/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.huggingface.translator;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.VisionLanguageInput;
import ai.djl.modality.cv.translator.BaseImagePreProcessor;
import ai.djl.modality.cv.translator.BaseImageTranslator;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.translate.NoBatchifyTranslator;
import ai.djl.translate.TranslateException;
import ai.djl.translate.TranslatorContext;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

public class ZeroShotImageClassificationTranslator
implements NoBatchifyTranslator<VisionLanguageInput, Classifications> {
    private HuggingFaceTokenizer tokenizer;
    private BaseImageTranslator<?> imageProcessor;
    private boolean int32;

    ZeroShotImageClassificationTranslator(HuggingFaceTokenizer tokenizer, BaseImageTranslator<?> imageProcessor, boolean int32) {
        this.tokenizer = tokenizer;
        this.imageProcessor = imageProcessor;
        this.int32 = int32;
    }

    public NDList processInput(TranslatorContext ctx, VisionLanguageInput input) throws TranslateException {
        NDManager manager = ctx.getNDManager();
        String template = input.getHypothesisTemplate();
        String[] candidates = input.getCandidates();
        if (candidates == null || candidates.length == 0) {
            throw new TranslateException("Missing candidates in input");
        }
        ArrayList<String> sequences = new ArrayList<String>(candidates.length);
        for (String candidate : candidates) {
            sequences.add(this.applyTemplate(template, candidate));
        }
        Encoding[] encodings = this.tokenizer.batchEncode(sequences);
        NDList list = Encoding.toNDList(encodings, manager, false, this.int32);
        Image img = input.getImage();
        NDList imageFeatures = this.imageProcessor.processInput(ctx, img);
        NDArray array = ((NDArray)imageFeatures.get(0)).expandDims(0);
        list.add((Object)array);
        ctx.setAttachment("candidates", (Object)candidates);
        return list;
    }

    public Classifications processOutput(TranslatorContext ctx, NDList list) throws TranslateException {
        NDArray logits = list.get("logits_per_image");
        logits = logits.squeeze().softmax(0);
        String[] candidates = (String[])ctx.getAttachment("candidates");
        List<String> classes = Arrays.asList(candidates);
        return new Classifications(classes, logits, candidates.length);
    }

    private String applyTemplate(String template, String arg) {
        int pos = template.indexOf("{}");
        if (pos == -1) {
            return template + arg;
        }
        int len = template.length();
        return template.substring(0, pos) + arg + template.substring(pos + 2, len);
    }

    public static Builder builder(HuggingFaceTokenizer tokenizer) {
        return new Builder(tokenizer);
    }

    public static Builder builder(HuggingFaceTokenizer tokenizer, Map<String, ?> arguments) {
        Builder builder = ZeroShotImageClassificationTranslator.builder(tokenizer);
        builder.configure(arguments);
        return builder;
    }

    public static final class Builder
    extends BaseImageTranslator.BaseBuilder<Builder> {
        private HuggingFaceTokenizer tokenizer;
        private boolean int32;

        Builder(HuggingFaceTokenizer tokenizer) {
            this.tokenizer = tokenizer;
        }

        protected Builder self() {
            return this;
        }

        public Builder optInt32(boolean int32) {
            this.int32 = int32;
            return this;
        }

        public void configure(Map<String, ?> arguments) {
            this.configPreProcess(arguments);
            this.optInt32(ArgumentsUtil.booleanValue(arguments, (String)"int32"));
        }

        public ZeroShotImageClassificationTranslator build() throws IOException {
            BaseImagePreProcessor processor = new BaseImagePreProcessor((BaseImageTranslator.BaseBuilder)this);
            return new ZeroShotImageClassificationTranslator(this.tokenizer, (BaseImageTranslator<?>)processor, this.int32);
        }
    }
}

