/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.common.liblinear;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import de.bwaldvogel.liblinear.FeatureNode;
import de.bwaldvogel.liblinear.Linear;
import de.bwaldvogel.liblinear.Model;
import de.bwaldvogel.liblinear.Parameter;
import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.SplittableRandom;
import java.util.logging.Logger;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.Trainer;
import org.tribuo.common.liblinear.LibLinearModel;
import org.tribuo.common.liblinear.LibLinearType;
import org.tribuo.provenance.DatasetProvenance;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;
import org.tribuo.util.Util;

/**
 * Base class for trainers which delegate model fitting to LibLinear.
 * <p>
 * Subclasses supply the task-specific steps: converting a {@link Dataset} into LibLinear's
 * sparse representation ({@link #extractData}), fitting one or more LibLinear {@link Model}s
 * ({@link #trainModels}), and wrapping them in a {@link LibLinearModel} ({@link #createModel}).
 * This class owns the hyperparameters, the RNG used to seed each LibLinear run, and the
 * train-invocation counter. The RNG and counter are only read or mutated while holding this
 * instance's monitor.
 *
 * @param <T> the output type.
 */
public abstract class LibLinearTrainer<T extends Output<T>>
implements Trainer<T> {
    private static final Logger logger = Logger.getLogger(LibLinearTrainer.class.getName());

    /** Regression epsilon used when a constructor does not specify one. */
    private static final double DEFAULT_EPSILON = 0.1;

    /** RNG seed used when a constructor does not specify one. */
    private static final long DEFAULT_SEED = 12345L;

    /** LibLinear parameters built from the configured fields by {@link #postConfig()}. */
    protected Parameter libLinearParams;
    @Config(description="Algorithm to use.")
    protected LibLinearType<T> trainerType;
    @Config(description="Cost penalty for misclassifications.")
    protected double cost = 1.0;
    @Config(description="Maximum number of iterations before terminating.")
    protected int maxIterations = 1000;
    @Config(description="Stop iterating when the loss score decreases by less than this value.")
    protected double terminationCriterion = 0.1;
    @Config(description="Epsilon insensitivity in the regression cost function.")
    protected double epsilon = DEFAULT_EPSILON;
    @Config(description="RNG seed.")
    protected long seed = DEFAULT_SEED;

    /** Source of per-invocation seeds; guarded by {@code this}. */
    private SplittableRandom rng;

    /** Number of completed train invocations; guarded by {@code this}. */
    private int trainInvocationCount = 0;

    /**
     * Creates an unconfigured trainer.
     * <p>
     * {@link #libLinearParams} and the RNG stay unset until {@link #postConfig()} is called.
     */
    protected LibLinearTrainer() {
    }

    /**
     * Creates a trainer with the default regression epsilon and default seed.
     *
     * @param trainerType          the LibLinear algorithm to use.
     * @param cost                 the cost penalty for misclassifications.
     * @param maxIterations        the maximum number of iterations.
     * @param terminationCriterion the stopping tolerance.
     */
    protected LibLinearTrainer(LibLinearType<T> trainerType, double cost, int maxIterations, double terminationCriterion) {
        this(trainerType, cost, maxIterations, terminationCriterion, DEFAULT_EPSILON);
    }

    /**
     * Creates a trainer with the default regression epsilon.
     *
     * @param trainerType          the LibLinear algorithm to use.
     * @param cost                 the cost penalty for misclassifications.
     * @param maxIterations        the maximum number of iterations.
     * @param terminationCriterion the stopping tolerance.
     * @param seed                 the RNG seed.
     */
    protected LibLinearTrainer(LibLinearType<T> trainerType, double cost, int maxIterations, double terminationCriterion, long seed) {
        this(trainerType, cost, maxIterations, terminationCriterion, DEFAULT_EPSILON, seed);
    }

    /**
     * Creates a trainer with the default seed.
     *
     * @param trainerType          the LibLinear algorithm to use.
     * @param cost                 the cost penalty for misclassifications.
     * @param maxIterations        the maximum number of iterations.
     * @param terminationCriterion the stopping tolerance.
     * @param epsilon              the epsilon insensitivity for regression cost functions.
     */
    protected LibLinearTrainer(LibLinearType<T> trainerType, double cost, int maxIterations, double terminationCriterion, double epsilon) {
        this(trainerType, cost, maxIterations, terminationCriterion, epsilon, DEFAULT_SEED);
    }

    /**
     * Creates a fully specified trainer and immediately runs {@link #postConfig()}.
     *
     * @param trainerType          the LibLinear algorithm to use.
     * @param cost                 the cost penalty for misclassifications.
     * @param maxIterations        the maximum number of iterations.
     * @param terminationCriterion the stopping tolerance.
     * @param epsilon              the epsilon insensitivity for regression cost functions.
     * @param seed                 the RNG seed.
     */
    protected LibLinearTrainer(LibLinearType<T> trainerType, double cost, int maxIterations, double terminationCriterion, double epsilon, long seed) {
        this.trainerType = trainerType;
        this.cost = cost;
        this.maxIterations = maxIterations;
        this.terminationCriterion = terminationCriterion;
        this.epsilon = epsilon;
        this.seed = seed;
        this.postConfig();
    }

    /**
     * Builds the LibLinear {@link Parameter} from the configured fields and reseeds the RNG
     * from {@link #seed}.
     * <p>
     * Also calls the static {@link Linear#disableDebugOutput()}, which affects LibLinear
     * library-wide rather than just this trainer. The invocation counter is not reset here.
     */
    public void postConfig() {
        this.libLinearParams = new Parameter(this.trainerType.getSolverType(), this.cost, this.terminationCriterion, this.maxIterations, this.epsilon);
        this.rng = new SplittableRandom(this.seed);
        Linear.disableDebugOutput();
    }

    /**
     * Trains a model with empty run provenance, incrementing the invocation count.
     *
     * @param examples the training data.
     * @return the trained model.
     */
    public LibLinearModel<T> train(Dataset<T> examples) {
        return this.train(examples, Collections.emptyMap());
    }

    /**
     * Trains a model, incrementing the invocation count.
     *
     * @param examples      the training data.
     * @param runProvenance extra provenance recorded in the model.
     * @return the trained model.
     */
    public LibLinearModel<T> train(Dataset<T> examples, Map<String, Provenance> runProvenance) {
        // -1 means "don't reset the invocation count, just increment it".
        return this.train(examples, runProvenance, -1);
    }

    /**
     * Trains a model.
     * <p>
     * The RNG split, provenance capture and counter increment happen atomically under this
     * instance's monitor. The actual LibLinear training then runs outside the lock using a
     * locally split RNG.
     *
     * @param examples        the training data.
     * @param runProvenance   extra provenance recorded in the model.
     * @param invocationCount if not {@code -1}, the RNG and counter are first rewound to this
     *                        invocation via {@link #setInvocationCount(int)}.
     * @return the trained model.
     * @throws IllegalArgumentException if the dataset contains unknown outputs, or if
     *                                  {@code invocationCount} is negative and not {@code -1}.
     */
    public LibLinearModel<T> train(Dataset<T> examples, Map<String, Provenance> runProvenance, int invocationCount) {
        if (examples.getOutputInfo().getUnknownCount() > 0) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }
        TrainerProvenance trainerProvenance;
        SplittableRandom localRNG;
        synchronized (this) {
            if (invocationCount != -1) {
                this.setInvocationCount(invocationCount);
            }
            localRNG = this.rng.split();
            // Captured before the increment so provenance records the count this run used.
            trainerProvenance = this.getProvenance();
            ++this.trainInvocationCount;
        }
        ImmutableFeatureMap featureIDMap = examples.getFeatureIDMap();
        ImmutableOutputInfo<T> outputIDInfo = examples.getOutputIDInfo();
        Parameter curParams = this.setupParameters(outputIDInfo);
        curParams.setRandom(new Random(localRNG.nextLong()));
        ModelProvenance provenance = new ModelProvenance(LibLinearModel.class.getName(), OffsetDateTime.now(), (DatasetProvenance) examples.getProvenance(), trainerProvenance, runProvenance);
        Pair<FeatureNode[][], double[][]> data = this.extractData(examples, outputIDInfo, featureIDMap);
        // The extra feature slot is the bias term appended by exampleToNodes.
        List<Model> models = this.trainModels(curParams, featureIDMap.size() + 1, data.getA(), data.getB());
        return this.createModel(provenance, featureIDMap, outputIDInfo, models);
    }

    /**
     * Returns the number of train invocations so far.
     * <p>
     * Synchronized so the value written under the lock in {@code train} and
     * {@link #setInvocationCount(int)} is visible to the calling thread.
     *
     * @return the invocation count.
     */
    public synchronized int getInvocationCount() {
        return this.trainInvocationCount;
    }

    /**
     * Rewinds the RNG to {@link #seed} and advances it as if {@code invocationCount} train calls
     * had already happened.
     *
     * @param invocationCount the invocation count to set; must be non-negative.
     * @throws IllegalArgumentException if {@code invocationCount} is negative.
     */
    public synchronized void setInvocationCount(int invocationCount) {
        if (invocationCount < 0) {
            throw new IllegalArgumentException("The supplied invocationCount is less than zero.");
        }
        this.rng = new SplittableRandom(this.seed);
        for (int i = 0; i < invocationCount; i++) {
            // The split result is discarded; the call itself advances rng's state.
            this.rng.split();
        }
        this.trainInvocationCount = invocationCount;
    }

    @Override
    public String toString() {
        StringBuilder buffer = new StringBuilder();
        buffer.append("LibLinearTrainer(");
        buffer.append("solver=");
        buffer.append(this.libLinearParams.getSolverType());
        buffer.append(",cost=");
        buffer.append(this.libLinearParams.getC());
        buffer.append(",terminationCriterion=");
        buffer.append(this.libLinearParams.getEps());
        buffer.append(",maxIterations=");
        buffer.append(this.libLinearParams.getMaxIters());
        buffer.append(",regression-epsilon=");
        buffer.append(this.libLinearParams.getP());
        buffer.append(",seed=");
        buffer.append(this.seed);
        buffer.append(')');
        return buffer.toString();
    }

    /**
     * Fits the LibLinear model(s) for this task.
     *
     * @param curParams   the per-invocation LibLinear parameters (already seeded).
     * @param numFeatures the number of features, including the bias feature.
     * @param features    the sparse feature rows produced by {@link #extractData}.
     * @param outputs     the output values produced by {@link #extractData}.
     * @return the trained LibLinear models.
     */
    protected abstract List<Model> trainModels(Parameter curParams, int numFeatures, FeatureNode[][] features, double[][] outputs);

    /**
     * Wraps the trained LibLinear models in a Tribuo model.
     *
     * @param provenance   the model provenance.
     * @param featureIDMap the feature domain.
     * @param outputIDInfo the output domain.
     * @param models       the trained LibLinear models.
     * @return the Tribuo model.
     */
    protected abstract LibLinearModel<T> createModel(ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, List<Model> models);

    /**
     * Converts a dataset into LibLinear's sparse feature rows and output values.
     *
     * @param data         the dataset.
     * @param outputInfo   the output domain.
     * @param featureMap   the feature domain.
     * @return a pair of (feature rows, output values).
     */
    protected abstract Pair<FeatureNode[][], double[][]> extractData(Dataset<T> data, ImmutableOutputInfo<T> outputInfo, ImmutableFeatureMap featureMap);

    /**
     * Returns the LibLinear parameters for one training run.
     * <p>
     * The default returns a clone of {@link #libLinearParams}, so per-run changes such as
     * seeding do not leak into the shared instance. Subclasses may customise it per output
     * domain.
     *
     * @param info the output domain.
     * @return a fresh parameter object.
     */
    protected Parameter setupParameters(ImmutableOutputInfo<T> info) {
        return this.libLinearParams.clone();
    }

    /**
     * Converts an example into LibLinear's sparse row representation.
     * <p>
     * A feature with id {@code i} becomes a node with index {@code i + 1}. Features whose id is
     * negative (presumably names absent from {@code featureIDMap} — confirm against
     * {@code ImmutableFeatureMap.getID}) are dropped. Repeated ids have their values summed.
     * Nodes are kept in ascending index order. A bias node with index
     * {@code featureIDMap.size() + 1} and value {@code 1.0} is appended last.
     *
     * @param example      the example to convert.
     * @param featureIDMap the feature domain.
     * @param features     a scratch list, cleared and reused; may be {@code null}.
     * @param <T>          the output type.
     * @return a new array of feature nodes.
     */
    public static <T extends Output<T>> FeatureNode[] exampleToNodes(Example<T> example, ImmutableFeatureMap featureIDMap, List<FeatureNode> features) {
        int biasIndex = featureIDMap.size() + 1;
        if (features == null) {
            features = new ArrayList<>();
        }
        features.clear();
        int prevIdx = -1;
        for (Feature f : example) {
            int id = featureIDMap.getID(f.getName());
            if (id > prevIdx) {
                // Fast path: a new maximum id can be appended while keeping the list sorted.
                prevIdx = id;
                features.add(new FeatureNode(id + 1, f.getValue()));
            } else if (id > -1) {
                // Out-of-order or repeated id: search by LibLinear index (id + 1).
                int collisionIdx = Util.binarySearch(features, id + 1, FeatureNode::getIndex);
                if (collisionIdx < 0) {
                    // Absent: a negative result encodes -(insertionPoint + 1).
                    features.add(-(collisionIdx + 1), new FeatureNode(id + 1, f.getValue()));
                } else {
                    // Present: accumulate the duplicate feature's value.
                    FeatureNode n = features.get(collisionIdx);
                    n.setValue(n.getValue() + f.getValue());
                }
            }
            // Otherwise id is negative and the feature is skipped.
        }
        features.add(new FeatureNode(biasIndex, 1.0));
        return features.toArray(new FeatureNode[0]);
    }

    /**
     * Returns the provenance of this trainer.
     *
     * @return the trainer provenance.
     */
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }
}

