/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.pytorch.zoo.cv.objectdetection;

import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.modality.cv.translator.ObjectDetectionTranslator;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.translate.TranslatorContext;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class PtSsdTranslator
extends ObjectDetectionTranslator {
    private NDArray boxRecover;
    private int figSize;
    private int[] featSize;
    private int[] steps;
    private int[] scale;
    private int[][] aspectRatio;

    protected PtSsdTranslator(Builder builder) {
        super((ObjectDetectionTranslator.ObjectDetectionBuilder)builder);
        this.figSize = builder.figSize;
        this.featSize = builder.featSize;
        this.steps = builder.steps;
        this.scale = builder.scale;
        this.aspectRatio = builder.aspectRatio;
    }

    public void prepare(TranslatorContext ctx) throws Exception {
        super.prepare(ctx);
        NDManager manager = ctx.getPredictorManager();
        this.boxRecover = this.boxRecover(manager, this.figSize, this.featSize, this.steps, this.scale, this.aspectRatio);
    }

    public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {
        double scaleXY = 0.1;
        double scaleWH = 0.2;
        NDArray prob = ((NDArray)list.get(1)).swapAxes(0, 1).softmax(1).get(":, 1:", new Object[0]);
        prob = NDArrays.stack((NDList)new NDList(new NDArray[]{prob.argMax(1).toType(DataType.FLOAT32, false), prob.max(new int[]{1})}));
        NDArray boundingBoxes = ((NDArray)list.get(0)).swapAxes(0, 1);
        NDArray bbWH = boundingBoxes.get(":, 2:", new Object[0]).mul((Number)scaleWH).exp().mul(this.boxRecover.get(":, 2:", new Object[0]));
        NDArray bbXY = boundingBoxes.get(":, :2", new Object[0]).mul((Number)scaleXY).mul(this.boxRecover.get(":, 2:", new Object[0])).add(this.boxRecover.get(":, :2", new Object[0])).sub(bbWH.mul((Number)Float.valueOf(0.5f)));
        boundingBoxes = NDArrays.concat((NDList)new NDList(new NDArray[]{bbXY, bbWH}), (int)1);
        NDArray cutOff = prob.get(new long[]{1L}).gte((Number)Float.valueOf(this.threshold));
        boundingBoxes = boundingBoxes.transpose().booleanMask(cutOff, 1).transpose();
        prob = prob.booleanMask(cutOff, 1);
        long[] order = prob.get(new long[]{1L}).argSort().toLongArray();
        double desiredIoU = 0.45;
        prob = prob.transpose();
        ArrayList<String> retNames = new ArrayList<String>();
        ArrayList<Double> retProbs = new ArrayList<Double>();
        ArrayList<Rectangle> retBB = new ArrayList<Rectangle>();
        ConcurrentHashMap<Integer, List> recorder = new ConcurrentHashMap<Integer, List>();
        for (int i = order.length - 1; i >= 0; --i) {
            long currMaxLoc = order[i];
            float[] classProb = prob.get(new long[]{currMaxLoc}).toFloatArray();
            int classId = (int)classProb[0];
            double probability = classProb[1];
            double[] boxArr = boundingBoxes.get(new long[]{currMaxLoc}).toDoubleArray();
            Rectangle rect = new Rectangle(boxArr[0], boxArr[1], boxArr[2], boxArr[3]);
            List boxes = recorder.getOrDefault(classId, new ArrayList());
            boolean belowIoU = true;
            for (BoundingBox box : boxes) {
                if (!(box.getIoU((BoundingBox)rect) > desiredIoU)) continue;
                belowIoU = false;
                break;
            }
            if (!belowIoU) continue;
            boxes.add(rect);
            recorder.put(classId, boxes);
            String className = (String)this.classes.get(classId);
            retNames.add(className);
            retProbs.add(probability);
            retBB.add(rect);
        }
        return new DetectedObjects(retNames, retProbs, retBB);
    }

    NDArray boxRecover(NDManager manager, int figSize, int[] featSize, int[] steps, int[] scale, int[][] aspectRatio) {
        double[] fk = manager.create(steps).toType(DataType.FLOAT64, true).getNDArrayInternal().rdiv((Number)figSize).toDoubleArray();
        ArrayList<double[]> defaultBoxes = new ArrayList<double[]>();
        for (int idx = 0; idx < featSize.length; ++idx) {
            double sk1 = (double)scale[idx] * 1.0 / (double)figSize;
            double sk2 = (double)scale[idx + 1] * 1.0 / (double)figSize;
            double sk3 = Math.sqrt(sk1 * sk2);
            ArrayList<double[]> array = new ArrayList<double[]>();
            array.add(new double[]{sk1, sk1});
            array.add(new double[]{sk3, sk3});
            for (int alpha : aspectRatio[idx]) {
                double w = sk1 * Math.sqrt(alpha);
                double h = sk1 / Math.sqrt(alpha);
                array.add(new double[]{w, h});
                array.add(new double[]{h, w});
            }
            Object object = array.iterator();
            while (object.hasNext()) {
                double[] size = (double[])object.next();
                for (int i = 0; i < featSize[idx]; ++i) {
                    for (int j = 0; j < featSize[idx]; ++j) {
                        double cx = ((double)j + 0.5) / fk[idx];
                        double cy = ((double)i + 0.5) / fk[idx];
                        defaultBoxes.add(new double[]{cx, cy, size[0], size[1]});
                    }
                }
            }
        }
        double[][] boxes = new double[defaultBoxes.size()][((double[])defaultBoxes.get(0)).length];
        for (int i = 0; i < defaultBoxes.size(); ++i) {
            boxes[i] = (double[])defaultBoxes.get(i);
        }
        return manager.create(boxes).clip((Number)0.0, (Number)1.0);
    }

    public static Builder builder() {
        return new Builder();
    }

    public static Builder builder(Map<String, ?> arguments) {
        Builder builder = new Builder();
        builder.configPreProcess(arguments);
        builder.configPostProcess(arguments);
        return builder;
    }

    public static class Builder
    extends ObjectDetectionTranslator.ObjectDetectionBuilder<Builder> {
        private int figSize;
        private int[] featSize;
        private int[] steps;
        private int[] scale;
        private int[][] aspectRatio;

        public Builder setBoxes(int figSize, int[] featSize, int[] steps, int[] scale, int[][] aspectRatio) {
            this.figSize = figSize;
            this.featSize = featSize;
            this.steps = steps;
            this.scale = scale;
            this.aspectRatio = aspectRatio;
            return this;
        }

        protected Builder self() {
            return null;
        }

        protected void configPreProcess(Map<String, ?> arguments) {
            super.configPreProcess(arguments);
        }

        protected void configPostProcess(Map<String, ?> arguments) {
            super.configPostProcess(arguments);
            this.threshold = ArgumentsUtil.floatValue(arguments, (String)"threshold", (float)0.4f);
            this.figSize = ArgumentsUtil.intValue(arguments, (String)"size", (int)300);
            List list = (List)arguments.get("featSize");
            this.featSize = list == null ? new int[]{38, 19, 10, 5, 3, 1} : list.stream().mapToInt(Double::intValue).toArray();
            list = (List)arguments.get("steps");
            this.steps = list == null ? new int[]{8, 16, 32, 64, 100, 300} : list.stream().mapToInt(Double::intValue).toArray();
            list = (List)arguments.get("scale");
            this.scale = list == null ? new int[]{21, 45, 99, 153, 207, 261, 315} : list.stream().mapToInt(Double::intValue).toArray();
            List ratio = (List)arguments.get("aspectRatios");
            if (ratio == null) {
                this.aspectRatio = new int[][]{{2}, {2, 3}, {2, 3}, {2, 3}, {2}, {2}};
            } else {
                this.aspectRatio = new int[ratio.size()][];
                for (int i = 0; i < this.aspectRatio.length; ++i) {
                    this.aspectRatio[i] = ((List)ratio.get(i)).stream().mapToInt(Double::intValue).toArray();
                }
            }
        }

        public PtSsdTranslator build() {
            this.validate();
            return new PtSsdTranslator(this);
        }
    }
}

