/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.modality.cv.translator;

import ai.djl.Model;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Mask;
import ai.djl.modality.cv.output.Point;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.translate.NoBatchifyTranslator;
import ai.djl.translate.NoopTranslator;
import ai.djl.translate.Pipeline;
import ai.djl.translate.TranslatorContext;
import ai.djl.util.JsonUtils;
import com.google.gson.annotations.SerializedName;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.UUID;

/**
 * A {@link NoBatchifyTranslator} for SAM2-style promptable segmentation. It turns an image plus
 * point/box prompts ({@link Sam2Input}) into a {@link DetectedObjects} holding a single mask.
 *
 * <p>How the image embeddings are produced is decided in {@link #prepare}:
 *
 * <ul>
 *   <li>{@code encoder} path set: a separate encoder model is loaded from that path and run on
 *       the image;
 *   <li>only {@code encode_method} set: the main model itself is used as the encoder, and the
 *       method name is passed through a named placeholder array;
 *   <li>neither set: no encoder runs, and the preprocessed image, locations and labels are
 *       handed to the main model directly.
 * </ul>
 */
public class Sam2Translator implements NoBatchifyTranslator<Sam2Input, DetectedObjects> {

    /** Side length, in pixels, of the square image fed to the model; prompts are scaled to it. */
    private static final int IMAGE_SIZE = 1024;

    /** Per-channel mean used by the normalization step (the common ImageNet statistics). */
    private static final float[] MEAN = {0.485f, 0.456f, 0.406f};

    /** Per-channel standard deviation used by the normalization step. */
    private static final float[] STD = {0.229f, 0.224f, 0.225f};

    private final Pipeline pipeline = new Pipeline();
    private final String encoderPath;
    private final String encodeMethod;
    private Predictor<NDList, NDList> predictor;

    /**
     * Creates a translator from builder settings.
     *
     * @param builder supplies the optional encoder path and encode method (either may be null)
     */
    public Sam2Translator(Builder builder) {
        pipeline.add(new Resize(IMAGE_SIZE, IMAGE_SIZE));
        pipeline.add(new ToTensor());
        pipeline.add(new Normalize(MEAN, STD));
        this.encoderPath = builder.encoderPath;
        this.encodeMethod = builder.encodeMethod;
    }

    /**
     * Creates the encoder predictor, if one is configured.
     *
     * <p>A configured encoder path is first used as given. If it is relative and does not exist,
     * it is resolved against the main model's directory. The encoder model and its predictor are
     * attached to the main model's manager; presumably they are released together with it.
     *
     * @param ctx the translator context
     * @throws IOException if the encoder model cannot be found or read
     * @throws ModelException if the encoder model cannot be loaded
     */
    @Override
    public void prepare(TranslatorContext ctx) throws IOException, ModelException {
        Model model = ctx.getModel();
        if (encoderPath == null) {
            if (encodeMethod != null) {
                // Encode through a method of the main model itself (see processInput).
                predictor = model.newPredictor(new NoopTranslator(null));
                model.getNDManager().attachInternal(UUID.randomUUID().toString(), predictor);
            }
            return;
        }

        Path path = Paths.get(encoderPath);
        if (!path.isAbsolute() && Files.notExists(path)) {
            path = model.getModelPath().resolve(encoderPath);
        }
        if (!Files.exists(path)) {
            throw new IOException("encoder model not found: " + encoderPath);
        }

        NDManager manager = ctx.getNDManager();
        Model encoder = manager.getEngine().newModel("encoder", manager.getDevice());
        try {
            encoder.load(path);
            predictor = encoder.newPredictor(new NoopTranslator(null));
        } catch (IOException | ModelException | RuntimeException e) {
            // Do not leak a half-initialized encoder when loading fails.
            encoder.close();
            throw e;
        }
        model.getNDManager().attachInternal(UUID.randomUUID().toString(), predictor);
        model.getNDManager().attachInternal(UUID.randomUUID().toString(), encoder);
    }

    /**
     * Preprocesses the image and prompts, and runs the encoder when one is available.
     *
     * <p>Without an encoder the result is {@code (image, locations, labels)}. With an encoder it
     * is {@code (emb[2], emb[0], emb[1], locations, labels, maskInput, hasMaskInput)}, where
     * {@code emb} are the encoder outputs. Presumably {@code emb[2]} holds the image embeddings
     * and {@code emb[0..1]} the high-resolution features; confirm against the exported model. The
     * mask input is an all-zero {@code (1, 1, 256, 256)} array and the has-mask flag is zero.
     *
     * @param ctx the translator context; receives the original "width"/"height" attachments
     * @param input the image and prompts
     * @return the tensors for the decoder model
     * @throws IllegalStateException if the encoder returns fewer than three outputs
     */
    @Override
    public NDList processInput(TranslatorContext ctx, Sam2Input input) throws Exception {
        Image image = input.getImage();
        int width = image.getWidth();
        int height = image.getHeight();
        // Remembered so processOutput can scale the mask back to the original image size.
        ctx.setAttachment("width", width);
        ctx.setAttachment("height", height);

        NDManager manager = ctx.getNDManager();
        NDArray pixels = image.toNDArray(manager, Image.Flag.COLOR);
        pixels = pipeline.transform(new NDList(pixels)).get(0).expandDims(0);

        float[] buf = input.toLocationArray(width, height);
        NDArray locations = manager.create(buf, new Shape(1, buf.length / 2, 2));
        NDArray labels = manager.create(input.getLabels());
        if (predictor == null) {
            return new NDList(pixels, locations, labels);
        }

        NDList embeddings;
        if (encodeMethod == null) {
            embeddings = predictor.predict(new NDList(pixels));
        } else {
            // The method to invoke is conveyed through the name of a placeholder array.
            NDArray placeholder = manager.create("");
            placeholder.setName("module_method:" + encodeMethod);
            embeddings = predictor.predict(new NDList(placeholder, pixels));
        }
        if (embeddings.size() < 3) {
            throw new IllegalStateException(
                    "Expected at least 3 encoder outputs, got " + embeddings.size());
        }

        NDArray mask = manager.zeros(new Shape(1, 1, 256, 256));
        NDArray hasMask = manager.zeros(new Shape(1));
        // Clear encoder-assigned names; presumably so the decoder binds its inputs by position.
        for (NDArray embedding : embeddings) {
            embedding.setName(null);
        }
        return new NDList(
                embeddings.get(2),
                embeddings.get(0),
                embeddings.get(1),
                locations,
                labels,
                mask,
                hasMask);
    }

    /**
     * Converts decoder output into a single full-image mask.
     *
     * <p>Reads {@code list[0]} as mask logits and {@code list[1]} as per-mask scores. It picks the
     * highest-scoring mask and bilinearly resizes the logits to the original image size. Pixels
     * whose logit is greater than zero become the mask.
     *
     * @param ctx the translator context holding the "width"/"height" attachments
     * @param list the decoder outputs
     * @return a {@link DetectedObjects} with one unnamed {@link Mask} and its score
     */
    @Override
    public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {
        NDArray logits = list.get(0);
        NDArray scores = list.get(1).squeeze(0);
        long best = scores.argMax().getLong();
        int width = (Integer) ctx.getAttachment("width");
        int height = (Integer) ctx.getAttachment("height");

        long[] size = {height, width};
        // The internal interpolation API takes the interpolation mode as an int.
        int mode = Image.Interpolation.BILINEAR.ordinal();
        logits = logits.getNDArrayInternal().interpolation(size, mode, false);
        NDArray masks = logits.gt(0.0f).squeeze(0);
        float[][] dist = Mask.toMask(masks.get(best).toType(DataType.FLOAT32, true));

        Mask mask = new Mask(0.0, 0.0, width, height, dist, true);
        double probability = scores.getFloat(best);
        List<String> classes = Collections.singletonList("");
        List<Double> probabilities = Collections.singletonList(probability);
        List<BoundingBox> boxes = Collections.singletonList(mask);
        return new DetectedObjects(classes, probabilities, boxes);
    }

    /**
     * Creates a builder with no encoder path and no encode method.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return builder(Collections.emptyMap());
    }

    /**
     * Creates a builder initialized from model arguments.
     *
     * @param arguments reads the optional keys {@code "encoder"} and {@code "encode_method"}
     * @return a new builder
     */
    public static Builder builder(Map<String, ?> arguments) {
        return new Builder(arguments);
    }

    /** Builder for {@link Sam2Translator}. */
    public static class Builder {
        String encoderPath;
        String encodeMethod;

        Builder(Map<String, ?> arguments) {
            this.encoderPath = ArgumentsUtil.stringValue(arguments, "encoder");
            this.encodeMethod = ArgumentsUtil.stringValue(arguments, "encode_method");
        }

        /**
         * Sets the path of a separate encoder model.
         *
         * <p>A relative path that does not exist is resolved against the model directory.
         *
         * @param encoderPath the encoder model path
         * @return this builder
         */
        public Builder optEncoderPath(String encoderPath) {
            this.encoderPath = encoderPath;
            return this;
        }

        /**
         * Sets the name of the module method that produces image embeddings.
         *
         * @param encodeMethod the method name
         * @return this builder
         */
        public Builder optEncodeMethod(String encodeMethod) {
            this.encodeMethod = encodeMethod;
            return this;
        }

        /**
         * Builds the translator.
         *
         * @return a new {@link Sam2Translator}
         */
        public Sam2Translator build() {
            return new Sam2Translator(this);
        }
    }

    /**
     * An image together with prompt locations.
     *
     * <p>Each location carries a label. Points added through the builder default to label 1. A
     * box is stored as two consecutive locations: the top-left corner with label 2 and the
     * bottom-right corner with label 3.
     */
    public static final class Sam2Input {
        private final Image image;
        private final Point[] points;
        private final int[] labels;
        private final boolean visualize;

        /**
         * Creates an input without visualization.
         *
         * @param image the input image
         * @param points prompt locations in original-image pixel coordinates
         * @param labels one label per location
         * @throws IllegalArgumentException if {@code points} and {@code labels} differ in length
         */
        public Sam2Input(Image image, Point[] points, int[] labels) {
            this(image, points, labels, false);
        }

        /**
         * Creates an input.
         *
         * @param image the input image
         * @param points prompt locations in original-image pixel coordinates
         * @param labels one label per location
         * @param visualize the visualization flag; this translator does not read it
         * @throws IllegalArgumentException if {@code points} and {@code labels} differ in length
         */
        public Sam2Input(Image image, Point[] points, int[] labels, boolean visualize) {
            if (points.length != labels.length) {
                throw new IllegalArgumentException(
                        "points and labels must have the same length: "
                                + points.length
                                + " != "
                                + labels.length);
            }
            this.image = image;
            this.points = points;
            this.labels = labels;
            this.visualize = visualize;
        }

        /**
         * Returns the input image.
         *
         * @return the image
         */
        public Image getImage() {
            return image;
        }

        /**
         * Returns whether visualization was requested.
         *
         * @return the visualization flag
         */
        public boolean isVisualize() {
            return visualize;
        }

        /**
         * Returns the point prompts, i.e. every location whose label is below 2 (not a box corner).
         *
         * @return a new mutable list of points
         */
        public List<Point> getPoints() {
            List<Point> list = new ArrayList<>();
            for (int i = 0; i < labels.length; ++i) {
                if (labels[i] < 2) {
                    list.add(points[i]);
                }
            }
            return list;
        }

        /**
         * Returns the box prompts. Each box is built from a label-2 location and the location
         * that follows it.
         *
         * <p>A label-2 location with no following location is ignored.
         *
         * @return a new mutable list of rectangles
         */
        public List<Rectangle> getBoxes() {
            List<Rectangle> list = new ArrayList<>();
            for (int i = 0; i + 1 < labels.length; ++i) {
                if (labels[i] != 2) {
                    continue;
                }
                double width = points[i + 1].getX() - points[i].getX();
                double height = points[i + 1].getY() - points[i].getY();
                list.add(new Rectangle(points[i], width, height));
            }
            return list;
        }

        /**
         * Flattens the locations into {@code [x0, y0, x1, y1, ...]}. Coordinates are rescaled from
         * the original image size to the square model input size.
         */
        float[] toLocationArray(int width, int height) {
            float[] ret = new float[points.length * 2];
            int i = 0;
            for (Point point : points) {
                ret[i++] = (float) point.getX() / (float) width * IMAGE_SIZE;
                ret[i++] = (float) point.getY() / (float) height * IMAGE_SIZE;
            }
            return ret;
        }

        /** Returns the labels as a {@code 1 x n} float matrix. */
        float[][] getLabels() {
            float[][] buf = new float[1][labels.length];
            for (int i = 0; i < labels.length; ++i) {
                buf[0][i] = labels[i];
            }
            return buf;
        }

        /**
         * Parses an input from JSON shaped like:
         *
         * <pre>{@code
         * {"image_url": "...",
         *  "prompt": [{"type": "point", "data": [x, y], "label": 1},
         *             {"type": "rectangle", "data": [x, y, right, bottom]}],
         *  "visualize": false}
         * }</pre>
         *
         * <p>Prompt entries with any other {@code type} are ignored. A point without a
         * {@code label} gets the JSON default of 0.
         *
         * @param input the JSON text
         * @return the parsed input, with the image loaded from {@code image_url}
         * @throws IOException if the image cannot be loaded
         * @throws IllegalArgumentException if the image URL or prompts are missing or malformed
         */
        public static Sam2Input fromJson(String input) throws IOException {
            Prompt prompt = JsonUtils.GSON.fromJson(input, Prompt.class);
            if (prompt == null || prompt.image == null) {
                throw new IllegalArgumentException("Missing image_url value");
            }
            if (prompt.prompt == null || prompt.prompt.length == 0) {
                throw new IllegalArgumentException("Missing prompt value");
            }
            Image image = ImageFactory.getInstance().fromUrl(prompt.image);
            Builder builder = builder(image);
            if (prompt.visualize) {
                builder.visualize();
            }
            for (Location location : prompt.prompt) {
                if (location == null) {
                    throw new IllegalArgumentException("Invalid null prompt entry");
                }
                if ("point".equals(location.type)) {
                    int[] data = requireData(location, 2);
                    builder.addPoint(data[0], data[1], location.label);
                } else if ("rectangle".equals(location.type)) {
                    int[] data = requireData(location, 4);
                    builder.addBox(data[0], data[1], data[2], data[3]);
                }
            }
            return builder.build();
        }

        /** Returns the location's data, failing clearly if it has fewer than {@code count} values. */
        private static int[] requireData(Location location, int count) {
            int[] data = location.data;
            if (data == null || data.length < count) {
                throw new IllegalArgumentException(
                        "Prompt of type \"" + location.type + "\" requires " + count + " data values");
            }
            return data;
        }

        /**
         * Creates a builder for the given image.
         *
         * @param image the input image
         * @return a new builder
         */
        public static Builder builder(Image image) {
            return new Builder(image);
        }

        /** Gson binding for the top-level JSON object. */
        private static final class Prompt {
            @SerializedName(value = "image_url")
            String image;

            Location[] prompt;
            boolean visualize;

            private Prompt() {}

            public void setImage(String image) {
                this.image = image;
            }

            public void setPrompt(Location[] prompt) {
                this.prompt = prompt;
            }

            public void setVisualize(boolean visualize) {
                this.visualize = visualize;
            }
        }

        /** Gson binding for one prompt entry. */
        private static final class Location {
            String type;
            int[] data;
            int label;

            private Location() {}

            public void setType(String type) {
                this.type = type;
            }

            public void setData(int[] data) {
                this.data = data;
            }

            public void setLabel(int label) {
                this.label = label;
            }
        }

        /** Incrementally collects prompts for a {@link Sam2Input}. */
        public static final class Builder {
            private final Image image;
            private final List<Point> points;
            private final List<Integer> labels;
            private boolean visualize;

            Builder(Image image) {
                this.image = image;
                this.points = new ArrayList<>();
                this.labels = new ArrayList<>();
            }

            /**
             * Adds a point with label 1.
             *
             * @param x the x coordinate in original-image pixels
             * @param y the y coordinate in original-image pixels
             * @return this builder
             */
            public Builder addPoint(int x, int y) {
                return addPoint(x, y, 1);
            }

            /**
             * Adds a labeled point.
             *
             * @param x the x coordinate in original-image pixels
             * @param y the y coordinate in original-image pixels
             * @param label the point label
             * @return this builder
             */
            public Builder addPoint(int x, int y, int label) {
                return addPoint(new Point(x, y), label);
            }

            /**
             * Adds a labeled point.
             *
             * @param point the location in original-image pixels
             * @param label the point label
             * @return this builder
             */
            public Builder addPoint(Point point, int label) {
                points.add(point);
                labels.add(label);
                return this;
            }

            /**
             * Adds a box as two corner locations: top-left with label 2, bottom-right with label 3.
             *
             * @param x the left edge
             * @param y the top edge
             * @param right the right edge
             * @param bottom the bottom edge
             * @return this builder
             */
            public Builder addBox(int x, int y, int right, int bottom) {
                addPoint(new Point(x, y), 2);
                addPoint(new Point(right, bottom), 3);
                return this;
            }

            /**
             * Sets the visualization flag.
             *
             * @return this builder
             */
            public Builder visualize() {
                this.visualize = true;
                return this;
            }

            /**
             * Builds the input.
             *
             * @return a new {@link Sam2Input}
             */
            public Sam2Input build() {
                Point[] location = points.toArray(new Point[0]);
                int[] array = labels.stream().mapToInt(Integer::intValue).toArray();
                return new Sam2Input(image, location, array, visualize);
            }
        }
    }
}

