/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.basicmodelzoo.cv.classification;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Block;
import ai.djl.nn.Blocks;
import ai.djl.nn.ParallelBlock;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.convolutional.Conv2d;
import ai.djl.nn.norm.BatchNorm;
import ai.djl.nn.pooling.Pool;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Objects;

/**
 * Factory for MobileNetV2-style classification networks ("MobileNetV2: Inverted Residuals and
 * Linear Bottlenecks").
 *
 * <p>The built network is: a stem of 1x1 conv/batch-norm/ReLU6 units, {@value #MULTILENGTH}
 * stages of linear bottlenecks, a second stack of 1x1 conv/batch-norm/ReLU6 units, global average
 * pooling, and a final 1x1 convolution whose output is reshaped to {@code (batch, outSize)}.
 */
public final class MobileNetV2 {
    /** Required length of the array passed to {@link Builder#optFilters(int[])}. */
    public static final int FILTERLENGTH = 9;
    /** Required length of the array passed to {@link Builder#optRepeatTimes(int[])}. */
    public static final int REPEATLENGTH = 9;
    /** Required length of the array passed to {@link Builder#optStrides(int[])}. */
    public static final int STRIDELENGTH = 9;
    /**
     * Required length of the array passed to {@link Builder#optMultiTimes(int[])}; also the number
     * of bottleneck stages built by {@link #mobilenetV2(Builder)}.
     */
    public static final int MULTILENGTH = 7;

    /** Epsilon used by every batch-norm layer in this model. */
    private static final float BATCH_NORM_EPSILON = 2E-5f;

    private MobileNetV2() {
    }

    /**
     * Creates one inverted-residual bottleneck block.
     *
     * <p>The main branch is: a 1x1 expansion to {@code inputChannels * t} channels, a 3x3 grouped
     * (groups equal to the expanded width) convolution with the given stride and padding 1, and a
     * 1x1 projection to {@code outputChannels}. Each convolution is bias-free and followed by batch
     * norm; only the first two are followed by ReLU6. The expansion convolution is always present,
     * even when {@code t == 1}.
     *
     * <p>When {@code stride == 1} and {@code inputChannels == outputChannels}, the branch output is
     * summed with an identity shortcut.
     *
     * @param inputChannels channels entering the block; sizes the expansion and decides whether the
     *     identity shortcut is used
     * @param outputChannels channels produced by the final projection
     * @param stride spatial stride of the 3x3 convolution
     * @param t expansion factor
     * @param batchNormMomentum momentum for every batch-norm layer in the block
     * @return the bottleneck block
     */
    public static Block linearBottleNeck(
            int inputChannels, int outputChannels, int stride, int t, float batchNormMomentum) {
        int expanded = inputChannels * t;
        SequentialBlock block = new SequentialBlock();
        block.add(pointwiseConv(expanded, 1))
                .add(batchNorm(batchNormMomentum))
                .add(Activation.relu6Block())
                .add(
                        Conv2d.builder()
                                .setKernelShape(new Shape(3, 3))
                                .setFilters(expanded)
                                .optStride(new Shape(stride, stride))
                                .optPadding(new Shape(1, 1))
                                .optGroups(expanded)
                                .optBias(false)
                                .build())
                .add(batchNorm(batchNormMomentum))
                .add(Activation.relu6Block())
                // Linear projection: no activation after this batch norm.
                .add(pointwiseConv(outputChannels, 1))
                .add(batchNorm(batchNormMomentum));

        if (stride == 1 && inputChannels == outputChannels) {
            // Residual connection: element-wise sum of the branch output and the unchanged input.
            return new ParallelBlock(
                    list ->
                            new NDList(
                                    NDArrays.add(
                                            list.get(0).singletonOrThrow(),
                                            list.get(1).singletonOrThrow())),
                    Arrays.asList(block, Blocks.identityBlock()));
        }
        return block;
    }

    /**
     * Creates a stage of consecutive bottleneck blocks.
     *
     * <p>Only the first bottleneck changes the channel count and applies {@code stride}; the
     * remaining {@code repeat - 1} blocks map {@code outputChannels} to itself with stride 1. At
     * least one bottleneck is always built, so {@code repeat} values below 1 behave like 1.
     *
     * @param repeat number of bottleneck blocks in the stage
     * @param inputChannels channels entering the stage
     * @param outputChannels channels produced by every block of the stage
     * @param stride stride of the first block
     * @param t expansion factor for every block
     * @param batchNormMomentum momentum for every batch-norm layer in the stage
     * @return the stage as a sequential block
     */
    public static Block makeStage(
            int repeat,
            int inputChannels,
            int outputChannels,
            int stride,
            int t,
            float batchNormMomentum) {
        SequentialBlock layers = new SequentialBlock();
        layers.add(linearBottleNeck(inputChannels, outputChannels, stride, t, batchNormMomentum));
        // Counting from 1 (rather than comparing against repeat - 1) cannot overflow.
        for (int i = 1; i < repeat; ++i) {
            layers.add(linearBottleNeck(outputChannels, outputChannels, 1, t, batchNormMomentum));
        }
        return layers;
    }

    /**
     * Builds the full network from the configuration held by {@code builder}.
     *
     * <p>Configuration arrays are indexed as follows:
     *
     * <ul>
     *   <li>index 0: stem, with {@code repeatTimes[0]} conv/BN/ReLU6 units of {@code filters[0]}
     *       channels and stride {@code strides[0]} each;
     *   <li>indices 1..{@value #MULTILENGTH}: bottleneck stage {@code i} maps {@code filters[i - 1]}
     *       to {@code filters[i]} channels, repeats {@code repeatTimes[i]} times, uses stride
     *       {@code strides[i]} and expansion {@code multiTimes[i - 1]};
     *   <li>last index: {@code repeatTimes[last]} conv/BN/ReLU6 units of {@code filters[last]}
     *       channels and stride {@code strides[last]} each.
     * </ul>
     *
     * <p>The builder's values are read only while this method runs; later changes to the builder do
     * not affect the returned block.
     *
     * @param builder the model configuration
     * @return the MobileNetV2 block
     */
    public static Block mobilenetV2(Builder builder) {
        final int lastFilterIndex = FILTERLENGTH - 1;
        final int lastRepeatIndex = REPEATLENGTH - 1;
        final int lastStrideIndex = STRIDELENGTH - 1;

        // Snapshot the values read by the reshape lambdas below. Those lambdas run on every forward
        // pass, long after this method returns, so they must not read the mutable, reusable builder
        // at that point.
        final int lastChannels = builder.filters[lastFilterIndex];
        final long outSize = builder.outSize;
        final float momentum = builder.batchNormMomentum;

        SequentialBlock pre = new SequentialBlock();
        for (int i = 0; i < builder.repeatTimes[0]; ++i) {
            addPointwiseConvBnRelu6(pre, builder.filters[0], builder.strides[0], momentum);
        }

        SequentialBlock mobileNet = new SequentialBlock();
        mobileNet.add(pre);
        for (int i = 0; i < MULTILENGTH; ++i) {
            mobileNet.add(
                    makeStage(
                            builder.repeatTimes[i + 1],
                            builder.filters[i],
                            builder.filters[i + 1],
                            builder.strides[i + 1],
                            builder.multiTimes[i],
                            momentum));
        }

        SequentialBlock conv1 = new SequentialBlock();
        for (int i = 0; i < builder.repeatTimes[lastRepeatIndex]; ++i) {
            addPointwiseConvBnRelu6(conv1, lastChannels, builder.strides[lastStrideIndex], momentum);
        }

        // Final 1x1 classifier convolution (with bias, unlike the other convolutions).
        Conv2d conv2 =
                Conv2d.builder().setKernelShape(new Shape(1, 1)).setFilters((int) outSize).build();

        return mobileNet
                .add(conv1)
                .add(Pool.globalAvgPool2dBlock())
                // Restore a (batch, channels, 1, 1) layout so the 1x1 convolution can consume it.
                .addSingleton(array -> array.reshape(array.getShape().get(0), lastChannels, 1, 1))
                .add(conv2)
                .addSingleton(array -> array.reshape(array.getShape().get(0), outSize));
    }

    /**
     * Returns a new {@link Builder} with the default MobileNetV2 configuration.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /** Creates a bias-free 1x1 convolution with {@code filters} outputs and a square stride. */
    private static Conv2d pointwiseConv(int filters, int stride) {
        return Conv2d.builder()
                .setKernelShape(new Shape(1, 1))
                .setFilters(filters)
                .optStride(new Shape(stride, stride))
                .optBias(false)
                .build();
    }

    /** Creates a batch-norm layer using this model's epsilon and the given momentum. */
    private static Block batchNorm(float momentum) {
        return BatchNorm.builder().optEpsilon(BATCH_NORM_EPSILON).optMomentum(momentum).build();
    }

    /** Appends a pointwise convolution, a batch norm and a ReLU6 to {@code target}. */
    private static void addPointwiseConvBnRelu6(
            SequentialBlock target, int filters, int stride, float momentum) {
        target.add(pointwiseConv(filters, stride))
                .add(batchNorm(momentum))
                .add(Activation.relu6Block());
    }

    /**
     * Configuration for {@link MobileNetV2#mobilenetV2(Builder)}.
     *
     * <p>Array setters copy their argument, so changing the caller's array afterwards has no
     * effect on this builder.
     */
    public static final class Builder {
        float batchNormMomentum = 0.9f;
        long outSize = 10L;
        int[] repeatTimes = new int[]{1, 1, 2, 3, 4, 3, 3, 1, 1};
        int[] filters = new int[]{32, 16, 24, 32, 64, 96, 160, 320, 1280};
        int[] strides = new int[]{2, 1, 2, 2, 2, 1, 2, 1, 1};
        int[] multiTimes = new int[]{1, 6, 6, 6, 6, 6, 6};

        Builder() {
        }

        /**
         * Sets the momentum used by every batch-norm layer (default {@code 0.9}).
         *
         * @param batchNormMomentum the batch-norm momentum
         * @return this builder
         */
        public Builder optBatchNormMomentum(float batchNormMomentum) {
            this.batchNormMomentum = batchNormMomentum;
            return this;
        }

        /**
         * Sets the number of outputs of the final convolution (default {@code 10}).
         *
         * @param outSize the output size
         * @return this builder
         */
        public Builder setOutSize(long outSize) {
            this.outSize = outSize;
            return this;
        }

        /**
         * Sets the per-section channel counts.
         *
         * @param filters exactly {@value MobileNetV2#FILTERLENGTH} channel counts
         * @return this builder
         * @throws NullPointerException if {@code filters} is null
         * @throws IllegalArgumentException if the length is not {@value MobileNetV2#FILTERLENGTH}
         */
        public Builder optFilters(int[] filters) {
            this.filters = checkedCopy(filters, FILTERLENGTH, "optFilters", "filters");
            return this;
        }

        /**
         * Sets how many times each section is repeated.
         *
         * @param repeatTimes exactly {@value MobileNetV2#REPEATLENGTH} repeat counts
         * @return this builder
         * @throws NullPointerException if {@code repeatTimes} is null
         * @throws IllegalArgumentException if the length is not {@value MobileNetV2#REPEATLENGTH}
         */
        public Builder optRepeatTimes(int[] repeatTimes) {
            this.repeatTimes =
                    checkedCopy(repeatTimes, REPEATLENGTH, "optRepeatTimes", "repeatTimes");
            return this;
        }

        /**
         * Sets the stride of each section.
         *
         * @param strides exactly {@value MobileNetV2#STRIDELENGTH} strides
         * @return this builder
         * @throws NullPointerException if {@code strides} is null
         * @throws IllegalArgumentException if the length is not {@value MobileNetV2#STRIDELENGTH}
         */
        public Builder optStrides(int[] strides) {
            this.strides = checkedCopy(strides, STRIDELENGTH, "optStrides", "strides");
            return this;
        }

        /**
         * Sets the expansion factor of each bottleneck stage.
         *
         * @param multiTimes exactly {@value MobileNetV2#MULTILENGTH} expansion factors
         * @return this builder
         * @throws NullPointerException if {@code multiTimes} is null
         * @throws IllegalArgumentException if the length is not {@value MobileNetV2#MULTILENGTH}
         */
        public Builder optMultiTimes(int[] multiTimes) {
            this.multiTimes = checkedCopy(multiTimes, MULTILENGTH, "optMultiTimes", "multiTimes");
            return this;
        }

        /**
         * Builds the network from this configuration.
         *
         * @return the MobileNetV2 block
         */
        public Block build() {
            return MobileNetV2.mobilenetV2(this);
        }

        /**
         * Validates {@code values} and returns a private copy of it.
         *
         * @throws NullPointerException if {@code values} is null
         * @throws IllegalArgumentException if {@code values.length != expectedLength}
         */
        private static int[] checkedCopy(
                int[] values, int expectedLength, String method, String name) {
            Objects.requireNonNull(values, name);
            if (values.length != expectedLength) {
                throw new IllegalArgumentException(
                        String.format(
                                "%s requires %s of length %d, but was given %s of length %d instead",
                                method, name, expectedLength, name, values.length));
            }
            return values.clone();
        }
    }
}

