/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.regression.slm;

import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayList;
import java.util.Collections;
import java.util.logging.Logger;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.regression.slm.SLMTrainer;

/**
 * A sparse linear regression trainer using least angle regression (LARS).
 * <p>
 * On every step except the last, {@link #newWeights(SLMTrainer.SLMState)} moves the weights
 * along the least-squares direction for the current active set. The move is truncated at the
 * smallest non-negative step length at which an inactive feature would catch up with the
 * active set.
 */
public class LARSTrainer extends SLMTrainer {
    private static final Logger logger = Logger.getLogger(LARSTrainer.class.getName());

    /**
     * Constructs a LARS trainer.
     *
     * @param maxNumFeatures feature budget forwarded to {@link SLMTrainer}.
     *                       NOTE(review): the no-arg constructor passes -1, which presumably
     *                       means "no limit" — confirm against SLMTrainer.
     */
    public LARSTrainer(int maxNumFeatures) {
        // NOTE(review): the meaning of the boolean flag is defined by SLMTrainer — confirm there.
        super(true, maxNumFeatures);
    }

    /**
     * Constructs a LARS trainer with {@code maxNumFeatures = -1}.
     */
    public LARSTrainer() {
        this(-1);
    }

    /**
     * Computes the weight vector produced by one LARS step.
     * <p>
     * On the final step ({@code state.last}) this defers to {@link SLMTrainer}'s weight
     * computation. Otherwise it solves least squares of the active columns ({@code state.xpi})
     * against the residual ({@code state.r}) to get a direction. It then scales that direction
     * by the smallest finite, non-negative candidate step length over all inactive features
     * and adds it to {@code state.beta}.
     *
     * @param state the current solver state
     * @return the updated weights, or {@code null} if the least-squares solve failed or no
     *         finite non-negative step length exists
     */
    @Override
    protected DenseVector newWeights(SLMTrainer.SLMState state) {
        if (state.last) {
            return super.newWeights(state);
        }
        Pair<DenseVector, DenseMatrix> deltapi = SLMTrainer.ordinaryLeastSquares(state.xpi, state.r);
        if (deltapi == null) {
            // The least-squares solve failed; propagate the failure signal to the caller.
            return null;
        }
        DenseVector delta = state.unpack(deltapi.getA());
        DenseMatrix xpiInv = deltapi.getB();

        // Sum of all entries of the second OLS result (the LARS "A" normaliser).
        double invSum = xpiInv.rowSum().sum();
        // NOTE(review): presumably the shared absolute correlation of the active set (LARS "C").
        double bigC = state.C;
        DenseVector wa = SLMTrainer.getWA(xpiInv, invSum);
        DenseVector ar = SLMTrainer.getA(state.X, state.xpi, wa);

        // Running minimum instead of a boxed candidate list. Starting at +Infinity means NaN
        // and infinite candidates can never be selected, because the comparison is strict.
        double gamma = Double.POSITIVE_INFINITY;
        for (int i = 0; i < state.numFeatures; ++i) {
            if (state.activeSet.contains(i)) {
                continue;
            }
            double c = state.corr.get(i);
            double a = ar.get(i);
            gamma = minStep(gamma, (bigC - c) / (invSum - a));
            gamma = minStep(gamma, (bigC + c) / (invSum + a));
        }

        if (gamma == Double.POSITIVE_INFINITY) {
            // Previously Collections.min threw NoSuchElementException here (or an infinite step
            // corrupted the weights). Report failure the same way as a failed OLS solve.
            logger.warning("LARS found no finite non-negative step length; no update computed.");
            return null;
        }
        delta.scaleInPlace(gamma);
        // The upcast is kept so that the same add overload is selected as before.
        return state.beta.add((SGDVector) delta);
    }

    /**
     * Returns {@code candidate} if it is a valid step length (non-negative) smaller than
     * {@code best}, otherwise returns {@code best}.
     */
    private static double minStep(double best, double candidate) {
        return (candidate >= 0.0 && candidate < best) ? candidate : best;
    }

    @Override
    public String toString() {
        return "LARSTrainer(maxNumFeatures=" + this.maxNumFeatures + ")";
    }
}

