/*
 * Decompiled with CFR 0.152.
 */
package net.schmizz.sshj.transport;

import java.math.BigInteger;
import java.security.GeneralSecurityException;
import java.security.PublicKey;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.Queue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import net.schmizz.concurrent.ErrorDeliveryUtil;
import net.schmizz.concurrent.Event;
import net.schmizz.sshj.common.Buffer;
import net.schmizz.sshj.common.DisconnectReason;
import net.schmizz.sshj.common.ErrorNotifiable;
import net.schmizz.sshj.common.Factory;
import net.schmizz.sshj.common.KeyType;
import net.schmizz.sshj.common.Message;
import net.schmizz.sshj.common.SSHException;
import net.schmizz.sshj.common.SSHPacket;
import net.schmizz.sshj.common.SSHPacketHandler;
import net.schmizz.sshj.common.SecurityUtils;
import net.schmizz.sshj.transport.NegotiatedAlgorithms;
import net.schmizz.sshj.transport.Proposal;
import net.schmizz.sshj.transport.TransportException;
import net.schmizz.sshj.transport.TransportImpl;
import net.schmizz.sshj.transport.cipher.Cipher;
import net.schmizz.sshj.transport.compression.Compression;
import net.schmizz.sshj.transport.digest.Digest;
import net.schmizz.sshj.transport.kex.KeyExchange;
import net.schmizz.sshj.transport.mac.MAC;
import net.schmizz.sshj.transport.verification.AlgorithmsVerifier;
import net.schmizz.sshj.transport.verification.HostKeyVerifier;
import org.slf4j.Logger;

final class KeyExchanger
implements SSHPacketHandler,
ErrorNotifiable {
    private final Logger log;
    private final TransportImpl transport;
    private final Queue<HostKeyVerifier> hostVerifiers = new LinkedList<HostKeyVerifier>();
    private final Queue<AlgorithmsVerifier> algorithmVerifiers = new LinkedList<AlgorithmsVerifier>();
    private final AtomicBoolean kexOngoing = new AtomicBoolean();
    private Expected expected = Expected.KEXINIT;
    private KeyExchange kex;
    private byte[] sessionID;
    private Proposal clientProposal;
    private NegotiatedAlgorithms negotiatedAlgs;
    private final Event<TransportException> kexInitSent;
    private final Event<TransportException> done;

    KeyExchanger(TransportImpl trans) {
        this.transport = trans;
        this.log = trans.getConfig().getLoggerFactory().getLogger(this.getClass());
        this.kexInitSent = new Event<TransportException>("kexinit sent", TransportException.chainer, trans.getConfig().getLoggerFactory());
        this.done = new Event<TransportException>("kex done", TransportException.chainer, trans.getWriteLock(), trans.getConfig().getLoggerFactory());
    }

    synchronized void addHostKeyVerifier(HostKeyVerifier hkv) {
        this.hostVerifiers.add(hkv);
    }

    synchronized void addAlgorithmsVerifier(AlgorithmsVerifier verifier) {
        this.algorithmVerifiers.add(verifier);
    }

    byte[] getSessionID() {
        return Arrays.copyOf(this.sessionID, this.sessionID.length);
    }

    boolean isKexDone() {
        return this.done.isSet();
    }

    boolean isKexOngoing() {
        return this.kexOngoing.get();
    }

    void startKex(boolean waitForDone) throws TransportException {
        if (!this.kexOngoing.getAndSet(true)) {
            this.done.clear();
            this.sendKexInit();
        }
        if (waitForDone) {
            this.waitForDone();
        }
    }

    void waitForDone() throws TransportException {
        this.done.await(this.transport.getTimeoutMs(), TimeUnit.MILLISECONDS);
    }

    private synchronized void ensureKexOngoing() throws TransportException {
        if (!this.isKexOngoing()) {
            throw new TransportException(DisconnectReason.PROTOCOL_ERROR, "Key exchange packet received when key exchange was not ongoing");
        }
    }

    private static void ensureReceivedMatchesExpected(Message got, Message expected) throws TransportException {
        if (got != expected) {
            throw new TransportException(DisconnectReason.PROTOCOL_ERROR, "Was expecting " + (Object)((Object)expected));
        }
    }

    private void sendKexInit() throws TransportException {
        this.log.debug("Sending SSH_MSG_KEXINIT");
        this.clientProposal = new Proposal(this.transport.getConfig());
        this.transport.write(this.clientProposal.getPacket());
        this.kexInitSent.set();
    }

    private void sendNewKeys() throws TransportException {
        this.log.debug("Sending SSH_MSG_NEWKEYS");
        this.transport.write(new SSHPacket(Message.NEWKEYS));
    }

    private synchronized void verifyHost(PublicKey key) throws TransportException {
        for (HostKeyVerifier hkv : this.hostVerifiers) {
            this.log.debug("Trying to verify host key with {}", (Object)hkv);
            if (!hkv.verify(this.transport.getRemoteHost(), this.transport.getRemotePort(), key)) continue;
            return;
        }
        throw new TransportException(DisconnectReason.HOST_KEY_NOT_VERIFIABLE, "Could not verify `" + (Object)((Object)KeyType.fromKey(key)) + "` host key with fingerprint `" + SecurityUtils.getFingerprint(key) + "` for `" + this.transport.getRemoteHost() + "` on port " + this.transport.getRemotePort());
    }

    private void setKexDone() {
        this.kexOngoing.set(false);
        this.kexInitSent.clear();
        this.done.set();
    }

    private void gotKexInit(SSHPacket buf) throws TransportException {
        buf.rpos(buf.rpos() - 1);
        Proposal serverProposal = new Proposal(buf);
        this.negotiatedAlgs = this.clientProposal.negotiate(serverProposal);
        this.log.debug("Negotiated algorithms: {}", (Object)this.negotiatedAlgs);
        for (AlgorithmsVerifier v : this.algorithmVerifiers) {
            this.log.debug("Trying to verify algorithms with {}", (Object)v);
            if (v.verify(this.negotiatedAlgs)) continue;
            throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED, "Failed to verify negotiated algorithms `" + this.negotiatedAlgs + "`");
        }
        this.kex = (KeyExchange)Factory.Named.Util.create(this.transport.getConfig().getKeyExchangeFactories(), this.negotiatedAlgs.getKeyExchangeAlgorithm());
        try {
            this.kex.init(this.transport, this.transport.getServerID(), this.transport.getClientID(), serverProposal.getPacket().getCompactData(), this.clientProposal.getPacket().getCompactData());
        }
        catch (GeneralSecurityException e) {
            throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED, (Throwable)e);
        }
    }

    private static byte[] resizedKey(byte[] E, int blockSize, Digest hash, BigInteger K, byte[] H) {
        while (blockSize > E.length) {
            Buffer.PlainBuffer buffer = (Buffer.PlainBuffer)((Buffer.PlainBuffer)((Buffer.PlainBuffer)new Buffer.PlainBuffer().putMPInt(K)).putRawBytes(H)).putRawBytes(E);
            hash.update(buffer.array(), 0, buffer.available());
            byte[] foo = hash.digest();
            byte[] bar = new byte[E.length + foo.length];
            System.arraycopy(E, 0, bar, 0, E.length);
            System.arraycopy(foo, 0, bar, E.length, foo.length);
            E = bar;
        }
        return E;
    }

    private void gotNewKeys() {
        Digest hash = this.kex.getHash();
        byte[] H = this.kex.getH();
        if (this.sessionID == null) {
            this.sessionID = H;
        }
        Buffer.PlainBuffer hashInput = (Buffer.PlainBuffer)((Buffer.PlainBuffer)((Buffer.PlainBuffer)((Buffer.PlainBuffer)new Buffer.PlainBuffer().putMPInt(this.kex.getK())).putRawBytes(H)).putByte((byte)0)).putRawBytes(this.sessionID);
        int pos = hashInput.available() - this.sessionID.length - 1;
        hashInput.array()[pos] = 65;
        hash.update(hashInput.array(), 0, hashInput.available());
        byte[] initialIV_C2S = hash.digest();
        hashInput.array()[pos] = 66;
        hash.update(hashInput.array(), 0, hashInput.available());
        byte[] initialIV_S2C = hash.digest();
        hashInput.array()[pos] = 67;
        hash.update(hashInput.array(), 0, hashInput.available());
        byte[] encryptionKey_C2S = hash.digest();
        hashInput.array()[pos] = 68;
        hash.update(hashInput.array(), 0, hashInput.available());
        byte[] encryptionKey_S2C = hash.digest();
        hashInput.array()[pos] = 69;
        hash.update(hashInput.array(), 0, hashInput.available());
        byte[] integrityKey_C2S = hash.digest();
        hashInput.array()[pos] = 70;
        hash.update(hashInput.array(), 0, hashInput.available());
        byte[] integrityKey_S2C = hash.digest();
        Cipher cipher_C2S = (Cipher)Factory.Named.Util.create(this.transport.getConfig().getCipherFactories(), this.negotiatedAlgs.getClient2ServerCipherAlgorithm());
        cipher_C2S.init(Cipher.Mode.Encrypt, KeyExchanger.resizedKey(encryptionKey_C2S, cipher_C2S.getBlockSize(), hash, this.kex.getK(), this.kex.getH()), initialIV_C2S);
        Cipher cipher_S2C = (Cipher)Factory.Named.Util.create(this.transport.getConfig().getCipherFactories(), this.negotiatedAlgs.getServer2ClientCipherAlgorithm());
        cipher_S2C.init(Cipher.Mode.Decrypt, KeyExchanger.resizedKey(encryptionKey_S2C, cipher_S2C.getBlockSize(), hash, this.kex.getK(), this.kex.getH()), initialIV_S2C);
        MAC mac_C2S = (MAC)Factory.Named.Util.create(this.transport.getConfig().getMACFactories(), this.negotiatedAlgs.getClient2ServerMACAlgorithm());
        mac_C2S.init(KeyExchanger.resizedKey(integrityKey_C2S, mac_C2S.getBlockSize(), hash, this.kex.getK(), this.kex.getH()));
        MAC mac_S2C = (MAC)Factory.Named.Util.create(this.transport.getConfig().getMACFactories(), this.negotiatedAlgs.getServer2ClientMACAlgorithm());
        mac_S2C.init(KeyExchanger.resizedKey(integrityKey_S2C, mac_S2C.getBlockSize(), hash, this.kex.getK(), this.kex.getH()));
        Compression compression_S2C = (Compression)Factory.Named.Util.create(this.transport.getConfig().getCompressionFactories(), this.negotiatedAlgs.getServer2ClientCompressionAlgorithm());
        Compression compression_C2S = (Compression)Factory.Named.Util.create(this.transport.getConfig().getCompressionFactories(), this.negotiatedAlgs.getClient2ServerCompressionAlgorithm());
        this.transport.getEncoder().setAlgorithms(cipher_C2S, mac_C2S, compression_C2S);
        this.transport.getDecoder().setAlgorithms(cipher_S2C, mac_S2C, compression_S2C);
    }

    @Override
    public void handle(Message msg, SSHPacket buf) throws TransportException {
        switch (this.expected) {
            case KEXINIT: {
                KeyExchanger.ensureReceivedMatchesExpected(msg, Message.KEXINIT);
                this.log.debug("Received SSH_MSG_KEXINIT");
                this.startKex(false);
                this.kexInitSent.await(this.transport.getTimeoutMs(), TimeUnit.MILLISECONDS);
                this.gotKexInit(buf);
                this.expected = Expected.FOLLOWUP;
                break;
            }
            case FOLLOWUP: {
                this.ensureKexOngoing();
                this.log.debug("Received kex followup data");
                try {
                    if (!this.kex.next(msg, buf)) break;
                    this.verifyHost(this.kex.getHostKey());
                    this.sendNewKeys();
                    this.expected = Expected.NEWKEYS;
                    break;
                }
                catch (GeneralSecurityException e) {
                    throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED, (Throwable)e);
                }
            }
            case NEWKEYS: {
                KeyExchanger.ensureReceivedMatchesExpected(msg, Message.NEWKEYS);
                this.ensureKexOngoing();
                this.log.debug("Received SSH_MSG_NEWKEYS");
                this.gotNewKeys();
                this.setKexDone();
                this.expected = Expected.KEXINIT;
                break;
            }
            default: {
                assert (false);
                break;
            }
        }
    }

    @Override
    public void notifyError(SSHException error) {
        this.log.debug("Got notified of {}", (Object)error.toString());
        ErrorDeliveryUtil.alertEvents((Throwable)error, this.kexInitSent, this.done);
    }

    private static enum Expected {
        KEXINIT,
        FOLLOWUP,
        NEWKEYS;

    }
}

