/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn.transformer;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Activation;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterType;
import ai.djl.nn.core.Linear;
import ai.djl.nn.norm.BatchNorm;
import ai.djl.nn.norm.Dropout;
import ai.djl.nn.transformer.IdEmbedding;
import ai.djl.nn.transformer.MemoryScope;
import ai.djl.nn.transformer.TransformerEncoderBlock;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * A block that embeds token ids, type ids and positions, normalizes and applies dropout to the
 * summed embedding, runs the result through a stack of {@link TransformerEncoderBlock}s and
 * finally appends a pooled projection of a single token.
 *
 * <p>The block expects three inputs, named {@code tokenIds}, {@code typeIds} and {@code masks}
 * (in that order), and produces two outputs: the encoded sequence and the pooled output.
 */
public final class BertBlock
extends AbstractBlock {
    private static final byte VERSION = 1;
    private static final String PARAM_POSITION_EMBEDDING = "positionEmbedding";
    private final int embeddingSize;
    private final int tokenDictionarySize;
    private final int typeDictionarySize;
    private final IdEmbedding tokenEmbedding;
    private final IdEmbedding typeEmbedding;
    private final Parameter positionEmbedding;
    private final BatchNorm embeddingNorm;
    private final Dropout embeddingDropout;
    private final List<TransformerEncoderBlock> transformerEncoderBlocks;
    private final Linear pooling;

    /**
     * Creates the block and registers all child blocks and parameters from the builder settings.
     *
     * @param builder the builder holding the configuration
     */
    private BertBlock(Builder builder) {
        super(VERSION);
        this.embeddingSize = builder.embeddingSize;
        // Embedding for the input tokens.
        this.tokenEmbedding = this.addChildBlock("tokenEmbedding", new IdEmbedding.Builder().setEmbeddingSize(builder.embeddingSize).setDictionarySize(builder.tokenDictionarySize).build());
        this.tokenDictionarySize = builder.tokenDictionarySize;
        // Position table of shape (maxSequenceLength, embeddingSize), stored as a plain parameter.
        this.positionEmbedding = this.addParameter(new Parameter(PARAM_POSITION_EMBEDDING, this, ParameterType.WEIGHT), new Shape(builder.maxSequenceLength, builder.embeddingSize));
        // Embedding for the input types.
        this.typeEmbedding = this.addChildBlock("typeEmbedding", new IdEmbedding.Builder().setEmbeddingSize(builder.embeddingSize).setDictionarySize(builder.typeDictionarySize).build());
        this.typeDictionarySize = builder.typeDictionarySize;
        // Normalization (over axis 2) and dropout applied to the summed embedding.
        this.embeddingNorm = this.addChildBlock("embeddingNorm", BatchNorm.builder().optAxis(2).build());
        this.embeddingDropout = this.addChildBlock("embeddingDropout", Dropout.builder().optRate(builder.hiddenDropoutProbability).build());
        // The transformer stack; each block is registered as "transformer_<index>".
        this.transformerEncoderBlocks = new ArrayList<>(builder.transformerBlockCount);
        for (int i = 0; i < builder.transformerBlockCount; ++i) {
            this.transformerEncoderBlocks.add(this.addChildBlock("transformer_" + i, new TransformerEncoderBlock(builder.embeddingSize, builder.attentionHeadCount, builder.hiddenSize, 0.1f, Activation::gelu)));
        }
        // Projection used to build the pooled output.
        this.pooling = this.addChildBlock("poolingProjection", Linear.builder().setUnits(builder.embeddingSize).optBias(true).build());
    }

    /**
     * Returns the token embedding child block.
     *
     * @return the token embedding block
     */
    public IdEmbedding getTokenEmbedding() {
        return this.tokenEmbedding;
    }

    /**
     * Returns the embedding size this block was built with.
     *
     * @return the embedding size
     */
    public int getEmbeddingSize() {
        return this.embeddingSize;
    }

    /**
     * Returns the token dictionary size this block was built with.
     *
     * @return the token dictionary size
     */
    public int getTokenDictionarySize() {
        return this.tokenDictionarySize;
    }

    /**
     * Returns the type dictionary size this block was built with.
     *
     * @return the type dictionary size
     */
    public int getTypeDictionarySize() {
        return this.typeDictionarySize;
    }

    /**
     * Computes the output shapes from the shape of the first input.
     *
     * @param manager the manager (unused)
     * @param inputShapes the input shapes; only the first two dimensions of the first shape are
     *     read, as batch size B and sequence length S
     * @return two shapes: {@code (B, S, embeddingSize)} and {@code (B, embeddingSize)}
     */
    @Override
    public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
        long batchSize = inputShapes[0].get(0);
        long sequenceLength = inputShapes[0].get(1);
        return new Shape[]{new Shape(batchSize, sequenceLength, this.embeddingSize), new Shape(batchSize, this.embeddingSize)};
    }

    /**
     * Initializes all child blocks and sets the input names to {@code tokenIds}, {@code typeIds}
     * and {@code masks}.
     *
     * @param manager the manager used for initialization
     * @param dataType the data type used for initialization
     * @param inputShapes the input shapes; index 0 is used for the token embedding and index 1
     *     for the type embedding
     */
    @Override
    public void initializeChildBlocks(NDManager manager, DataType dataType, Shape ... inputShapes) {
        this.beforeInitialize(inputShapes);
        this.inputNames = Arrays.asList("tokenIds", "typeIds", "masks");
        Shape[] tokenShape = new Shape[]{inputShapes[0]};
        Shape[] typeShape = new Shape[]{inputShapes[1]};
        // The token embedding's output shape is reused as the input shape of every later stage.
        Shape[] embeddingOutput = this.tokenEmbedding.initialize(manager, dataType, tokenShape);
        this.typeEmbedding.initialize(manager, dataType, typeShape);
        this.embeddingNorm.initialize(manager, dataType, embeddingOutput);
        this.embeddingDropout.initialize(manager, dataType, embeddingOutput);
        for (TransformerEncoderBlock tb : this.transformerEncoderBlocks) {
            tb.initialize(manager, dataType, embeddingOutput);
        }
        long batchSize = inputShapes[0].get(0);
        this.pooling.initialize(manager, dataType, new Shape(batchSize, this.embeddingSize));
    }

    /**
     * Expands a 2D input mask into a 3D attention mask.
     *
     * <p>{@code ids} is reshaped to {@code (B, F, 1)} (as ones) and {@code mask} to
     * {@code (B, 1, T)}; their matrix product has shape {@code (B, F, T)} and repeats the mask
     * row of each batch entry F times.
     *
     * @param ids the ids; only the shape {@code (B, F)} is used
     * @param mask the mask of shape {@code (B, T)}
     * @return a float32 array of shape {@code (B, F, T)}
     */
    public static NDArray createAttentionMaskFromInputMask(NDArray ids, NDArray mask) {
        long batchSize = ids.getShape().get(0);
        long fromSeqLength = ids.getShape().get(1);
        long toSeqLength = mask.getShape().get(1);
        NDArray broadcastOnes = ids.onesLike().toType(DataType.FLOAT32, false).reshape(batchSize, fromSeqLength, 1L);
        NDArray mask3D = mask.toType(DataType.FLOAT32, false).reshape(batchSize, 1L, toSeqLength);
        return broadcastOnes.matMul(mask3D);
    }

    /**
     * Delegates to {@link #forward(ParameterStore, NDList, boolean)}; {@code params} is ignored.
     */
    @Override
    protected NDList forwardInternal(ParameterStore ps, NDList inputs, boolean training, PairList<String, Object> params) {
        return this.forward(ps, inputs, training);
    }

    /**
     * Unpacks the three inputs and runs the forward pass.
     *
     * @param ps the parameter store
     * @param inputs a list holding token ids, type ids and masks at indices 0, 1 and 2
     * @param training whether the pass runs in training mode
     * @return the encoded sequence followed by the pooled output
     */
    @Override
    public NDList forward(ParameterStore ps, NDList inputs, boolean training) {
        NDArray tokenIds = inputs.get(0);
        NDArray typeIds = inputs.get(1);
        NDArray masks = inputs.get(2);
        return this.forward(ps, tokenIds, typeIds, masks, training);
    }

    /**
     * Runs the forward pass on explicit inputs.
     *
     * @param ps the parameter store
     * @param tokenIds the token ids
     * @param typeIds the type ids
     * @param masks the input mask
     * @param training whether the pass runs in training mode
     * @return a list with the output of the last transformer block followed by the pooled
     *     output of shape {@code (B, E)}
     */
    public NDList forward(ParameterStore ps, NDArray tokenIds, NDArray typeIds, NDArray masks, boolean training) {
        MemoryScope initScope = MemoryScope.from(tokenIds).add(typeIds, masks);
        // Build the three embeddings and merge them by addition.
        NDArray embeddedTokens = this.tokenEmbedding.forward(ps, tokenIds, training);
        NDArray embeddedTypes = this.typeEmbedding.forward(ps, typeIds, training);
        NDArray embeddedPositions = ps.getValue(this.positionEmbedding, tokenIds.getDevice(), training);
        // NOTE(review): the position table is (maxSequenceLength, embeddingSize) and is added
        // without slicing, so this presumably relies on broadcasting over the batch axis and on
        // the input sequence length matching maxSequenceLength - confirm.
        NDArray embedding = embeddedTokens.add(embeddedTypes).add(embeddedPositions);
        NDList normalizedEmbedding = this.embeddingNorm.forward(ps, new NDList(embedding), training);
        NDList dropoutEmbedding = this.embeddingDropout.forward(ps, normalizedEmbedding, training);
        // Turn the mask into an additive offset: (1 - mask) * -100000, i.e. entries that are 1
        // become 0 and entries that are 0 become -100000. The extra axis of size 1 is inserted
        // at position 1 (presumably the attention-head axis - confirm).
        NDArray attentionMask = BertBlock.createAttentionMaskFromInputMask(tokenIds, masks);
        Shape maskShape = attentionMask.getShape();
        NDArray offsetMask = attentionMask.reshape(maskShape.get(0), 1L, maskShape.get(1), maskShape.get(2)).toType(DataType.FLOAT32, false).mul(-1.0f).add(1.0f).mul(-100000.0f);
        NDList lastOutput = dropoutEmbedding;
        initScope.remove(tokenIds, typeIds, masks).waitToRead(dropoutEmbedding).waitToRead(offsetMask).close();
        // Feed the output of each transformer block into the next one, always with the same mask.
        for (TransformerEncoderBlock block : this.transformerEncoderBlocks) {
            NDList input = new NDList(lastOutput.head(), offsetMask);
            MemoryScope innerScope = MemoryScope.from(input);
            lastOutput = block.forward(ps, input, training);
            innerScope.remove(offsetMask).waitToRead(lastOutput).close();
        }
        // Pooled output: project the selected token through the pooling layer and apply tanh.
        // The selection is reshaped explicitly to (B, E): an argument-less squeeze() would also
        // drop the batch axis whenever the batch size is 1.
        // NOTE(review): the variable is called firstToken but index 1 (the second position) is
        // selected; left unchanged because trained weights may depend on it - confirm intent.
        Shape encodedShape = lastOutput.head().getShape();
        NDArray firstToken = lastOutput.head().get(new NDIndex(":,1,:")).reshape(encodedShape.get(0), encodedShape.get(2));
        NDArray pooledFirstToken = this.pooling.forward(ps, new NDList(firstToken), training).head().tanh();
        lastOutput.add(pooledFirstToken);
        return lastOutput;
    }

    /**
     * Creates a new builder with default settings.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Builder for {@link BertBlock}.
     *
     * <p>Only the token dictionary size is mandatory. Note that {@code hiddenSize} is computed
     * from the embedding size only at construction time and inside the preset methods;
     * {@link #optEmbeddingSize(int)} does not update it.
     */
    public static final class Builder {
        int tokenDictionarySize;
        int typeDictionarySize = 16;
        int embeddingSize = 768;
        int transformerBlockCount = 12;
        int attentionHeadCount = 12;
        int hiddenSize = 4 * this.embeddingSize;
        float hiddenDropoutProbability = 0.1f;
        int maxSequenceLength = 512;

        private Builder() {
        }

        /**
         * Sets the token dictionary size (required).
         *
         * @param tokenDictionarySize the number of entries in the token dictionary
         * @return this builder
         */
        public Builder setTokenDictionarySize(int tokenDictionarySize) {
            this.tokenDictionarySize = tokenDictionarySize;
            return this;
        }

        /**
         * Sets the type dictionary size (default 16).
         *
         * @param typeDictionarySize the number of entries in the type dictionary
         * @return this builder
         */
        public Builder optTypeDictionarySize(int typeDictionarySize) {
            this.typeDictionarySize = typeDictionarySize;
            return this;
        }

        /**
         * Sets the embedding size (default 768). Does not change the hidden size.
         *
         * @param embeddingSize the embedding size
         * @return this builder
         */
        public Builder optEmbeddingSize(int embeddingSize) {
            this.embeddingSize = embeddingSize;
            return this;
        }

        /**
         * Sets the number of transformer blocks (default 12).
         *
         * @param transformerBlockCount the number of transformer blocks
         * @return this builder
         */
        public Builder optTransformerBlockCount(int transformerBlockCount) {
            this.transformerBlockCount = transformerBlockCount;
            return this;
        }

        /**
         * Sets the number of attention heads passed to each transformer block (default 12).
         *
         * @param attentionHeadCount the number of attention heads
         * @return this builder
         */
        public Builder optAttentionHeadCount(int attentionHeadCount) {
            this.attentionHeadCount = attentionHeadCount;
            return this;
        }

        /**
         * Sets the hidden size passed to each transformer block (default 4 * 768).
         *
         * @param hiddenSize the hidden size
         * @return this builder
         */
        public Builder optHiddenSize(int hiddenSize) {
            this.hiddenSize = hiddenSize;
            return this;
        }

        /**
         * Sets the rate of the dropout applied to the embedding (default 0.1).
         *
         * @param hiddenDropoutProbability the dropout rate
         * @return this builder
         */
        public Builder optHiddenDropoutProbability(float hiddenDropoutProbability) {
            this.hiddenDropoutProbability = hiddenDropoutProbability;
            return this;
        }

        /**
         * Sets the number of rows of the position embedding table (default 512).
         *
         * @param maxSequenceLength the maximum sequence length
         * @return this builder
         */
        public Builder optMaxSequenceLength(int maxSequenceLength) {
            this.maxSequenceLength = maxSequenceLength;
            return this;
        }

        /**
         * Applies the "nano" preset: 2 types, embedding 256, 4 blocks, 4 heads, hidden 1024,
         * dropout 0.1, max sequence length 128.
         *
         * @return this builder
         */
        public Builder nano() {
            this.typeDictionarySize = 2;
            this.embeddingSize = 256;
            this.transformerBlockCount = 4;
            this.attentionHeadCount = 4;
            this.hiddenSize = 4 * this.embeddingSize;
            this.hiddenDropoutProbability = 0.1f;
            this.maxSequenceLength = 128;
            return this;
        }

        /**
         * Applies the "micro" preset: 2 types, embedding 512, 12 blocks, 8 heads, hidden 2048,
         * dropout 0.1, max sequence length 128.
         *
         * @return this builder
         */
        public Builder micro() {
            this.typeDictionarySize = 2;
            this.embeddingSize = 512;
            this.transformerBlockCount = 12;
            this.attentionHeadCount = 8;
            this.hiddenSize = 4 * this.embeddingSize;
            this.hiddenDropoutProbability = 0.1f;
            this.maxSequenceLength = 128;
            return this;
        }

        /**
         * Applies the "base" preset: 16 types, embedding 768, 12 blocks, 12 heads, hidden 3072,
         * dropout 0.1, max sequence length 256.
         *
         * @return this builder
         */
        public Builder base() {
            this.typeDictionarySize = 16;
            this.embeddingSize = 768;
            this.transformerBlockCount = 12;
            this.attentionHeadCount = 12;
            this.hiddenSize = 4 * this.embeddingSize;
            this.hiddenDropoutProbability = 0.1f;
            this.maxSequenceLength = 256;
            return this;
        }

        /**
         * Applies the "large" preset: 16 types, embedding 1024, 24 blocks, 16 heads, hidden
         * 4096, dropout 0.1, max sequence length 512.
         *
         * @return this builder
         */
        public Builder large() {
            this.typeDictionarySize = 16;
            this.embeddingSize = 1024;
            this.transformerBlockCount = 24;
            this.attentionHeadCount = 16;
            this.hiddenSize = 4 * this.embeddingSize;
            this.hiddenDropoutProbability = 0.1f;
            this.maxSequenceLength = 512;
            return this;
        }

        /**
         * Builds the block.
         *
         * @return a new {@link BertBlock}
         * @throws IllegalArgumentException if the token dictionary size was not set to a
         *     positive value
         */
        public BertBlock build() {
            // A dictionary cannot have zero or a negative number of entries.
            if (this.tokenDictionarySize <= 0) {
                throw new IllegalArgumentException("You must specify the dictionary size.");
            }
            return new BertBlock(this);
        }
    }
}

