/*
 * Decompiled with CFR 0.152.
 */
package org.bouncycastle.pqc.crypto.xmss;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Stack;
import java.util.TreeMap;
import org.bouncycastle.asn1.ASN1ObjectIdentifier;
import org.bouncycastle.pqc.crypto.xmss.BDSTreeHash;
import org.bouncycastle.pqc.crypto.xmss.HashTreeAddress;
import org.bouncycastle.pqc.crypto.xmss.LTreeAddress;
import org.bouncycastle.pqc.crypto.xmss.OTSHashAddress;
import org.bouncycastle.pqc.crypto.xmss.WOTSPlus;
import org.bouncycastle.pqc.crypto.xmss.WOTSPlusParameters;
import org.bouncycastle.pqc.crypto.xmss.WOTSPlusPublicKeyParameters;
import org.bouncycastle.pqc.crypto.xmss.XMSSNode;
import org.bouncycastle.pqc.crypto.xmss.XMSSNodeUtil;
import org.bouncycastle.pqc.crypto.xmss.XMSSParameters;
import org.bouncycastle.pqc.crypto.xmss.XMSSUtil;

/*
 * Multiple versions of this class in jar - see https://www.benf.org/other/cfr/multi-version-jar.html
 */
/**
 * Serializable traversal state for one XMSS hash tree.
 *
 * <p>The state holds the tree root, the authentication path for the leaf at
 * {@link #getIndex()}, and the auxiliary node stores ({@code retain},
 * {@code keep}, {@code stack} and the per-height {@link BDSTreeHash}
 * instances) that {@code nextAuthenticationPath} consumes and refills when it
 * advances to the next leaf.
 *
 * <p>Serialization notes: {@code wotsPlus} and {@code maxIndex} are
 * {@code transient}. {@code maxIndex} is appended to the stream by
 * {@code writeObject} and read back by {@code readObject}; {@code wotsPlus}
 * is not restored by {@code readObject}, so a deserialized state has to be
 * re-bound to a WOTS+ instance through {@link #withWOTSDigest} or
 * {@link #withMaxIndex} before it is advanced or copied.
 */
public final class BDS
implements Serializable {
    private static final long serialVersionUID = 1L;
    private transient WOTSPlus wotsPlus;
    private final int treeHeight;
    private final List<BDSTreeHash> treeHashInstances;
    private int k;
    private XMSSNode root;
    private List<XMSSNode> authenticationPath;
    private Map<Integer, LinkedList<XMSSNode>> retain;
    private Stack<XMSSNode> stack;
    private Map<Integer, XMSSNode> keep;
    private int index;
    private boolean used;
    private transient int maxIndex;

    /**
     * Creates a state positioned at {@code index} without computing any tree
     * nodes: the root stays {@code null}, the node stores stay empty and the
     * state is flagged as used.
     *
     * @param params   XMSS parameters supplying the WOTS+ instance, height and k
     * @param maxIndex upper bound stored as {@link #getMaxIndex()}
     * @param index    leaf index stored as {@link #getIndex()}
     * @throws IllegalArgumentException if the parameters' k is not valid for their height
     */
    BDS(XMSSParameters params, int maxIndex, int index) {
        this(params.getWOTSPlus(), params.getHeight(), params.getK(), maxIndex);
        this.index = index;
        this.used = true;
    }

    /**
     * Creates a state for leaf index 0 by computing the complete tree.
     * {@code maxIndex} is set to {@code 2^height - 1}.
     *
     * @param params         XMSS parameters supplying the WOTS+ instance, height and k
     * @param publicSeed     public seed passed to the WOTS+ key import
     * @param secretKeySeed  secret seed the WOTS+ secret keys are derived from
     * @param otsHashAddress address template; only its layer/tree/chain/hash/keyAndMask
     *                       values are reused, the OTS address is set per leaf
     * @throws NullPointerException     if {@code otsHashAddress} is {@code null}
     * @throws IllegalArgumentException if the parameters' k is not valid for their height
     */
    BDS(XMSSParameters params, byte[] publicSeed, byte[] secretKeySeed, OTSHashAddress otsHashAddress) {
        this(params.getWOTSPlus(), params.getHeight(), params.getK(), (1 << params.getHeight()) - 1);
        this.initialize(publicSeed, secretKeySeed, otsHashAddress);
    }

    /**
     * Creates a state for leaf index 0 by computing the complete tree and then
     * advances it one leaf at a time until {@link #getIndex()} reaches
     * {@code index}.
     *
     * @param params         XMSS parameters supplying the WOTS+ instance, height and k
     * @param publicSeed     public seed passed to the WOTS+ key import
     * @param secretKeySeed  secret seed the WOTS+ secret keys are derived from
     * @param otsHashAddress address template (see the four-argument constructor)
     * @param index          target leaf index; values {@code <= 0} leave the state at 0
     * @throws NullPointerException     if {@code otsHashAddress} is {@code null}
     * @throws IllegalStateException    if advancing would pass {@code maxIndex}
     * @throws IllegalArgumentException if the parameters' k is not valid for their height
     */
    BDS(XMSSParameters params, byte[] publicSeed, byte[] secretKeySeed, OTSHashAddress otsHashAddress, int index) {
        this(params.getWOTSPlus(), params.getHeight(), params.getK(), (1 << params.getHeight()) - 1);
        this.initialize(publicSeed, secretKeySeed, otsHashAddress);
        while (this.index < index) {
            this.nextAuthenticationPath(publicSeed, secretKeySeed, otsHashAddress);
            // nextAuthenticationPath refuses to run on a used state; keep the flag cleared between steps.
            this.used = false;
        }
    }

    /**
     * Base constructor: validates {@code k} and allocates empty node stores,
     * with one {@link BDSTreeHash} per height in {@code [0, treeHeight - k)}.
     *
     * @throws IllegalArgumentException if {@code k > treeHeight}, {@code k < 2}
     *                                  or {@code treeHeight - k} is odd
     */
    private BDS(WOTSPlus wotsPlus, int treeHeight, int k, int maxIndex) {
        this.wotsPlus = wotsPlus;
        this.treeHeight = treeHeight;
        this.maxIndex = maxIndex;
        this.k = k;
        if (k > treeHeight || k < 2 || (treeHeight - k) % 2 != 0) {
            throw new IllegalArgumentException("illegal value for BDS parameter k");
        }
        this.authenticationPath = new ArrayList<XMSSNode>();
        this.retain = new TreeMap<Integer, LinkedList<XMSSNode>>();
        this.stack = new Stack<XMSSNode>();
        this.treeHashInstances = new ArrayList<BDSTreeHash>();
        for (int height = 0; height < treeHeight - k; ++height) {
            this.treeHashInstances.add(new BDSTreeHash(height));
        }
        this.keep = new TreeMap<Integer, XMSSNode>();
        this.index = 0;
        this.used = false;
    }

    /**
     * Shared copy logic for all copy-style constructors. Collections are
     * copied (the {@code retain} queues individually, the tree-hash instances
     * via {@code clone()}); the {@link XMSSNode} objects themselves are shared
     * with {@code last}.
     *
     * @param last     state to copy from
     * @param wotsPlus WOTS+ instance the new state is bound to
     * @param maxIndex value for the new state's {@code maxIndex}
     * @param used     value for the new state's {@code used} flag
     */
    private BDS(BDS last, WOTSPlus wotsPlus, int maxIndex, boolean used) {
        this.wotsPlus = wotsPlus;
        this.treeHeight = last.treeHeight;
        this.k = last.k;
        this.root = last.root;
        this.authenticationPath = new ArrayList<XMSSNode>(last.authenticationPath);
        this.retain = new TreeMap<Integer, LinkedList<XMSSNode>>();
        for (Map.Entry<Integer, LinkedList<XMSSNode>> entry : last.retain.entrySet()) {
            // Each queue is consumed with removeFirst(), so the copy needs its own list.
            this.retain.put(entry.getKey(), new LinkedList<XMSSNode>(entry.getValue()));
        }
        this.stack = new Stack<XMSSNode>();
        this.stack.addAll(last.stack);
        this.treeHashInstances = new ArrayList<BDSTreeHash>();
        for (BDSTreeHash treeHash : last.treeHashInstances) {
            this.treeHashInstances.add(treeHash.clone());
        }
        this.keep = new TreeMap<Integer, XMSSNode>(last.keep);
        this.index = last.index;
        this.maxIndex = maxIndex;
        this.used = used;
    }

    /**
     * Copies {@code last}, binding the copy to a new WOTS+ instance built from
     * the parameters of {@code last}'s WOTS+ instance.
     *
     * @param last state to copy; its WOTS+ instance must be set
     */
    BDS(BDS last) {
        this(last, new WOTSPlus(last.wotsPlus.getParams()), last.maxIndex, last.used);
    }

    /**
     * Copies {@code last} with the used flag cleared and advances the copy to
     * the next leaf index.
     */
    private BDS(BDS last, byte[] publicSeed, byte[] secretKeySeed, OTSHashAddress otsHashAddress) {
        this(last, new WOTSPlus(last.wotsPlus.getParams()), last.maxIndex, false);
        this.nextAuthenticationPath(publicSeed, secretKeySeed, otsHashAddress);
    }

    /**
     * Copies {@code last}, binding the copy to a WOTS+ instance created for
     * {@code digest}, then validates the copied state.
     */
    private BDS(BDS last, ASN1ObjectIdentifier digest) {
        this(last, new WOTSPlus(new WOTSPlusParameters(digest)), last.maxIndex, last.used);
        this.validate();
    }

    /**
     * Same as {@link #BDS(BDS, ASN1ObjectIdentifier)} but replaces
     * {@code maxIndex} with the supplied value.
     */
    private BDS(BDS last, int maxIndex, ASN1ObjectIdentifier digest) {
        this(last, new WOTSPlus(new WOTSPlusParameters(digest)), maxIndex, last.used);
        this.validate();
    }

    /**
     * Returns a new state for the next leaf index. The work is done on a copy
     * of this state's collections; the copy's used flag starts cleared.
     *
     * @param publicSeed     public seed passed to the WOTS+ key import
     * @param secretKeySeed  secret seed the WOTS+ secret keys are derived from
     * @param otsHashAddress address template for the computation
     * @return the advanced copy
     * @throws NullPointerException  if {@code otsHashAddress} is {@code null}
     * @throws IllegalStateException if this state's index is already past {@code maxIndex - 1}
     */
    public BDS getNextState(byte[] publicSeed, byte[] secretKeySeed, OTSHashAddress otsHashAddress) {
        return new BDS(this, publicSeed, secretKeySeed, otsHashAddress);
    }

    /**
     * Computes the whole tree leaf by leaf and fills the node stores for leaf
     * index 0.
     *
     * <p>For every leaf the WOTS+ key pair at that OTS address is derived and
     * compressed into a leaf node with {@code XMSSNodeUtil.lTree}. The leaf is
     * then merged with stack nodes of equal height via
     * {@code XMSSNodeUtil.randomizeHash}. Before each merge the right-hand
     * node is recorded where needed: into {@code authenticationPath},
     * a tree-hash instance, or {@code retain}, depending on its position and
     * height. The last node left on the stack becomes the root.
     *
     * @throws NullPointerException if {@code otsHashAddress} is {@code null}
     */
    private void initialize(byte[] publicSeed, byte[] secretSeed, OTSHashAddress otsHashAddress) {
        if (otsHashAddress == null) {
            throw new NullPointerException("otsHashAddress == null");
        }
        // Address templates carrying only the layer and tree address of the caller's OTS address.
        LTreeAddress lTreeAddress = (LTreeAddress)((LTreeAddress.Builder)((LTreeAddress.Builder)new LTreeAddress.Builder().withLayerAddress(otsHashAddress.getLayerAddress())).withTreeAddress(otsHashAddress.getTreeAddress())).build();
        HashTreeAddress hashTreeAddress = (HashTreeAddress)((HashTreeAddress.Builder)((HashTreeAddress.Builder)new HashTreeAddress.Builder().withLayerAddress(otsHashAddress.getLayerAddress())).withTreeAddress(otsHashAddress.getTreeAddress())).build();
        final int leafCount = 1 << this.treeHeight;
        for (int indexLeaf = 0; indexLeaf < leafCount; ++indexLeaf) {
            // Point the OTS address at this leaf and derive its WOTS+ public key.
            otsHashAddress = (OTSHashAddress)((OTSHashAddress.Builder)((OTSHashAddress.Builder)((OTSHashAddress.Builder)new OTSHashAddress.Builder().withLayerAddress(otsHashAddress.getLayerAddress())).withTreeAddress(otsHashAddress.getTreeAddress())).withOTSAddress(indexLeaf).withChainAddress(otsHashAddress.getChainAddress()).withHashAddress(otsHashAddress.getHashAddress()).withKeyAndMask(otsHashAddress.getKeyAndMask())).build();
            this.wotsPlus.importKeys(this.wotsPlus.getWOTSPlusSecretKey(secretSeed, otsHashAddress), publicSeed);
            WOTSPlusPublicKeyParameters wotsPlusPublicKey = this.wotsPlus.getPublicKey(otsHashAddress);
            // Compress the WOTS+ public key into the leaf node.
            lTreeAddress = (LTreeAddress)((LTreeAddress.Builder)((LTreeAddress.Builder)((LTreeAddress.Builder)new LTreeAddress.Builder().withLayerAddress(lTreeAddress.getLayerAddress())).withTreeAddress(lTreeAddress.getTreeAddress())).withLTreeAddress(indexLeaf).withTreeHeight(lTreeAddress.getTreeHeight()).withTreeIndex(lTreeAddress.getTreeIndex()).withKeyAndMask(lTreeAddress.getKeyAndMask())).build();
            XMSSNode node = XMSSNodeUtil.lTree(this.wotsPlus, wotsPlusPublicKey, lTreeAddress);
            // Restart the hash-tree address at this leaf; no tree height is set on the builder here.
            hashTreeAddress = (HashTreeAddress)((HashTreeAddress.Builder)((HashTreeAddress.Builder)((HashTreeAddress.Builder)new HashTreeAddress.Builder().withLayerAddress(hashTreeAddress.getLayerAddress())).withTreeAddress(hashTreeAddress.getTreeAddress())).withTreeIndex(indexLeaf).withKeyAndMask(hashTreeAddress.getKeyAndMask())).build();
            // Merge upwards while the stack top is a sibling-height node.
            while (!this.stack.isEmpty() && this.stack.peek().getHeight() == node.getHeight()) {
                int indexOnHeight = indexLeaf / (1 << node.getHeight());
                if (indexOnHeight == 1) {
                    // Right sibling on the path of leaf 0: part of the initial authentication path.
                    this.authenticationPath.add(node);
                }
                if (indexOnHeight == 3 && node.getHeight() < this.treeHeight - this.k) {
                    // Pre-seed the tree-hash instance of this height with its first node.
                    this.treeHashInstances.get(node.getHeight()).setNode(node);
                }
                if (indexOnHeight >= 3 && (indexOnHeight & 1) == 1 && node.getHeight() >= this.treeHeight - this.k && node.getHeight() <= this.treeHeight - 2) {
                    // Upper right-hand nodes are kept in per-height FIFO queues instead of being recomputed.
                    if (this.retain.get(node.getHeight()) == null) {
                        LinkedList<XMSSNode> queue = new LinkedList<XMSSNode>();
                        queue.add(node);
                        this.retain.put(node.getHeight(), queue);
                    } else {
                        this.retain.get(node.getHeight()).add(node);
                    }
                }
                // Move the address to the parent position, hash the pair, then bump the height.
                hashTreeAddress = (HashTreeAddress)((HashTreeAddress.Builder)((HashTreeAddress.Builder)((HashTreeAddress.Builder)new HashTreeAddress.Builder().withLayerAddress(hashTreeAddress.getLayerAddress())).withTreeAddress(hashTreeAddress.getTreeAddress())).withTreeHeight(hashTreeAddress.getTreeHeight()).withTreeIndex((hashTreeAddress.getTreeIndex() - 1) / 2).withKeyAndMask(hashTreeAddress.getKeyAndMask())).build();
                node = XMSSNodeUtil.randomizeHash(this.wotsPlus, this.stack.pop(), node, hashTreeAddress);
                node = new XMSSNode(node.getHeight() + 1, node.getValue());
                hashTreeAddress = (HashTreeAddress)((HashTreeAddress.Builder)((HashTreeAddress.Builder)((HashTreeAddress.Builder)new HashTreeAddress.Builder().withLayerAddress(hashTreeAddress.getLayerAddress())).withTreeAddress(hashTreeAddress.getTreeAddress())).withTreeHeight(hashTreeAddress.getTreeHeight() + 1).withTreeIndex(hashTreeAddress.getTreeIndex()).withKeyAndMask(hashTreeAddress.getKeyAndMask())).build();
            }
            this.stack.push(node);
        }
        this.root = this.stack.pop();
    }

    /**
     * Advances the state from leaf {@code index} to {@code index + 1},
     * updating the authentication path in place and giving the tree-hash
     * instances up to {@code (treeHeight - k) / 2} update steps.
     *
     * @throws NullPointerException  if {@code otsHashAddress} is {@code null}
     * @throws IllegalStateException if the state is flagged as used, or
     *                               {@code index > maxIndex - 1}
     */
    private void nextAuthenticationPath(byte[] publicSeed, byte[] secretSeed, OTSHashAddress otsHashAddress) {
        if (otsHashAddress == null) {
            throw new NullPointerException("otsHashAddress == null");
        }
        if (this.used) {
            throw new IllegalStateException("index already used");
        }
        if (this.index > this.maxIndex - 1) {
            throw new IllegalStateException("index out of bounds");
        }
        // NOTE(review): tau is whatever XMSSUtil.calculateTau defines - presumably the height of the
        // lowest ancestor of the current leaf that is a left node; confirm against XMSSUtil.
        int tau = XMSSUtil.calculateTau(this.index, this.treeHeight);
        if (((this.index >> (tau + 1)) & 1) == 0 && tau < this.treeHeight - 1) {
            // Remember the current node at height tau; it is combined into a new left node later.
            this.keep.put(tau, this.authenticationPath.get(tau));
        }
        LTreeAddress lTreeAddress = (LTreeAddress)((LTreeAddress.Builder)((LTreeAddress.Builder)new LTreeAddress.Builder().withLayerAddress(otsHashAddress.getLayerAddress())).withTreeAddress(otsHashAddress.getTreeAddress())).build();
        HashTreeAddress hashTreeAddress = (HashTreeAddress)((HashTreeAddress.Builder)((HashTreeAddress.Builder)new HashTreeAddress.Builder().withLayerAddress(otsHashAddress.getLayerAddress())).withTreeAddress(otsHashAddress.getTreeAddress())).build();
        if (tau == 0) {
            // Only the bottom entry changes: it becomes the leaf node computed for the current index.
            otsHashAddress = (OTSHashAddress)((OTSHashAddress.Builder)((OTSHashAddress.Builder)((OTSHashAddress.Builder)new OTSHashAddress.Builder().withLayerAddress(otsHashAddress.getLayerAddress())).withTreeAddress(otsHashAddress.getTreeAddress())).withOTSAddress(this.index).withChainAddress(otsHashAddress.getChainAddress()).withHashAddress(otsHashAddress.getHashAddress()).withKeyAndMask(otsHashAddress.getKeyAndMask())).build();
            this.wotsPlus.importKeys(this.wotsPlus.getWOTSPlusSecretKey(secretSeed, otsHashAddress), publicSeed);
            WOTSPlusPublicKeyParameters wotsPlusPublicKey = this.wotsPlus.getPublicKey(otsHashAddress);
            lTreeAddress = (LTreeAddress)((LTreeAddress.Builder)((LTreeAddress.Builder)((LTreeAddress.Builder)new LTreeAddress.Builder().withLayerAddress(lTreeAddress.getLayerAddress())).withTreeAddress(lTreeAddress.getTreeAddress())).withLTreeAddress(this.index).withTreeHeight(lTreeAddress.getTreeHeight()).withTreeIndex(lTreeAddress.getTreeIndex()).withKeyAndMask(lTreeAddress.getKeyAndMask())).build();
            XMSSNode node = XMSSNodeUtil.lTree(this.wotsPlus, wotsPlusPublicKey, lTreeAddress);
            this.authenticationPath.set(0, node);
        } else {
            // New node at height tau: hash of the current path node and the kept node at height tau - 1.
            hashTreeAddress = (HashTreeAddress)((HashTreeAddress.Builder)((HashTreeAddress.Builder)((HashTreeAddress.Builder)new HashTreeAddress.Builder().withLayerAddress(hashTreeAddress.getLayerAddress())).withTreeAddress(hashTreeAddress.getTreeAddress())).withTreeHeight(tau - 1).withTreeIndex(this.index >> tau).withKeyAndMask(hashTreeAddress.getKeyAndMask())).build();
            this.wotsPlus.importKeys(this.wotsPlus.getWOTSPlusSecretKey(secretSeed, otsHashAddress), publicSeed);
            XMSSNode node = XMSSNodeUtil.randomizeHash(this.wotsPlus, this.authenticationPath.get(tau - 1), this.keep.get(tau - 1), hashTreeAddress);
            node = new XMSSNode(node.getHeight() + 1, node.getValue());
            this.authenticationPath.set(tau, node);
            this.keep.remove(tau - 1);
            // Refill every level below tau: lower levels from tree-hash results, upper levels from retain.
            for (int height = 0; height < tau; ++height) {
                if (height < this.treeHeight - this.k) {
                    this.authenticationPath.set(height, this.treeHashInstances.get(height).getTailNode());
                } else {
                    this.authenticationPath.set(height, this.retain.get(height).removeFirst());
                }
            }
            // Restart the tree-hash instances that were just consumed, unless their start leaf is outside the tree.
            int minHeight = Math.min(tau, this.treeHeight - this.k);
            for (int height = 0; height < minHeight; ++height) {
                int startIndex = this.index + 1 + 3 * (1 << height);
                if (startIndex < (1 << this.treeHeight)) {
                    this.treeHashInstances.get(height).initialize(startIndex);
                }
            }
        }
        // Fixed update budget per step, each spent on the instance chosen by getBDSTreeHashInstanceForUpdate().
        int updateBudget = (this.treeHeight - this.k) >> 1;
        for (int i = 0; i < updateBudget; ++i) {
            BDSTreeHash treeHash = this.getBDSTreeHashInstanceForUpdate();
            if (treeHash != null) {
                treeHash.update(this.stack, this.wotsPlus, publicSeed, secretSeed, otsHashAddress);
            }
        }
        ++this.index;
    }

    /**
     * @return {@code true} if this state has been flagged as used
     */
    boolean isUsed() {
        return this.used;
    }

    /**
     * Flags this state as used; {@code nextAuthenticationPath} rejects used states.
     */
    void markUsed() {
        this.used = true;
    }

    /**
     * Picks the tree-hash instance that should receive the next update: among
     * the instances that are initialized and not finished, the one with the
     * smallest {@code getHeight()}, ties broken by the smallest
     * {@code getIndexLeaf()} (the earliest list entry wins an exact tie).
     *
     * @return the selected instance, or {@code null} if none is eligible
     */
    private BDSTreeHash getBDSTreeHashInstanceForUpdate() {
        BDSTreeHash best = null;
        for (BDSTreeHash candidate : this.treeHashInstances) {
            if (candidate.isFinished() || !candidate.isInitialized()) {
                continue;
            }
            if (best == null || candidate.getHeight() < best.getHeight()) {
                best = candidate;
            } else if (candidate.getHeight() == best.getHeight() && candidate.getIndexLeaf() < best.getIndexLeaf()) {
                best = candidate;
            }
        }
        return best;
    }

    /**
     * Sanity-checks a copied state: all node stores must be non-null and the
     * index must be accepted by {@code XMSSUtil.isIndexValid}.
     *
     * @throws IllegalStateException if any check fails
     */
    private void validate() {
        if (this.authenticationPath == null) {
            throw new IllegalStateException("authenticationPath == null");
        }
        if (this.retain == null) {
            throw new IllegalStateException("retain == null");
        }
        if (this.stack == null) {
            throw new IllegalStateException("stack == null");
        }
        if (this.treeHashInstances == null) {
            throw new IllegalStateException("treeHashInstances == null");
        }
        if (this.keep == null) {
            throw new IllegalStateException("keep == null");
        }
        if (!XMSSUtil.isIndexValid(this.treeHeight, this.index)) {
            throw new IllegalStateException("index in BDS state out of bounds");
        }
    }

    /**
     * @return the height of the tree this state belongs to
     */
    protected int getTreeHeight() {
        return this.treeHeight;
    }

    /**
     * @return the tree root, or {@code null} if the tree was never computed
     */
    protected XMSSNode getRoot() {
        return this.root;
    }

    /**
     * @return a new list holding the current authentication path nodes;
     *         modifying the list does not affect this state
     */
    protected List<XMSSNode> getAuthenticationPath() {
        return new ArrayList<XMSSNode>(this.authenticationPath);
    }

    /**
     * @return the leaf index this state is positioned at
     */
    protected int getIndex() {
        return this.index;
    }

    /**
     * @return the upper bound checked by {@code nextAuthenticationPath}
     */
    public int getMaxIndex() {
        return this.maxIndex;
    }

    /**
     * Returns a validated copy of this state bound to a new WOTS+ instance
     * created for the given digest.
     *
     * @param digestName digest identifier passed to {@link WOTSPlusParameters}
     * @return the re-bound copy
     * @throws IllegalStateException if the copied state fails validation
     */
    public BDS withWOTSDigest(ASN1ObjectIdentifier digestName) {
        return new BDS(this, digestName);
    }

    /**
     * Returns a validated copy of this state with a different
     * {@code maxIndex}, bound to a new WOTS+ instance created for the given
     * digest.
     *
     * @param maxIndex   upper bound for the copy
     * @param digestName digest identifier passed to {@link WOTSPlusParameters}
     * @return the re-bound copy
     * @throws IllegalStateException if the copied state fails validation
     */
    public BDS withMaxIndex(int maxIndex, ASN1ObjectIdentifier digestName) {
        return new BDS(this, maxIndex, digestName);
    }

    /**
     * Restores the non-transient fields, then {@code maxIndex}: read from the
     * stream when extra data is available, otherwise defaulted to
     * {@code 2^treeHeight - 1} (streams written without the trailing int).
     * {@code wotsPlus} is left unset.
     *
     * @throws IOException if {@code maxIndex} exceeds {@code 2^treeHeight - 1},
     *                     {@code index > maxIndex + 1}, or unread data remains
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        final int largestIndex = (1 << this.treeHeight) - 1;
        if (in.available() != 0) {
            this.maxIndex = in.readInt();
        } else {
            this.maxIndex = largestIndex;
        }
        if (this.maxIndex > largestIndex || this.index > this.maxIndex + 1 || in.available() != 0) {
            throw new IOException("inconsistent BDS data detected");
        }
    }

    /**
     * Writes the non-transient fields followed by {@code maxIndex} as an int.
     */
    private void writeObject(ObjectOutputStream out) throws IOException {
        out.defaultWriteObject();
        out.writeInt(this.maxIndex);
    }
}

