/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.samediff;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Map;
import lombok.NonNull;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.ControlFlow;
import org.nd4j.autodiff.samediff.SDIndex;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.linalg.api.blas.params.MMulTranspose;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.shape.CreateView;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class SDVariable
implements Serializable {
    /** SLF4J logger for this class. */
    private static final Logger log = LoggerFactory.getLogger(SDVariable.class);
    /** SameDiff instance this variable belongs to; the methods of this class delegate to it. */
    protected SameDiff sameDiff;
    /** Name of this variable, as returned by {@code name()}. */
    protected String varName;
    /** Kind of variable; this class compares it against PLACEHOLDER, CONSTANT, VARIABLE and ARRAY. */
    protected VariableType variableType;
    /** Shape passed to the constructor or to {@code setShape}; may be null. */
    protected long[] shape;
    /** Data type; when null it is resolved lazily by {@code dataType()}. */
    protected DataType dataType;
    /**
     * Not referenced in the visible part of this class.
     * NOTE(review): presumably the function that produced this variable - confirm.
     */
    private DifferentialFunction creator;

    /**
     * Creates a variable bound to the given SameDiff instance.
     *
     * @param varName  name of the variable; must not be null
     * @param varType  type of the variable; must not be null
     * @param sameDiff SameDiff instance the variable belongs to; must not be null
     * @param shape    shape to store for the variable (stored as-is, may be null)
     * @param dataType data type; {@code DataType.UNKNOWN} fails the precondition check
     *                 unless {@code varType} is {@code PLACEHOLDER}
     * @throws NullPointerException if {@code varName}, {@code varType} or {@code sameDiff} is null
     */
    public SDVariable(@NonNull String varName, @NonNull VariableType varType, @NonNull SameDiff sameDiff, long[] shape, DataType dataType) {
        if (varName == null) {
            throw new NullPointerException("varName is marked non-null but is null");
        }
        if (varType == null) {
            throw new NullPointerException("varType is marked non-null but is null");
        }
        if (sameDiff == null) {
            throw new NullPointerException("sameDiff is marked non-null but is null");
        }
        // Placeholders are the only variables allowed to carry an UNKNOWN data type.
        if (varType != VariableType.PLACEHOLDER) {
            Preconditions.checkState(dataType != DataType.UNKNOWN, "Unknown datatype is not allowed for SDVariables (variable name: %s)", (Object)varName);
        }
        // varName is guaranteed non-null at this point, so the former "generate a name when
        // null" branch could never run and has been removed.
        this.sameDiff = sameDiff;
        this.varName = varName;
        this.variableType = varType;
        this.dataType = dataType;
        this.shape = shape;
    }

    /**
     * @return the name stored in {@code varName}
     */
    public String name() {
        return varName;
    }

    /**
     * Overwrites the stored variable name. Only the field on this object is changed here;
     * nothing is updated on the SameDiff instance by this method.
     *
     * @param varName new name to store
     */
    public void setVarName(String varName) {
        this.varName = varName;
    }

    /**
     * Deprecated alias that returns the same value as {@code name()}.
     *
     * @return the variable name
     */
    @Deprecated
    public String getVarName() {
        return name();
    }

    /**
     * @return true when the variable type is {@code VariableType.PLACEHOLDER}
     */
    public boolean isPlaceHolder() {
        return VariableType.PLACEHOLDER == variableType;
    }

    /**
     * @return true when the variable type is {@code VariableType.CONSTANT}
     */
    public boolean isConstant() {
        return VariableType.CONSTANT == variableType;
    }

    /**
     * Looks up the array for this variable without enforcing that one exists;
     * equivalent to {@code getArr(false)}.
     *
     * @return the array, or null when none is available
     */
    public INDArray getArr() {
        final boolean enforceExistence = false;
        return getArr(enforceExistence);
    }

    /**
     * Looks up the array associated with this variable on the SameDiff instance.
     *
     * <p>If SameDiff already holds an array under this variable's name it is returned directly.
     * Otherwise, ARRAY-type variables yield the eager array when eager mode is on and null when
     * it is off; all other types fall back to {@code getArrForVarName}.
     *
     * @param enforceExistence when true, a missing array is reported with an exception instead of null
     * @return the array, or null when none exists and {@code enforceExistence} is false
     * @throws UnsupportedOperationException if no array exists, the variable is of ARRAY type and
     *                                       {@code enforceExistence} is true
     * @throws IllegalStateException         if no array exists for a non-ARRAY variable and
     *                                       {@code enforceExistence} is true
     */
    public INDArray getArr(boolean enforceExistence) {
        if (sameDiff.arrayAlreadyExistsForVarName(getVarName())) {
            return sameDiff.getArrForVarName(getVarName());
        }
        if (variableType == VariableType.ARRAY) {
            if (enforceExistence) {
                throw new UnsupportedOperationException("Cannot get array for ARRAY type SDVariable - use SDVariable.exec or SameDiff.output instead");
            }
            return sameDiff.isEagerMode() ? sameDiff.getEagerArrForVarName(name()) : null;
        }
        INDArray array = sameDiff.getArrForVarName(getVarName());
        if (array == null && enforceExistence) {
            throw new IllegalStateException("No array exists for variable \"" + name() + "\"");
        }
        return array;
    }

    /**
     * Alias for {@code getGradient()}.
     *
     * @return the gradient variable as returned by {@code getGradient()}
     */
    public SDVariable gradient() {
        return getGradient();
    }

    /**
     * Returns the gradient variable SameDiff holds for this variable.
     * A precondition check requires this variable's data type to be a floating point type.
     *
     * @return result of {@code sameDiff.getGradForVariable} for this variable's name
     */
    public SDVariable getGradient() {
        Preconditions.checkState(
                dataType().isFPType(),
                "Cannot get gradient of %s datatype variable \"%s\": only floating point variables have gradients",
                (Object)dataType(),
                (Object)getVarName());
        return sameDiff.getGradForVariable(getVarName());
    }

    /**
     * Returns the shape of this variable when it can be determined.
     *
     * <p>Placeholders return the stored shape (possibly null). In eager mode a non-null stored
     * shape is returned for any type. VARIABLE and CONSTANT types otherwise report the shape of
     * their array, when one exists.
     *
     * @return the shape, or null when it is not known
     */
    public long[] getShape() {
        boolean placeholder = variableType == VariableType.PLACEHOLDER;
        if (placeholder || (sameDiff.isEagerMode() && shape != null)) {
            return shape;
        }
        boolean arrayBacked = variableType == VariableType.VARIABLE || variableType == VariableType.CONSTANT;
        if (arrayBacked && getArr() != null) {
            return getArr().shape();
        }
        return null;
    }

    /**
     * Stores the given shape on this object. The array is kept by reference (no copy is made).
     *
     * @param shape shape to store
     */
    public void setShape(long ... shape) {
        this.shape = shape;
    }

    /**
     * Returns the stored shape of a placeholder variable.
     *
     * @return the shape stored on this variable (may be null)
     * @throws IllegalStateException if this variable is not of type PLACEHOLDER
     */
    public long[] placeholderShape() {
        if (this.variableType != VariableType.PLACEHOLDER) {
            // The variable name is wrapped in a matching pair of quotes (the closing quote was missing).
            throw new IllegalStateException("placeholderShape() can only be used for placeholder variables: variable \"" + this.getVarName() + "\" is a variable of type " + this.variableType);
        }
        return this.shape;
    }

    /**
     * Returns the data type of this variable, resolving and caching it on first use.
     *
     * <p>When no type is stored, a non-ARRAY variable with an available array takes that array's
     * type; every other case is recorded as {@code DataType.UNKNOWN}.
     *
     * @return the (possibly just resolved) data type
     */
    public DataType dataType() {
        if (dataType != null) {
            return dataType;
        }
        if (variableType != VariableType.ARRAY && getArr() != null) {
            dataType = getArr().dataType();
        } else {
            dataType = DataType.UNKNOWN;
        }
        return dataType;
    }

    /**
     * Builds a shape descriptor from {@code getShape()} and {@code dataType()}.
     *
     * @return result of {@code LongShapeDescriptor.fromShape}
     */
    public LongShapeDescriptor getShapeDescriptor() {
        long[] currentShape = getShape();
        return LongShapeDescriptor.fromShape(currentShape, dataType());
    }

    /**
     * Casts this variable to the given data type, passing a null output name.
     *
     * @param dataType target data type; must not be null
     * @return result of {@code castTo(null, dataType)}
     * @throws NullPointerException if {@code dataType} is null
     */
    public SDVariable castTo(@NonNull DataType dataType) {
        if (null == dataType) {
            throw new NullPointerException("dataType is marked non-null but is null");
        }
        return castTo(null, dataType);
    }

    /**
     * Casts this variable to the given data type via {@code sameDiff.castTo(name, this, dataType)}.
     *
     * @param name     name passed through to SameDiff for the output (may be null)
     * @param dataType target data type; must not be null
     * @return the variable returned by SameDiff
     * @throws NullPointerException if {@code dataType} is null
     */
    public SDVariable castTo(String name, @NonNull DataType dataType) {
        if (null == dataType) {
            throw new NullPointerException("dataType is marked non-null but is null");
        }
        return sameDiff.castTo(name, this, dataType);
    }

    /**
     * @return result of {@code sameDiff.var(this)}
     */
    public SDVariable dup() {
        return sameDiff.var(this);
    }

    /**
     * Passes this variable and the double value of {@code value} to {@code sameDiff.scalarSet}.
     *
     * @param value scalar to use; converted with {@code doubleValue()}
     * @return the variable returned by SameDiff
     */
    public SDVariable assign(Number value) {
        double scalar = value.doubleValue();
        return sameDiff.scalarSet(this, scalar);
    }

    /**
     * @return result of {@code sameDiff.math.neg(this)}
     */
    public SDVariable neg() {
        return sameDiff.math.neg(this);
    }

    /**
     * @param name name passed through for the output variable
     * @return result of {@code sameDiff.math().neg(name, this)}
     */
    public SDVariable neg(String name) {
        return sameDiff.math().neg(name, this);
    }

    /**
     * Overload of {@code lt(String, double)} that passes a null name.
     *
     * @param value scalar to compare against
     * @return the resulting variable
     */
    public SDVariable lt(double value) {
        return lt(null, value);
    }

    /**
     * Delegates to {@code sameDiff.lt(name, this, value)}.
     *
     * @param name  name passed through for the output variable
     * @param value scalar to compare against
     * @return the variable returned by SameDiff
     */
    public SDVariable lt(String name, double value) {
        return sameDiff.lt(name, this, value);
    }

    /**
     * Overload of {@code lte(String, double)} that passes a null name.
     *
     * @param value scalar to compare against
     * @return the resulting variable
     */
    public SDVariable lte(double value) {
        return lte(null, value);
    }

    /**
     * Delegates to {@code sameDiff.lte(name, this, value)}.
     *
     * @param name  name passed through for the output variable
     * @param value scalar to compare against
     * @return the variable returned by SameDiff
     */
    public SDVariable lte(String name, double value) {
        return sameDiff.lte(name, this, value);
    }

    /**
     * Overload of {@code gt(String, double)} that passes a null name.
     *
     * @param value scalar to compare against
     * @return the resulting variable
     */
    public SDVariable gt(double value) {
        return gt(null, value);
    }

    /**
     * Delegates to {@code sameDiff.gt(name, this, value)}.
     *
     * @param name  name passed through for the output variable
     * @param value scalar to compare against
     * @return the variable returned by SameDiff
     */
    public SDVariable gt(String name, double value) {
        return sameDiff.gt(name, this, value);
    }

    /**
     * Overload of {@code gte(String, double)} that passes a null name.
     *
     * @param value scalar to compare against
     * @return the resulting variable
     */
    public SDVariable gte(double value) {
        return gte(null, value);
    }

    /**
     * Delegates to {@code sameDiff.gte(name, this, value)}.
     *
     * @param name  name passed through for the output variable
     * @param value scalar to compare against
     * @return the variable returned by SameDiff
     */
    public SDVariable gte(String name, double value) {
        return sameDiff.gte(name, this, value);
    }

    /**
     * Overload of {@code eq(String, double)} that passes a null name.
     *
     * @param value scalar to compare against
     * @return the resulting variable
     */
    public SDVariable eq(double value) {
        return eq(null, value);
    }

    /**
     * Delegates to {@code sameDiff.eq(name, this, value)}.
     *
     * @param name  name passed through for the output variable
     * @param value scalar to compare against
     * @return the variable returned by SameDiff
     */
    public SDVariable eq(String name, double value) {
        return sameDiff.eq(name, this, value);
    }

    /**
     * Overload of {@code neq(String, double)} that passes a null name.
     *
     * @param value scalar to compare against
     * @return the resulting variable
     */
    public SDVariable neq(double value) {
        return neq(null, value);
    }

    /**
     * Delegates to {@code sameDiff.neq(name, this, value)}.
     *
     * @param name  name passed through for the output variable
     * @param value scalar to compare against
     * @return the variable returned by SameDiff
     */
    public SDVariable neq(String name, double value) {
        return sameDiff.neq(name, this, value);
    }

    /**
     * Overload of {@code lt(String, SDVariable)} that passes a null name.
     *
     * @param other variable to compare against
     * @return the resulting variable
     */
    public SDVariable lt(SDVariable other) {
        return lt(null, other);
    }

    /**
     * Delegates to {@code sameDiff.lt(name, this, other)}.
     *
     * @param name  name passed through for the output variable
     * @param other variable to compare against
     * @return the variable returned by SameDiff
     */
    public SDVariable lt(String name, SDVariable other) {
        return sameDiff.lt(name, this, other);
    }

    /**
     * Overload of {@code lte(String, SDVariable)} that passes a null name.
     *
     * @param other variable to compare against
     * @return the resulting variable
     */
    public SDVariable lte(SDVariable other) {
        return lte(null, other);
    }

    /**
     * Delegates to {@code sameDiff.lte(name, this, other)}.
     *
     * @param name  name passed through for the output variable
     * @param other variable to compare against
     * @return the variable returned by SameDiff
     */
    public SDVariable lte(String name, SDVariable other) {
        return sameDiff.lte(name, this, other);
    }

    /**
     * Overload of {@code gt(String, SDVariable)} that passes a null name.
     *
     * @param other variable to compare against
     * @return the resulting variable
     */
    public SDVariable gt(SDVariable other) {
        return gt(null, other);
    }

    /**
     * Delegates to {@code sameDiff.gt(name, this, other)}.
     *
     * @param name  name passed through for the output variable
     * @param other variable to compare against
     * @return the variable returned by SameDiff
     */
    public SDVariable gt(String name, SDVariable other) {
        return sameDiff.gt(name, this, other);
    }

    /**
     * Overload of {@code gte(String, SDVariable)} that passes a null name.
     *
     * @param other variable to compare against
     * @return the resulting variable
     */
    public SDVariable gte(SDVariable other) {
        return gte(null, other);
    }

    /**
     * Delegates to {@code sameDiff.gte(name, this, other)}.
     *
     * @param name  name passed through for the output variable
     * @param other variable to compare against
     * @return the variable returned by SameDiff
     */
    public SDVariable gte(String name, SDVariable other) {
        return sameDiff.gte(name, this, other);
    }

    /**
     * Overload of {@code eq(String, SDVariable)} that passes a null name.
     *
     * @param other variable to compare against
     * @return the resulting variable
     */
    public SDVariable eq(SDVariable other) {
        return eq(null, other);
    }

    /**
     * Delegates to {@code sameDiff.eq(name, this, other)}.
     *
     * @param name  name passed through for the output variable
     * @param other variable to compare against
     * @return the variable returned by SameDiff
     */
    public SDVariable eq(String name, SDVariable other) {
        return sameDiff.eq(name, this, other);
    }

    /**
     * Overload of {@code neq(String, SDVariable)} that passes a null name.
     *
     * @param other variable to compare against
     * @return the resulting variable
     */
    public SDVariable neq(SDVariable other) {
        return neq(null, other);
    }

    /**
     * Delegates to {@code sameDiff.neq(name, this, other)}.
     *
     * @param name  name passed through for the output variable
     * @param other variable to compare against
     * @return the variable returned by SameDiff
     */
    public SDVariable neq(String name, SDVariable other) {
        return sameDiff.neq(name, this, other);
    }

    /**
     * Overload of {@code mmul(String, SDVariable)} that passes a null name.
     *
     * @param other second input
     * @return the resulting variable
     */
    public SDVariable mmul(SDVariable other) {
        return mmul(null, other);
    }

    /**
     * Delegates to {@code sameDiff.mmul(name, this, other)}.
     *
     * @param name  name passed through for the output variable
     * @param other second input
     * @return the variable returned by SameDiff
     */
    public SDVariable mmul(String name, SDVariable other) {
        return sameDiff.mmul(name, this, other);
    }

    /**
     * Delegates to {@code sameDiff.mmul} with the three transpose flags read from
     * {@code mMulTranspose} (A, B and result, in that order).
     *
     * @param name          name passed through for the output variable
     * @param other         second input
     * @param mMulTranspose source of the transpose flags; must not be null
     * @return the variable returned by SameDiff
     * @throws NullPointerException if {@code mMulTranspose} is null
     */
    public SDVariable mmul(String name, SDVariable other, @NonNull MMulTranspose mMulTranspose) {
        if (null == mMulTranspose) {
            throw new NullPointerException("mMulTranspose is marked non-null but is null");
        }
        return sameDiff.mmul(
                name,
                this,
                other,
                mMulTranspose.isTransposeA(),
                mMulTranspose.isTransposeB(),
                mMulTranspose.isTransposeResult());
    }

    /**
     * Overload of {@code dot(String, SDVariable, int...)} that passes a null name.
     *
     * @param other      second input
     * @param dimensions dimensions passed through
     * @return the resulting variable
     */
    public SDVariable dot(SDVariable other, int ... dimensions) {
        return dot(null, other, dimensions);
    }

    /**
     * Delegates to {@code sameDiff.dot(name, this, other, dimensions)}.
     *
     * @param name       name passed through for the output variable
     * @param other      second input
     * @param dimensions dimensions passed through
     * @return the variable returned by SameDiff
     */
    public SDVariable dot(String name, SDVariable other, int ... dimensions) {
        return sameDiff.dot(name, this, other, dimensions);
    }

    /**
     * Overload of {@code add(String, double)} that passes a null name.
     *
     * @param scalar scalar operand
     * @return the resulting variable
     */
    public SDVariable add(double scalar) {
        return add(null, scalar);
    }

    /**
     * Calls {@code sameDiff.math.add(this, scalar)} and then passes the result together with
     * {@code varName} to {@code sameDiff.updateVariableNameAndReference}.
     *
     * @param varName name to apply to the result
     * @param scalar  scalar operand
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable add(String varName, double scalar) {
        return sameDiff.updateVariableNameAndReference(sameDiff.math.add(this, scalar), varName);
    }

    /**
     * Overload of {@code add(String, SDVariable)} that passes a null name.
     *
     * @param other second operand
     * @return the resulting variable
     */
    public SDVariable add(SDVariable other) {
        return add(null, other);
    }

    /**
     * Calls {@code sameDiff.math.add(this, x)} and then passes the result together with
     * {@code name} to {@code sameDiff.updateVariableNameAndReference}.
     *
     * @param name name to apply to the result
     * @param x    second operand
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable add(String name, SDVariable x) {
        return sameDiff.updateVariableNameAndReference(sameDiff.math.add(this, x), name);
    }

    /**
     * Alias for {@code add(SDVariable)}.
     *
     * @param other second operand
     * @return the resulting variable
     */
    public SDVariable plus(SDVariable other) {
        return add(other);
    }

    /**
     * Alias for {@code add(double)}.
     *
     * @param other scalar operand
     * @return the resulting variable
     */
    public SDVariable plus(double other) {
        return add(other);
    }

    /**
     * Overload of {@code sub(String, double)} that passes a null name.
     *
     * @param scalar scalar operand
     * @return the resulting variable
     */
    public SDVariable sub(double scalar) {
        return sub(null, scalar);
    }

    /**
     * Calls {@code sameDiff.math.sub(this, scalar)} and then passes the result together with
     * {@code varName} to {@code sameDiff.updateVariableNameAndReference}.
     *
     * @param varName name to apply to the result
     * @param scalar  scalar operand
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable sub(String varName, double scalar) {
        return sameDiff.updateVariableNameAndReference(sameDiff.math.sub(this, scalar), varName);
    }

    /**
     * Overload of {@code sub(String, SDVariable)} that passes a null name.
     *
     * @param x second operand
     * @return the resulting variable
     */
    public SDVariable sub(SDVariable x) {
        return sub(null, x);
    }

    /**
     * Calls {@code sameDiff.math.sub(this, x)} and then passes the result together with
     * {@code name} to {@code sameDiff.updateVariableNameAndReference}.
     *
     * @param name name to apply to the result
     * @param x    second operand
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable sub(String name, SDVariable x) {
        return sameDiff.updateVariableNameAndReference(sameDiff.math.sub(this, x), name);
    }

    /**
     * Alias for {@code sub(SDVariable)}.
     *
     * @param other second operand
     * @return the resulting variable
     */
    public SDVariable minus(SDVariable other) {
        return sub(other);
    }

    /**
     * Alias for {@code sub(double)}.
     *
     * @param other scalar operand
     * @return the resulting variable
     */
    public SDVariable minus(double other) {
        return sub(other);
    }

    /**
     * Overload of {@code div(String, double)} that passes a null name.
     *
     * @param scalar scalar operand
     * @return the resulting variable
     */
    public SDVariable div(double scalar) {
        return div(null, scalar);
    }

    /**
     * Calls {@code sameDiff.math.div(this, scalar)} and then passes the result together with
     * {@code varName} to {@code sameDiff.updateVariableNameAndReference}.
     *
     * @param varName name to apply to the result
     * @param scalar  scalar operand
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable div(String varName, double scalar) {
        return sameDiff.updateVariableNameAndReference(sameDiff.math.div(this, scalar), varName);
    }

    /**
     * Overload of {@code div(String, SDVariable)} that passes a null name.
     *
     * @param x second operand
     * @return the resulting variable
     */
    public SDVariable div(SDVariable x) {
        return div(null, x);
    }

    /**
     * Calls {@code sameDiff.math.div(this, x)} and then passes the result together with
     * {@code name} to {@code sameDiff.updateVariableNameAndReference}.
     *
     * @param name name to apply to the result
     * @param x    second operand
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable div(String name, SDVariable x) {
        return sameDiff.updateVariableNameAndReference(sameDiff.math.div(this, x), name);
    }

    /**
     * Calls {@code sameDiff.math.floorDiv(this, x)} and then passes the result together with
     * {@code name} to {@code sameDiff.updateVariableNameAndReference}.
     *
     * @param name name to apply to the result
     * @param x    second operand
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable fdiv(String name, SDVariable x) {
        return sameDiff.updateVariableNameAndReference(sameDiff.math.floorDiv(this, x), name);
    }

    /**
     * Calls {@code sameDiff.math.mod(this, x)} and then passes the result together with
     * {@code name} to {@code sameDiff.updateVariableNameAndReference}.
     *
     * @param name name to apply to the result
     * @param x    second operand
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable mod(String name, SDVariable x) {
        return sameDiff.updateVariableNameAndReference(sameDiff.math.mod(this, x), name);
    }

    /**
     * Overload of {@code mul(String, double)} that passes a null name.
     *
     * @param scalar scalar operand
     * @return the resulting variable
     */
    public SDVariable mul(double scalar) {
        return mul(null, scalar);
    }

    /**
     * Calls {@code sameDiff.math.mul(this, scalar)} and then passes the result together with
     * {@code varName} to {@code sameDiff.updateVariableNameAndReference}.
     *
     * @param varName name to apply to the result
     * @param scalar  scalar operand
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable mul(String varName, double scalar) {
        return sameDiff.updateVariableNameAndReference(sameDiff.math.mul(this, scalar), varName);
    }

    /**
     * Overload of {@code mul(String, SDVariable)} that passes a null name.
     *
     * @param x second operand
     * @return the resulting variable
     */
    public SDVariable mul(SDVariable x) {
        return mul(null, x);
    }

    /**
     * Calls {@code sameDiff.math.mul(this, x)} and then passes the result together with
     * {@code name} to {@code sameDiff.updateVariableNameAndReference}.
     *
     * @param name name to apply to the result
     * @param x    second operand
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable mul(String name, SDVariable x) {
        return sameDiff.updateVariableNameAndReference(sameDiff.math.mul(this, x), name);
    }

    /**
     * Alias for {@code mul(SDVariable)}.
     *
     * @param other second operand
     * @return the resulting variable
     */
    public SDVariable times(SDVariable other) {
        return mul(other);
    }

    /**
     * Alias for {@code mul(double)}.
     *
     * @param other scalar operand
     * @return the resulting variable
     */
    public SDVariable times(double other) {
        return mul(other);
    }

    /**
     * Overload of {@code pow(String, double)} that passes a null name.
     *
     * @param scalar scalar operand
     * @return the resulting variable
     */
    public SDVariable pow(double scalar) {
        return pow(null, scalar);
    }

    /**
     * Calls {@code sameDiff.math.pow(this, scalar)} and then passes the result together with
     * {@code varName} to {@code sameDiff.updateVariableNameAndReference}.
     *
     * @param varName name to apply to the result
     * @param scalar  scalar operand
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable pow(String varName, double scalar) {
        return sameDiff.updateVariableNameAndReference(sameDiff.math.pow(this, scalar), varName);
    }

    /**
     * Overload of {@code rsub(String, double)} that passes a null name.
     *
     * @param scalar scalar operand
     * @return the resulting variable
     */
    public SDVariable rsub(double scalar) {
        return rsub(null, scalar);
    }

    /**
     * Calls {@code sameDiff.math.rsub(this, scalar)} and then passes the result together with
     * {@code varName} to {@code sameDiff.updateVariableNameAndReference}.
     *
     * @param varName name to apply to the result
     * @param scalar  scalar operand
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable rsub(String varName, double scalar) {
        return sameDiff.updateVariableNameAndReference(sameDiff.math.rsub(this, scalar), varName);
    }

    /**
     * Overload of {@code rsub(String, SDVariable)} that passes a null name.
     *
     * @param x second operand
     * @return the resulting variable
     */
    public SDVariable rsub(SDVariable x) {
        return rsub(null, x);
    }

    /**
     * Calls {@code sameDiff.math.rsub(this, x)} and then passes the result together with
     * {@code name} to {@code sameDiff.updateVariableNameAndReference}.
     *
     * @param name name to apply to the result
     * @param x    second operand
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable rsub(String name, SDVariable x) {
        return sameDiff.updateVariableNameAndReference(sameDiff.math.rsub(this, x), name);
    }

    /**
     * Overload of {@code rdiv(String, double)} that passes a null name.
     *
     * @param scalar scalar operand
     * @return the resulting variable
     */
    public SDVariable rdiv(double scalar) {
        return rdiv(null, scalar);
    }

    /**
     * Calls {@code sameDiff.math.rdiv(this, scalar)} and then passes the result together with
     * {@code varName} to {@code sameDiff.updateVariableNameAndReference}.
     *
     * @param varName name to apply to the result
     * @param scalar  scalar operand
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable rdiv(String varName, double scalar) {
        return sameDiff.updateVariableNameAndReference(sameDiff.math.rdiv(this, scalar), varName);
    }

    /**
     * Overload of {@code rdiv(String, SDVariable)} that passes a null name.
     *
     * @param sameDiffVariable second operand
     * @return the resulting variable
     */
    public SDVariable rdiv(SDVariable sameDiffVariable) {
        return rdiv(null, sameDiffVariable);
    }

    /**
     * Calls {@code sameDiff.math.rdiv(this, x)} and then passes the result together with
     * {@code name} to {@code sameDiff.updateVariableNameAndReference}.
     *
     * @param name name to apply to the result
     * @param x    second operand
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable rdiv(String name, SDVariable x) {
        return sameDiff.updateVariableNameAndReference(sameDiff.math.rdiv(this, x), name);
    }

    /**
     * Overload of {@code squaredDifference(String, SDVariable)} that passes a null name.
     *
     * @param x second operand
     * @return the resulting variable
     */
    public SDVariable squaredDifference(SDVariable x) {
        return squaredDifference(null, x);
    }

    /**
     * Calls {@code sameDiff.math().squaredDifference(this, x)} and then passes the result together
     * with {@code name} to {@code sameDiff.updateVariableNameAndReference}.
     *
     * @param name name to apply to the result
     * @param x    second operand
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable squaredDifference(String name, SDVariable x) {
        return sameDiff.updateVariableNameAndReference(sameDiff.math().squaredDifference(this, x), name);
    }

    /**
     * Overload of {@code sum(String, int...)} that passes a null name.
     *
     * @param dimensions dimensions passed through
     * @return the resulting variable
     */
    public SDVariable sum(int ... dimensions) {
        return sum(null, dimensions);
    }

    /**
     * Overload of {@code sum(String, boolean, int...)} that passes a null name.
     *
     * @param keepDims   flag passed through
     * @param dimensions dimensions passed through
     * @return the resulting variable
     */
    public SDVariable sum(boolean keepDims, int ... dimensions) {
        return sum(null, keepDims, dimensions);
    }

    /**
     * Overload of {@code sum(String, boolean, int...)} with {@code keepDims} set to false.
     *
     * @param name       name passed through for the output variable
     * @param dimensions dimensions passed through
     * @return the resulting variable
     */
    public SDVariable sum(String name, int ... dimensions) {
        final boolean keepDims = false;
        return sum(name, keepDims, dimensions);
    }

    /**
     * Delegates to {@code sameDiff.sum(name, this, keepDims, dimensions)}.
     *
     * @param name       name passed through for the output variable
     * @param keepDims   flag passed through
     * @param dimensions dimensions passed through
     * @return the variable returned by SameDiff
     */
    public SDVariable sum(String name, boolean keepDims, int ... dimensions) {
        return sameDiff.sum(name, this, keepDims, dimensions);
    }

    /**
     * Overload of {@code mean(String, boolean, int...)} that passes a null name.
     *
     * @param keepDims   flag passed through
     * @param dimensions dimensions passed through
     * @return the resulting variable
     */
    public SDVariable mean(boolean keepDims, int ... dimensions) {
        return mean(null, keepDims, dimensions);
    }

    /**
     * Overload of {@code mean(String, boolean, int...)} with {@code keepDims} set to false.
     *
     * @param name       name passed through for the output variable
     * @param dimensions dimensions passed through
     * @return the resulting variable
     */
    public SDVariable mean(String name, int ... dimensions) {
        final boolean keepDims = false;
        return mean(name, keepDims, dimensions);
    }

    /**
     * Overload of {@code mean(String, boolean, int...)} with a null name and
     * {@code keepDims} set to false.
     *
     * @param dimensions dimensions passed through
     * @return the resulting variable
     */
    public SDVariable mean(int ... dimensions) {
        final boolean keepDims = false;
        return mean(null, keepDims, dimensions);
    }

    /**
     * Delegates to {@code sameDiff.mean(name, this, keepDims, dimensions)}.
     *
     * @param name       name passed through for the output variable
     * @param keepDims   flag passed through
     * @param dimensions dimensions passed through
     * @return the variable returned by SameDiff
     */
    public SDVariable mean(String name, boolean keepDims, int ... dimensions) {
        return sameDiff.mean(name, this, keepDims, dimensions);
    }

    /**
     * Overload of {@code std(String, boolean, int...)} that passes a null name.
     *
     * @param biasCorrected flag passed through
     * @param dimensions    dimensions passed through
     * @return the resulting variable
     */
    public SDVariable std(boolean biasCorrected, int ... dimensions) {
        return std(null, biasCorrected, dimensions);
    }

    /**
     * Delegates to {@code sameDiff.standardDeviation(name, this, biasCorrected, dimensions)}.
     *
     * @param name          name passed through for the output variable
     * @param biasCorrected flag passed through
     * @param dimensions    dimensions passed through
     * @return the variable returned by SameDiff
     */
    public SDVariable std(String name, boolean biasCorrected, int ... dimensions) {
        return sameDiff.standardDeviation(name, this, biasCorrected, dimensions);
    }

    /**
     * Delegates to
     * {@code sameDiff.standardDeviation(name, this, biasCorrected, keepDims, dimensions)}.
     *
     * @param name          name passed through for the output variable
     * @param biasCorrected flag passed through
     * @param keepDims      flag passed through
     * @param dimensions    dimensions passed through
     * @return the variable returned by SameDiff
     */
    public SDVariable std(String name, boolean biasCorrected, boolean keepDims, int ... dimensions) {
        return sameDiff.standardDeviation(name, this, biasCorrected, keepDims, dimensions);
    }

    /**
     * Overload of {@code prod(String, int...)} that passes a null name.
     *
     * @param dimensions dimensions passed through
     * @return the resulting variable
     */
    public SDVariable prod(int ... dimensions) {
        return prod(null, dimensions);
    }

    /**
     * Overload of {@code prod(String, boolean, int...)} that passes a null name.
     *
     * @param keepDims   flag passed through
     * @param dimensions dimensions passed through
     * @return the resulting variable
     */
    public SDVariable prod(boolean keepDims, int ... dimensions) {
        return prod(null, keepDims, dimensions);
    }

    /**
     * Delegates to {@code sameDiff.prod(name, this, dimensions)}.
     *
     * @param name       name passed through for the output variable
     * @param dimensions dimensions passed through
     * @return the variable returned by SameDiff
     */
    public SDVariable prod(String name, int ... dimensions) {
        return sameDiff.prod(name, this, dimensions);
    }

    /**
     * Delegates to {@code sameDiff.prod(name, this, keepDims, dimensions)}.
     *
     * @param name       name passed through for the output variable
     * @param keepDims   flag passed through
     * @param dimensions dimensions passed through
     * @return the variable returned by SameDiff
     */
    public SDVariable prod(String name, boolean keepDims, int ... dimensions) {
        return sameDiff.prod(name, this, keepDims, dimensions);
    }

    /**
     * Overload of {@code min(String, int...)} that passes a null name.
     *
     * @param dimensions dimensions passed through
     * @return the resulting variable
     */
    public SDVariable min(int ... dimensions) {
        return min(null, dimensions);
    }

    /**
     * Overload of {@code min(String, boolean, int...)} that passes a null name.
     *
     * @param keepDims   flag passed through
     * @param dimensions dimensions passed through
     * @return the resulting variable
     */
    public SDVariable min(boolean keepDims, int ... dimensions) {
        return min(null, keepDims, dimensions);
    }

    /**
     * Overload of {@code min(String, boolean, int...)} with {@code keepDims} set to false.
     *
     * @param name       name passed through for the output variable
     * @param dimensions dimensions passed through
     * @return the resulting variable
     */
    public SDVariable min(String name, int ... dimensions) {
        final boolean keepDims = false;
        return min(name, keepDims, dimensions);
    }

    /**
     * Delegates to {@code sameDiff.min(name, this, keepDims, dimensions)}.
     *
     * @param name       name passed through for the output variable
     * @param keepDims   flag passed through
     * @param dimensions dimensions passed through
     * @return the variable returned by SameDiff
     */
    public SDVariable min(String name, boolean keepDims, int ... dimensions) {
        return sameDiff.min(name, this, keepDims, dimensions);
    }

    /**
     * Overload of {@code max(String, int...)} that passes a null name.
     *
     * @param dimensions dimensions passed through
     * @return the resulting variable
     */
    public SDVariable max(int ... dimensions) {
        return max(null, dimensions);
    }

    /**
     * Overload of {@code max(String, boolean, int...)} with {@code keepDims} set to false.
     *
     * @param name       name passed through for the output variable
     * @param dimensions dimensions passed through
     * @return the resulting variable
     */
    public SDVariable max(String name, int ... dimensions) {
        final boolean keepDims = false;
        return max(name, keepDims, dimensions);
    }

    /**
     * Overload of {@code max(String, boolean, int...)} that passes a null name.
     *
     * @param keepDims   flag passed through
     * @param dimensions dimensions passed through
     * @return the resulting variable
     */
    public SDVariable max(boolean keepDims, int ... dimensions) {
        return max(null, keepDims, dimensions);
    }

    /**
     * Delegates to {@code sameDiff.max(name, this, keepDims, dimensions)}.
     *
     * @param name       name passed through for the output variable
     * @param keepDims   flag passed through
     * @param dimensions dimensions passed through
     * @return the variable returned by SameDiff
     */
    public SDVariable max(String name, boolean keepDims, int ... dimensions) {
        return sameDiff.max(name, this, keepDims, dimensions);
    }

    /**
     * Overload of {@code norm1(String, int...)} that passes a null name.
     *
     * @param dimensions dimensions passed through
     * @return the resulting variable
     */
    public SDVariable norm1(int ... dimensions) {
        return norm1(null, dimensions);
    }

    /**
     * Overload of {@code norm1(String, boolean, int...)} that passes a null name.
     *
     * @param keepDims   flag passed through
     * @param dimensions dimensions passed through
     * @return the resulting variable
     */
    public SDVariable norm1(boolean keepDims, int ... dimensions) {
        return norm1(null, keepDims, dimensions);
    }

    /**
     * Overload of {@code norm1(String, boolean, int...)} with {@code keepDims} set to false.
     *
     * @param name       name passed through for the output variable
     * @param dimensions dimensions passed through
     * @return the resulting variable
     */
    public SDVariable norm1(String name, int ... dimensions) {
        final boolean keepDims = false;
        return norm1(name, keepDims, dimensions);
    }

    /**
     * Delegates to {@code sameDiff.norm1(name, this, keepDims, dimensions)}.
     *
     * @param name       name passed through for the output variable
     * @param keepDims   flag passed through
     * @param dimensions dimensions passed through
     * @return the variable returned by SameDiff
     */
    public SDVariable norm1(String name, boolean keepDims, int ... dimensions) {
        return sameDiff.norm1(name, this, keepDims, dimensions);
    }

    /**
     * Overload of {@code norm2(String, int...)} that passes a null name.
     *
     * @param dimensions dimensions passed through
     * @return the resulting variable
     */
    public SDVariable norm2(int ... dimensions) {
        return norm2(null, dimensions);
    }

    /**
     * Overload of {@code norm2(String, boolean, int...)} that passes a null name.
     *
     * @param keepDims   flag passed through
     * @param dimensions dimensions passed through
     * @return the resulting variable
     */
    public SDVariable norm2(boolean keepDims, int ... dimensions) {
        return norm2(null, keepDims, dimensions);
    }

    /**
     * Overload of {@code norm2(String, boolean, int...)} with {@code keepDims} set to false.
     *
     * @param name       name passed through for the output variable
     * @param dimensions dimensions passed through
     * @return the resulting variable
     */
    public SDVariable norm2(String name, int ... dimensions) {
        final boolean keepDims = false;
        return norm2(name, keepDims, dimensions);
    }

    /**
     * Delegates to {@code sameDiff.norm2(name, this, keepDims, dimensions)}.
     *
     * @param name       name passed through for the output variable
     * @param keepDims   flag passed through
     * @param dimensions dimensions passed through
     * @return the variable returned by SameDiff
     */
    public SDVariable norm2(String name, boolean keepDims, int ... dimensions) {
        return sameDiff.norm2(name, this, keepDims, dimensions);
    }

    /**
     * Overload of {@code normmax(String, int...)} that passes a null name.
     *
     * @param dimensions dimensions passed through
     * @return the resulting variable
     */
    public SDVariable normmax(int ... dimensions) {
        return normmax(null, dimensions);
    }

    /**
     * Overload of {@code normmax(String, boolean, int...)} that passes a null name.
     *
     * @param keepDims   flag passed through
     * @param dimensions dimensions passed through
     * @return the resulting variable
     */
    public SDVariable normmax(boolean keepDims, int ... dimensions) {
        return normmax(null, keepDims, dimensions);
    }

    /**
     * Overload of {@code normmax(String, boolean, int...)} with {@code keepDims} set to false.
     *
     * @param name       name passed through for the output variable
     * @param dimensions dimensions passed through
     * @return the resulting variable
     */
    public SDVariable normmax(String name, int ... dimensions) {
        final boolean keepDims = false;
        return normmax(name, keepDims, dimensions);
    }

    /**
     * Delegates to {@code sameDiff.normmax(name, this, keepDims, dimensions)}.
     *
     * @param name       name passed through for the output variable
     * @param keepDims   flag passed through
     * @param dimensions dimensions passed through
     * @return the variable returned by SameDiff
     */
    public SDVariable normmax(String name, boolean keepDims, int ... dimensions) {
        return sameDiff.normmax(name, this, keepDims, dimensions);
    }

    /**
     * Overload of {@code argmax(String, int...)} that passes a null name.
     *
     * @param dimensions dimensions passed through
     * @return the resulting variable
     */
    public SDVariable argmax(int ... dimensions) {
        return argmax(null, dimensions);
    }

    /**
     * Delegates to {@code sameDiff.argmax(name, this, dimensions)}.
     *
     * @param name       name passed through for the output variable
     * @param dimensions dimensions passed through
     * @return the variable returned by SameDiff
     */
    public SDVariable argmax(String name, int ... dimensions) {
        return sameDiff.argmax(name, this, dimensions);
    }

    /**
     * Delegates to {@code sameDiff.argmax(name, this, keepDims, dimensions)}.
     *
     * @param name       name passed through for the output variable
     * @param keepDims   flag passed through
     * @param dimensions dimensions passed through
     * @return the variable returned by SameDiff
     */
    public SDVariable argmax(String name, boolean keepDims, int ... dimensions) {
        return sameDiff.argmax(name, this, keepDims, dimensions);
    }

    /**
     * Overload of {@code argmin(String, int...)} that passes a null name.
     *
     * @param dimensions dimensions passed through
     * @return the resulting variable
     */
    public SDVariable argmin(int ... dimensions) {
        return argmin(null, dimensions);
    }

    /**
     * Delegates to {@code sameDiff.argmin(name, this, dimensions)}.
     *
     * @param name       name passed through for the output variable
     * @param dimensions dimensions passed through
     * @return the variable returned by SameDiff
     */
    public SDVariable argmin(String name, int ... dimensions) {
        return sameDiff.argmin(name, this, dimensions);
    }

    /**
     * Delegates to {@code sameDiff.argmin(name, this, keepDims, dimensions)}.
     *
     * <p>This overload previously forwarded to {@code sameDiff.argmax}, so it returned the
     * arg-max result despite its name; it now matches {@code argmin(String, int...)}.
     *
     * @param name       name passed through for the output variable
     * @param keepDims   flag passed through
     * @param dimensions dimensions passed through
     * @return the variable returned by SameDiff
     */
    public SDVariable argmin(String name, boolean keepDims, int ... dimensions) {
        return this.sameDiff.argmin(name, this, keepDims, dimensions);
    }

    /**
     * Returns {@code sameDiff.prod} applied to {@code shape()} with an empty dimensions array.
     *
     * @return the variable returned by SameDiff
     */
    public SDVariable length() {
        SDVariable shapeVar = shape();
        return sameDiff.prod(shapeVar, new int[0]);
    }

    /**
     * @return result of {@code sameDiff.shape(this)}
     */
    public SDVariable shape() {
        return sameDiff.shape(this);
    }

    /**
     * @return result of {@code sameDiff.rank(this)}
     */
    public SDVariable rank() {
        return sameDiff.rank(this);
    }

    /**
     * Delegates to {@code sameDiff.reshape(this, newShape)}.
     *
     * @param newShape variable passed through as the new shape
     * @return the variable returned by SameDiff
     */
    public SDVariable reshape(SDVariable newShape) {
        return sameDiff.reshape(this, newShape);
    }

    /**
     * Delegates to {@code sameDiff.reshape(name, this, newShape)}.
     *
     * @param name     name passed through for the output variable
     * @param newShape variable passed through as the new shape
     * @return the variable returned by SameDiff
     */
    public SDVariable reshape(String name, SDVariable newShape) {
        return sameDiff.reshape(name, this, newShape);
    }

    /**
     * Converts the int shape with {@code ArrayUtil.toLongArray} and delegates to
     * {@code sameDiff.reshape(this, long[])}.
     *
     * @param newShape new shape as ints
     * @return the variable returned by SameDiff
     */
    public SDVariable reshape(int ... newShape) {
        long[] shapeAsLongs = ArrayUtil.toLongArray(newShape);
        return sameDiff.reshape(this, shapeAsLongs);
    }

    /**
     * Delegates to {@code sameDiff.reshape(this, newShape)}.
     *
     * @param newShape new shape
     * @return the variable returned by SameDiff
     */
    public SDVariable reshape(long ... newShape) {
        return sameDiff.reshape(this, newShape);
    }

    /**
     * Delegates to {@code sameDiff.permute(this, dimensions)}.
     *
     * @param dimensions dimensions passed through
     * @return the variable returned by SameDiff
     */
    public SDVariable permute(int ... dimensions) {
        return sameDiff.permute(this, dimensions);
    }

    /**
     * Delegates to {@code sameDiff.permute(this, dimensions)}.
     *
     * @param dimensions variable passed through as the dimensions
     * @return the variable returned by SameDiff
     */
    public SDVariable permute(SDVariable dimensions) {
        return sameDiff.permute(this, dimensions);
    }

    /**
     * Associates the given array with this variable via
     * {@code sameDiff.associateArrayWithVariable}.
     *
     * @param array array to associate
     * @return this variable, for chaining
     */
    public SDVariable setArray(INDArray array) {
        sameDiff.associateArrayWithVariable(array, this);
        return this;
    }

    /**
     * Requests this variable from {@code sameDiff.output} with a null placeholder map and
     * returns the entry stored under this variable's name.
     *
     * @return the array found under {@code name()} in the output map
     */
    public INDArray eval() {
        // A typed null keeps the same SameDiff.output overload as the explicit cast would.
        Map<String, INDArray> noPlaceholders = null;
        Map<String, INDArray> outputs = sameDiff.output(noPlaceholders, name());
        return outputs.get(name());
    }

    /**
     * Requests this variable from {@code sameDiff.output} using the given placeholder map and
     * returns the entry stored under this variable's name.
     *
     * @param placeholders placeholder map passed through to SameDiff
     * @return the array found under {@code name()} in the output map
     */
    public INDArray eval(Map<String, INDArray> placeholders) {
        Map<String, INDArray> outputs = sameDiff.output(placeholders, name());
        return outputs.get(name());
    }

    /**
     * Renders name, variable type and data type; the shape is appended only for placeholders
     * that have a non-null stored shape.
     *
     * @return textual description of this variable
     */
    public String toString() {
        StringBuilder text = new StringBuilder();
        text.append("SDVariable(name=\"").append(varName);
        text.append("\",variableType=").append(variableType);
        text.append(",dtype=").append(dataType);
        if (variableType == VariableType.PLACEHOLDER && shape != null) {
            text.append(",shape=").append(Arrays.toString(shape));
        }
        text.append(")");
        return text.toString();
    }

    /**
     * Records {@code controlDependency} as a control dependency of this variable.
     *
     * <p>Three cases are distinguished by whether each variable is the output of an op:
     * <ul>
     *   <li>this variable is not an op output: the dependency is recorded variable-to-variable;</li>
     *   <li>only this variable is an op output: the dependency is recorded on this variable's op
     *       against the dependency variable;</li>
     *   <li>both are op outputs: the dependency is recorded op-to-op.</li>
     * </ul>
     * In every case both directions are updated, lists are created on demand and duplicates
     * are not added.
     *
     * @param controlDependency variable this variable should depend on
     */
    public void addControlDependency(SDVariable controlDependency) {
        Variable self = (Variable)sameDiff.getVariables().get((Object)getVarName());
        Variable dependency = (Variable)sameDiff.getVariables().get((Object)controlDependency.name());

        if (self.getOutputOfOp() == null) {
            // Variable -> variable dependency.
            if (self.getControlDeps() == null) {
                self.setControlDeps(new ArrayList<String>());
            }
            if (!self.getControlDeps().contains(dependency.getName())) {
                self.getControlDeps().add(dependency.getName());
            }
            if (dependency.getControlDepsForVar() == null) {
                dependency.setControlDepsForVar(new ArrayList<String>());
            }
            if (!dependency.getControlDepsForVar().contains(self.getName())) {
                dependency.getControlDepsForVar().add(self.getName());
            }
            return;
        }

        if (dependency.getOutputOfOp() == null) {
            // Op (producing this variable) -> variable dependency.
            SameDiffOp selfOp = sameDiff.getOps().get(self.getOutputOfOp());
            if (selfOp.getVarControlDeps() == null) {
                selfOp.setVarControlDeps(new ArrayList<String>());
            }
            if (!selfOp.getVarControlDeps().contains(dependency.getName())) {
                selfOp.getVarControlDeps().add(dependency.getName());
            }
            if (dependency.getControlDepsForOp() == null) {
                dependency.setControlDepsForOp(new ArrayList<String>());
            }
            if (!dependency.getControlDepsForOp().contains(selfOp.getName())) {
                dependency.getControlDepsForOp().add(selfOp.getName());
            }
            return;
        }

        // Op -> op dependency.
        SameDiffOp selfOp = sameDiff.getOps().get(self.getOutputOfOp());
        SameDiffOp dependencyOp = sameDiff.getOps().get(dependency.getOutputOfOp());
        if (selfOp.getControlDeps() == null) {
            selfOp.setControlDeps(new ArrayList<String>());
        }
        if (!selfOp.getControlDeps().contains(dependencyOp.getName())) {
            selfOp.getControlDeps().add(dependencyOp.getName());
        }
        if (dependencyOp.getControlDepFor() == null) {
            dependencyOp.setControlDepFor(new ArrayList<String>());
        }
        if (!dependencyOp.getControlDepFor().contains(selfOp.getName())) {
            dependencyOp.getControlDepFor().add(selfOp.getName());
        }
    }

    /**
     * Creates a view of this variable from the given indices.
     *
     * <p>Each index is translated into a variable through the matching {@code CreateView}
     * factory (interval, point or all), and the resulting array is passed to
     * {@code sameDiff.createView}.
     *
     * @param indices one index per entry to translate
     * @return the variable returned by {@code sameDiff.createView}
     * @throws IllegalArgumentException if an index has a type not handled here
     */
    public SDVariable getView(SDIndex ... indices) {
        SDVariable[] viewArgs = new SDVariable[indices.length];
        for (int dim = 0; dim < indices.length; ++dim) {
            switch (indices[dim].getIndexType()) {
                case INTERVAL: {
                    viewArgs[dim] = CreateView.createInterval(this.sameDiff, indices[dim].getIntervalBegin(), indices[dim].getIntervalEnd(), indices[dim].getIntervalStrides(), indices[dim].isInclusive() ? 1L : 0L);
                    break;
                }
                case POINT: {
                    viewArgs[dim] = CreateView.createPoint(this.sameDiff, indices[dim].getPointIndex());
                    break;
                }
                case POINT_INPUT: {
                    viewArgs[dim] = CreateView.createPoint(this.sameDiff, indices[dim].getPointVar());
                    break;
                }
                case INTERVAL_INPUT: {
                    viewArgs[dim] = CreateView.createInterval(this.sameDiff, indices[dim].getIntervalInputBegin(), indices[dim].getIntervalInputEnd(), indices[dim].getIntervalStrideInput(), indices[dim].getInclusiveInput());
                    break;
                }
                case ALL: {
                    viewArgs[dim] = CreateView.createAll(this.sameDiff);
                    break;
                }
                default: {
                    throw new IllegalArgumentException("Illegal type " + indices[dim].getIndexType());
                }
            }
        }
        return this.sameDiff.createView(this, viewArgs);
    }

    /**
     * Slices this variable using one index per entry of {@code indices} and returns the result
     * of {@code sameDiff.stridedSlice}.
     *
     * <p>If any index is of type POINT_INPUT or INTERVAL_INPUT, the SDVariable-based overload of
     * {@code stridedSlice} is used with begin/end/strides built by concatenating variables;
     * otherwise the {@code long[]}-based overload is used. Per-entry flags for "begin unset",
     * "end unset" and "shrink axis" are collected in int arrays and converted with
     * {@code binArrToInt}.
     *
     * @param indices indices to apply; the array is copied before any entry is replaced
     * @return the variable returned by {@code sameDiff.stridedSlice}
     */
    public SDVariable get(SDIndex ... indices) {
        int ndims = indices.length;
        boolean variableIndices = false;
        // Work on a copy so that replacing entries below does not modify the caller's array.
        SDIndex[] inputIndices = Arrays.copyOf(indices, indices.length);
        indices = inputIndices;
        // First pass: detect variable-based indices and convert static POINT/INTERVAL entries
        // into variable-based ones.
        // NOTE(review): the flag is set while iterating, so static entries that come BEFORE the
        // first *_INPUT entry are not converted - confirm whether that is intended.
        block4: for (int i = 0; i < indices.length; ++i) {
            if (indices[i].getIndexType() == SDIndex.IndexType.POINT_INPUT || indices[i].getIndexType() == SDIndex.IndexType.INTERVAL_INPUT) {
                variableIndices = true;
            }
            if (!variableIndices || indices[i].getIndexType() != SDIndex.IndexType.INTERVAL && indices[i].getIndexType() != SDIndex.IndexType.POINT) continue;
            switch (indices[i].getIndexType()) {
                case INTERVAL: {
                    // NOTE(review): getIntervalEnd() is passed twice; the third argument may have
                    // been meant to be the interval strides - verify against SDIndex.interval.
                    indices[i] = SDIndex.interval(this.sameDiff.constant(indices[i].getIntervalBegin()), this.sameDiff.constant(indices[i].getIntervalEnd()), this.sameDiff.constant(indices[i].getIntervalEnd()));
                    continue block4;
                }
                case POINT: {
                    indices[i] = SDIndex.point(this.sameDiff.constant(indices[i].getPointIndex()), indices[i].isPointKeepDim());
                }
            }
        }
        // Static (long-based) slice description, one entry per index.
        long[] begin = new long[ndims];
        long[] end = new long[ndims];
        long[] strides = new long[ndims];
        // 0/1 flags per index; converted to ints via binArrToInt further down.
        int[] begin_mask_arr = new int[ndims];
        int[] end_mask_arr = new int[ndims];
        int[] shrink_axis_mask_arr = new int[ndims];
        // Variable-based slice description, built up by concatenation along dimension 0.
        SDVariable beginVar = null;
        SDVariable endVar = null;
        SDVariable stridesVar = null;
        for (int i = 0; i < ndims; ++i) {
            strides[i] = 1L;
            SDIndex index = indices[i];
            SDIndex.IndexType indexType = index.getIndexType();
            if (indexType == SDIndex.IndexType.ALL) {
                // ALL: flag both begin and end for this entry.
                begin_mask_arr[i] = 1;
                end_mask_arr[i] = 1;
                continue;
            }
            if (indexType == SDIndex.IndexType.POINT || indexType == SDIndex.IndexType.POINT_INPUT) {
                if (indexType == SDIndex.IndexType.POINT) {
                    // A point is expressed as the range [pointIndex, pointIndex + 1).
                    long pointIndex;
                    begin[i] = pointIndex = index.getPointIndex();
                    end[i] = pointIndex + 1L;
                } else if (indexType == SDIndex.IndexType.POINT_INPUT) {
                    if (beginVar == null && endVar == null) {
                        beginVar = index.getPointVar();
                        endVar = index.getPointVar().add(1.0);
                    } else {
                        beginVar = this.sameDiff.concat(0, beginVar, index.getPointVar());
                        endVar = this.sameDiff.concat(0, endVar, index.getPointVar().add(1.0));
                    }
                }
                // Unless the point index asks to keep the dimension, flag it for shrinking.
                if (index.isPointKeepDim()) continue;
                shrink_axis_mask_arr[i] = 1;
                continue;
            }
            if (indexType != SDIndex.IndexType.INTERVAL && indexType != SDIndex.IndexType.INTERVAL_INPUT) continue;
            // Interval begin: flagged when absent, otherwise taken from the variable or the long value.
            if (index.getIntervalBegin() == null && indexType != SDIndex.IndexType.INTERVAL_INPUT) {
                begin_mask_arr[i] = 1;
            } else if (indexType == SDIndex.IndexType.INTERVAL_INPUT) {
                beginVar = beginVar == null ? index.getIntervalInputBegin() : this.sameDiff.concat(0, beginVar, index.getIntervalInputBegin());
            } else {
                begin[i] = index.getIntervalBegin();
            }
            // Interval end: same handling as begin.
            if (index.getIntervalEnd() == null && indexType != SDIndex.IndexType.INTERVAL_INPUT) {
                end_mask_arr[i] = 1;
            } else if (indexType == SDIndex.IndexType.INTERVAL_INPUT) {
                endVar = endVar == null ? index.getIntervalInputEnd() : this.sameDiff.concat(0, endVar, index.getIntervalInputEnd());
            } else {
                end[i] = index.getIntervalEnd();
            }
            // Strides: default to 1 when the index carries no stride value.
            if (index.getIntervalStrides() == null) {
                strides[i] = 1L;
                if (stridesVar != null) {
                    stridesVar = this.sameDiff.concat(0, stridesVar, this.sameDiff.constant(1).reshape(1));
                    continue;
                }
                stridesVar = this.sameDiff.constant(1).reshape(1);
                continue;
            }
            strides[i] = index.getIntervalStrides();
            // NOTE(review): getIntervalStrideInput() is used here whenever getIntervalStrides()
            // is non-null; confirm it cannot be null for a static INTERVAL index.
            stridesVar = stridesVar != null ? this.sameDiff.concat(0, stridesVar, index.getIntervalStrideInput()) : index.getIntervalStrideInput();
        }
        // binArrToInt is defined outside the visible code; presumably it packs the 0/1 flags
        // into an int bitmask - confirm.
        int begin_mask = SDVariable.binArrToInt(begin_mask_arr);
        int end_mask = SDVariable.binArrToInt(end_mask_arr);
        int shrink_axis = SDVariable.binArrToInt(shrink_axis_mask_arr);
        if (variableIndices) {
            if (stridesVar == null) {
                stridesVar = this.sameDiff.onesLike(beginVar);
            }
            return this.sameDiff.stridedSlice(this, beginVar, endVar, stridesVar, begin_mask, end_mask, 0, 0, shrink_axis);
        }
        return this.sameDiff.stridedSlice(this, begin, end, strides, begin_mask, end_mask, 0, 0, shrink_axis);
    }

    /**
     * Builds a variable derived from {@code input.shape()} in which the entry at position 0 is
     * replaced by {@code sliceIndexInput}: the shape is multiplied by a mask that is 0 at
     * position 0 and 1 elsewhere, and {@code sliceIndexInput} masked to position 0 is added.
     *
     * @param input           variable whose shape and rank are used
     * @param sliceIndexInput value multiplied into the position-0 mask
     * @return the combined variable
     */
    public static SDVariable sliceEnd(SDVariable input, SDVariable sliceIndexInput) {
        SameDiff sd = input.getSameDiff();
        // Positions 0 .. rank-1 as INT64.
        SDVariable positions = sd.range(sd.constant(0), input.rank(), sd.constant(1), DataType.INT64);
        SDVariable keepMask = positions.gt(0.0).castTo(DataType.INT64);
        SDVariable firstPositionMask = positions.eq(0.0).castTo(DataType.INT64);
        SDVariable firstPositionValue = firstPositionMask.mul(sliceIndexInput);
        return input.shape().mul(keepMask).add(firstPositionValue);
    }

    /**
     * Builds a loop over {@code indices} via {@code sameDiff.loopWithConditions} and returns the
     * fourth loop output (the one named "output" in the loop body).
     *
     * <p>The loop body comes from {@code createLoopConcat(this, indices)}. The loop variables
     * are, in order: an INT32 iteration counter initialised from {@code Nd4j.ones(1)}, the
     * length of {@code indices}, a boolean constant, an initial slice of this variable, this
     * variable and {@code indices}.
     *
     * @param indices variable holding the indices to iterate over
     * @return element 3 of the array returned by {@code loopWithConditions}
     */
    public SDVariable get(SDVariable indices) {
        SDVariable zeroOffsets = sameDiff.zerosLike(shape()).castTo(DataType.INT64);
        SDVariable sliceBegin = zeroOffsets.castTo(DataType.INT64);
        SDVariable onesPerEntry = sameDiff.onesLike(shape()).castTo(DataType.INT64);
        SDVariable sliceSize = SDVariable.sliceEnd(this, onesPerEntry);
        SDVariable firstSlice = sameDiff.slice(this, sliceBegin, sliceSize);
        SDVariable iteration = sameDiff.var(Nd4j.ones(1).castTo(DataType.INT32));
        SDVariable keepGoing = sameDiff.constant("curr_cond", true);
        SDVariable indexCount = indices.length();
        SameDiff body = SDVariable.createLoopConcat(this, indices);
        SDVariable[] loopOutputs = sameDiff.loopWithConditions(
                ControlFlow.LoopParams.builder()
                        .functionBody(body)
                        .loopVars(new SDVariable[]{iteration, indexCount, keepGoing, firstSlice, this, indices})
                        .functionBodyInputs(new String[]{"index", "max", "cond", "input", "pullFrom", "indices"})
                        .functionBodyOutputs(new String[]{"index", "max", "cond", "output", "pullFrom", "indices"})
                        .functionName("slices")
                        .loopName("outputs")
                        .build());
        return loopOutputs[3];
    }

    /**
     * Writes entries of {@code toPut} into this variable by running a loop whose body is produced
     * by {@link #createLoopPut(SDVariable, SDVariable)}; caching is disabled on that body graph.
     *
     * <p>The loop counter starts from {@code Nd4j.zeros(1)} cast to INT32 and the length of
     * {@code indices} is passed as {@code "max"}.
     *
     * @param indices    passed to the loop body as {@code "indices"}
     * @param toPut      passed to the loop body as {@code "toPut"}
     * @param putIndices passed to the loop body as {@code "indicesPut"}
     * @return element 3 of the loop results, which is the position of {@code "assignOutput"} in
     *         the declared function body outputs
     */
    public SDVariable put(SDVariable indices, SDVariable toPut, SDVariable putIndices) {
        INDArray counterStart = Nd4j.zeros(1).castTo(DataType.INT32);
        SDVariable counter = this.sameDiff.var(counterStart);
        SDVariable keepGoing = this.sameDiff.constant(true);
        SDVariable indexCount = indices.length();
        SameDiff body = SDVariable.createLoopPut(this, indices);
        body.setEnableCache(false);

        SDVariable[] loopState = new SDVariable[]{counter, indexCount, keepGoing, this, toPut, indices, putIndices};
        String[] bodyInputs = new String[]{"index", "max", "cond", "assignTo", "toPut", "indices", "indicesPut"};
        String[] bodyOutputs = new String[]{"index", "max", "cond", "assignOutput", "toPut", "indices", "indicesPut"};
        ControlFlow.LoopParams params = ControlFlow.LoopParams.builder()
                .functionBody(body)
                .loopVars(loopState)
                .functionBodyInputs(bodyInputs)
                .functionBodyOutputs(bodyOutputs)
                .functionName("sliceputs")
                .loopName("outputs")
                .build();
        SDVariable[] results = this.sameDiff.loopWithConditions(params);
        return results[3];
    }

    /**
     * Creates a fresh {@link SameDiff} graph used as the loop body by
     * {@link #put(SDVariable, SDVariable, SDVariable)}.
     *
     * <p>The graph declares the placeholders {@code "index"}, {@code "max"}, {@code "cond"},
     * {@code "assignTo"}, {@code "toPut"}, {@code "indices"} and {@code "indicesPut"}. It looks up
     * the current entry of both index vectors, assigns the selected view of {@code "toPut"} into
     * the selected view of {@code "assignTo"}, and exposes {@code "assignTo"} through an identity
     * op named {@code "assignOutput"} that has a control dependency on the assign op.
     *
     * @param relative supplies the data type of the {@code "assignTo"} and {@code "toPut"} placeholders
     * @param indices  supplies the data type of the {@code "indices"} and {@code "indicesPut"} placeholders
     * @return the populated loop body graph
     */
    public static SameDiff createLoopPut(SDVariable relative, SDVariable indices) {
        SameDiff body = SameDiff.create();
        SDVariable counter = body.placeHolder("index", DataType.INT32, new long[0]);
        // "max" and "cond" are only declared here; nothing else in this body reads them.
        body.placeHolder("max", DataType.INT32, new long[0]);
        body.placeHolder("cond", DataType.BOOL, new long[0]);
        SDVariable target = body.placeHolder("assignTo", relative.dataType(), new long[0]);
        SDVariable source = body.placeHolder("toPut", relative.dataType(), new long[0]);

        SDVariable targetIndicesRaw = body.placeHolder("indices", indices.dataType(), new long[0]);
        SDVariable targetIndices = targetIndicesRaw.reshape("indicesReshape", targetIndicesRaw.length());
        SDVariable sourceIndicesRaw = body.placeHolder("indicesPut", indices.dataType(), new long[0]);
        SDVariable sourceIndices = sourceIndicesRaw.reshape("indicesPutReshape", sourceIndicesRaw.length());

        SDVariable targetPosition = targetIndices.getView(SDIndex.point(counter))
                .reshape(1)
                .castTo("indexToReceive", DataType.INT64);
        SDVariable sourcePosition = sourceIndices.getView(SDIndex.point(counter))
                .reshape(1)
                .castTo("indexToPut", DataType.INT64);

        SDVariable sourceView = source.getView(SDIndex.point(sourcePosition));
        SDVariable targetView = target.getView(SDIndex.point(targetPosition));
        SDVariable assigned = body.assign(targetView, sourceView);

        // The identity output is tied to the assign op through a control dependency.
        SDVariable exposed = body.identity("assignOutput", target);
        exposed.addControlDependency(assigned);
        return body;
    }

    /**
     * Creates a fresh {@link SameDiff} graph used as the loop body by {@link #get(SDVariable)}.
     *
     * <p>The graph declares the placeholders {@code "index"}, {@code "max"}, {@code "cond"},
     * {@code "input"}, {@code "pullFrom"} and {@code "indices"}. It looks up the current entry of
     * the index vector, takes that point from {@code "pullFrom"}, expands it along dimension 0
     * (op named {@code "outputSlice"}) and concatenates it onto {@code "input"} along dimension 0
     * in an op named {@code "output"}.
     *
     * @param relative supplies the data type of the {@code "input"} and {@code "pullFrom"} placeholders
     * @param indices  supplies the data type of the {@code "indices"} placeholder
     * @return the populated loop body graph
     */
    public static SameDiff createLoopConcat(SDVariable relative, SDVariable indices) {
        SameDiff body = SameDiff.create();
        SDVariable counter = body.placeHolder("index", DataType.INT32, new long[0]);
        // "max" and "cond" are only declared here; nothing else in this body reads them.
        body.placeHolder("max", DataType.INT32, new long[0]);
        body.placeHolder("cond", DataType.BOOL, new long[0]);
        SDVariable accumulated = body.placeHolder("input", relative.dataType(), new long[0]);
        SDVariable source = body.placeHolder("pullFrom", relative.dataType(), new long[0]);

        SDVariable indicesRaw = body.placeHolder("indices", indices.dataType(), new long[0]);
        SDVariable flatIndices = indicesRaw.reshape("indicesReshape", indicesRaw.length());

        SDVariable position = flatIndices.get(SDIndex.point(counter))
                .reshape(1)
                .castTo("indexToReceive", DataType.INT64);
        SDVariable picked = source.get(SDIndex.point(position));
        SDVariable pickedRow = body.expandDims("outputSlice", picked, 0);
        // The concat result is registered in the graph under the name "output".
        body.concat("output", 0, accumulated, pickedRow);
        return body;
    }

    /**
     * Delegates to {@code SameDiff#convertToConstant} for this variable.
     *
     * @return whatever the owning {@link SameDiff} instance returns for the conversion
     */
    public SDVariable convertToConstant() {
        SDVariable converted = sameDiff.convertToConstant(this);
        return converted;
    }

    /**
     * Delegates to {@code SameDiff#convertToVariable} for this variable.
     *
     * @return whatever the owning {@link SameDiff} instance returns for the conversion
     */
    public SDVariable convertToVariable() {
        SDVariable converted = sameDiff.convertToVariable(this);
        return converted;
    }

    /**
     * Asks the owning {@link SameDiff} instance to rename this variable from its current name
     * (as reported by {@code getVarName()}) to {@code newName}.
     *
     * @param newName the name to rename to
     * @return this variable, for call chaining
     */
    public SDVariable rename(String newName) {
        String currentName = getVarName();
        sameDiff.renameVariable(currentName, newName);
        return this;
    }

    /**
     * Registers this variable's name with the owning {@link SameDiff} instance via
     * {@code addLossVariable}.
     */
    public void markAsLoss() {
        String name = getVarName();
        sameDiff.addLossVariable(name);
    }

    /**
     * Queries the owning {@link SameDiff} instance through {@code variableHasGradient} using this
     * variable's name.
     *
     * @return the result reported by {@code variableHasGradient}
     */
    public boolean hasGradient() {
        String name = getVarName();
        return sameDiff.variableHasGradient(name);
    }

    /**
     * Packs an array of flags into an int bit mask, treating element {@code i} as bit {@code i}
     * (element 0 is the least significant bit).
     *
     * <p>Only elements equal to exactly 1 contribute; any other value counts as an unset bit.
     * The per-position weight is doubled with plain int arithmetic, so it wraps to 0 after
     * position 31 and later elements contribute nothing.
     *
     * @param arr flags to pack
     * @return the packed mask
     */
    private static int binArrToInt(int[] arr) {
        int packed = 0;
        int weight = 1;
        for (int flag : arr) {
            if (flag == 1) {
                packed += weight;
            }
            // Shifting left by one is the same wrapping doubling as multiplying by two.
            weight <<= 1;
        }
        return packed;
    }

    /**
     * Computes a hash from the variable name, variable type and data type.
     *
     * <p>The identity hash of the instance is deliberately not mixed in: {@link #equals(Object)}
     * compares these same three fields (plus array contents for some variable types), so two
     * distinct instances that are equal must produce the same hash for hash-based collections
     * to behave correctly.
     *
     * @return a hash consistent with {@link #equals(Object)}
     */
    @Override
    public int hashCode() {
        int result = this.varName != null ? this.varName.hashCode() : 0;
        result = 31 * result + (this.variableType != null ? this.variableType.hashCode() : 0);
        result = 31 * result + (this.dataType != null ? this.dataType.hashCode() : 0);
        return result;
    }

    /**
     * Creates a new {@code SDVariable} that copies this variable's type, data type and shape
     * (the shape array is duplicated, not shared), using the supplied name and owner.
     *
     * @param name name assigned to the copy
     * @param sd   {@link SameDiff} instance assigned to the copy
     * @return the new variable
     */
    public SDVariable clone(String name, SameDiff sd) {
        SDVariable copy = new SDVariable();
        copy.sameDiff = sd;
        copy.varName = name;
        copy.variableType = this.variableType;
        copy.dataType = this.dataType;
        // A freshly constructed SDVariable already has a null shape, so only the non-null case is set.
        if (this.shape != null) {
            copy.shape = this.shape.clone();
        }
        return copy;
    }

    /**
     * Creates a new {@code SDVariable} that copies this variable's name, type, data type and
     * shape (the shape array is duplicated, not shared), attached to the supplied owner.
     *
     * @param sd {@link SameDiff} instance assigned to the copy
     * @return the new variable
     */
    public SDVariable clone(SameDiff sd) {
        SDVariable copy = new SDVariable();
        copy.sameDiff = sd;
        copy.varName = this.varName;
        copy.variableType = this.variableType;
        copy.dataType = this.dataType;
        // A freshly constructed SDVariable already has a null shape, so only the non-null case is set.
        if (this.shape != null) {
            copy.shape = this.shape.clone();
        }
        return copy;
    }

    /**
     * Compares this variable with another object.
     *
     * <p>Two variables are equal when their names, variable types and data types match. For
     * {@code VARIABLE} and {@code CONSTANT} types the arrays returned by {@code getArr()} must
     * also be equal.
     *
     * <p>Null handling: a null name only matches another null name, and a null array only
     * matches another null array, so instances created through the no-arg constructor or
     * without an array compare without throwing {@link NullPointerException}.
     *
     * @param o object to compare against
     * @return {@code true} if {@code o} is an equal {@code SDVariable}
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof SDVariable)) {
            return false;
        }
        SDVariable s = (SDVariable)o;
        // Null-safe name comparison; hashCode() already tolerates a null name.
        if (this.varName == null ? s.varName != null : !this.varName.equals(s.varName)) {
            return false;
        }
        if (this.variableType != s.variableType) {
            return false;
        }
        if (this.dataType != s.dataType) {
            return false;
        }
        if (this.variableType == VariableType.VARIABLE || this.variableType == VariableType.CONSTANT) {
            INDArray a1 = this.getArr();
            INDArray a2 = s.getArr();
            // Guard against a missing array on either side instead of dereferencing null.
            if (a1 == null || a2 == null) {
                return a1 == a2;
            }
            return a1.equals(a2);
        }
        return true;
    }

    /**
     * @return the {@link SameDiff} instance currently stored on this variable
     */
    public SameDiff getSameDiff() {
        return sameDiff;
    }

    /**
     * @return the {@link DifferentialFunction} stored in the {@code creator} field, or
     *         {@code null} if none has been set
     */
    public DifferentialFunction getCreator() {
        return creator;
    }

    /**
     * Replaces the {@link SameDiff} instance stored on this variable.
     *
     * @param sameDiff the instance to store; no validation is performed here
     */
    public void setSameDiff(SameDiff sameDiff) {
        this.sameDiff = sameDiff;
    }

    /**
     * Stores the given {@link DifferentialFunction} in the {@code creator} field.
     *
     * @param creator the function to store; no validation is performed here
     */
    public void setCreator(DifferentialFunction creator) {
        this.creator = creator;
    }

    /**
     * No-argument constructor that leaves every field at its default value.
     * The {@code clone} methods of this class use it and then populate the fields directly.
     */
    public SDVariable() {
    }

    /**
     * @return the {@link VariableType} currently stored on this variable
     */
    public VariableType getVariableType() {
        return variableType;
    }

    /**
     * Replaces the {@link VariableType} stored on this variable.
     *
     * @param variableType the type to store; no validation is performed here
     */
    public void setVariableType(VariableType variableType) {
        this.variableType = variableType;
    }

    /**
     * Replaces the {@link DataType} stored on this variable.
     *
     * @param dataType the data type to store; no validation is performed here
     */
    public void setDataType(DataType dataType) {
        this.dataType = dataType;
    }
}

