/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.regression.slm;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.SparseTrainer;
import org.tribuo.Trainer;
import org.tribuo.WeightedExamples;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.Matrix;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.provenance.DatasetProvenance;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.slm.SparseLinearModel;

/**
 * A trainer for sparse linear regression models.
 * <p>
 * Each output dimension is fitted independently. The trainer repeatedly picks the
 * not-yet-selected feature whose absolute correlation with the current residual is largest,
 * adds it to the active set, and recomputes the weights via {@link #newWeights(SLMState)}
 * (ordinary least squares on the active columns in this class; subclasses may override it).
 * Selection stops when the requested number of features is active, when
 * {@link #newWeights(SLMState)} returns {@code null}, or when no feature can be selected.
 * <p>
 * When {@code normalize} is false a bias column is appended to every example and competes for
 * selection like any other feature. When it is true, the feature columns and output rows are
 * centred and scaled instead, and no bias column is added.
 */
public class SLMTrainer
implements SparseTrainer<Regressor>,
WeightedExamples {
    private static final Logger logger = Logger.getLogger(SLMTrainer.class.getName());

    /** Maximum number of features to select; values below 1 or above the feature count mean "all features". */
    @Config(description="Maximum number of features to use.")
    protected int maxNumFeatures = -1;

    /** If true, centre and scale features and outputs instead of adding a bias column. */
    @Config(description="Normalize the data first.")
    protected boolean normalize;

    /** Number of times {@code train} has been invoked on this trainer. */
    protected int trainInvocationCounter = 0;

    /**
     * Constructs an SLMTrainer.
     *
     * @param normalize If true, centre and scale the data rather than adding a bias column.
     * @param maxNumFeatures Maximum number of features to select (below 1 means all features).
     */
    public SLMTrainer(boolean normalize, int maxNumFeatures) {
        this.normalize = normalize;
        this.maxNumFeatures = maxNumFeatures;
    }

    /**
     * Constructs an SLMTrainer which may select every feature.
     *
     * @param normalize If true, centre and scale the data rather than adding a bias column.
     */
    public SLMTrainer(boolean normalize) {
        this(normalize, -1);
    }

    /**
     * No-argument constructor, presumably for the configuration system, which fills the
     * {@code @Config} fields.
     */
    protected SLMTrainer() {
    }

    /**
     * Computes the weights for the current active set.
     * <p>
     * This implementation fits ordinary least squares on {@code state.xpi} (the active columns)
     * against {@code state.y}, then scatters the result back into a full-width vector.
     *
     * @param state The per-dimension training state.
     * @return The new full-width weight vector, or {@code null} if the least-squares system
     *         could not be solved (the LU factorization was absent).
     */
    protected DenseVector newWeights(SLMState state) {
        Pair<DenseVector, DenseMatrix> result = SLMTrainer.ordinaryLeastSquares(state.xpi, state.y);
        if (result == null) {
            return null;
        }
        return state.unpack(result.getA());
    }

    @Override
    public SparseLinearModel train(Dataset<Regressor> examples, Map<String, Provenance> runProvenance) {
        return this.train(examples, runProvenance, -1);
    }

    /**
     * Trains a sparse linear model, with one independent fit per output dimension.
     *
     * @param examples The training data. It must not contain unknown outputs.
     * @param runProvenance Extra provenance recorded in the model.
     * @param invocationCount If not -1, the invocation counter is set to this value before
     *        the trainer provenance is captured.
     * @return The trained model.
     * @throws IllegalArgumentException If the dataset contains unknown outputs, or if
     *         {@code invocationCount} is negative and not -1.
     */
    public SparseLinearModel train(Dataset<Regressor> examples, Map<String, Provenance> runProvenance, int invocationCount) {
        if (examples.getOutputInfo().getUnknownCount() > 0) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }
        TrainerProvenance trainerProvenance;
        synchronized (this) {
            if (invocationCount != -1) {
                this.setInvocationCount(invocationCount);
            }
            trainerProvenance = this.getProvenance();
            ++this.trainInvocationCounter;
        }
        ImmutableOutputInfo<Regressor> outputInfo = examples.getOutputIDInfo();
        ImmutableFeatureMap featureIDMap = examples.getFeatureIDMap();
        Set<Regressor> domain = outputInfo.getDomain();
        int numOutputs = outputInfo.size();
        int numExamples = examples.size();
        // Without normalization a bias column is appended after the real features.
        int numFeatures = this.normalize ? featureIDMap.size() : featureIDMap.size() + 1;

        // Inputs and outputs are both multiplied by sqrt(weight), which turns the weighted
        // least-squares objective into an unweighted one over the scaled data.
        DenseMatrix outputMatrix = new DenseMatrix(numOutputs, numExamples);
        SparseVector[] inputs = new SparseVector[numExamples];
        int n = 0;
        for (Example<Regressor> e : examples) {
            double sqrtWeight = Math.sqrt(e.getWeight());
            inputs[n] = SparseVector.createSparseVector(e, featureIDMap, !this.normalize);
            inputs[n].scaleInPlace(sqrtWeight);
            for (Regressor.DimensionTuple dim : e.getOutput()) {
                outputMatrix.set(outputInfo.getID(dim), n, dim.getValue() * sqrtWeight);
            }
            ++n;
        }

        DenseMatrix featureMatrix = DenseMatrix.createDenseMatrix(inputs);
        // Defaults (zero means, unit norms) apply when normalization is disabled.
        double[] featureMeans = new double[numFeatures];
        double[] featureNorms = new double[numFeatures];
        double[] outputMeans = new double[numOutputs];
        double[] outputNorms = new double[numOutputs];
        if (this.normalize) {
            for (int i = 0; i < numFeatures; ++i) {
                DenseVector col = featureMatrix.getColumn(i);
                double colMean = col.meanVariance().getMean();
                double colNorm = centredNorm(col, colMean);
                col.foreachInPlace(a -> (a - colMean) / colNorm);
                featureMatrix.setColumn(i, col);
                featureMeans[i] = colMean;
                featureNorms[i] = colNorm;
            }
            for (int i = 0; i < numOutputs; ++i) {
                // NOTE(review): unlike the feature columns, rows are not written back. This relies
                // on getRow returning a view onto outputMatrix -- confirm against DenseMatrix.
                DenseVector row = outputMatrix.getRow(i);
                double rowMean = row.meanVariance().getMean();
                double rowNorm = centredNorm(row, rowMean);
                row.foreachInPlace(a -> (a - rowMean) / rowNorm);
                outputMeans[i] = rowMean;
                outputNorms[i] = rowNorm;
            }
        } else {
            Arrays.fill(featureNorms, 1.0);
            Arrays.fill(outputNorms, 1.0);
        }

        int numToSelect = (this.maxNumFeatures < 1 || this.maxNumFeatures > featureIDMap.size()) ? featureIDMap.size() : this.maxNumFeatures;
        String[] dimensionNames = new String[numOutputs];
        SparseVector[] modelWeights = new SparseVector[numOutputs];
        for (Regressor r : domain) {
            int id = outputInfo.getID(r);
            dimensionNames[id] = r.getNames()[0];
            SLMState state = new SLMState(featureMatrix, outputMatrix.getRow(id), featureIDMap, this.normalize);
            modelWeights[id] = this.trainSingleDimension(state, numToSelect);
        }
        ModelProvenance provenance = new ModelProvenance(SparseLinearModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
        return new SparseLinearModel("slm-model", dimensionNames, provenance, featureIDMap, outputInfo, modelWeights, DenseVector.createDenseVector(featureMeans), DenseVector.createDenseVector(featureNorms), outputMeans, outputNorms, !this.normalize);
    }

    @Override
    public int getInvocationCount() {
        return this.trainInvocationCounter;
    }

    /**
     * Sets the invocation counter.
     *
     * @param invocationCount The new counter value.
     * @throws IllegalArgumentException If {@code invocationCount} is negative.
     */
    public void setInvocationCount(int invocationCount) {
        if (invocationCount < 0) {
            throw new IllegalArgumentException("The supplied invocationCount is less than zero.");
        }
        this.trainInvocationCounter = invocationCount;
    }

    @Override
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }

    @Override
    public String toString() {
        return "SLMTrainer(normalize=" + this.normalize + ",maxNumFeatures=" + this.maxNumFeatures + ")";
    }

    /**
     * Computes sqrt(sum((v_i - mean)^2)) over the elements of {@code v}.
     * <p>
     * A zero result (e.g. a constant column or row) is replaced by 1.0. Otherwise the
     * subsequent {@code (a - mean) / norm} would compute 0/0 and fill the data with NaN;
     * with 1.0 the values are simply centred to zero.
     */
    private static double centredNorm(DenseVector v, double mean) {
        double norm = Math.sqrt(v.reduce(0.0, a -> a - mean, (a, b) -> b + a * a));
        return norm == 0.0 ? 1.0 : norm;
    }

    /**
     * Greedily selects features for one output dimension and returns the fitted weights.
     *
     * @param state The per-dimension state, holding the feature matrix and this dimension's targets.
     * @param numToSelect The maximum number of features (including the bias, if present) to activate.
     * @return A sparse vector of width {@code state.numFeatures} holding the non-zero weights.
     */
    private SparseVector trainSingleDimension(SLMState state, int numToSelect) {
        int iter = 0;
        while (state.active.size() < numToSelect) {
            // Residual of the current fit, then its correlation with every feature column.
            state.r = state.y.subtract(state.X.leftMultiply(state.beta));
            logger.info("At iteration " + iter + " Average residual " + state.r.sum() / (double) state.numExamples);
            ++iter;
            state.corr = state.X.rightMultiply(state.r);

            double max = -1.0;
            int feature = -1;
            for (int i = 0; i < state.numFeatures; ++i) {
                if (state.activeSet.contains(i)) {
                    continue;
                }
                double absCorr = Math.abs(state.corr.get(i));
                if (absCorr > max) {
                    max = absCorr;
                    feature = i;
                }
            }
            if (feature == -1) {
                // Only reachable when every unselected correlation is NaN (NaN never compares
                // greater than max). Adding index -1 would break the lookups below.
                logger.log(Level.WARNING, "Stopping at feature " + state.active.size() + " no selectable feature remained.");
                break;
            }

            state.C = max;
            state.active.add(feature);
            state.activeSet.add(feature);
            if (!state.normalize && feature == state.numFeatures - 1) {
                logger.info("Bias selected");
            } else {
                logger.info("Feature selected: " + state.featureIDMap.get(feature).getName() + " (pos=" + feature + ")");
            }
            state.xpi = state.X.selectColumns(state.active);
            if (state.active.size() == numToSelect - 1) {
                state.last = true;
            }
            DenseVector betapi = this.newWeights(state);
            if (betapi == null) {
                logger.log(Level.INFO, "Stopping at feature " + state.active.size() + " matrix was no longer invertible.");
                break;
            }
            state.beta = betapi;
        }

        Map<Integer, Double> parameters = new HashMap<>();
        for (int i = 0; i < state.numFeatures; ++i) {
            double value = state.beta.get(i);
            if (value != 0.0) {
                parameters.put(i, value);
            }
        }
        return SparseVector.createSparseVector(state.numFeatures, parameters);
    }

    /**
     * Solves the least-squares problem {@code M w ~= target} via the normal equations, using an
     * explicit inverse of {@code M^T M}.
     *
     * @param M The design matrix (rows are examples, columns are the active features).
     * @param target The target vector.
     * @return A pair of (weights, inverse of {@code M^T M}), or {@code null} if the LU
     *         factorization of {@code M^T M} is absent.
     */
    static Pair<DenseVector, DenseMatrix> ordinaryLeastSquares(DenseMatrix M, DenseVector target) {
        Optional<DenseMatrix.LUFactorization> lu = M.matrixMultiply(M, true, false).luFactorization();
        if (!lu.isPresent()) {
            return null;
        }
        DenseMatrix inv = (DenseMatrix) lu.get().inverse();
        DenseVector weights = inv.matrixMultiply(M, false, true).leftMultiply(target);
        return new Pair<>(weights, inv);
    }

    /**
     * Returns {@code inv.rightMultiply(ones)} scaled by {@code AA}, where {@code ones} is an
     * all-ones vector of length {@code inv.getDimension2Size()}.
     * <p>
     * Not used in this class; presumably a helper for subclasses (e.g. LARS variants) -- confirm.
     */
    static DenseVector getWA(DenseMatrix inv, double AA) {
        DenseVector ones = new DenseVector(inv.getDimension2Size(), 1.0);
        DenseVector output = inv.rightMultiply(ones);
        output.scaleInPlace(AA);
        return output;
    }

    /**
     * Returns {@code D.rightMultiply(M.leftMultiply(v))}.
     * <p>
     * Not used in this class; presumably a helper for subclasses -- confirm.
     */
    static DenseVector getA(DenseMatrix D, DenseMatrix M, DenseVector v) {
        DenseVector u = M.leftMultiply(v);
        return D.rightMultiply(u);
    }

    /**
     * Mutable state for fitting a single output dimension.
     */
    static class SLMState {
        /** Number of rows in the feature matrix. */
        protected final int numExamples;
        /** Number of columns in the feature matrix (including the bias column if present). */
        protected final int numFeatures;
        /** True if the data was normalized, i.e. there is no bias column. */
        protected final boolean normalize;
        protected final ImmutableFeatureMap featureIDMap;
        /** Selected feature indices, for O(1) membership tests. */
        protected final Set<Integer> activeSet;
        /** Selected feature indices, in selection order. */
        protected final List<Integer> active;
        /** The full feature matrix. */
        protected final DenseMatrix X;
        /** The targets for this output dimension. */
        protected final DenseVector y;
        /** The columns of {@code X} restricted to {@code active}. */
        protected DenseMatrix xpi;
        /** The current residual {@code y - X beta}. */
        protected DenseVector r;
        /** The current full-width weight vector. */
        protected DenseVector beta;
        /** Largest absolute correlation found in the most recent selection step. */
        protected double C;
        /** Correlation of each feature column with the current residual. */
        protected DenseVector corr;
        /** Set when the active set is one short of the selection limit. */
        protected boolean last = false;

        public SLMState(DenseMatrix features, DenseVector outputs, ImmutableFeatureMap featureIDMap, boolean normalize) {
            this.numExamples = features.getDimension1Size();
            this.numFeatures = features.getDimension2Size();
            this.featureIDMap = featureIDMap;
            this.normalize = normalize;
            this.active = new ArrayList<Integer>(this.numFeatures);
            this.activeSet = new HashSet<Integer>();
            this.beta = new DenseVector(this.numFeatures);
            this.X = features;
            this.y = outputs;
        }

        /**
         * Scatters a compact vector (one value per active feature, in selection order) into a
         * full-width vector of length {@code numFeatures}. Inactive features are set to zero.
         */
        public DenseVector unpack(DenseVector values) {
            DenseVector u = new DenseVector(this.numFeatures);
            for (int i = 0; i < this.active.size(); ++i) {
                u.set(this.active.get(i).intValue(), values.get(i));
            }
            return u;
        }
    }
}

