/*
 * Decompiled with CFR 0.152.
 */
package org.tensorflow.op.core;

import java.util.Arrays;
import java.util.Collections;
import org.tensorflow.Operand;
import org.tensorflow.ndarray.index.Indices;
import org.tensorflow.op.Scope;
import org.tensorflow.op.core.Concat;
import org.tensorflow.op.core.Constant;
import org.tensorflow.op.core.Gather;
import org.tensorflow.op.core.ReduceProd;
import org.tensorflow.op.core.Reshape;
import org.tensorflow.op.core.Shape;
import org.tensorflow.op.core.Squeeze;
import org.tensorflow.op.core.StridedSliceHelper;
import org.tensorflow.op.core.Where;
import org.tensorflow.types.TBool;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.TInt64;
import org.tensorflow.types.family.TType;

/**
 * Builds a boolean-mask selection on a tensor out of core ops.
 *
 * <p>All functionality is exposed through the static {@link #create} method. Optional attributes
 * are built with {@link #axis(int)} or {@link #axis(Integer)}.
 */
public abstract class BooleanMask {
    /**
     * Selects the entries of {@code tensor} at the positions where {@code mask} is {@code true}.
     *
     * <p>Steps:
     * <ul>
     *   <li>The dimensions {@code [axis, axis + rank(mask))} of {@code tensor} are collapsed into a
     *       single dimension.</li>
     *   <li>{@code mask} is flattened to a vector, and the indices of its {@code true} elements are
     *       computed.</li>
     *   <li>Those indices are gathered from the collapsed tensor along {@code axis}.</li>
     * </ul>
     * Dimensions before {@code axis} and after the masked range are left untouched.
     *
     * @param scope current scope; the created ops are placed in a {@code "BooleanMask"} sub-scope
     * @param tensor the tensor to mask
     * @param mask boolean mask. Its rank must be statically known and at least 1, and its shape
     *     must be compatible with dimensions {@code [axis, axis + rank(mask))} of {@code tensor}.
     *     Individual dimension sizes of the mask may be unknown.
     * @param options optional attributes. If several set an axis, the last non-null one wins. A
     *     negative axis is resolved by adding the static rank of {@code tensor}.
     * @return the masked tensor
     * @throws IllegalArgumentException if the mask has an unknown number of dimensions, is a
     *     scalar, or has a shape incompatible with the corresponding dimensions of {@code tensor}
     */
    public static <T extends TType> Operand<T> create(Scope scope, Operand<T> tensor, Operand<TBool> mask, Options ... options) {
        scope = scope.withNameAsSubScope("BooleanMask");

        int axis = 0;
        if (options != null) {
            for (Options opts : options) {
                if (opts != null && opts.axis != null) {
                    axis = opts.axis;
                }
            }
        }
        if (axis < 0) {
            axis += tensor.rank();
        }

        org.tensorflow.ndarray.Shape maskShape = mask.shape();
        org.tensorflow.ndarray.Shape tensorShape = tensor.shape();
        // Only the mask's rank has to be static: it decides which tensor dimensions get collapsed.
        // Unknown sizes inside a known rank are fine, because every size used to build the graph
        // below is read from the tensor's run-time shape.
        if (maskShape.isUnknown()) {
            throw new IllegalArgumentException("Mask cannot have unknown number of dimensions");
        }
        int maskRank = maskShape.numDimensions();
        if (maskRank == 0) {
            throw new IllegalArgumentException("Mask cannot be a scalar.");
        }
        int maskEnd = axis + maskRank;

        Constant<TInt32> axisTensor = Constant.scalarOf(scope, axis);
        org.tensorflow.ndarray.Shape requiredMaskShape = tensorShape.subShape(axis, maskEnd);
        if (!requiredMaskShape.isCompatibleWith(maskShape)) {
            throw new IllegalArgumentException("Mask shape " + maskShape + " is not compatible with the required mask shape: " + requiredMaskShape + ".");
        }

        // Run-time shape of the input, so dimensions unknown at graph-build time still work.
        Shape<TInt32> liveShape = Shape.create(scope, tensor);
        // Total number of elements spanned by the masked dimensions.
        ReduceProd<TInt32> leadingSize = ReduceProd.create(scope, StridedSliceHelper.stridedSlice(scope, liveShape, Indices.range(axis, maskEnd)), Constant.arrayOf(scope, 0));
        // Target shape: [dims before axis..., leadingSize, dims after the masked range...].
        Concat<TInt32> flattenedShape = Concat.create(scope, Arrays.asList(
                StridedSliceHelper.stridedSlice(scope, liveShape, Indices.sliceTo(axis)),
                Reshape.create(scope, leadingSize, Constant.arrayOf(scope, 1)),
                StridedSliceHelper.stridedSlice(scope, liveShape, Indices.sliceFrom(maskEnd))),
                Constant.scalarOf(scope, 0));
        Reshape<T> flattened = Reshape.create(scope, tensor, flattenedShape);

        Reshape<TBool> flatMask = Reshape.create(scope, mask, Constant.arrayOf(scope, -1));
        // Where yields one int64 coordinate row per true element; drop the size-1 column (axis 1)
        // to obtain a vector of indices.
        Squeeze<TInt64> indices = Squeeze.create(scope, Where.create(scope, flatMask), Squeeze.axis(Collections.singletonList(1L)));
        return Gather.create(scope, flattened, indices, axisTensor);
    }

    /**
     * Creates options that set the axis at which masking starts.
     *
     * @param axis the axis; {@code null} leaves the default ({@code 0}) in effect
     * @return a new options instance
     */
    public static Options axis(Integer axis) {
        return new Options().axis(axis);
    }

    /**
     * Creates options that set the axis at which masking starts.
     *
     * @param axis the axis; negative values are resolved against the tensor's rank
     * @return a new options instance
     */
    public static Options axis(int axis) {
        return new Options().axis(axis);
    }

    /** Optional attributes for {@link BooleanMask#create}. */
    public static class Options {
        // Axis at which masking starts; null means "not set".
        private Integer axis;

        /**
         * Sets the axis at which masking starts.
         *
         * @param axis the axis, or {@code null} to leave it unset
         * @return this options instance
         */
        public Options axis(Integer axis) {
            this.axis = axis;
            return this;
        }

        private Options() {
        }
    }
}

