/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.regression.slm;

import java.util.ArrayList;
import java.util.Collections;
import java.util.logging.Logger;
import org.apache.commons.math3.linear.RealVector;
import org.tribuo.regression.slm.SLMTrainer;

/**
 * A sparse linear model trainer that performs least angle regression (LARS) on top of
 * {@link SLMTrainer}.
 *
 * <p>At each non-final step the coefficients of the active features are moved along the
 * least-squares direction fitted to the current residual. The step length is the smallest
 * non-negative, finite {@code gamma} at which an inactive feature's correlation ties with the
 * active set's correlation.</p>
 */
public class LARSTrainer
extends SLMTrainer {
    private static final Logger logger = Logger.getLogger(LARSTrainer.class.getName());

    /**
     * Constructs a LARS trainer.
     *
     * @param maxNumFeatures The maximum number of features to select. It is passed unchanged to
     *                       {@link SLMTrainer}, which defines how it is interpreted.
     */
    public LARSTrainer(int maxNumFeatures) {
        super(true, maxNumFeatures);
    }

    /**
     * Constructs a LARS trainer with {@code maxNumFeatures = -1}.
     * NOTE(review): presumably -1 means "no limit on the number of features"; confirm
     * against {@link SLMTrainer}.
     */
    public LARSTrainer() {
        this(-1);
    }

    /**
     * Computes the coefficient vector for the next LARS step.
     *
     * <p>When {@code state.last} is set this defers to {@link SLMTrainer#newWeights}.
     * Otherwise it proceeds as follows:</p>
     * <ul>
     *   <li>Solve ordinary least squares of the residual {@code state.r} on the active columns
     *       {@code state.xpi} to get a direction {@code delta}.</li>
     *   <li>Compute the step length {@code gamma}. It is the smallest non-negative, finite
     *       value of {@code (C - c_j) / (A - a_j)} or {@code (C + c_j) / (A + a_j)} over the
     *       inactive features {@code j}.</li>
     *   <li>Return {@code state.beta + gamma * delta}.</li>
     * </ul>
     *
     * <p>If no inactive feature yields such a step (for example, there are no inactive
     * features, or every ratio is negative, NaN or infinite), the method falls back to the
     * base-class update, i.e. the same update used on the final step. Previously this case
     * either threw {@code NoSuchElementException} or produced infinite/NaN weights.</p>
     *
     * @param state The current solver state.
     * @return The updated coefficient vector, or {@code null} if the least-squares solve
     *         failed ({@link SLMTrainer#ordinaryLeastSquares} returned {@code null}).
     */
    @Override
    protected RealVector newWeights(SLMTrainer.SLMState state) {
        if (state.last.booleanValue()) {
            return super.newWeights(state);
        }
        RealVector deltapi = SLMTrainer.ordinaryLeastSquares(state.xpi, state.r);
        if (deltapi == null) {
            return null;
        }
        // Expand the active-set direction back into the full feature space.
        RealVector delta = state.unpack(deltapi);

        // AA and CC mirror the A_A / C notation of the LARS step-length formula.
        double AA = SLMTrainer.sumInverted(state.xpi);
        double CC = state.C;
        RealVector wa = SLMTrainer.getwa(state.xpi, AA);
        RealVector ar = SLMTrainer.getA(state.X, state.xpi, wa);

        // Track the minimum directly instead of boxing every candidate into a list.
        double gamma = Double.POSITIVE_INFINITY;
        for (int i = 0; i < state.numFeatures; ++i) {
            if (state.activeSet.contains(i)) {
                continue;
            }
            double c = state.corr.getEntry(i);
            double a = ar.getEntry(i);
            gamma = Math.min(gamma, stepCandidate(CC - c, AA - a));
            gamma = Math.min(gamma, stepCandidate(CC + c, AA + a));
        }

        if (Double.isInfinite(gamma)) {
            // No inactive feature can catch up, so there is no valid LARS step length.
            // Scaling delta by infinity would corrupt the weights, so use the base update.
            logger.fine("LARS found no finite non-negative step length; falling back to the base-class update.");
            return super.newWeights(state);
        }
        return state.beta.add(delta.mapMultiplyToSelf(gamma));
    }

    /**
     * Evaluates one LARS step-length candidate.
     *
     * @param numerator   The numerator of the ratio.
     * @param denominator The denominator of the ratio.
     * @return {@code numerator / denominator} if it is finite and {@code >= 0}. Otherwise
     *         {@code +Infinity}, meaning "not a candidate" for the running minimum.
     */
    private static double stepCandidate(double numerator, double denominator) {
        double v = numerator / denominator;
        // NaN fails both checks; +/-Infinity fails isFinite.
        if (v >= 0.0 && Double.isFinite(v)) {
            return v;
        }
        return Double.POSITIVE_INFINITY;
    }

    @Override
    public String toString() {
        return "LARSTrainer(maxNumFeatures=" + this.maxNumFeatures + ")";
    }
}

