/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.mxnet.zoo.nlp.embedding;

import ai.djl.Application;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.modality.nlp.SimpleVocabulary;
import ai.djl.modality.nlp.Vocabulary;
import ai.djl.modality.nlp.embedding.TrainableWordEmbedding;
import ai.djl.mxnet.zoo.MxModelZoo;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.nn.Block;
import ai.djl.nn.core.Embedding;
import ai.djl.repository.Artifact;
import ai.djl.repository.MRL;
import ai.djl.repository.Repository;
import ai.djl.repository.zoo.BaseModelLoader;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import ai.djl.translate.TranslatorFactory;
import ai.djl.util.Pair;
import ai.djl.util.Utils;
import java.io.IOException;
import java.io.InputStream;
import java.util.List;
import java.util.Map;

/**
 * Model loader for the GloVe word-embedding artifact ({@code ai.djl.mxnet:glove}, version {@code
 * 0.0.2}) served by the MXNet model zoo.
 *
 * <p>In addition to the loader's default factories, this loader registers a translator factory for
 * {@code String -> NDList}, so a single word can be looked up directly.
 */
public class GloveWordEmbeddingModelLoader extends BaseModelLoader {
    private static final Application APPLICATION = Application.NLP.WORD_EMBEDDING;
    private static final String GROUP_ID = "ai.djl.mxnet";
    private static final String ARTIFACT_ID = "glove";
    private static final String VERSION = "0.0.2";

    /** Argument / model-property key naming the token used for out-of-vocabulary words. */
    private static final String UNKNOWN_TOKEN_KEY = "unknownToken";

    /**
     * Creates a loader for the GloVe artifact.
     *
     * @param repository the repository the GloVe artifact is resolved from
     */
    public GloveWordEmbeddingModelLoader(Repository repository) {
        super(repository, MRL.model(APPLICATION, GROUP_ID, ARTIFACT_ID), VERSION, new MxModelZoo());
        factories.put(new Pair<>(String.class, NDList.class), new FactoryImpl());
    }

    /**
     * Attaches a {@link TrainableWordEmbedding} built from the artifact's vocabulary to a model.
     *
     * @param model the freshly created model to configure
     * @param artifact the resolved GloVe artifact; its {@code idx_to_token} file supplies the
     *     vocabulary and its {@code dimensions} property is passed to the embedding builder
     * @param arguments loader arguments; {@code unknownToken} names the out-of-vocabulary token
     * @return {@code model}, with its block and its {@code unknownToken} property set
     * @throws IOException if the {@code idx_to_token} file cannot be opened or read
     */
    private Model customGloveBlock(Model model, Artifact artifact, Map<String, Object> arguments)
            throws IOException {
        List<String> idxToToken;
        // try-with-resources so the repository stream is released even if reading fails.
        try (InputStream is =
                resource.getRepository()
                        .openStream(artifact.getFiles().get("idx_to_token"), null)) {
            idxToToken = Utils.readLines(is);
        }
        String unknownToken = (String) arguments.get(UNKNOWN_TOKEN_KEY);
        // NOTE(review): "dimensions" reads like the embedding vector size, yet it is passed to
        // optNumEmbeddings. Kept unchanged to preserve behavior; confirm against the Embedding
        // builder API of this DJL version.
        TrainableWordEmbedding wordEmbedding =
                TrainableWordEmbedding.builder()
                        .optNumEmbeddings(
                                Integer.parseInt(artifact.getProperties().get("dimensions")))
                        .setVocabulary(new SimpleVocabulary(idxToToken))
                        .optUnknownToken(unknownToken)
                        .optUseDefault(true)
                        .build();
        model.setBlock(wordEmbedding);
        model.setProperty(UNKNOWN_TOKEN_KEY, unknownToken);
        return model;
    }

    /**
     * Creates an empty model and attaches the GloVe embedding block to it.
     *
     * @param name the model name
     * @param device the device the model is created on
     * @param artifact the resolved GloVe artifact
     * @param arguments loader arguments (see {@code unknownToken})
     * @param engine the engine name passed to {@link Model#newInstance}
     * @return the configured model
     * @throws IOException if the artifact's vocabulary cannot be read
     */
    protected Model createModel(
            String name,
            Device device,
            Artifact artifact,
            Map<String, Object> arguments,
            String engine)
            throws IOException {
        Model model = Model.newInstance(name, device, engine);
        return customGloveBlock(model, artifact, arguments);
    }

    /**
     * Loads the GloVe model with {@code NDList} input and output types.
     *
     * @return the loaded model
     * @throws IOException if the model files cannot be read
     * @throws ModelNotFoundException if no matching model is found
     * @throws MalformedModelException if the model files are invalid
     */
    public ZooModel<NDList, NDList> loadModel()
            throws IOException, ModelNotFoundException, MalformedModelException {
        Criteria<NDList, NDList> criteria =
                Criteria.builder()
                        .setTypes(NDList.class, NDList.class)
                        .optApplication(APPLICATION)
                        .build();
        return loadModel(criteria);
    }

    /**
     * Translator that looks up a single word in the model's embedding block.
     *
     * <p>Words the embedding does not contain are replaced by the configured unknown token before
     * lookup.
     */
    private static final class TranslatorImpl implements Translator<String, NDList> {
        private final String unknownToken;
        private Embedding<String> embedding;

        public TranslatorImpl(String unknownToken) {
            this.unknownToken = unknownToken;
        }

        /**
         * Captures the model's block as the embedding used for lookups.
         *
         * @throws IllegalArgumentException if the model's block is missing or not an {@link
         *     Embedding}
         */
        @Override
        public void prepare(NDManager manager, Model model) {
            Block block = model.getBlock();
            // instanceof (rather than catching ClassCastException) also rejects a null block,
            // which a bare cast would accept and leave to fail later in processInput.
            if (!(block instanceof Embedding)) {
                throw new IllegalArgumentException("The model was not an embedding");
            }
            // customGloveBlock builds the block over a String vocabulary.
            @SuppressWarnings("unchecked")
            Embedding<String> stringEmbedding = (Embedding<String>) block;
            this.embedding = stringEmbedding;
        }

        /** Returns the model output unchanged. */
        @Override
        public NDList processOutput(TranslatorContext ctx, NDList list) {
            return list;
        }

        /**
         * Wraps the embedding lookup result for {@code input} in an {@link NDList}.
         *
         * <p>Out-of-vocabulary words are looked up as the unknown token instead.
         */
        @Override
        public NDList processInput(TranslatorContext ctx, String input) {
            String token = embedding.hasItem(input) ? input : unknownToken;
            return new NDList(ctx.getNDManager().create(embedding.embed(token)));
        }

        /** Returns {@link Batchifier#STACK}. */
        @Override
        public Batchifier getBatchifier() {
            return Batchifier.STACK;
        }
    }

    /** Creates {@link TranslatorImpl}s configured from the {@code unknownToken} argument. */
    private static final class FactoryImpl implements TranslatorFactory<String, NDList> {
        private FactoryImpl() {}

        @Override
        public Translator<String, NDList> newInstance(Model model, Map<String, ?> arguments) {
            // A missing argument map is treated like a missing key: no unknown token.
            String unknownToken =
                    arguments == null ? null : (String) arguments.get(UNKNOWN_TOKEN_KEY);
            return new TranslatorImpl(unknownToken);
        }
    }
}

