/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.regression.slm;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.linear.SingularMatrixException;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.SparseTrainer;
import org.tribuo.Trainer;
import org.tribuo.WeightedExamples;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.provenance.DatasetProvenance;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.slm.SparseLinearModel;
import org.tribuo.util.Util;

public class SLMTrainer
implements SparseTrainer<Regressor>,
WeightedExamples {
    private static final Logger logger = Logger.getLogger(SLMTrainer.class.getName());
    @Config(description="Maximum number of features to use.")
    protected int maxNumFeatures = -1;
    @Config(description="Normalize the data first.")
    protected boolean normalize;
    protected int trainInvocationCounter = 0;

    public SLMTrainer(boolean normalize, int maxNumFeatures) {
        this.normalize = normalize;
        this.maxNumFeatures = maxNumFeatures;
    }

    public SLMTrainer(boolean normalize) {
        this(normalize, -1);
    }

    protected SLMTrainer() {
    }

    protected RealVector newWeights(SLMState state) {
        RealVector result = SLMTrainer.ordinaryLeastSquares(state.xpi, state.y);
        if (result == null) {
            return null;
        }
        return state.unpack(result);
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public SparseLinearModel train(Dataset<Regressor> examples, Map<String, Provenance> runProvenance) {
        TrainerProvenance trainerProvenance;
        if (examples.getOutputInfo().getUnknownCount() > 0) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }
        SLMTrainer sLMTrainer = this;
        synchronized (sLMTrainer) {
            trainerProvenance = this.getProvenance();
            ++this.trainInvocationCounter;
        }
        ImmutableOutputInfo outputInfo = examples.getOutputIDInfo();
        ImmutableFeatureMap featureIDMap = examples.getFeatureIDMap();
        Set domain = outputInfo.getDomain();
        int numOutputs = outputInfo.size();
        int numExamples = examples.size();
        int numFeatures = this.normalize ? featureIDMap.size() : featureIDMap.size() + 1;
        double[][] outputs = new double[numOutputs][numExamples];
        SparseVector[] inputs = new SparseVector[numExamples];
        int n = 0;
        for (Example e : examples) {
            inputs[n] = SparseVector.createSparseVector((Example)e, (ImmutableFeatureMap)featureIDMap, (!this.normalize ? 1 : 0) != 0);
            double curWeight = Math.sqrt(e.getWeight());
            inputs[n].scaleInPlace(curWeight);
            for (Regressor.DimensionTuple r : (Regressor)e.getOutput()) {
                int id = outputInfo.getID((Output)r);
                outputs[id][n] = r.getValue() * curWeight;
            }
            ++n;
        }
        Array2DRowRealMatrix featureMatrix = new Array2DRowRealMatrix(numExamples, numFeatures);
        double[] denseFeatures = new double[numFeatures];
        for (int i = 0; i < inputs.length; ++i) {
            Arrays.fill(denseFeatures, 0.0);
            for (VectorTuple vec : inputs[i]) {
                denseFeatures[vec.index] = vec.value;
            }
            featureMatrix.setRow(i, denseFeatures);
        }
        double[] featureMeans = new double[numFeatures];
        double[] featureVariances = new double[numFeatures];
        double[] outputMeans = new double[numOutputs];
        double[] outputVariances = new double[numOutputs];
        if (this.normalize) {
            int i;
            for (i = 0; i < numFeatures; ++i) {
                double[] featV = featureMatrix.getColumn(i);
                featureMeans[i] = Util.mean((double[])featV);
                int j = 0;
                while (j < featV.length) {
                    int n2 = j++;
                    featV[n2] = featV[n2] - featureMeans[i];
                }
                ArrayRealVector xp = new ArrayRealVector(featV);
                featureVariances[i] = xp.getNorm();
                featureMatrix.setColumnVector(i, xp.mapDivideToSelf(featureVariances[i]));
            }
            for (i = 0; i < numOutputs; ++i) {
                int j;
                outputMeans[i] = Util.mean((double[])outputs[i]);
                double sum = 0.0;
                for (j = 0; j < numExamples; ++j) {
                    double[] dArray = outputs[i];
                    int n3 = j;
                    dArray[n3] = dArray[n3] - outputMeans[i];
                    sum += outputs[i][j] * outputs[i][j];
                }
                outputVariances[i] = Math.sqrt(sum);
                j = 0;
                while (j < numExamples) {
                    double[] dArray = outputs[i];
                    int n4 = j++;
                    dArray[n4] = dArray[n4] / outputVariances[i];
                }
            }
        } else {
            Arrays.fill(featureMeans, 0.0);
            Arrays.fill(featureVariances, 1.0);
            Arrays.fill(outputMeans, 0.0);
            Arrays.fill(outputVariances, 1.0);
        }
        Array2DRowRealMatrix outputMatrix = new Array2DRowRealMatrix(outputs);
        int[] exampleRows = new int[numExamples];
        for (int i = 0; i < numExamples; ++i) {
            exampleRows[i] = i;
        }
        ArrayRealVector one = new ArrayRealVector(numExamples, 1.0);
        int numToSelect = this.maxNumFeatures < 1 || this.maxNumFeatures > featureIDMap.size() ? featureIDMap.size() : this.maxNumFeatures;
        String[] dimensionNames = new String[numOutputs];
        SparseVector[] modelWeights = new SparseVector[numOutputs];
        for (Regressor r : domain) {
            int id = outputInfo.getID((Output)r);
            dimensionNames[id] = r.getNames()[0];
            SLMState state = new SLMState((RealMatrix)featureMatrix, outputMatrix.getRowVector(id), featureIDMap, this.normalize);
            modelWeights[id] = this.trainSingleDimension(state, exampleRows, numToSelect, (RealVector)one);
        }
        ModelProvenance provenance = new ModelProvenance(SparseLinearModel.class.getName(), OffsetDateTime.now(), (DatasetProvenance)examples.getProvenance(), trainerProvenance, runProvenance);
        return new SparseLinearModel("slm-model", dimensionNames, provenance, featureIDMap, (ImmutableOutputInfo<Regressor>)outputInfo, modelWeights, DenseVector.createDenseVector((double[])featureMeans), DenseVector.createDenseVector((double[])featureVariances), outputMeans, outputVariances, !this.normalize);
    }

    public int getInvocationCount() {
        return this.trainInvocationCounter;
    }

    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl((Trainer)this);
    }

    public String toString() {
        return "SFSTrainer(normalize=" + this.normalize + ",maxNumFeatures=" + this.maxNumFeatures + ")";
    }

    private SparseVector trainSingleDimension(SLMState state, int[] exampleRows, int numToSelect, RealVector one) {
        int iter = 0;
        while (state.active.size() < numToSelect) {
            RealVector betapi;
            state.r = state.y.subtract(state.X.operate(state.beta));
            logger.info("At iteration " + iter + " Average residual " + state.r.dotProduct(one) / (double)state.numExamples);
            ++iter;
            state.corr = state.X.transpose().operate(state.r);
            double max = -1.0;
            int feature = -1;
            for (int i = 0; i < state.numFeatures; ++i) {
                double absCorr;
                if (state.activeSet.contains(i) || !((absCorr = Math.abs(state.corr.getEntry(i))) > max)) continue;
                max = absCorr;
                feature = i;
            }
            state.C = max;
            state.active.add(feature);
            state.activeSet.add(feature);
            if (!state.normalize && feature == state.numFeatures - 1) {
                logger.info("Bias selected");
            } else {
                logger.info("Feature selected: " + state.featureIDMap.get(feature).getName() + " (pos=" + feature + ")");
            }
            int[] activeFeatures = Util.toPrimitiveInt(state.active);
            state.xpi = state.X.getSubMatrix(exampleRows, activeFeatures);
            if (state.active.size() == numToSelect - 1) {
                state.last = true;
            }
            if ((betapi = this.newWeights(state)) == null) {
                logger.log(Level.INFO, "Stopping at feature " + state.active.size() + " matrix was no longer invertible.");
                break;
            }
            state.beta = betapi;
        }
        HashMap<Integer, Double> parameters = new HashMap<Integer, Double>();
        for (int i = 0; i < state.numFeatures; ++i) {
            if (state.beta.getEntry(i) == 0.0) continue;
            parameters.put(i, state.beta.getEntry(i));
        }
        return SparseVector.createSparseVector((int)state.numFeatures, parameters);
    }

    static RealVector ordinaryLeastSquares(RealMatrix M, RealVector target) {
        RealMatrix inv;
        try {
            inv = new LUDecomposition(M.transpose().multiply(M)).getSolver().getInverse();
        }
        catch (SingularMatrixException s) {
            return null;
        }
        return inv.multiply(M.transpose()).operate(target);
    }

    static double sumInverted(RealMatrix matrix) {
        RealMatrix inv = new LUDecomposition(matrix.transpose().multiply(matrix)).getSolver().getInverse();
        ArrayRealVector one = new ArrayRealVector(matrix.getColumnDimension(), 1.0);
        return one.dotProduct(inv.operate((RealVector)one));
    }

    static RealVector getwa(RealMatrix M, double AA) {
        RealMatrix inv = new LUDecomposition(M.transpose().multiply(M)).getSolver().getInverse();
        ArrayRealVector one = new ArrayRealVector(M.getColumnDimension(), 1.0);
        return inv.operate((RealVector)one).mapMultiply(AA);
    }

    static RealVector getA(RealMatrix D, RealMatrix M, RealVector v) {
        RealVector u = M.operate(v);
        return D.transpose().operate(u);
    }

    static class SLMState {
        protected final int numExamples;
        protected final int numFeatures;
        protected final boolean normalize;
        protected final ImmutableFeatureMap featureIDMap;
        protected final Set<Integer> activeSet;
        protected final List<Integer> active;
        protected final RealMatrix X;
        protected final RealVector y;
        protected RealMatrix xpi;
        protected RealVector r;
        protected RealVector beta;
        protected double C;
        protected RealVector corr;
        protected Boolean last = false;

        public SLMState(RealMatrix features, RealVector outputs, ImmutableFeatureMap featureIDMap, boolean normalize) {
            this.numExamples = features.getRowDimension();
            this.numFeatures = features.getColumnDimension();
            this.featureIDMap = featureIDMap;
            this.normalize = normalize;
            this.active = new ArrayList<Integer>();
            this.activeSet = new HashSet<Integer>();
            this.beta = new ArrayRealVector(this.numFeatures);
            this.X = features;
            this.y = outputs;
        }

        public RealVector unpack(RealVector values) {
            ArrayRealVector u = new ArrayRealVector(this.numFeatures);
            for (int i = 0; i < this.active.size(); ++i) {
                u.setEntry(this.active.get(i).intValue(), values.getEntry(i));
            }
            return u;
        }
    }
}

