/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.ann;

import java.io.Serializable;
import org.apache.spark.ml.ann.ANNGradient;
import org.apache.spark.ml.ann.ANNUpdater;
import org.apache.spark.ml.ann.DataStacker;
import org.apache.spark.ml.ann.Topology;
import org.apache.spark.ml.ann.TopologyModel;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.mllib.linalg.VectorImplicits$;
import org.apache.spark.mllib.linalg.Vectors$;
import org.apache.spark.mllib.optimization.Gradient;
import org.apache.spark.mllib.optimization.GradientDescent;
import org.apache.spark.mllib.optimization.LBFGS;
import org.apache.spark.mllib.optimization.Optimizer;
import org.apache.spark.mllib.optimization.Updater;
import org.apache.spark.rdd.RDD;
import org.apache.spark.storage.StorageLevel;
import org.apache.spark.storage.StorageLevel$;
import scala.Function1;
import scala.MatchError;
import scala.Tuple2;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

@ScalaSignature(bytes="\u0006\u0001\u0005%f!\u0002\u0015*\u0001-\u001a\u0004\u0002C\u001f\u0001\u0005\u0003\u0005\u000b\u0011B \t\u0011\r\u0003!Q1A\u0005\u0002\u0011C\u0001\u0002\u0013\u0001\u0003\u0002\u0003\u0006I!\u0012\u0005\t\u0013\u0002\u0011)\u0019!C\u0001\t\"A!\n\u0001B\u0001B\u0003%Q\tC\u0003L\u0001\u0011\u0005A\nC\u0004R\u0001\u0001\u0007I\u0011\u0002*\t\u000fY\u0003\u0001\u0019!C\u0005/\"1Q\f\u0001Q!\nMCqA\u0018\u0001A\u0002\u0013%q\fC\u0004g\u0001\u0001\u0007I\u0011B4\t\r%\u0004\u0001\u0015)\u0003a\u0011\u001dQ\u0007\u00011A\u0005\n\u0011Cqa\u001b\u0001A\u0002\u0013%A\u000e\u0003\u0004o\u0001\u0001\u0006K!\u0012\u0005\b_\u0002\u0001\r\u0011\"\u0003q\u0011\u001d!\b\u00011A\u0005\nUDaa\u001e\u0001!B\u0013\t\bb\u0002=\u0001\u0001\u0004%I!\u001f\u0005\n\u0003\u000b\u0001\u0001\u0019!C\u0005\u0003\u000fAq!a\u0003\u0001A\u0003&!\u0010C\u0005\u0002\u000e\u0001\u0001\r\u0011\"\u0003\u0002\u0010!I\u0011q\u0003\u0001A\u0002\u0013%\u0011\u0011\u0004\u0005\t\u0003;\u0001\u0001\u0015)\u0003\u0002\u0012!I\u0011q\u0004\u0001A\u0002\u0013%\u0011\u0011\u0005\u0005\n\u0003S\u0001\u0001\u0019!C\u0005\u0003WA\u0001\"a\f\u0001A\u0003&\u00111\u0005\u0005\u0007\u0003c\u0001A\u0011\u0001*\t\u000f\u0005M\u0002\u0001\"\u0001\u00026!1\u0011Q\b\u0001\u0005\u0002}Cq!a\u0010\u0001\t\u0003\t\t\u0005C\u0004\u0002F\u0001!\t!a\u0012\t\u000f\u0005-\u0003\u0001\"\u0001\u0002N!9\u0011Q\u000b\u0001\u0005\u0002\u0005]\u0003bBA0\u0001\u0011\u0005\u0011\u0011\r\u0005\b\u0003K\u0002A\u0011AA4\u0011!\tY\u0007\u0001Q\u0005\n\u00055\u0004\u0002CA:\u0001\u0001&I!!\u001e\t\u000f\u0005m\u0004\u0001\"\u0001\u0002~\t\u0011b)Z3e\r>\u0014x/\u0019:e)J\f\u0017N\\3s\u0015\tQ3&A\u0002b]:T!\u0001L\u0017\u0002\u00055d'B\u0001\u00180\u0003\u0015\u0019\b/\u0019:l\u0015\t\u0001\u0014'\u0001\u0004ba\u0006\u001c\u0007.\u001a\u0006\u0002e\u0005\u0019qN]4\u0014\u0007\u0001!$\b\u0005\u00026q5\taGC\u00018\u0003\u0015\u00198-\u00197b\u0013\tIdG\u0001\u0004B]f\u0014VM\u001a\t\u0003kmJ!\u0001\u0010\u001c\u0003\u0019M+'/[1mSj\f'\r\\3\u0002\u0011Q|\u0007o\u001c7pOf\u001c\u0001\u0001\u0005\u0002A\u00036\t\u0011&\u0003\u0002CS\tAAk\u001c9pY><\u00170A\u0005j]B,HoU5{KV\tQ\t\u0005\u00026\r&\u0011qI\u000e\u0002\u0004\u0013:$\u0018AC5oaV$8+\u001b>fA\u0005Qq.\u001e;qkR\u001c\u0016N_3\u0002\u0017=,H\u000f];u'&TX\rI\u0001\u0007y%t\u0017\u000e\u001e \u0015\t5su\n\u0015\t\u0003\u0001\u0002AQ!\u0010\u0004A\u0002}BQa\u0011\u0004A\u0002\u0015CQ!\u0013\u0004A\u0002\u0015\u000bQaX:fK\u0012,\u0012a\u0015\t\u0003kQK!!\u0016\u001c\u0003\t1{gnZ\u0001\n?N,W\rZ0%KF$\"\u0001W.\u0011\u0005UJ\u0016B\u0001.7\u0005\u0011)f.\u001b;\t\u000fqC\u0011\u0011!a\u0001'\u0006\u0019\u0001\u0010J\u0019\u0002\r}\u001bX-\u001a3!\u0003!yv/Z5hQR\u001cX#\u00011\u0011\u0005\u0005$W\"\u00012\u000b\u0005\r\\\u0013A\u00027j]\u0006dw-\u0003\u0002fE\n1a+Z2u_J\fAbX<fS\u001eDGo]0%KF$\"\u0001\u00175\t\u000fq[\u0011\u0011!a\u0001A\u0006Iql^3jO\"$8\u000fI\u0001\u000b?N$\u0018mY6TSj,\u0017AD0ti\u0006\u001c7nU5{K~#S-\u001d\u000b\u000316Dq\u0001\u0018\b\u0002\u0002\u0003\u0007Q)A\u0006`gR\f7m[*ju\u0016\u0004\u0013a\u00033bi\u0006\u001cF/Y2lKJ,\u0012!\u001d\t\u0003\u0001JL!a]\u0015\u0003\u0017\u0011\u000bG/Y*uC\u000e\\WM]\u0001\u0010I\u0006$\u0018m\u0015;bG.,'o\u0018\u0013fcR\u0011\u0001L\u001e\u0005\b9F\t\t\u00111\u0001r\u00031!\u0017\r^1Ti\u0006\u001c7.\u001a:!\u0003%yvM]1eS\u0016tG/F\u0001{!\rY\u0018\u0011A\u0007\u0002y*\u0011QP`\u0001\r_B$\u0018.\\5{CRLwN\u001c\u0006\u0003\u007f6\nQ!\u001c7mS\nL1!a\u0001}\u0005!9%/\u00193jK:$\u0018!D0he\u0006$\u0017.\u001a8u?\u0012*\u0017\u000fF\u0002Y\u0003\u0013Aq\u0001\u0018\u000b\u0002\u0002\u0003\u0007!0\u0001\u0006`OJ\fG-[3oi\u0002\n\u0001bX;qI\u0006$XM]\u000b\u0003\u0003#\u00012a_A\n\u0013\r\t)\u0002 \u0002\b+B$\u0017\r^3s\u00031yV\u000f\u001d3bi\u0016\u0014x\fJ3r)\rA\u00161\u0004\u0005\t9^\t\t\u00111\u0001\u0002\u0012\u0005Iq,\u001e9eCR,'\u000fI\u0001\n_B$\u0018.\\5{KJ,\"!a\t\u0011\u0007m\f)#C\u0002\u0002(q\u0014\u0011b\u00149uS6L'0\u001a:\u0002\u001b=\u0004H/[7ju\u0016\u0014x\fJ3r)\rA\u0016Q\u0006\u0005\t9j\t\t\u00111\u0001\u0002$\u0005Qq\u000e\u001d;j[&TXM\u001d\u0011\u0002\u000f\u001d,GoU3fI\u000691/\u001a;TK\u0016$G\u0003BA\u001c\u0003si\u0011\u0001\u0001\u0005\u0007\u0003wi\u0002\u0019A*\u0002\u000bY\fG.^3\u0002\u0015\u001d,GoV3jO\"$8/\u0001\u0006tKR<V-[4iiN$B!a\u000e\u0002D!1\u00111H\u0010A\u0002\u0001\fAb]3u'R\f7m[*ju\u0016$B!a\u000e\u0002J!1\u00111\b\u0011A\u0002\u0015\u000bAbU$E\u001fB$\u0018.\\5{KJ,\"!a\u0014\u0011\u0007m\f\t&C\u0002\u0002Tq\u0014qb\u0012:bI&,g\u000e\u001e#fg\u000e,g\u000e^\u0001\u000f\u0019\n3uiU(qi&l\u0017N_3s+\t\tI\u0006E\u0002|\u00037J1!!\u0018}\u0005\u0015a%IR$T\u0003)\u0019X\r^+qI\u0006$XM\u001d\u000b\u0005\u0003o\t\u0019\u0007C\u0004\u0002<\r\u0002\r!!\u0005\u0002\u0017M,Go\u0012:bI&,g\u000e\u001e\u000b\u0005\u0003o\tI\u0007\u0003\u0004\u0002<\u0011\u0002\rA_\u0001\u000fkB$\u0017\r^3He\u0006$\u0017.\u001a8u)\rA\u0016q\u000e\u0005\u0007\u0003c*\u0003\u0019\u0001>\u0002\u0011\u001d\u0014\u0018\rZ5f]R\fQ\"\u001e9eCR,W\u000b\u001d3bi\u0016\u0014Hc\u0001-\u0002x!9\u0011\u0011\u0010\u0014A\u0002\u0005E\u0011aB;qI\u0006$XM]\u0001\u0006iJ\f\u0017N\u001c\u000b\u0005\u0003\u007f\n9\nE\u00046\u0003\u0003\u000b))a#\n\u0007\u0005\reG\u0001\u0004UkBdWM\r\t\u0004\u0001\u0006\u001d\u0015bAAES\tiAk\u001c9pY><\u00170T8eK2\u0004R!NAG\u0003#K1!a$7\u0005\u0015\t%O]1z!\r)\u00141S\u0005\u0004\u0003+3$A\u0002#pk\ndW\rC\u0004\u0002\u001a\u001e\u0002\r!a'\u0002\t\u0011\fG/\u0019\t\u0007\u0003;\u000b\u0019+a*\u000e\u0005\u0005}%bAAQ[\u0005\u0019!\u000f\u001a3\n\t\u0005\u0015\u0016q\u0014\u0002\u0004%\u0012#\u0005#B\u001b\u0002\u0002\u0002\u0004\u0007")
public class FeedForwardTrainer
implements scala.Serializable {
    private final Topology topology;
    private final int inputSize;
    private final int outputSize;
    private long _seed;
    private Vector _weights;
    private int _stackSize;
    private DataStacker dataStacker;
    private Gradient _gradient;
    private Updater _updater;
    private Optimizer optimizer;

    public int inputSize() {
        return this.inputSize;
    }

    public int outputSize() {
        return this.outputSize;
    }

    private long _seed() {
        return this._seed;
    }

    private void _seed_$eq(long x$1) {
        this._seed = x$1;
    }

    private Vector _weights() {
        return this._weights;
    }

    private void _weights_$eq(Vector x$1) {
        this._weights = x$1;
    }

    private int _stackSize() {
        return this._stackSize;
    }

    private void _stackSize_$eq(int x$1) {
        this._stackSize = x$1;
    }

    private DataStacker dataStacker() {
        return this.dataStacker;
    }

    private void dataStacker_$eq(DataStacker x$1) {
        this.dataStacker = x$1;
    }

    private Gradient _gradient() {
        return this._gradient;
    }

    private void _gradient_$eq(Gradient x$1) {
        this._gradient = x$1;
    }

    private Updater _updater() {
        return this._updater;
    }

    private void _updater_$eq(Updater x$1) {
        this._updater = x$1;
    }

    private Optimizer optimizer() {
        return this.optimizer;
    }

    private void optimizer_$eq(Optimizer x$1) {
        this.optimizer = x$1;
    }

    public long getSeed() {
        return this._seed();
    }

    public FeedForwardTrainer setSeed(long value) {
        this._seed_$eq(value);
        return this;
    }

    public Vector getWeights() {
        return this._weights();
    }

    public FeedForwardTrainer setWeights(Vector value) {
        this._weights_$eq(value);
        return this;
    }

    public FeedForwardTrainer setStackSize(int value) {
        this._stackSize_$eq(value);
        this.dataStacker_$eq(new DataStacker(value, this.inputSize(), this.outputSize()));
        return this;
    }

    public GradientDescent SGDOptimizer() {
        GradientDescent sgd = new GradientDescent(this._gradient(), this._updater());
        this.optimizer_$eq(sgd);
        return sgd;
    }

    public LBFGS LBFGSOptimizer() {
        LBFGS lbfgs = new LBFGS(this._gradient(), this._updater());
        this.optimizer_$eq(lbfgs);
        return lbfgs;
    }

    public FeedForwardTrainer setUpdater(Updater value) {
        this._updater_$eq(value);
        this.updateUpdater(value);
        return this;
    }

    public FeedForwardTrainer setGradient(Gradient value) {
        this._gradient_$eq(value);
        this.updateGradient(value);
        return this;
    }

    private void updateGradient(Gradient gradient) {
        Optimizer optimizer = this.optimizer();
        if (optimizer instanceof LBFGS) {
            LBFGS lBFGS = (LBFGS)optimizer;
            lBFGS.setGradient(gradient);
            return;
        }
        if (optimizer instanceof GradientDescent) {
            GradientDescent gradientDescent = (GradientDescent)optimizer;
            gradientDescent.setGradient(gradient);
            return;
        }
        throw new UnsupportedOperationException(new StringBuilder(54).append("Only LBFGS and GradientDescent are supported but got ").append(optimizer.getClass()).append(".").toString());
    }

    private void updateUpdater(Updater updater) {
        Optimizer optimizer = this.optimizer();
        if (optimizer instanceof LBFGS) {
            LBFGS lBFGS = (LBFGS)optimizer;
            lBFGS.setUpdater(updater);
            return;
        }
        if (optimizer instanceof GradientDescent) {
            GradientDescent gradientDescent = (GradientDescent)optimizer;
            gradientDescent.setUpdater(updater);
            return;
        }
        throw new UnsupportedOperationException(new StringBuilder(54).append("Only LBFGS and GradientDescent are supported but got ").append(optimizer.getClass()).append(".").toString());
    }

    public Tuple2<TopologyModel, double[]> train(RDD<Tuple2<Vector, Vector>> data) {
        Tuple2<org.apache.spark.mllib.linalg.Vector, double[]> tuple2;
        Vector w = this.getWeights() == null ? this.topology.model(this._seed()).weights() : this.getWeights();
        RDD trainData = this.dataStacker().stack(data).map((Function1 & Serializable & scala.Serializable)v -> new Tuple2((Object)BoxesRunTime.boxToDouble((double)v._1$mcD$sp()), (Object)Vectors$.MODULE$.fromML((Vector)v._2())), ClassTag$.MODULE$.apply(Tuple2.class));
        StorageLevel storageLevel = trainData.getStorageLevel();
        StorageLevel storageLevel2 = StorageLevel$.MODULE$.NONE();
        boolean handlePersistence = !(storageLevel != null ? !storageLevel.equals(storageLevel2) : storageLevel2 != null);
        Object object = handlePersistence ? trainData.persist(StorageLevel$.MODULE$.MEMORY_AND_DISK()) : BoxedUnit.UNIT;
        Optimizer optimizer = this.optimizer();
        if (optimizer instanceof LBFGS) {
            LBFGS lBFGS = (LBFGS)optimizer;
            tuple2 = lBFGS.optimizeWithLossReturned((RDD<Tuple2<Object, org.apache.spark.mllib.linalg.Vector>>)trainData, VectorImplicits$.MODULE$.mlVectorToMLlibVector(w));
        } else if (optimizer instanceof GradientDescent) {
            GradientDescent gradientDescent = (GradientDescent)optimizer;
            tuple2 = gradientDescent.optimizeWithLossReturned((RDD<Tuple2<Object, org.apache.spark.mllib.linalg.Vector>>)trainData, VectorImplicits$.MODULE$.mlVectorToMLlibVector(w));
        } else {
            throw new UnsupportedOperationException(new StringBuilder(54).append("Only LBFGS and GradientDescent are supported but got ").append(optimizer.getClass()).append(".").toString());
        }
        Tuple2<org.apache.spark.mllib.linalg.Vector, double[]> tuple22 = tuple2;
        if (tuple22 == null) {
            throw new MatchError(tuple22);
        }
        org.apache.spark.mllib.linalg.Vector newWeights = (org.apache.spark.mllib.linalg.Vector)tuple22._1();
        double[] lossHistory = (double[])tuple22._2();
        Tuple2 tuple23 = new Tuple2((Object)newWeights, (Object)lossHistory);
        org.apache.spark.mllib.linalg.Vector newWeights2 = (org.apache.spark.mllib.linalg.Vector)tuple23._1();
        double[] lossHistory2 = (double[])tuple23._2();
        Object object2 = handlePersistence ? trainData.unpersist(trainData.unpersist$default$1()) : BoxedUnit.UNIT;
        return new Tuple2((Object)this.topology.model(VectorImplicits$.MODULE$.mllibVectorToMLVector(newWeights2)), (Object)lossHistory2);
    }

    public FeedForwardTrainer(Topology topology, int inputSize, int outputSize) {
        this.topology = topology;
        this.inputSize = inputSize;
        this.outputSize = outputSize;
        this._seed = this.getClass().getName().hashCode();
        this._weights = null;
        this._stackSize = 128;
        this.dataStacker = new DataStacker(this._stackSize(), inputSize, outputSize);
        this._gradient = new ANNGradient(topology, this.dataStacker());
        this._updater = new ANNUpdater();
        this.optimizer = this.LBFGSOptimizer().setConvergenceTol(1.0E-4).setNumIterations(100);
    }
}

