/*
 * Decompiled with CFR 0.152.
 */
package com.mongodb.reactivestreams.client.internal.crypt;

import com.mongodb.MongoOperationTimeoutException;
import com.mongodb.MongoSocketException;
import com.mongodb.MongoSocketReadTimeoutException;
import com.mongodb.MongoSocketWriteTimeoutException;
import com.mongodb.ServerAddress;
import com.mongodb.connection.AsyncCompletionHandler;
import com.mongodb.connection.SocketSettings;
import com.mongodb.connection.SslSettings;
import com.mongodb.internal.TimeoutContext;
import com.mongodb.internal.TimeoutSettings;
import com.mongodb.internal.connection.AsynchronousChannelStream;
import com.mongodb.internal.connection.DefaultInetAddressResolver;
import com.mongodb.internal.connection.OperationContext;
import com.mongodb.internal.connection.Stream;
import com.mongodb.internal.connection.StreamFactory;
import com.mongodb.internal.connection.TlsChannelStreamFactoryFactory;
import com.mongodb.internal.crypt.capi.MongoKeyDecryptor;
import com.mongodb.internal.diagnostics.logging.Logger;
import com.mongodb.internal.diagnostics.logging.Loggers;
import com.mongodb.internal.time.Timeout;
import com.mongodb.lang.NonNull;
import com.mongodb.lang.Nullable;
import com.mongodb.spi.dns.InetAddressResolver;
import java.io.Closeable;
import java.nio.channels.CompletionHandler;
import java.nio.channels.InterruptedByTimeoutException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import javax.net.ssl.SSLContext;
import org.bson.ByteBuf;
import org.bson.ByteBufNIO;
import org.bson.assertions.Assertions;
import reactor.core.publisher.Mono;
import reactor.core.publisher.MonoSink;

/**
 * Performs the network exchange that libmongocrypt requests in order to decrypt a data key via a KMS provider.
 *
 * <p>Each subscription to {@link #decryptKey} opens a TLS stream to the host named by the {@link MongoKeyDecryptor},
 * writes the request message it supplies, and feeds response bytes back to it until it reports that no more bytes
 * are needed. The stream is closed when the exchange completes or fails.</p>
 */
class KeyManagementService implements Closeable {
    private static final Logger LOGGER = Loggers.getLogger("client");
    private static final String TIMEOUT_ERROR_MESSAGE = "KMS key decryption exceeded the timeout limit.";
    private final Map<String, SSLContext> kmsProviderSslContextMap;
    private final int timeoutMillis;
    private final TlsChannelStreamFactoryFactory tlsChannelStreamFactoryFactory;

    /**
     * @param kmsProviderSslContextMap SSL contexts keyed by KMS provider name. A provider with no entry yields a
     *                                 {@code null} context in the SSL settings. NOTE(review): presumably the default
     *                                 context is used in that case; confirm.
     * @param timeoutMillis            connect and read timeout, in milliseconds, for every KMS connection; must be
     *                                 positive
     */
    KeyManagementService(final Map<String, SSLContext> kmsProviderSslContextMap, final int timeoutMillis) {
        Assertions.assertTrue("timeoutMillis > 0", timeoutMillis > 0);
        this.kmsProviderSslContextMap = kmsProviderSslContextMap;
        this.tlsChannelStreamFactoryFactory = new TlsChannelStreamFactoryFactory(new DefaultInetAddressResolver());
        this.timeoutMillis = timeoutMillis;
    }

    /**
     * Closes the underlying TLS stream factory factory.
     */
    @Override
    public void close() {
        tlsChannelStreamFactoryFactory.close();
    }

    /**
     * Decrypts a data key by exchanging messages with the KMS server named by {@code keyDecryptor}.
     *
     * <p>The connection is not opened until the returned {@link Mono} is subscribed; every subscription opens a new
     * stream. Socket-level driver exceptions that carry a cause are replaced by that cause (see
     * {@link #unWrapException}).</p>
     *
     * @param keyDecryptor     supplies the KMS host, provider name, and request message, and consumes the response
     * @param operationTimeout the remaining operation timeout, or {@code null} if none applies; it must not be
     *                         infinite
     * @return a {@code Mono} that completes empty once the decryptor needs no more bytes, or signals the failure
     */
    Mono<Void> decryptKey(final MongoKeyDecryptor keyDecryptor, @Nullable final Timeout operationTimeout) {
        SocketSettings socketSettings = SocketSettings.builder()
                .connectTimeout(timeoutMillis, TimeUnit.MILLISECONDS)
                .readTimeout(timeoutMillis, TimeUnit.MILLISECONDS)
                .build();
        StreamFactory streamFactory = tlsChannelStreamFactoryFactory.create(socketSettings,
                SslSettings.builder()
                        .enabled(true)
                        .context(kmsProviderSslContextMap.get(keyDecryptor.getKmsProvider()))
                        .build());
        ServerAddress serverAddress = new ServerAddress(keyDecryptor.getHostName());

        LOGGER.info("Connecting to KMS server at " + serverAddress);

        return Mono.<Void>create(sink -> {
            // Build the context first: it throws when the operation timeout has already expired, and doing that
            // before the stream exists leaves nothing to clean up.
            OperationContext operationContext = createOperationContext(operationTimeout, socketSettings);
            Stream stream = streamFactory.create(serverAddress);
            stream.openAsync(operationContext, new AsyncCompletionHandler<Void>() {
                @Override
                public void completed(@Nullable final Void ignored) {
                    streamWrite(stream, keyDecryptor, operationContext, sink);
                }

                @Override
                public void failed(final Throwable t) {
                    stream.close();
                    handleError(t, operationContext, sink);
                }
            });
        }).onErrorMap(this::unWrapException);
    }

    /**
     * Writes the decryptor's request message to the open stream, then starts reading the response.
     * On write failure the stream is closed and the error is reported through {@link #handleError}.
     */
    private void streamWrite(final Stream stream, final MongoKeyDecryptor keyDecryptor,
                             final OperationContext operationContext, final MonoSink<Void> sink) {
        List<ByteBuf> byteBufs = Collections.singletonList(new ByteBufNIO(keyDecryptor.getMessage()));
        stream.writeAsync(byteBufs, operationContext, new AsyncCompletionHandler<Void>() {
            @Override
            public void completed(@Nullable final Void ignored) {
                streamRead(stream, keyDecryptor, operationContext, sink);
            }

            @Override
            public void failed(final Throwable t) {
                stream.close();
                handleError(t, operationContext, sink);
            }
        });
    }

    /**
     * Reads up to {@code keyDecryptor.bytesNeeded()} bytes and feeds them to the decryptor, repeating until it needs
     * no more. At that point the stream is closed and the sink completes.
     *
     * <p>Every failure path releases the read buffer (when one is outstanding), closes the stream, and signals the
     * sink exactly once.</p>
     */
    private void streamRead(final Stream stream, final MongoKeyDecryptor keyDecryptor,
                            final OperationContext operationContext, final MonoSink<Void> sink) {
        int bytesNeeded = keyDecryptor.bytesNeeded();
        if (bytesNeeded <= 0) {
            stream.close();
            sink.success();
            return;
        }

        // The cast assumes the factory produces AsynchronousChannelStream instances, since the raw channel is
        // needed to read exactly the number of bytes the decryptor asks for.
        AsynchronousChannelStream asyncStream = (AsynchronousChannelStream) stream;
        ByteBuf buffer = asyncStream.getBuffer(bytesNeeded);
        long readTimeoutMS = operationContext.getTimeoutContext().getReadTimeoutMS();
        asyncStream.getChannel().read(buffer.asNIO(), readTimeoutMS, TimeUnit.MILLISECONDS, null,
                new CompletionHandler<Integer, Void>() {
                    @Override
                    public void completed(final Integer bytesRead, final Void attachment) {
                        if (bytesRead < 0) {
                            // The peer closed the connection before the decryptor received the full response.
                            // Retrying would just return -1 again, so fail instead of looping forever.
                            buffer.release();
                            ServerAddress address = stream.getAddress();
                            stream.close();
                            sink.error(new MongoSocketException("Prematurely reached end of stream", address));
                            return;
                        }
                        buffer.flip();
                        try {
                            keyDecryptor.feed(buffer.asNIO());
                        } catch (Throwable t) {
                            buffer.release();
                            stream.close();
                            sink.error(t);
                            return;
                        }
                        buffer.release();
                        try {
                            streamRead(stream, keyDecryptor, operationContext, sink);
                        } catch (Throwable t) {
                            stream.close();
                            sink.error(t);
                        }
                    }

                    @Override
                    public void failed(final Throwable t, final Void attachment) {
                        buffer.release();
                        stream.close();
                        handleError(t, operationContext, sink);
                    }
                });
    }

    /**
     * Signals {@code t} to the sink. When it is a socket timeout and the operation has a {@code timeoutMS}, it is
     * first wrapped via {@link TimeoutContext#createMongoTimeoutException(String, Throwable)}.
     */
    private static void handleError(final Throwable t, final OperationContext operationContext,
                                    final MonoSink<Void> sink) {
        if (isTimeoutException(t) && operationContext.getTimeoutContext().hasTimeoutMS()) {
            sink.error(TimeoutContext.createMongoTimeoutException(TIMEOUT_ERROR_MESSAGE, t));
        } else {
            sink.error(t);
        }
    }

    /**
     * Builds an operation context whose timeouts come from {@code socketSettings} and, when present, the remaining
     * time of {@code operationTimeout}.
     *
     * @throws AssertionError                  if {@code operationTimeout} is infinite
     * @throws MongoOperationTimeoutException if {@code operationTimeout} has already expired
     */
    private OperationContext createOperationContext(@Nullable final Timeout operationTimeout,
                                                    final SocketSettings socketSettings) {
        TimeoutSettings timeoutSettings;
        if (operationTimeout == null) {
            timeoutSettings = createTimeoutSettings(socketSettings, null);
        } else {
            timeoutSettings = operationTimeout.call(TimeUnit.MILLISECONDS,
                    () -> {
                        throw new AssertionError("operationTimeout cannot be infinite");
                    },
                    ms -> createTimeoutSettings(socketSettings, ms),
                    () -> {
                        throw new MongoOperationTimeoutException(TIMEOUT_ERROR_MESSAGE);
                    });
        }
        return OperationContext.simpleOperationContext(new TimeoutContext(timeoutSettings));
    }

    /**
     * Creates timeout settings from the socket's connect/read timeouts.
     * The server-selection and max-wait timeouts are set to 0.
     *
     * @param ms the operation timeout in milliseconds, or {@code null} for none
     */
    @NonNull
    private static TimeoutSettings createTimeoutSettings(final SocketSettings socketSettings, @Nullable final Long ms) {
        return new TimeoutSettings(
                0L,
                socketSettings.getConnectTimeout(TimeUnit.MILLISECONDS),
                socketSettings.getReadTimeout(TimeUnit.MILLISECONDS),
                ms,
                0L);
    }

    /**
     * Replaces a {@link MongoSocketException} with its cause, if it has one; any other throwable is returned as is.
     * A cause-less socket exception must be kept, because an {@code onErrorMap} mapper that returns {@code null} would
     * turn the real error into a NullPointerException.
     */
    private Throwable unWrapException(final Throwable t) {
        if (t instanceof MongoSocketException && t.getCause() != null) {
            return t.getCause();
        }
        return t;
    }

    /**
     * Returns whether {@code t} is a driver socket read/write timeout or an NIO {@link InterruptedByTimeoutException}.
     */
    private static boolean isTimeoutException(final Throwable t) {
        return t instanceof MongoSocketReadTimeoutException
                || t instanceof MongoSocketWriteTimeoutException
                || t instanceof InterruptedByTimeoutException;
    }
}

