/*
 * Decompiled with CFR 0.152.
 */
package io.r2dbc.mssql.client.ssl;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.CompositeByteBuf;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.ssl.SslHandler;
import io.r2dbc.mssql.client.ConnectionContext;
import io.r2dbc.mssql.client.TdsEncoder;
import io.r2dbc.mssql.client.ssl.ContextProxy;
import io.r2dbc.mssql.client.ssl.SslConfiguration;
import io.r2dbc.mssql.client.ssl.SslEventHandler;
import io.r2dbc.mssql.client.ssl.SslState;
import io.r2dbc.mssql.message.header.Header;
import io.r2dbc.mssql.message.header.HeaderOptions;
import io.r2dbc.mssql.message.header.PacketIdProvider;
import io.r2dbc.mssql.message.header.Status;
import io.r2dbc.mssql.message.header.Type;
import io.r2dbc.mssql.message.tds.ContextualTdsFragment;
import io.r2dbc.mssql.message.tds.TdsFragment;
import io.r2dbc.mssql.util.Assert;
import java.security.GeneralSecurityException;
import javax.net.ssl.SSLEngine;
import reactor.util.Logger;
import reactor.util.Loggers;
import reactor.util.annotation.Nullable;

/**
 * Channel handler that drives the SSL handshake wrapped in TDS packets and, once the handshake is done,
 * routes traffic through an {@link SslHandler} depending on the negotiated {@link SslState}.
 * <p>
 * While the handshake is in progress, outbound SSL payload is collected in an output buffer and written on
 * {@link #flush(ChannelHandlerContext)}; inbound data is reassembled from TDS packets (see {@link Chunk})
 * and handed to the {@link SslHandler}.
 * <p>
 * The handler is marked {@link ChannelHandler.Sharable} although it keeps per-connection state: it removes
 * and re-adds itself to the pipeline when {@link SslState#CONNECTION} is negotiated (see
 * {@link #userEventTriggered(ChannelHandlerContext, Object)}).
 * NOTE(review): presumably one instance is still meant to serve a single channel only - confirm.
 */
@ChannelHandler.Sharable
public final class TdsSslHandler extends ChannelDuplexHandler {

    private static final Logger LOGGER = Loggers.getLogger(TdsSslHandler.class);

    public static final boolean DEBUG_ENABLED = LOGGER.isDebugEnabled();

    /**
     * Number of bytes the header contributes to {@link Header#getLength()}; {@link #unwrap} builds headers
     * with a length of {@code HEADER_LENGTH + payload}.
     */
    private static final int HEADER_LENGTH = 8;

    private final ConnectionContext connectionContext;

    private final PacketIdProvider packetIdProvider;

    private final SslConfiguration sslConfiguration;

    private volatile SslHandler sslHandler;

    /** Context of the {@link ContextProxy} handler; used as the context for handshake-time SSL calls. */
    private ChannelHandlerContext context;

    /** Collects SSL payload written during the handshake until the next flush. */
    private ByteBuf outputBuffer;

    private SslState state = SslState.OFF;

    private boolean handshakeDone;

    /** Partially received handshake message, or {@code null} if no reassembly is in progress. */
    @Nullable
    private Chunk chunk;

    /**
     * Creates a new {@link TdsSslHandler}.
     *
     * @param packetIdProvider provider for TDS packet ids, must not be {@code null}.
     * @param sslConfiguration SSL configuration used to create the {@link SslHandler}, must not be {@code null}.
     * @param context          connection context used to decorate log messages, must not be {@code null}.
     */
    public TdsSslHandler(PacketIdProvider packetIdProvider, SslConfiguration sslConfiguration, ConnectionContext context) {
        Assert.requireNonNull(packetIdProvider, "PacketIdProvider must not be null");
        Assert.requireNonNull(sslConfiguration, "SslConfiguration must not be null");
        Assert.requireNonNull(context, "ConnectionContext must not be null");
        this.packetIdProvider = packetIdProvider;
        this.sslConfiguration = sslConfiguration;
        this.connectionContext = context;
    }

    /**
     * Replaces the {@link SslHandler} delegate.
     *
     * @param sslHandler the handler to use.
     */
    void setSslHandler(SslHandler sslHandler) {
        this.sslHandler = sslHandler;
    }

    /**
     * Overrides the current {@link SslState}.
     *
     * @param state the state to use.
     */
    void setState(SslState state) {
        this.state = state;
    }

    /**
     * Creates an {@link SslHandler} backed by a new {@link SSLEngine} obtained from the configured SSL context.
     *
     * @param sslConfiguration the SSL configuration.
     * @param allocator        allocator handed to the engine factory.
     * @return the new {@link SslHandler}.
     * @throws GeneralSecurityException declared for failures while obtaining the SSL context.
     */
    private static SslHandler createSslHandler(SslConfiguration sslConfiguration, ByteBufAllocator allocator) throws GeneralSecurityException {
        SSLEngine sslEngine = sslConfiguration.getSslProvider().getSslContext().newEngine(allocator);
        return new SslHandler(sslEngine);
    }

    /**
     * Reacts to {@link SslState} user events.
     * <ul>
     * <li>{@link SslState#LOGIN_ONLY} / {@link SslState#CONNECTION}: records the state, creates the
     * {@link SslHandler}, registers {@link ContextProxy} and {@link SslEventHandler} after this handler and
     * starts the SSL handler against the proxy context.</li>
     * <li>{@link SslState#NEGOTIATED}: marks the handshake as done, writes a header reset and, for
     * {@link SslState#CONNECTION}, moves this handler to the head of the pipeline.</li>
     * </ul>
     * The event is always forwarded to the next handler.
     */
    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
        if (evt == SslState.LOGIN_ONLY || evt == SslState.CONNECTION) {
            this.state = (SslState) evt;
            this.sslHandler = createSslHandler(this.sslConfiguration, ctx.alloc());
            LOGGER.debug(this.connectionContext.getMessage("Registering Context Proxy and SSL Event Handlers to propagate SSL events to channelRead()"));
            // Assumes this handler is registered in the pipeline under its class name.
            ctx.pipeline().addAfter(getClass().getName(), ContextProxy.class.getName(), new ContextProxy());
            ctx.pipeline().addAfter(ContextProxy.class.getName(), SslEventHandler.class.getName(), new SslEventHandler());
            this.context = ctx.channel().pipeline().context(ContextProxy.class.getName());
            ctx.write(HeaderOptions.create(Type.PRE_LOGIN, Status.empty()));
            this.sslHandler.handlerAdded(this.context);
        }
        if (evt == SslState.NEGOTIATED) {
            LOGGER.debug(this.connectionContext.getMessage("SSL Handshake done"));
            ctx.write(TdsEncoder.ResetHeader.INSTANCE, ctx.voidPromise());
            this.handshakeDone = true;
            if (this.state == SslState.CONNECTION) {
                LOGGER.debug(this.connectionContext.getMessage("Reordering handlers for full SSL usage"));
                ctx.pipeline().remove(this);
                ctx.pipeline().addFirst(this);
            }
        }
        super.userEventTriggered(ctx, evt);
    }

    /**
     * Allocates the output buffer used to collect handshake payload.
     */
    @Override
    public void handlerAdded(ChannelHandlerContext ctx) {
        this.outputBuffer = ctx.alloc().buffer();
    }

    /**
     * Releases the output buffer, if any.
     */
    @Override
    public void handlerRemoved(ChannelHandlerContext ctx) {
        if (this.outputBuffer != null) {
            this.outputBuffer.release();
            this.outputBuffer = null;
        }
    }

    /**
     * Notifies the {@link SslHandler} (if present) that the channel went inactive and releases any partially
     * reassembled handshake message.
     * <p>
     * Without an {@link SslHandler} the event is forwarded to the next handler so it is not swallowed here.
     */
    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        try {
            SslHandler delegate = this.sslHandler;
            if (delegate != null) {
                // The SslHandler is expected to forward the event itself (Netty's ByteToMessageDecoder fires
                // channelInactive on the given context), so it is not fired a second time here.
                delegate.channelInactive(ctx);
            } else {
                super.channelInactive(ctx);
            }
        } finally {
            // Release pending buffers even if the delegate throws.
            Chunk pending = this.chunk;
            if (pending != null) {
                this.chunk = null;
                pending.fullMessage.release();
                pending.aggregator.release();
            }
        }
    }

    /**
     * Writes a message according to the current SSL state.
     * <ul>
     * <li>Handshake done and SSL active: the message is converted by {@link #unwrap} and written and flushed
     * through the {@link SslHandler}. In {@link SslState#LOGIN_ONLY} the state then switches to
     * {@link SslState#AFTER_LOGIN_ONLY}, so only the first such write is encrypted.</li>
     * <li>Wrapping required (handshake in progress): {@code msg} is expected to be a {@link ByteBuf}; its
     * content is appended to the output buffer and the message is released. Note that {@code promise} is not
     * completed on this path.</li>
     * <li>Otherwise the write is passed on unchanged.</li>
     * </ul>
     */
    @Override
    public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
        boolean sslActive = this.state == SslState.NEGOTIATED || this.state == SslState.LOGIN_ONLY || this.state == SslState.CONNECTION;
        if (this.handshakeDone && sslActive) {
            Object payload = this.unwrap(ctx.alloc(), msg);
            this.sslHandler.write(ctx, payload, promise);
            this.sslHandler.flush(ctx);
            if (this.state == SslState.LOGIN_ONLY) {
                this.state = SslState.AFTER_LOGIN_ONLY;
            }
            return;
        }
        if (this.requiresWrapping()) {
            if (DEBUG_ENABLED) {
                LOGGER.debug(this.connectionContext.getMessage("Write wrapping: Append to output buffer"));
            }
            ByteBuf sslPayload = (ByteBuf) msg;
            this.outputBuffer.writeBytes(sslPayload);
            sslPayload.release();
            return;
        }
        super.write(ctx, msg, promise);
    }

    /**
     * Converts a TDS fragment into the raw bytes to be encrypted.
     * <p>
     * A {@link ContextualTdsFragment} is copied into a new buffer prefixed with an encoded {@link Header}
     * (type from the fragment's header options, status combined with {@code EOM}, next packet id); the
     * fragment's own buffer is released. A plain {@link TdsFragment} yields its buffer as-is. Any other
     * message is returned unchanged.
     *
     * @param allocator allocator for the framed buffer.
     * @param msg       the outbound message.
     * @return the object to hand to the {@link SslHandler}.
     */
    private Object unwrap(ByteBufAllocator allocator, Object msg) {
        if (msg instanceof ContextualTdsFragment) {
            ContextualTdsFragment tdsFragment = (ContextualTdsFragment) msg;
            HeaderOptions headerOptions = tdsFragment.getHeaderOptions();
            ByteBuf body = tdsFragment.getByteBuf();
            Status eom = headerOptions.getStatus().and(Status.StatusBit.EOM);
            Header header = new Header(headerOptions.getType(), eom, HEADER_LENGTH + body.readableBytes(), 0, (int) this.packetIdProvider.nextPacketId(), 0);
            ByteBuf framed = allocator.buffer((int) header.getLength());
            header.encode(framed);
            framed.writeBytes(body);
            body.release();
            return framed;
        }
        if (msg instanceof TdsFragment) {
            return ((TdsFragment) msg).getByteBuf();
        }
        return msg;
    }

    /**
     * Flushes pending data. While wrapping is required, the collected output buffer is swapped for a fresh
     * one, written and flushed as a single message, and auto-read is enabled on the channel; otherwise the
     * flush is passed on unchanged.
     */
    @Override
    public void flush(ChannelHandlerContext ctx) throws Exception {
        if (!this.requiresWrapping()) {
            super.flush(ctx);
            return;
        }
        if (DEBUG_ENABLED) {
            LOGGER.debug(this.connectionContext.getMessage("Write wrapping: Flushing output buffer and enable auto-read"));
        }
        ByteBuf message = this.outputBuffer;
        this.outputBuffer = ctx.alloc().buffer();
        ctx.writeAndFlush(message);
        ctx.channel().config().setAutoRead(true);
    }

    /**
     * Flushes handshake output that accumulated during the read before forwarding the event.
     */
    @Override
    public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
        if (this.isInHandshake() && this.outputBuffer.readableBytes() > 0) {
            this.flush(ctx);
        }
        super.channelReadComplete(ctx);
    }

    /**
     * Routes inbound data: during the handshake it is consumed by {@link #readHandshake}; with full
     * connection encryption it is passed to the {@link SslHandler}; otherwise it is forwarded unchanged.
     */
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        if (this.isInHandshake()) {
            this.readHandshake(ctx, (ByteBuf) msg);
            return;
        }
        if (this.handshakeDone && this.state == SslState.CONNECTION) {
            this.sslHandler.channelRead(ctx, msg);
            return;
        }
        super.channelRead(ctx, msg);
    }

    /**
     * Consumes inbound handshake data: strips the TDS header, reassembles messages spanning several reads
     * and passes complete {@code PRE_LOGIN} payloads to the {@link SslHandler}. Buffers that are not passed
     * on are released here because nothing else will.
     */
    private void readHandshake(ChannelHandlerContext ctx, ByteBuf buffer) throws Exception {
        Chunk pending = this.chunk;
        Header header;
        ByteBuf payload;
        if (pending == null) {
            if (!Header.canDecode(buffer)) {
                // No header can be decoded and nothing is being reassembled: the data is dropped,
                // so release it to avoid leaking the buffer.
                buffer.release();
                return;
            }
            header = Header.decode(buffer);
            if (!Chunk.isCompletePacketAvailable(header, buffer)) {
                // Capture the allocator before releasing the buffer it belongs to.
                ByteBufAllocator allocator = buffer.alloc();
                ByteBuf defragmented = allocator.buffer((int) header.getLength());
                defragmented.writeBytes(buffer);
                buffer.release();
                this.chunk = new Chunk(header, defragmented, allocator.compositeBuffer());
                ctx.read();
                return;
            }
            payload = buffer;
        } else {
            pending.defragment(buffer);
            if (!pending.isCompleteHandshakeAvailable()) {
                return;
            }
            payload = pending.fullMessage;
            header = pending.header;
            pending.aggregator.release();
            this.chunk = null;
        }
        if (header.getType() == Type.PRE_LOGIN) {
            this.sslHandler.channelRead(this.context, payload);
        } else {
            // Payload of any other packet type is not consumed during the handshake.
            payload.release();
        }
    }

    /**
     * @return {@code true} while wrapping is required and the handshake has not completed yet.
     */
    private boolean isInHandshake() {
        return this.requiresWrapping() && !this.handshakeDone;
    }

    /**
     * @return {@code true} if the state is {@link SslState#LOGIN_ONLY} or {@link SslState#CONNECTION}.
     */
    private boolean requiresWrapping() {
        return this.state == SslState.LOGIN_ONLY || this.state == SslState.CONNECTION;
    }

    /**
     * Reassembly state for a handshake message that arrives in several reads and possibly several packets.
     * Incoming buffers are queued in {@link #aggregator}; packet payload (without headers) is copied into
     * {@link #fullMessage}.
     */
    static class Chunk {

        /** Header of the packet currently being received. */
        Header header;

        /** Payload collected so far across all packets of the message. */
        final ByteBuf fullMessage;

        /** Queue of received buffers that still have to be processed. */
        final CompositeByteBuf aggregator;

        /** Total payload length of the packets that preceded the current {@link #header}. */
        int decoded = 0;

        Chunk(Header header, ByteBuf fullMessage, CompositeByteBuf aggregator) {
            this.header = header;
            this.fullMessage = fullMessage;
            this.aggregator = aggregator;
        }

        /**
         * Adds a received buffer and moves as much payload as possible into {@link #fullMessage}. Whenever
         * the current packet is complete and another header can be decoded, processing continues with that
         * packet until the message is complete or more data is needed.
         *
         * @param chunk the received buffer; it becomes a component of the aggregator.
         */
        void defragment(ByteBuf chunk) {
            this.aggregator.addComponent(true, chunk);
            while (this.aggregator.isReadable()) {
                int remainder = this.getRemainingLength();
                if (this.aggregator.readableBytes() < remainder) {
                    // Current packet is still incomplete; wait for more data.
                    break;
                }
                this.fullMessage.writeBytes(this.aggregator, remainder);
                if (!Header.canDecode(this.aggregator)) {
                    break;
                }
                this.updateHeader(Header.decode(this.aggregator));
                if (this.isCompleteHandshakeAvailable()) {
                    break;
                }
            }
        }

        /**
         * Switches to the next packet header, accounting the previous packet's payload as decoded.
         */
        void updateHeader(Header header) {
            this.decoded += this.header.getLength() - HEADER_LENGTH;
            this.header = header;
        }

        /**
         * @return {@code true} if the current header carries {@code EOM} and its packet is fully received.
         */
        boolean isCompleteHandshakeAvailable() {
            return this.header.is(Status.StatusBit.EOM) && this.getRemainingLength() <= 0;
        }

        /**
         * @return number of payload bytes still missing for the current packet.
         */
        int getRemainingLength() {
            int receivedOfCurrentPacket = this.fullMessage.readableBytes() - this.decoded;
            return this.header.getLength() - (receivedOfCurrentPacket + HEADER_LENGTH);
        }

        /**
         * @return {@code true} if {@code buffer} (positioned after the header) holds the packet's whole payload.
         */
        static boolean isCompletePacketAvailable(Header header, ByteBuf buffer) {
            return buffer.readableBytes() + HEADER_LENGTH >= header.getLength();
        }
    }
}

