/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.regression.liblinear;

import com.oracle.labs.mlrg.olcut.util.Pair;
import de.bwaldvogel.liblinear.FeatureNode;
import de.bwaldvogel.liblinear.Linear;
import de.bwaldvogel.liblinear.Model;
import de.bwaldvogel.liblinear.Parameter;
import de.bwaldvogel.liblinear.Problem;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.logging.Logger;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.common.liblinear.LibLinearModel;
import org.tribuo.common.liblinear.LibLinearTrainer;
import org.tribuo.common.liblinear.LibLinearType;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.regression.ImmutableRegressionInfo;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.liblinear.LibLinearRegressionModel;
import org.tribuo.regression.liblinear.LinearRegressionType;

/**
 * A {@link LibLinearTrainer} for {@link Regressor} outputs.
 * <p>
 * Each liblinear {@link Problem} built here has a single target vector. This trainer
 * therefore fits one liblinear {@link Model} per regression output dimension and wraps
 * them in a {@link LibLinearRegressionModel}.
 */
public class LibLinearRegressionTrainer extends LibLinearTrainer<Regressor> {
    private static final Logger logger = Logger.getLogger(LibLinearRegressionTrainer.class.getName());

    /**
     * When {@code true}, each per-dimension liblinear run is given a fresh
     * {@code new Random(0L)} via {@link Parameter#setRandom}, so every dimension starts from
     * the same RNG state. Package-private and mutable; presumably toggled by tests in this
     * package — confirm before changing its visibility.
     */
    boolean forceZero = false;

    /**
     * Creates a trainer using {@link LinearRegressionType.LinearType#L2R_L2LOSS_SVR} with
     * cost 1.0, 1000 max iterations, termination criterion 0.1, epsilon 0.1 and seed 12345.
     */
    public LibLinearRegressionTrainer() {
        this(new LinearRegressionType(LinearRegressionType.LinearType.L2R_L2LOSS_SVR));
    }

    /**
     * Creates a trainer of the given type with cost 1.0, 1000 max iterations, termination
     * criterion 0.1, epsilon 0.1 and seed 12345.
     *
     * @param trainerType the liblinear regression algorithm to use.
     */
    public LibLinearRegressionTrainer(LinearRegressionType trainerType) {
        this(trainerType, 1.0, 1000, 0.1, 0.1);
    }

    /**
     * Creates a trainer with the supplied hyperparameters and the default seed 12345.
     *
     * @param trainerType          the liblinear regression algorithm to use.
     * @param cost                 the cost value forwarded to {@link LibLinearTrainer}.
     * @param maxIterations        the iteration limit forwarded to {@link LibLinearTrainer}.
     * @param terminationCriterion the termination criterion forwarded to {@link LibLinearTrainer}.
     * @param epsilon              the epsilon value forwarded to {@link LibLinearTrainer}.
     */
    public LibLinearRegressionTrainer(LinearRegressionType trainerType, double cost, int maxIterations, double terminationCriterion, double epsilon) {
        this(trainerType, cost, maxIterations, terminationCriterion, epsilon, 12345L);
    }

    /**
     * Creates a trainer with the supplied hyperparameters and RNG seed.
     *
     * @param trainerType          the liblinear regression algorithm to use.
     * @param cost                 the cost value forwarded to {@link LibLinearTrainer}.
     * @param maxIterations        the iteration limit forwarded to {@link LibLinearTrainer}.
     * @param terminationCriterion the termination criterion forwarded to {@link LibLinearTrainer}.
     * @param epsilon              the epsilon value forwarded to {@link LibLinearTrainer}.
     * @param seed                 the RNG seed forwarded to {@link LibLinearTrainer}.
     */
    public LibLinearRegressionTrainer(LinearRegressionType trainerType, double cost, int maxIterations, double terminationCriterion, double epsilon, long seed) {
        super(trainerType, cost, maxIterations, terminationCriterion, epsilon, seed);
    }

    /**
     * Validates the configuration after the superclass's post-config step.
     *
     * @throws IllegalArgumentException if the configured trainer type is not a regression type.
     */
    @Override
    public void postConfig() {
        super.postConfig();
        if (!trainerType.isRegression()) {
            throw new IllegalArgumentException("Supplied classification or anomaly detection parameters to a regression linear model.");
        }
    }

    /**
     * Trains one liblinear model per output dimension, all sharing the same feature matrix.
     *
     * @param curParams   the liblinear parameters used for every dimension.
     * @param numFeatures the feature count stored in each {@link Problem}.
     * @param features    one sparse feature row per example.
     * @param outputs     one target vector per output dimension, indexed by dimension id.
     * @return the trained models, in the same order as {@code outputs}.
     */
    @Override
    protected List<Model> trainModels(Parameter curParams, int numFeatures, FeatureNode[][] features, double[][] outputs) {
        List<Model> models = new ArrayList<>(outputs.length);
        for (double[] target : outputs) {
            Problem problem = new Problem();
            problem.l = features.length;
            problem.y = target;
            problem.x = features;
            problem.n = numFeatures;
            problem.bias = 1.0;
            if (forceZero) {
                // Reset per dimension (not once up front) so every dimension sees the same RNG state.
                curParams.setRandom(new Random(0L));
            }
            models.add(Linear.train(problem, curParams));
        }
        return models;
    }

    /**
     * Wraps the per-dimension liblinear models into a Tribuo regression model.
     *
     * @throws IllegalArgumentException if the number of models differs from the number of
     *                                  output dimensions.
     */
    @Override
    protected LibLinearModel<Regressor> createModel(ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> outputIDInfo, List<Model> models) {
        if (models.size() != outputIDInfo.size()) {
            throw new IllegalArgumentException("Regression uses one model per dimension. Found " + models.size() + " models, and " + outputIDInfo.size() + " dimensions.");
        }
        return new LibLinearRegressionModel("liblinear-regression-model", provenance, featureIDMap, outputIDInfo, models);
    }

    /**
     * Converts a dataset into liblinear's sparse feature rows plus one target vector per
     * output dimension.
     * <p>
     * Each example's values are read in natural order. They are written to the row given by
     * {@link ImmutableRegressionInfo#getNaturalOrderToIDMapping()}, so {@code outputs[id]}
     * holds the targets for dimension {@code id}.
     *
     * @param data       the training data.
     * @param outputInfo the output domain; assumed (as in the original cast) to be an
     *                   {@link ImmutableRegressionInfo}.
     * @param featureMap the feature domain used to build the sparse rows.
     * @return a pair of (feature rows, per-dimension targets).
     * @throws IllegalArgumentException if an example's dimension count differs from the
     *                                  output domain's size.
     */
    @Override
    protected Pair<FeatureNode[][], double[][]> extractData(Dataset<Regressor> data, ImmutableOutputInfo<Regressor> outputInfo, ImmutableFeatureMap featureMap) {
        int numOutputs = outputInfo.size();
        int[] dimensionIds = ((ImmutableRegressionInfo) outputInfo).getNaturalOrderToIDMapping();
        ArrayList<FeatureNode> featureCache = new ArrayList<>();
        FeatureNode[][] features = new FeatureNode[data.size()][];
        double[][] outputs = new double[numOutputs][data.size()];
        int i = 0;
        for (Example<Regressor> e : data) {
            double[] curOutputs = e.getOutput().getValues();
            // Without this check, a short Regressor silently leaves zero targets and misaligns the
            // natural-order mapping, and a long one fails with an opaque index error.
            if (curOutputs.length != numOutputs) {
                throw new IllegalArgumentException("Example " + i + " has " + curOutputs.length
                        + " regression dimensions, but the output domain has " + numOutputs + ".");
            }
            for (int j = 0; j < curOutputs.length; ++j) {
                outputs[dimensionIds[j]][i] = curOutputs[j];
            }
            features[i] = exampleToNodes(e, featureMap, featureCache);
            ++i;
        }
        return new Pair<>(features, outputs);
    }
}

