/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.dataset.api;

import java.util.Arrays;
import lombok.NonNull;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Static helpers for reshaping ("tailoring") and merging the feature, label and mask arrays held in
 * {@link DataSet} instances.
 *
 * <p>Layouts handled by the code below: rank 1 and 2 arrays are used as-is. Rank 3 arrays are
 * treated as time series {@code [minibatch, vectorSize, timeSeriesLength]}. Rank 4 arrays are
 * treated as {@code [instances, channels, height, width]}.
 */
public class DataSetUtil {
    private static final Logger log = LoggerFactory.getLogger(DataSetUtil.class);

    /**
     * Converts either the features or the labels of a data set to a 2d matrix, applying the
     * corresponding mask array (if any).
     *
     * @param dataSet     data set to read from; must not be null
     * @param areFeatures {@code true} to use the features and features mask, {@code false} to use
     *                    the labels and labels mask
     * @return the 2d representation, as produced by {@link #tailor2d(INDArray, INDArray)}
     * @throws NullPointerException if {@code dataSet} is null
     */
    public static INDArray tailor2d(@NonNull DataSet dataSet, boolean areFeatures) {
        if (dataSet == null) {
            throw new NullPointerException("dataSet is marked non-null but is null");
        }
        INDArray data = areFeatures ? dataSet.getFeatures() : dataSet.getLabels();
        INDArray mask = areFeatures ? dataSet.getFeaturesMaskArray() : dataSet.getLabelsMaskArray();
        return tailor2d(data, mask);
    }

    /**
     * Converts an array of rank 1 to 4 to a 2d matrix.
     *
     * <p>Rank 1 and 2 arrays are returned unchanged. Rank 3 arrays are delegated to
     * {@link #tailor3d2d(INDArray, INDArray)} (mask applied). Rank 4 arrays are delegated to
     * {@link #tailor4d2d(INDArray)` (mask ignored).
     *
     * @param data array to convert; must not be null
     * @param mask optional mask; only used for rank 3 data
     * @return the 2d representation (may be {@code null} for fully masked rank 3 data)
     * @throws NullPointerException if {@code data} is null
     * @throws RuntimeException     if the rank is not in 1..4
     */
    public static INDArray tailor2d(@NonNull INDArray data, INDArray mask) {
        if (data == null) {
            throw new NullPointerException("data is marked non-null but is null");
        }
        switch (data.rank()) {
            case 1:
            case 2:
                return data;
            case 3:
                return tailor3d2d(data, mask);
            case 4:
                return tailor4d2d(data);
            default:
                throw new RuntimeException("Unsupported data rank");
        }
    }

    /**
     * Converts the rank 3 features or labels of a data set to 2d, removing masked-out time steps.
     *
     * @param dataset     data set to read from
     * @param areFeatures {@code true} for features and features mask, {@code false} for labels and
     *                    labels mask
     * @return see {@link #tailor3d2d(INDArray, INDArray)}
     */
    public static INDArray tailor3d2d(DataSet dataset, boolean areFeatures) {
        INDArray data = areFeatures ? dataset.getFeatures() : dataset.getLabels();
        INDArray mask = areFeatures ? dataset.getFeaturesMaskArray() : dataset.getLabelsMaskArray();
        return tailor3d2d(data, mask);
    }

    /**
     * Reshapes time series data {@code [minibatch, vectorSize, timeSeriesLength]} into a 2d matrix
     * with {@code vectorSize} columns, with one row per (example, time step) pair. When a mask of
     * shape {@code [minibatch, timeSeriesLength]} is given, only rows whose mask entry is non-zero
     * are kept.
     *
     * @param data rank 3 time series array; must not be null
     * @param mask optional 2d mask
     * @return the 2d array; {@code data}'s full 2d form if no entry is masked out, or {@code null}
     *         if every entry is masked out
     * @throws NullPointerException     if {@code data} is null
     * @throws IllegalArgumentException if the mask's minibatch or length dimensions do not match
     */
    public static INDArray tailor3d2d(@NonNull INDArray data, INDArray mask) {
        if (data == null) {
            throw new NullPointerException("data is marked non-null but is null");
        }
        if (mask != null && (data.size(0) != mask.size(0) || data.size(2) != mask.size(1))) {
            throw new IllegalArgumentException("Invalid mask array/data combination: got data with shape [minibatch, vectorSize, timeSeriesLength] = " + Arrays.toString(data.shape()) + "; got mask with shape [minibatch,timeSeriesLength] = " + Arrays.toString(mask.shape()) + "; minibatch and timeSeriesLength dimensions must match");
        }
        // Work on a non-view, 'f'-ordered array with standard strides before reshaping.
        if (data.ordering() != 'f' || data.isView() || !Shape.strideDescendingCAscendingF(data)) {
            data = data.dup('f');
        }
        long[] shape = data.shape();
        INDArray as2d;
        if (shape[0] == 1L) {
            // Single example: take the tensor along dimensions 1 and 2, then transpose it.
            as2d = data.tensorAlongDimension(0L, 1, 2).permutei(1, 0);
        } else if (shape[2] == 1L) {
            // Time series length of 1.
            as2d = data.tensorAlongDimension(0L, 1, 0);
        } else {
            // [mb, size, ts] -> [mb, ts, size] -> [mb*ts, size] in 'f' order.
            INDArray permuted = data.permute(0, 2, 1);
            as2d = permuted.reshape('f', shape[0] * shape[2], shape[1]);
        }
        if (mask == null) {
            return as2d;
        }
        if (mask.ordering() != 'f' || mask.isView() || !Shape.strideDescendingCAscendingF(mask)) {
            mask = mask.dup('f');
        }
        // Flatten the mask in 'f' order, the same order used for the rows of as2d above.
        INDArray mask1d = mask.reshape('f', mask.length(), 1L);
        float[] maskValues = mask1d.data().asFloat();
        // Count non-zero entries directly. A sum would miscount non-binary mask values.
        int[] rowsToPull = new int[maskValues.length];
        int numKept = 0;
        for (int i = 0; i < maskValues.length; i++) {
            if (maskValues[i] != 0.0f) {
                rowsToPull[numKept++] = i;
            }
        }
        if ((long) numKept == mask.length()) {
            return as2d;
        }
        if (numKept == 0) {
            return null;
        }
        return Nd4j.pullRows(as2d, 1, Arrays.copyOf(rowsToPull, numKept));
    }

    /**
     * Converts the rank 4 features or labels of a data set to 2d.
     *
     * @param dataset     data set to read from
     * @param areFeatures {@code true} for features, {@code false} for labels
     * @return see {@link #tailor4d2d(INDArray)}
     */
    public static INDArray tailor4d2d(DataSet dataset, boolean areFeatures) {
        return tailor4d2d(areFeatures ? dataset.getFeatures() : dataset.getLabels());
    }

    /**
     * Converts a rank 4 array {@code [instances, channels, height, width]} to a 2d array of shape
     * {@code [instances*height*width, channels]}, preserving the input's data type.
     *
     * @param data rank 4 array; must not be null
     * @return the 2d array
     * @throws NullPointerException if {@code data} is null
     */
    public static INDArray tailor4d2d(@NonNull INDArray data) {
        if (data == null) {
            throw new NullPointerException("data is marked non-null but is null");
        }
        long instances = data.size(0);
        long channels = data.size(1);
        long height = data.size(2);
        long width = data.size(3);
        // Use the input's data type so that, e.g., double data is not narrowed to the default type.
        INDArray in2d = Nd4j.create(data.dataType(), channels, height * width * instances);
        // One TAD over dimensions (3, 2, 0) per channel; each becomes one row before transposing.
        long tads = data.tensorsAlongDimension(3, 2, 0);
        for (long i = 0; i < tads; i++) {
            INDArray thisTAD = data.tensorAlongDimension(i, 3, 2, 0);
            in2d.putRow(i, Nd4j.toFlattened(thisTAD));
        }
        return in2d.transposei();
    }

    /**
     * Multiplies rank 3 {@code data} in place by {@code mask}, broadcast along dimensions 0 and 2.
     * Masked-out (zero) entries therefore become zero. Does nothing if {@code mask} is null or
     * {@code data} is not rank 3.
     *
     * <p>NOTE(review): the mask is presumably {@code [minibatch, timeSeriesLength]} — confirm.
     *
     * @param data rank 3 array, modified in place
     * @param mask mask array, or {@code null}
     */
    public static void setMaskedValuesToZero(INDArray data, INDArray mask) {
        if (mask == null || data.rank() != 3) {
            return;
        }
        Nd4j.getExecutioner().exec(new BroadcastMulOp(data, mask, data, 0, 2));
    }

    /**
     * Merges multi-input features: for each input index, merges that input across all examples.
     *
     * @param featuresToMerge     {@code [example][inputIdx]} feature arrays; must not be null
     * @param featureMasksToMerge {@code [example][inputIdx]} masks; may be null or contain nulls
     * @return merged features, and merged masks ({@code null} if no input had a mask)
     * @throws NullPointerException if {@code featuresToMerge} is null
     */
    public static Pair<INDArray[], INDArray[]> mergeFeatures(@NonNull INDArray[][] featuresToMerge, INDArray[][] featureMasksToMerge) {
        if (featuresToMerge == null) {
            throw new NullPointerException("featuresToMerge is marked non-null but is null");
        }
        int nInArrs = featuresToMerge[0].length;
        INDArray[] outF = new INDArray[nInArrs];
        INDArray[] outM = null;
        for (int i = 0; i < nInArrs; i++) {
            Pair<INDArray, INDArray> p = mergeFeatures(featuresToMerge, featureMasksToMerge, i);
            outF[i] = p.getFirst();
            if (p.getSecond() == null) {
                continue;
            }
            // Allocate the mask array lazily, only once some input actually has a mask.
            if (outM == null) {
                outM = new INDArray[nInArrs];
            }
            outM[i] = p.getSecond();
        }
        return new Pair<>(outF, outM);
    }

    /**
     * Merges feature arrays (rank 2, 3 or 4, chosen from the first array) along the example
     * dimension.
     *
     * @param featuresToMerge     feature arrays; must not be null
     * @param featureMasksToMerge optional masks
     * @return the merged features and merged mask (mask may be null)
     * @throws NullPointerException  if {@code featuresToMerge} is null
     * @throws IllegalStateException if the first array's rank is not 2, 3 or 4
     */
    public static Pair<INDArray, INDArray> mergeFeatures(@NonNull INDArray[] featuresToMerge, INDArray[] featureMasksToMerge) {
        if (featuresToMerge == null) {
            throw new NullPointerException("featuresToMerge is marked non-null but is null");
        }
        Preconditions.checkNotNull(featuresToMerge[0], "Encountered null feature array when merging");
        switch (featuresToMerge[0].rank()) {
            case 2:
                return merge2d(featuresToMerge, featureMasksToMerge);
            case 3:
                return mergeTimeSeries(featuresToMerge, featureMasksToMerge);
            case 4:
                return merge4d(featuresToMerge, featureMasksToMerge);
            default:
                throw new IllegalStateException("Cannot merge examples: features rank must be in range 2 to 4 inclusive. First example features shape: " + Arrays.toString(featuresToMerge[0].shape()));
        }
    }

    /**
     * Merges the features at input index {@code inOutIdx} across multi-input examples.
     *
     * @param featuresToMerge     {@code [example][inputIdx]} feature arrays
     * @param featureMasksToMerge {@code [example][inputIdx]} masks; may be null
     * @param inOutIdx            input index to merge
     * @return merged features and mask
     */
    public static Pair<INDArray, INDArray> mergeFeatures(INDArray[][] featuresToMerge, INDArray[][] featureMasksToMerge, int inOutIdx) {
        Pair<INDArray[], INDArray[]> p = selectColumnFromMDSData(featuresToMerge, featureMasksToMerge, inOutIdx);
        return mergeFeatures(p.getFirst(), p.getSecond());
    }

    /**
     * Merges label arrays (rank 2, 3 or 4, chosen from the first array) along the example dimension.
     *
     * @param labelsToMerge     label arrays
     * @param labelMasksToMerge optional masks
     * @return the merged labels and merged mask (mask may be null)
     * @throws ND4JIllegalStateException if the first array's rank is not 2, 3 or 4
     */
    public static Pair<INDArray, INDArray> mergeLabels(INDArray[] labelsToMerge, INDArray[] labelMasksToMerge) {
        Preconditions.checkNotNull(labelsToMerge[0], "Cannot merge data: Encountered null labels array");
        switch (labelsToMerge[0].rank()) {
            case 2:
                return merge2d(labelsToMerge, labelMasksToMerge);
            case 3:
                return mergeTimeSeries(labelsToMerge, labelMasksToMerge);
            case 4:
                return merge4d(labelsToMerge, labelMasksToMerge);
            default:
                throw new ND4JIllegalStateException("Cannot merge examples: labels rank must be in range 2 to 4 inclusive. First example features shape: " + Arrays.toString(labelsToMerge[0].shape()));
        }
    }

    /**
     * Merges the labels at output index {@code inOutIdx} across multi-output examples.
     *
     * @param labelsToMerge     {@code [example][outputIdx]} label arrays; must not be null
     * @param labelMasksToMerge {@code [example][outputIdx]} masks; may be null
     * @param inOutIdx          output index to merge
     * @return merged labels and mask
     * @throws NullPointerException if {@code labelsToMerge} is null
     */
    public static Pair<INDArray, INDArray> mergeLabels(@NonNull INDArray[][] labelsToMerge, INDArray[][] labelMasksToMerge, int inOutIdx) {
        if (labelsToMerge == null) {
            throw new NullPointerException("labelsToMerge is marked non-null but is null");
        }
        Pair<INDArray[], INDArray[]> p = selectColumnFromMDSData(labelsToMerge, labelMasksToMerge, inOutIdx);
        return mergeLabels(p.getFirst(), p.getSecond());
    }

    /**
     * Extracts column {@code inOutIdx} from {@code [example][index]} arrays and masks. Mask entries
     * are left null where {@code masks} or {@code masks[i]} is null.
     */
    private static Pair<INDArray[], INDArray[]> selectColumnFromMDSData(@NonNull INDArray[][] arrays, INDArray[][] masks, int inOutIdx) {
        if (arrays == null) {
            throw new NullPointerException("arrays is marked non-null but is null");
        }
        INDArray[] a = new INDArray[arrays.length];
        INDArray[] m = new INDArray[arrays.length];
        for (int i = 0; i < arrays.length; i++) {
            a[i] = arrays[i][inOutIdx];
            if (masks != null && masks[i] != null) {
                m[i] = masks[i][inOutIdx];
            }
        }
        return new Pair<>(a, m);
    }

    /**
     * Merges the 2d arrays at index {@code inOutIdx} across multi-input/output examples.
     *
     * @throws NullPointerException if {@code arrays} is null
     */
    public static Pair<INDArray, INDArray> merge2d(@NonNull INDArray[][] arrays, INDArray[][] masks, int inOutIdx) {
        if (arrays == null) {
            throw new NullPointerException("arrays is marked non-null but is null");
        }
        Pair<INDArray[], INDArray[]> p = selectColumnFromMDSData(arrays, masks, inOutIdx);
        return merge2d(p.getFirst(), p.getSecond());
    }

    /**
     * Concatenates 2d arrays along dimension 0, merging masks when any are present.
     *
     * @param arrays 2d arrays, all with the same number of columns
     * @param masks  optional masks (array may be null or contain nulls)
     * @return the concatenated array and merged mask ({@code null} if no masks)
     * @throws IllegalStateException if column counts differ
     */
    public static Pair<INDArray, INDArray> merge2d(INDArray[] arrays, INDArray[] masks) {
        // Check the first entry before dereferencing it, so callers get the descriptive message.
        Preconditions.checkNotNull(arrays[0], "Encountered null array at position %s when merging data", 0);
        long cols = arrays[0].columns();
        boolean hasMasks = false;
        for (int i = 0; i < arrays.length; i++) {
            Preconditions.checkNotNull(arrays[i], "Encountered null array at position %s when merging data", i);
            if (arrays[i].columns() != cols) {
                throw new IllegalStateException("Cannot merge 2d arrays with different numbers of columns (firstNCols=" + cols + ", ithNCols=" + arrays[i].columns() + ")");
            }
            if (masks != null && masks[i] != null) {
                hasMasks = true;
            }
        }
        INDArray out = Nd4j.specialConcat(0, arrays);
        INDArray outMask = hasMasks ? mergePerOutputMasks2d(out.shape(), arrays, masks) : null;
        return new Pair<>(out, outMask);
    }

    /**
     * Merges the 2d masks at index {@code inOutIdx} across multi-input/output examples.
     *
     * @param outShape shape of the merged mask
     */
    public static INDArray mergePerOutputMasks2d(long[] outShape, INDArray[][] arrays, INDArray[][] masks, int inOutIdx) {
        Pair<INDArray[], INDArray[]> p = selectColumnFromMDSData(arrays, masks, inOutIdx);
        return mergePerOutputMasks2d(outShape, p.getFirst(), p.getSecond());
    }

    /**
     * Merges 2d masks.
     *
     * @deprecated use {@link #mergeMasks2d(long[], INDArray[], INDArray[])}
     */
    @Deprecated
    public static INDArray mergePerOutputMasks2d(long[] outShape, INDArray[] arrays, INDArray[] masks) {
        return mergeMasks2d(outShape, arrays, masks);
    }

    /**
     * Builds a merged 2d mask of shape {@code outShape}, initialised to ones (all present). Each
     * non-null {@code masks[i]} is copied into the rows belonging to {@code arrays[i]}.
     *
     * @param outShape shape of the merged mask ({@code [totalExamples, cols]})
     * @param arrays   arrays being merged; their {@code size(0)} gives each block's row count
     * @param masks    per-array masks; null entries leave the corresponding rows as ones
     * @return the merged mask, with the data type of {@code arrays[0]}
     */
    public static INDArray mergeMasks2d(long[] outShape, INDArray[] arrays, INDArray[] masks) {
        INDArray outMask = Nd4j.ones(arrays[0].dataType(), outShape);
        long rowsSoFar = 0;
        for (int i = 0; i < masks.length; i++) {
            long thisRows = arrays[i].size(0);
            if (masks[i] != null) {
                outMask.put(new INDArrayIndex[]{NDArrayIndex.interval(rowsSoFar, rowsSoFar + thisRows), NDArrayIndex.all()}, masks[i]);
            }
            // Advance even when this array has no mask; otherwise later masks land on the wrong rows.
            rowsSoFar += thisRows;
        }
        return outMask;
    }

    /**
     * Merges rank 4 masks along dimension 0. Examples whose mask is null are filled with ones.
     *
     * @param featuresOrLabels arrays being merged; used for the example count of unmasked entries
     * @param masks            per-array rank 4 masks (entries may be null)
     * @return the merged mask, or {@code null} if every mask is null
     * @throws IllegalStateException if a mask is not rank 4, or if dimensions 1-3 differ
     */
    public static INDArray mergeMasks4d(INDArray[] featuresOrLabels, INDArray[] masks) {
        long[] outShape = null;
        long mbCountNoMask = 0L;
        for (int i = 0; i < masks.length; i++) {
            if (masks[i] == null) {
                mbCountNoMask += featuresOrLabels[i].size(0);
                continue;
            }
            if (masks[i].rank() != 4) {
                throw new IllegalStateException("Cannot merge mask arrays: expected mask array of rank 4. Got mask array of rank " + masks[i].rank() + " with shape " + Arrays.toString(masks[i].shape()));
            }
            if (outShape == null) {
                outShape = masks[i].shape().clone();
                continue;
            }
            INDArray m = masks[i];
            if (m.size(1) != outShape[1] || m.size(2) != outShape[2] || m.size(3) != outShape[3]) {
                throw new IllegalStateException("Mismatched mask shapes: masks should have same depth/height/width for all examples. Prior examples had shape [mb," + outShape[1] + "," + outShape[2] + "," + outShape[3] + "], next example has shape " + Arrays.toString(m.shape()));
            }
            outShape[0] += m.size(0);
        }
        if (outShape == null) {
            return null;
        }
        outShape[0] += mbCountNoMask;
        INDArray outMask = Nd4j.ones(outShape);
        long exSoFar = 0;
        for (int i = 0; i < masks.length; i++) {
            if (masks[i] == null) {
                exSoFar += featuresOrLabels[i].size(0);
                continue;
            }
            long nEx = masks[i].size(0);
            outMask.put(new INDArrayIndex[]{NDArrayIndex.interval(exSoFar, exSoFar + nEx), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all()}, masks[i]);
            exSoFar += nEx;
        }
        return outMask;
    }

    /**
     * Merges the time series at index {@code inOutIdx} across multi-input/output examples.
     */
    public static Pair<INDArray, INDArray> mergeTimeSeries(INDArray[][] arrays, INDArray[][] masks, int inOutIdx) {
        Pair<INDArray[], INDArray[]> p = selectColumnFromMDSData(arrays, masks, inOutIdx);
        return mergeTimeSeries(p.getFirst(), p.getSecond());
    }

    /**
     * Merges rank 3 time series {@code [minibatch, size, length]} along dimension 0. Shorter series
     * are zero-padded to the longest length.
     *
     * <p>A mask is returned when any input has a mask or when lengths differ. With rank 2 (or no)
     * input masks the result is a {@code [totalExamples, maxLength]} mask, with padding positions
     * set to 0. With rank 3 input masks the result has the same shape as the merged data, and
     * null masks become ones over the valid length.
     *
     * @param arrays rank 3 arrays with equal {@code size(1)}
     * @param masks  optional masks (array may be null or contain nulls)
     * @return merged data and mask ({@code null} when no mask is needed)
     * @throws IllegalStateException         if {@code size(1)} differs between arrays
     * @throws UnsupportedOperationException if masks have a rank other than 2 or 3
     */
    public static Pair<INDArray, INDArray> mergeTimeSeries(INDArray[] arrays, INDArray[] masks) {
        long firstLength = arrays[0].size(2);
        long size = arrays[0].size(1);
        long maxLength = firstLength;
        boolean hasMask = false;
        // Rank of the last non-null mask seen; -1 if there are none.
        int maskRank = -1;
        boolean lengthsDiffer = false;
        long totalExamples = 0;
        for (int i = 0; i < arrays.length; i++) {
            totalExamples += arrays[i].size(0);
            long thisLength = arrays[i].size(2);
            maxLength = Math.max(maxLength, thisLength);
            if (thisLength != firstLength) {
                lengthsDiffer = true;
            }
            if (masks != null && masks[i] != null) {
                maskRank = masks[i].rank();
                hasMask = true;
            }
            if (arrays[i].size(1) != size) {
                throw new IllegalStateException("Cannot merge time series with different size for dimension 1 (first shape: " + Arrays.toString(arrays[0].shape()) + ", " + i + "th shape: " + Arrays.toString(arrays[i].shape()));
            }
        }
        INDArray arr = Nd4j.create(arrays[0].dataType(), totalExamples, size, maxLength);
        long examplesSoFar = 0;
        if (!hasMask && !lengthsDiffer) {
            // Equal lengths and no masks: plain stacking, no mask needed.
            for (INDArray a : arrays) {
                long thisNExamples = a.size(0);
                arr.put(new INDArrayIndex[]{NDArrayIndex.interval(examplesSoFar, examplesSoFar + thisNExamples), NDArrayIndex.all(), NDArrayIndex.all()}, a);
                examplesSoFar += thisNExamples;
            }
            return new Pair<>(arr, null);
        }
        INDArray mask;
        if ((lengthsDiffer && !hasMask) || maskRank == 2) {
            // 2d mask [totalExamples, maxLength], initialised to "present".
            mask = Nd4j.ones(arrays[0].dataType(), totalExamples, maxLength);
            for (int i = 0; i < arrays.length; i++) {
                INDArray a = arrays[i];
                long thisNExamples = a.size(0);
                long thisLength = a.size(2);
                INDArrayIndex rows = NDArrayIndex.interval(examplesSoFar, examplesSoFar + thisNExamples);
                arr.put(new INDArrayIndex[]{rows, NDArrayIndex.all(), NDArrayIndex.interval(0L, thisLength)}, a);
                if (masks != null && masks[i] != null) {
                    INDArray origMask = masks[i];
                    long maskLength = origMask.size(1);
                    mask.put(new INDArrayIndex[]{rows, NDArrayIndex.interval(0L, maskLength)}, origMask);
                    if (maskLength < maxLength) {
                        mask.get(rows, NDArrayIndex.interval(maskLength, maxLength)).assign(0);
                    }
                } else if (thisLength < maxLength) {
                    // No mask supplied: mark the padding as absent.
                    mask.get(rows, NDArrayIndex.interval(thisLength, maxLength)).assign(0);
                }
                examplesSoFar += thisNExamples;
            }
        } else if (maskRank == 3) {
            // Per-element mask, same shape as the merged data; zero (absent) by default.
            mask = Nd4j.create(arr.dataType(), arr.shape());
            for (int i = 0; i < arrays.length; i++) {
                INDArray m = masks[i];
                INDArray a = arrays[i];
                long thisNExamples = a.size(0);
                long thisLength = a.size(2);
                INDArrayIndex rows = NDArrayIndex.interval(examplesSoFar, examplesSoFar + thisNExamples);
                arr.put(new INDArrayIndex[]{rows, NDArrayIndex.all(), NDArrayIndex.interval(0L, thisLength)}, a);
                if (m == null) {
                    mask.get(rows, NDArrayIndex.all(), NDArrayIndex.interval(0L, thisLength)).assign(1);
                } else {
                    mask.put(new INDArrayIndex[]{rows, NDArrayIndex.all(), NDArrayIndex.interval(0L, thisLength)}, m);
                }
                examplesSoFar += thisNExamples;
            }
        } else {
            throw new UnsupportedOperationException("Cannot merge time series with mask rank " + maskRank);
        }
        return new Pair<>(arr, mask);
    }

    /**
     * Merges the rank 4 arrays at index {@code inOutIdx} across multi-input/output examples.
     */
    public static Pair<INDArray, INDArray> merge4d(INDArray[][] arrays, INDArray[][] masks, int inOutIdx) {
        Pair<INDArray[], INDArray[]> p = selectColumnFromMDSData(arrays, masks, inOutIdx);
        return merge4d(p.getFirst(), p.getSecond());
    }

    /**
     * Concatenates rank 4 arrays along dimension 0. Dimensions 1-3 must match.
     *
     * <p>Masks, if any, are merged according to the rank of the last non-null mask. Rank 2 masks
     * produce a {@code [totalExamples, maskColumns]} mask. Rank 4 masks are merged with
     * {@link #mergeMasks4d(INDArray[], INDArray[])}.
     *
     * @param arrays rank 4 arrays
     * @param masks  optional masks (array may be null or contain nulls)
     * @return the merged array and mask ({@code null} if no masks)
     * @throws IllegalStateException         if an array is not rank 4 or shapes differ
     * @throws UnsupportedOperationException if masks have a rank other than 2 or 4
     */
    public static Pair<INDArray, INDArray> merge4d(INDArray[] arrays, INDArray[] masks) {
        // Check the first entry before dereferencing it, so callers get the descriptive message.
        Preconditions.checkNotNull(arrays[0], "Encountered null array when merging data at position %s", 0);
        long nExamples = 0;
        long[] shape = arrays[0].shape();
        INDArray lastMask = null;
        for (int i = 0; i < arrays.length; i++) {
            Preconditions.checkNotNull(arrays[i], "Encountered null array when merging data at position %s", i);
            nExamples += arrays[i].size(0);
            long[] thisShape = arrays[i].shape();
            if (thisShape.length != 4) {
                throw new IllegalStateException("Cannot merge 4d arrays with non 4d arrays");
            }
            for (int j = 1; j < 4; j++) {
                if (thisShape[j] != shape[j]) {
                    throw new IllegalStateException("Cannot merge 4d arrays with different shape (other than # examples):  data[0].shape = " + Arrays.toString(shape) + ", data[" + i + "].shape = " + Arrays.toString(thisShape));
                }
            }
            if (masks != null && masks[i] != null) {
                lastMask = masks[i];
            }
        }
        INDArray out = Nd4j.specialConcat(0, arrays);
        if (lastMask == null) {
            return new Pair<>(out, null);
        }
        INDArray outMask;
        int maskRank = lastMask.rank();
        if (maskRank == 2) {
            // The mask is 2d, so its merged shape is [nExamples, maskCols], not the 4d data shape.
            outMask = mergeMasks2d(new long[]{nExamples, lastMask.size(1)}, arrays, masks);
        } else if (maskRank == 4) {
            outMask = mergeMasks4d(arrays, masks);
        } else {
            throw new UnsupportedOperationException("Cannot merge 4d arrays with mask rank " + maskRank);
        }
        return new Pair<>(out, outMask);
    }
}

