/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.huggingface.translator;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.huggingface.translator.TextEmbeddingTranslator;
import ai.djl.modality.nlp.EmbeddingOutput;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.nn.Activation;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import java.io.IOException;
import java.io.InputStream;
import java.nio.FloatBuffer;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * A {@link Translator} that turns text into an {@link EmbeddingOutput} holding sparse token
 * weights and, optionally, a dense embedding.
 *
 * <p>The sparse weights are produced by applying a linear layer (loaded from the
 * {@code sparseLinear} file during {@link #prepare(TranslatorContext)}) followed by ReLU to the
 * model's last hidden state. The dense embedding, when requested, is delegated to a wrapped
 * {@link TextEmbeddingTranslator}.
 */
public class SparseRetrievalTranslator implements Translator<String, EmbeddingOutput> {

    // NOTE(review): these literal strings are handed to tokenizer.encode(...) as text in the
    // constructor; confirm that this yields the ids of the intended special tokens.
    private static final String[] SPECIAL_TOKENS = {"cls_token", "eos_token", "pad_token", "unk_token"};

    /** Attachment key under which the batch encodings are passed from input to output processing. */
    private static final String ATTACHMENT_ENCODINGS = "encodings";

    /** Attachment key under which the attention mask is passed from input to output processing. */
    private static final String ATTACHMENT_ATTENTION_MASK = "attentionMask";

    private final HuggingFaceTokenizer tokenizer;
    private final TextEmbeddingTranslator translator;
    private final boolean includeTokenTypes;
    private final boolean int32;
    private final boolean returnDenseEmbedding;
    /** Token ids that are never reported in the sparse output. */
    private final Set<Long> unusedTokens;
    /** Path (absolute, or relative to the model directory) of the sparse linear weights file. */
    private final String sparseLinear;
    /** Loaded by {@link #prepare(TranslatorContext)}; stays {@code null} if nothing was loaded. */
    private NDList sparseLinearModel;

    /**
     * Creates the translator from a configured builder.
     *
     * @param builder the builder carrying the tokenizer and options
     */
    SparseRetrievalTranslator(Builder builder) {
        this.tokenizer = builder.tokenizer;
        this.translator = builder.baseBuilder.build();
        this.includeTokenTypes = builder.baseBuilder.includeTokenTypes;
        this.int32 = builder.baseBuilder.int32;
        this.returnDenseEmbedding = builder.returnDenseEmbedding;
        this.sparseLinear = builder.sparseLinear;
        Encoding encoding = tokenizer.encode(SPECIAL_TOKENS);
        this.unusedTokens = Arrays.stream(encoding.getIds()).boxed().collect(Collectors.toSet());
    }

    /**
     * Prepares the wrapped dense translator (if dense output is enabled) and loads the sparse
     * linear weights from disk (if a {@code sparseLinear} path is configured).
     *
     * @param ctx the translator context
     * @throws TranslateException if the configured sparse linear file does not exist
     * @throws Exception if the wrapped translator fails to prepare or the file cannot be read
     */
    public void prepare(TranslatorContext ctx) throws Exception {
        if (returnDenseEmbedding) {
            translator.prepare(ctx);
        }
        if (sparseLinear == null) {
            return;
        }
        Path file = Paths.get(sparseLinear);
        if (!file.isAbsolute()) {
            file = ctx.getModel().getModelPath().resolve(file);
        }
        if (Files.notExists(file)) {
            throw new TranslateException("sparseLinear file does not exist: " + sparseLinear);
        }
        // Create the sub-manager only once we know there is something to load into it, so no
        // empty manager is left hanging off the predictor manager.
        NDManager manager = ctx.getPredictorManager().newSubManager();
        try (InputStream is = Files.newInputStream(file)) {
            sparseLinearModel = NDList.decode(manager, is);
        }
    }

    /**
     * Tokenizes a single input by delegating to {@link #batchProcessInput} with a one-element
     * batch.
     *
     * @param ctx the translator context
     * @param input the text to encode
     * @return the model input arrays
     */
    public NDList processInput(TranslatorContext ctx, String input) {
        return batchProcessInput(ctx, Collections.singletonList(input));
    }

    /**
     * Tokenizes a batch of inputs and stores the encodings and the attention mask as context
     * attachments for use during output processing.
     *
     * @param ctx the translator context
     * @param inputs the texts to encode
     * @return the model input arrays
     */
    public NDList batchProcessInput(TranslatorContext ctx, List<String> inputs) {
        NDManager manager = ctx.getNDManager();
        Encoding[] encodings = tokenizer.batchEncode(inputs);
        NDList list = Encoding.toNDList(encodings, manager, includeTokenTypes, int32);
        ctx.setAttachment(ATTACHMENT_ENCODINGS, encodings);
        // The attention mask is read from index 1 of the list built above.
        ctx.setAttachment(ATTACHMENT_ATTENTION_MASK, list.get(1));
        return list;
    }

    /**
     * Processes the model output for a single input by delegating to
     * {@link #batchProcessOutput} and returning its first element.
     *
     * @param ctx the translator context
     * @param list the model output
     * @return the embedding output of the first (only) batch entry
     */
    public EmbeddingOutput processOutput(TranslatorContext ctx, NDList list) {
        return Objects.requireNonNull(batchProcessOutput(ctx, list)).get(0);
    }

    /**
     * Converts the model output of a batch into one {@link EmbeddingOutput} per input.
     *
     * <p>For each token position the weight is {@code relu(linear(lastHiddenState))}; a weight
     * is recorded under the token id (as a string) when it is strictly positive and the token
     * id is not one of the excluded ids. If dense output is enabled, the pooled embedding from
     * the wrapped translator is split evenly across the batch entries.
     *
     * @param ctx the translator context carrying the attachments set during input processing
     * @param list the model output
     * @return one embedding output per batch entry, in input order
     * @throws IllegalStateException if the sparse linear weights have not been loaded
     */
    public List<EmbeddingOutput> batchProcessOutput(TranslatorContext ctx, NDList list) {
        if (sparseLinearModel == null) {
            // Previously surfaced as an unexplained NullPointerException a few lines below.
            throw new IllegalStateException(
                    "Sparse linear model is not loaded; prepare() must be called with a non-null sparseLinear");
        }
        Encoding[] encodings = (Encoding[]) ctx.getAttachment(ATTACHMENT_ENCODINGS);
        int batchSize = encodings.length;
        List<EmbeddingOutput> embeddings = new ArrayList<>(batchSize);

        NDArray lastHiddenState = list.get("last_hidden_state");
        if (lastHiddenState == null) {
            // Fall back to the first output when no array carries that name.
            lastHiddenState = list.get(0);
        }

        // NOTE(review): when a dtype conversion actually happens, the converted arrays may be
        // attached to the long-lived manager of sparseLinearModel - confirm they are released.
        NDArray weight = sparseLinearModel.get("weight").toType(lastHiddenState.getDataType(), false);
        NDArray bias = sparseLinearModel.get("bias").toType(lastHiddenState.getDataType(), false);
        NDArray projected = lastHiddenState.getNDArrayInternal().linear(lastHiddenState, weight, bias).get(0);
        // squeeze(-1) presumes the linear layer yields a trailing dimension of size 1.
        NDArray sparseVecs = Activation.relu(projected).squeeze(-1);
        float[] data = sparseVecs.toFloatArray();

        // 'data' is consumed sequentially: one value per token id, encoding after encoding.
        int index = 0;
        for (Encoding encoding : encodings) {
            EmbeddingOutput embedding = new EmbeddingOutput();
            embeddings.add(embedding);
            for (long tokenId : encoding.getIds()) {
                float w = data[index++];
                if (!unusedTokens.contains(tokenId) && w > 0.0f) {
                    embedding.addTokenWeights(String.valueOf(tokenId), w);
                }
            }
        }

        if (returnDenseEmbedding) {
            NDArray attentionMask = (NDArray) ctx.getAttachment(ATTACHMENT_ATTENTION_MASK);
            NDArray output = translator.processEmbedding(list, attentionMask);
            FloatBuffer fb = output.toByteBuffer().asFloatBuffer();
            // The flat buffer is divided into equally sized slices, one per batch entry.
            int denseEmbeddingSize = fb.remaining() / batchSize;
            for (EmbeddingOutput embedding : embeddings) {
                float[] buf = new float[denseEmbeddingSize];
                fb.get(buf);
                embedding.setDenseEmbedding(buf);
            }
        }
        return embeddings;
    }

    /**
     * Creates a builder for the given tokenizer and applies the supplied arguments.
     *
     * @param tokenizer the tokenizer to use
     * @param arguments the configuration arguments
     * @return the configured builder
     */
    public static Builder builder(HuggingFaceTokenizer tokenizer, Map<String, ?> arguments) {
        Builder builder = new Builder(tokenizer);
        builder.configure(arguments);
        return builder;
    }

    /** Builder for {@link SparseRetrievalTranslator}. */
    public static final class Builder {
        HuggingFaceTokenizer tokenizer;
        TextEmbeddingTranslator.Builder baseBuilder;
        boolean returnDenseEmbedding;
        String sparseLinear;

        /**
         * Creates a builder with the default sparse linear file name.
         *
         * @param tokenizer the tokenizer to use
         */
        Builder(HuggingFaceTokenizer tokenizer) {
            this.tokenizer = tokenizer;
            this.baseBuilder = TextEmbeddingTranslator.builder(tokenizer);
            this.sparseLinear = "sparse_linear.safetensors";
        }

        /**
         * Sets whether the dense embedding is returned alongside the sparse weights.
         *
         * @param returnDenseEmbedding {@code true} to also compute the dense embedding
         * @return this builder
         */
        public Builder optReturnDenseEmbedding(boolean returnDenseEmbedding) {
            this.returnDenseEmbedding = returnDenseEmbedding;
            return this;
        }

        /**
         * Sets the path of the sparse linear weights file.
         *
         * @param sparseLinear an absolute path, or a path relative to the model directory
         * @return this builder
         */
        public Builder optSparseLinear(String sparseLinear) {
            this.sparseLinear = sparseLinear;
            return this;
        }

        /**
         * Configures this builder and the underlying text embedding builder from arguments.
         *
         * <p>Reads {@code returnDenseEmbedding} (default {@code false}) and {@code sparseLinear}
         * (default: the currently configured value).
         *
         * @param arguments the configuration arguments
         */
        public void configure(Map<String, ?> arguments) {
            baseBuilder.configure(arguments);
            optReturnDenseEmbedding(ArgumentsUtil.booleanValue(arguments, "returnDenseEmbedding", false));
            optSparseLinear(ArgumentsUtil.stringValue(arguments, "sparseLinear", sparseLinear));
        }

        /**
         * Builds the translator.
         *
         * @return the new translator
         * @throws IOException declared for API compatibility
         */
        public SparseRetrievalTranslator build() throws IOException {
            return new SparseRetrievalTranslator(this);
        }
    }
}

