/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.optim.aggregator;

import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.ml.feature.Instance;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator;
import scala.Function0;
import scala.Function2;
import scala.Predef$;
import scala.Serializable;
import scala.math.package$;
import scala.reflect.ScalaSignature;
import scala.runtime.DoubleRef;
import scala.runtime.java8.JFunction2;

@ScalaSignature(bytes="\u0006\u0001I3Q\u0001C\u0005\u0001\u001bUA\u0001b\n\u0001\u0003\u0002\u0003\u0006I!\u000b\u0005\tk\u0001\u0011\t\u0011)A\u0005m!A\u0011\b\u0001B\u0001B\u0003%!\bC\u0003B\u0001\u0011\u0005!\tC\u0004H\u0001\t\u0007I\u0011\u000b%\t\r1\u0003\u0001\u0015!\u0003J\u0011\u0015i\u0005\u0001\"\u0001O\u00055\te\tV!hOJ,w-\u0019;pe*\u0011!bC\u0001\u000bC\u001e<'/Z4bi>\u0014(B\u0001\u0007\u000e\u0003\u0015y\u0007\u000f^5n\u0015\tqq\"\u0001\u0002nY*\u0011\u0001#E\u0001\u0006gB\f'o\u001b\u0006\u0003%M\ta!\u00199bG\",'\"\u0001\u000b\u0002\u0007=\u0014xmE\u0002\u0001-q\u0001\"a\u0006\u000e\u000e\u0003aQ\u0011!G\u0001\u0006g\u000e\fG.Y\u0005\u00037a\u0011a!\u00118z%\u00164\u0007\u0003B\u000f\u001fA\u0019j\u0011!C\u0005\u0003?%\u0011A\u0004R5gM\u0016\u0014XM\u001c;jC\ndW\rT8tg\u0006;wM]3hCR|'\u000f\u0005\u0002\"I5\t!E\u0003\u0002$\u001b\u00059a-Z1ukJ,\u0017BA\u0013#\u0005!Ien\u001d;b]\u000e,\u0007CA\u000f\u0001\u00035\u00117MR3biV\u0014Xm]*uI\u000e\u0001\u0001c\u0001\u0016._5\t1F\u0003\u0002-\u001f\u0005I!M]8bI\u000e\f7\u000f^\u0005\u0003]-\u0012\u0011B\u0011:pC\u0012\u001c\u0017m\u001d;\u0011\u0007]\u0001$'\u0003\u000221\t)\u0011I\u001d:bsB\u0011qcM\u0005\u0003ia\u0011a\u0001R8vE2,\u0017\u0001\u00044ji&sG/\u001a:dKB$\bCA\f8\u0013\tA\u0004DA\u0004C_>dW-\u00198\u0002\u001d\t\u001c7i\\3gM&\u001c\u0017.\u001a8ugB\u0019!&L\u001e\u0011\u0005qzT\"A\u001f\u000b\u0005yj\u0011A\u00027j]\u0006dw-\u0003\u0002A{\t1a+Z2u_J\fa\u0001P5oSRtDcA\"F\rR\u0011a\u0005\u0012\u0005\u0006s\u0011\u0001\rA\u000f\u0005\u0006O\u0011\u0001\r!\u000b\u0005\u0006k\u0011\u0001\rAN\u0001\u0004I&lW#A%\u0011\u0005]Q\u0015BA&\u0019\u0005\rIe\u000e^\u0001\u0005I&l\u0007%A\u0002bI\u0012$\"a\u0014)\u000e\u0003\u0001AQ!U\u0004A\u0002\u0001\nA\u0001Z1uC\u0002")
/*
 * Loss/gradient aggregator for the AFT survival-regression objective ("AFT" presumably stands
 * for "accelerated failure time" -- inferred from the name, confirm against the Scala source).
 *
 * Each call to add(Instance) folds one data point into lossSum, gradientSumArray and weightSum.
 * The totals are exposed through loss()/gradient()/weight(), and partial aggregators are combined
 * through merge(...). All four of those delegate to DifferentiableLossAggregator.
 *
 * Coefficient layout, as read in add(...):
 *   [0, dim - 2)  per-feature coefficients, applied to standardized features (value / std)
 *   dim - 2       intercept
 *   dim - 1       log of the scale parameter sigma (sigma = exp(coefficients[dim - 1]))
 *
 * NOTE(review): this is CFR-decompiled Scala bytecode, not hand-written Java. The lambda bodies
 * in add(...) reference `sum$1` and `$this` and cast to `JFunction2.mcVID.sp`. These are
 * decompiler artifacts that do not resolve as written, so this file is not expected to compile
 * as Java. Treat the original Scala source as authoritative.
 */
public class AFTAggregator
implements DifferentiableLossAggregator<Instance, AFTAggregator> {
    // Per-feature standard deviations used to standardize features; std == 0.0 features are skipped.
    private final Broadcast<double[]> bcFeaturesStd;
    // When false, add(...) contributes 0.0 to the intercept slot of the gradient.
    private final boolean fitIntercept;
    // Current coefficients: feature coefficients, then intercept, then log(sigma).
    private final Broadcast<Vector> bcCoefficients;
    // Number of coefficients (bcCoefficients.value().size()), fixed at construction.
    private final int dim;
    // Running count of added instances (incremented by 1.0 per add, not by data.weight()).
    private double weightSum;
    // Running sum of per-instance loss terms.
    private double lossSum;
    // Gradient accumulator, lazily initialized by gradientSumArray$lzycompute().
    private double[] gradientSumArray;
    // Lazy-initialization flag: true once gradientSumArray has been assigned.
    private volatile boolean bitmap$0;

    /**
     * Merges {@code other} into this aggregator. The behavior is defined by the static
     * DifferentiableLossAggregator.merge$ helper (not visible here).
     */
    @Override
    public DifferentiableLossAggregator merge(DifferentiableLossAggregator other) {
        return DifferentiableLossAggregator.merge$(this, other);
    }

    /** Returns the aggregated gradient, as computed by DifferentiableLossAggregator.gradient$. */
    @Override
    public Vector gradient() {
        return DifferentiableLossAggregator.gradient$(this);
    }

    /** Returns the aggregated weight, as computed by DifferentiableLossAggregator.weight$. */
    @Override
    public double weight() {
        return DifferentiableLossAggregator.weight$(this);
    }

    /** Returns the aggregated loss, as computed by DifferentiableLossAggregator.loss$. */
    @Override
    public double loss() {
        return DifferentiableLossAggregator.loss$(this);
    }

    /** Returns the raw running weight counter. */
    @Override
    public double weightSum() {
        return this.weightSum;
    }

    /** Setter for weightSum ({@code _$eq} is Scala's encoding of the {@code weightSum_=} setter). */
    @Override
    public void weightSum_$eq(double x$1) {
        this.weightSum = x$1;
    }

    /** Returns the raw running loss sum. */
    @Override
    public double lossSum() {
        return this.lossSum;
    }

    /** Setter for lossSum ({@code _$eq} is Scala's encoding of the {@code lossSum_=} setter). */
    @Override
    public void lossSum_$eq(double x$1) {
        this.lossSum = x$1;
    }

    /**
     * Slow path of the lazy gradientSumArray initialization.
     *
     * Re-checks the flag while holding this object's monitor, so concurrent first callers assign
     * the array only once. The initial array comes from DifferentiableLossAggregator.gradientSumArray$
     * (presumably a zero-filled array of length dim() -- not visible here, confirm).
     */
    private double[] gradientSumArray$lzycompute() {
        AFTAggregator aFTAggregator = this;
        synchronized (aFTAggregator) {
            if (!this.bitmap$0) {
                this.gradientSumArray = DifferentiableLossAggregator.gradientSumArray$(this);
                this.bitmap$0 = true;
            }
        }
        return this.gradientSumArray;
    }

    /**
     * Returns the gradient accumulator, initializing it on first access.
     * The volatile flag is read without locking; only the first access takes the lock.
     */
    @Override
    public double[] gradientSumArray() {
        return !this.bitmap$0 ? this.gradientSumArray$lzycompute() : this.gradientSumArray;
    }

    /** Returns the number of coefficients (feature coefficients + intercept + log(sigma)). */
    @Override
    public int dim() {
        return this.dim;
    }

    /**
     * Adds one training instance to the running loss and gradient sums.
     *
     * Let:
     *   margin  = sum_j coefficients[j] * (x_j / std_j) + intercept, skipping features whose std is 0.0
     *   sigma   = exp(coefficients[dim - 1])
     *   epsilon = (log(label) - margin) / sigma
     *   delta   = data.weight()
     *
     * This method then:
     *   - adds delta * log(sigma) - delta * epsilon + exp(epsilon) to lossSum;
     *   - adds that term's derivatives with respect to the feature coefficients, the intercept
     *     (only if fitIntercept) and log(sigma) to gradientSumArray;
     *   - increments weightSum by 1.0.
     *
     * NOTE(review): data.weight() is used as {@code delta} inside the loss, not as a sample
     * weight. It is presumably the censoring indicator (1.0 = event observed); confirm against
     * the caller.
     *
     * @param data instance whose features, label (must be > 0) and weight (used as delta) are read
     * @return this aggregator, for chaining
     * @throws IllegalArgumentException (from Predef.require) if data.label() is not > 0.0
     */
    @Override
    public AFTAggregator add(Instance data) {
        double[] coefficients = ((Vector)this.bcCoefficients.value()).toArray();
        // The last two coefficient slots hold the intercept and log(sigma).
        double intercept = coefficients[this.dim() - 2];
        double sigma = package$.MODULE$.exp(coefficients[this.dim() - 1]);
        Vector xi = data.features();
        double ti = data.label();
        double delta = data.weight();
        // log(ti) is taken below, so the label must be strictly positive.
        Predef$.MODULE$.require(ti > 0.0, (Function0 & java.io.Serializable & Serializable)() -> "The lifetime or label should be  greater than 0.");
        double[] localFeaturesStd = (double[])this.bcFeaturesStd.value();
        // Mutable box so the lambda below can accumulate the dot product.
        DoubleRef sum = DoubleRef.create((double)0.0);
        // Dot product of the coefficients with the standardized features; zero-std features are skipped.
        // NOTE(review): `sum$1` is the decompiler's name for the captured `sum` above (not valid Java as written).
        xi.foreachNonZero((Function2)(JFunction2.mcVID.sp & java.io.Serializable & Serializable)(index, value) -> {
            block0: {
                if (localFeaturesStd[index] == 0.0) break block0;
                sum$1.elem += coefficients[index] * (value / localFeaturesStd[index]);
            }
        });
        double margin = sum.elem + intercept;
        double epsilon = (package$.MODULE$.log(ti) - margin) / sigma;
        this.lossSum_$eq(this.lossSum() + (delta * package$.MODULE$.log(sigma) - delta * epsilon + package$.MODULE$.exp(epsilon)));
        // Common factor: the loss term's derivative w.r.t. margin is -multiplier, so each feature's
        // gradient is multiplier * standardized feature.
        double multiplier = (delta - package$.MODULE$.exp(epsilon)) / sigma;
        // Per-feature gradient, using the same standardization and zero-std skip as the margin.
        // NOTE(review): `$this` is the decompiler's name for the enclosing instance (not valid Java as written).
        xi.foreachNonZero((Function2)(JFunction2.mcVID.sp & java.io.Serializable & Serializable)(index, value) -> {
            block0: {
                if (localFeaturesStd[index] == 0.0) break block0;
                $this.gradientSumArray()[index] = this.gradientSumArray()[index] + multiplier * (value / localFeaturesStd[index]);
            }
        });
        int n = this.dim() - 2;
        // Intercept gradient. It contributes 0.0 when the intercept is not fitted. The intercept
        // value is still added to margin above (presumably the caller keeps it at 0 in that case -- confirm).
        this.gradientSumArray()[n] = this.gradientSumArray()[n] + (this.fitIntercept ? multiplier : 0.0);
        int n2 = this.dim() - 1;
        // Derivative of the loss term with respect to log(sigma), the last coefficient slot.
        this.gradientSumArray()[n2] = this.gradientSumArray()[n2] + (delta + multiplier * sigma * epsilon);
        // Counts instances rather than adding data.weight(), presumably because the weight slot carries delta.
        this.weightSum_$eq(this.weightSum() + 1.0);
        return this;
    }

    /**
     * Creates an aggregator over the given broadcast feature stds and coefficients.
     *
     * @param bcFeaturesStd  per-feature standard deviations used to standardize features
     * @param fitIntercept   whether the intercept receives gradient contributions
     * @param bcCoefficients current coefficients (features, intercept, log(sigma)); their size becomes dim
     */
    public AFTAggregator(Broadcast<double[]> bcFeaturesStd, boolean fitIntercept, Broadcast<Vector> bcCoefficients) {
        this.bcFeaturesStd = bcFeaturesStd;
        this.fitIntercept = fitIntercept;
        this.bcCoefficients = bcCoefficients;
        // Runs before dim is assigned. gradientSumArray is lazy, presumably so that its
        // initialization can read dim() after this point -- confirm against the trait.
        DifferentiableLossAggregator.$init$(this);
        this.dim = ((Vector)bcCoefficients.value()).size();
    }
}

