/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.util;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.config.ExecutionResult;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
import org.nd4j.linalg.api.ops.impl.shape.ReductionShape;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.exception.ND4JException;
import org.nd4j.linalg.factory.Nd4j;

public class SameDiffUtils {

    /**
     * Concatenates per-batch execution results into a single array per output name.
     * <p>
     * For each {@link ExecutionResult}:
     * <ul>
     *   <li>If {@code getOutputs()} is non-null, its values are used, each unwrapped with {@code get()}.</li>
     *   <li>Otherwise, if {@code getValueOutputs()} is non-null, its values are used via {@code getTensorValue()}.</li>
     *   <li>Results with neither are skipped.</li>
     * </ul>
     * Arrays collected under the same name are concatenated along dimension 0, in the order the
     * batches appear in {@code outputs}.
     *
     * @param outputs per-batch execution results
     * @return map from output name to the concatenated array; empty if no batch contributed outputs
     * @throws ND4JException if concatenating the arrays for any output fails
     */
    public static Map<String, INDArray> stackOutputs(List<ExecutionResult> outputs) {
        Map<String, List<INDArray>> collected = new HashMap<>();
        for (ExecutionResult batch : outputs) {
            if (batch.getOutputs() != null) {
                for (String k : batch.getOutputs().keySet()) {
                    // Values are wrapped and unwrapped with get(); presumably an Optional,
                    // in which case an absent value would throw here. TODO confirm.
                    collected.computeIfAbsent(k, key -> new ArrayList<>())
                            .add(batch.getOutputs().get(k).get());
                }
            } else if (batch.getValueOutputs() != null) {
                for (String k : batch.getValueOutputs().keySet()) {
                    collected.computeIfAbsent(k, key -> new ArrayList<>())
                            .add(batch.getValueOutputs().get(k).getTensorValue());
                }
            }
        }

        Map<String, INDArray> ret = new HashMap<>();
        for (Map.Entry<String, List<INDArray>> entry : collected.entrySet()) {
            try {
                ret.put(entry.getKey(), Nd4j.concat(0, entry.getValue().toArray(new INDArray[0])));
            } catch (Exception e) {
                throw new ND4JException("Error concatenating batch outputs", e);
            }
        }
        return ret;
    }

    /**
     * Extracts the array for a single output name from each batch's output map.
     *
     * @param outputs per-batch maps from output name to array
     * @param output  the output name to look up in each map
     * @return one entry per batch, in batch order; an entry is {@code null} when that batch's
     *         map has no value for {@code output}
     */
    public static List<INDArray> getSingleOutput(List<Map<String, INDArray>> outputs, String output) {
        List<INDArray> batches = new ArrayList<>(outputs.size());
        for (Map<String, INDArray> batch : outputs) {
            batches.add(batch.get(output));
        }
        return batches;
    }

    /**
     * Creates an {@link ExternalErrorsFunction} over the given variables and calls its
     * {@code outputVariable()} method before returning it.
     *
     * @param sameDiff          the SameDiff instance the function is created in
     * @param externalGradients gradients passed through to the function; may be {@code null}
     *                          (the two-argument overload passes {@code null})
     * @param inputs            the variables the external errors apply to; at least one is required
     * @return the created function
     * @throws IllegalArgumentException presumably, via {@code Preconditions.checkArgument},
     *                                  when {@code inputs} is null or empty
     */
    public static ExternalErrorsFunction externalErrors(SameDiff sameDiff, Map<String, INDArray> externalGradients, SDVariable... inputs) {
        // Arrays.toString renders null as "null" and an empty array as "[]", so the %s placeholder
        // is always filled with something meaningful.
        Preconditions.checkArgument(inputs != null && inputs.length > 0,
                "Require at least one SDVariable to be specified when using external errors: got %s",
                Arrays.toString(inputs));
        ExternalErrorsFunction fn = new ExternalErrorsFunction(sameDiff, Arrays.asList(inputs), externalGradients);
        // The return value is discarded; presumably this creates/registers the op's output
        // variable in sameDiff. TODO confirm.
        fn.outputVariable();
        return fn;
    }

    /**
     * Creates an {@link ExternalErrorsFunction} with no external gradients.
     * Equivalent to {@code externalErrors(sameDiff, null, inputs)}.
     *
     * @param sameDiff the SameDiff instance the function is created in
     * @param inputs   the variables the external errors apply to; at least one is required
     * @return the created function
     */
    public static ExternalErrorsFunction externalErrors(SameDiff sameDiff, SDVariable[] inputs) {
        return externalErrors(sameDiff, null, inputs);
    }

    /**
     * Re-inserts size-1 axes into a reduction result so that it has the rank of the original
     * (pre-reduction) input.
     * <p>
     * The input is returned unchanged when {@link Shape#isWholeArray(int, int...)} reports a
     * whole-array reduction, or when a rank-2 input was reduced over a single dimension.
     * Otherwise one {@code expandDims} is applied per reduced dimension.
     * <p>
     * Negative dimensions are normalized against {@code origRank}, and the dimensions are
     * processed in ascending order. Inserting axes from lowest to highest reconstructs the
     * original positions regardless of the order in which the caller listed them.
     *
     * @param origRank   rank of the array before the reduction
     * @param reduceDims dimensions that were reduced; not modified
     * @param toExpand   the reduction result
     * @return {@code toExpand}, possibly with size-1 axes inserted at the reduced dimensions
     */
    public static SDVariable reductionBroadcastableWithOrigShape(int origRank, int[] reduceDims, SDVariable toExpand) {
        if (Shape.isWholeArray(origRank, reduceDims)) {
            return toExpand;
        }
        if (origRank == 2 && reduceDims.length == 1) {
            return toExpand;
        }
        // Work on a copy so the caller's array is not mutated by normalization or sorting.
        int[] dims = reduceDims.clone();
        for (int i = 0; i < dims.length; i++) {
            if (dims[i] < 0) {
                dims[i] += origRank;
            }
        }
        // Each insertion index refers to a position in the original rank. Inserting in ascending
        // order keeps every earlier insertion to the left of the next one, so each index stays valid.
        Arrays.sort(dims);
        for (int d : dims) {
            toExpand = toExpand.getSameDiff().expandDims(toExpand, d);
        }
        return toExpand;
    }

    /**
     * Reshapes a reduction result using the shape computed by {@link #reductionShape} with
     * {@code keepDim = true}, based on the original input's shape and the reduction axes.
     *
     * @param origInput the array that was reduced
     * @param axis      the reduction axes, as a variable
     * @param toExpand  the reduction result to reshape
     * @return {@code toExpand} reshaped to the computed keep-dims reduction shape
     */
    public static SDVariable reductionBroadcastableWithOrigShape(SDVariable origInput, SDVariable axis, SDVariable toExpand) {
        SDVariable reduceShape = reductionShape(origInput.shape(), axis, true);
        return toExpand.reshape(reduceShape);
    }

    /**
     * Builds a {@link ReductionShape} op in {@code shape}'s SameDiff instance and returns its
     * output variable.
     *
     * @param shape   the input shape, as a variable
     * @param axis    the reduction axes, as a variable
     * @param keepDim forwarded to {@link ReductionShape}; presumably controls whether reduced
     *                axes are kept as size 1. TODO confirm.
     * @return the op's output variable
     */
    public static SDVariable reductionShape(SDVariable shape, SDVariable axis, boolean keepDim) {
        return new ReductionShape(shape.getSameDiff(), shape, axis, keepDim).outputVariable();
    }

    /**
     * Checks that {@code function} is non-null and belongs to {@code sameDiff}, compared by
     * reference identity.
     *
     * @param sameDiff the expected SameDiff instance
     * @param function the variable to validate
     * @param op       the function being applied; used only in the error message
     * @throws IllegalStateException presumably, via {@code Preconditions.checkState}, when
     *                               either check fails
     */
    public static void validateDifferentialFunctionSameDiff(SameDiff sameDiff, SDVariable function, DifferentialFunction op) {
        Preconditions.checkState(function != null, "Passed in function was null.");
        // A single identity check carrying the descriptive message. An additional message-less
        // check of the same condition would fire first and hide this message.
        Preconditions.checkState(function.getSameDiff() == sameDiff,
                "Function applications must be contained in same sameDiff. The left %s must match this function %s",
                function, op);
    }

    private SameDiffUtils() {
        // Static utility class; not meant to be instantiated.
    }
}

