/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.linear;

import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ExecutorService;
import jsat.DataSet;
import jsat.SingleWeightVectorModel;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.svm.PlattSMO;
import jsat.distributions.Distribution;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.lossfunctions.LogisticLoss;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.IntList;
import jsat.utils.ListUtils;

/**
 * Binary logistic regression classifier whose single weight vector is learned
 * by repeatedly sweeping over the training points in random order and updating
 * one per-sample variable pair ({@code alpha_i}, {@code alphaPrime_i}) at a
 * time. The two values of each pair always sum to {@link #getC() C}, and every
 * update is found with a short one dimensional Newton iteration.
 * <p>
 * The class name suggests this is a dual coordinate descent (DCD) solver;
 * see the referenced literature for the derivation of the update rule.
 * <p>
 * Instances are not trained on construction: {@link #trainC(ClassificationDataSet)}
 * must be called before {@link #classify(DataPoint)}.
 */
public class LogisticRegressionDCD implements Classifier, Parameterized, SingleWeightVectorModel {
    private static final long serialVersionUID = -5813704270903243462L;

    /** Scale applied to C when picking the common starting value of every alpha. */
    private static final double eps_1 = 0.001;
    /** Upper bound on the common starting value of every alpha. */
    private static final double eps_2 = 1.0E-8;
    /** Starting points below this value cause the coordinate to be skipped. */
    private static final double MIN_START = 1.0E-20;
    /** Shrink factor used for the starting point and for rejected Newton steps. */
    private static final double SHRINK = 0.1;
    /** Maximum number of Newton steps taken for a single coordinate. */
    private static final int MAX_NEWTON_STEPS = 100;
    /** The Newton iteration stops once the absolute derivative is below this. */
    private static final double NEWTON_TOLERANCE = 1.0E-6;
    /** Training stops once no alpha moved by at least this much in one epoch. */
    private static final double CONVERGENCE_TOLERANCE = 1.0E-4;

    /** Learned weight vector; {@code null} until the model has been trained. */
    private Vec w;
    /** Learned bias term; stays at zero when {@link #useBias} is {@code false}. */
    private double bias;
    /** Whether a bias term is learned in addition to the weight vector. */
    private boolean useBias;
    /** Regularization constant; each alpha / alphaPrime pair sums to this value. */
    private double C;
    /** Maximum number of passes (epochs) over the training data. */
    private int maxIterations;

    /**
     * Creates a new learner with {@code C = 1} and at most 100 training epochs.
     */
    public LogisticRegressionDCD() {
        this(1.0);
    }

    /**
     * Creates a new learner with at most 100 training epochs.
     *
     * @param C the regularization constant, must be positive and finite
     * @throws IllegalArgumentException if {@code C} is not positive and finite
     */
    public LogisticRegressionDCD(double C) {
        this(C, 100);
    }

    /**
     * Creates a new learner.
     *
     * @param C the regularization constant, must be positive and finite
     * @param maxIterations the maximum number of training epochs, at least 1
     * @throws IllegalArgumentException if either argument is out of range
     */
    public LogisticRegressionDCD(double C, int maxIterations) {
        this.setC(C);
        this.setMaxIterations(maxIterations);
    }

    /**
     * Copy constructor. The weight vector, when present, is cloned so the copy
     * does not share it with the source.
     *
     * @param toCopy the model to copy
     */
    protected LogisticRegressionDCD(LogisticRegressionDCD toCopy) {
        this(toCopy.C, toCopy.maxIterations);
        if (toCopy.w != null) {
            this.w = toCopy.w.clone();
        }
        this.bias = toCopy.bias;
        this.useBias = toCopy.useBias;
    }

    /**
     * Sets the regularization constant.
     *
     * @param C the new value, must be positive and finite
     * @throws IllegalArgumentException if {@code C} is zero, negative, infinite or NaN
     */
    public void setC(double C) {
        if (C <= 0.0 || Double.isInfinite(C) || Double.isNaN(C)) {
            throw new IllegalArgumentException("C must be a positive constant, not " + C);
        }
        this.C = C;
    }

    /**
     * @return the regularization constant
     */
    public double getC() {
        return this.C;
    }

    /**
     * Sets the maximum number of passes over the training data.
     *
     * @param maxIterations the maximum number of epochs, at least 1
     * @throws IllegalArgumentException if {@code maxIterations} is less than 1
     */
    public void setMaxIterations(int maxIterations) {
        if (maxIterations < 1) {
            throw new IllegalArgumentException("iterations must be a positive value, not " + maxIterations);
        }
        this.maxIterations = maxIterations;
    }

    /**
     * @return the maximum number of passes over the training data
     */
    public int getMaxIterations() {
        return this.maxIterations;
    }

    /**
     * Sets whether a bias term is learned. The flag is {@code false} unless set.
     *
     * @param useBias {@code true} to learn a bias term
     */
    public void setUseBias(boolean useBias) {
        this.useBias = useBias;
    }

    /**
     * @return {@code true} if a bias term is learned
     */
    public boolean isUseBias() {
        return this.useBias;
    }

    /**
     * @return the internal weight vector (not a copy), or {@code null} if the
     *         model has not been trained
     */
    @Override
    public Vec getRawWeight() {
        return this.w;
    }

    /**
     * @return the current bias term
     */
    @Override
    public double getBias() {
        return this.bias;
    }

    /**
     * Returns the weight vector for the given index. Every index below 1 is
     * answered with the single weight vector of this model.
     *
     * @param index the weight vector index
     * @return the internal weight vector
     * @throws IndexOutOfBoundsException if {@code index} is 1 or larger
     */
    @Override
    public Vec getRawWeight(int index) {
        if (index < 1) {
            return this.getRawWeight();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    /**
     * Returns the bias for the given index. Every index below 1 is answered
     * with the single bias term of this model.
     *
     * @param index the weight vector index
     * @return the bias term
     * @throws IndexOutOfBoundsException if {@code index} is 1 or larger
     */
    @Override
    public double getBias(int index) {
        if (index < 1) {
            return this.getBias();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    /**
     * @return always 1, this model holds a single weight vector
     */
    @Override
    public int numWeightsVecs() {
        return 1;
    }

    /**
     * Classifies a data point from the score {@code w . x + bias}. The model
     * must have been trained first, since the weight vector is {@code null}
     * until then.
     *
     * @param data the point to classify; only its numerical values are used
     * @return the result produced by {@link LogisticLoss#classify(double)} for the score
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        return LogisticLoss.classify(this.w.dot(data.getNumericalValues()) + this.bias);
    }

    /**
     * Trains the model. The thread pool is not used; this simply calls
     * {@link #trainC(ClassificationDataSet)}.
     *
     * @param dataSet the training data
     * @param threadPool ignored
     */
    @Override
    public void trainC(ClassificationDataSet dataSet, ExecutorService threadPool) {
        this.trainC(dataSet);
    }

    /**
     * Trains the model on a two class data set. Any previously learned weight
     * vector and bias are discarded.
     * <p>
     * Training runs for at most {@link #getMaxIterations()} epochs and stops
     * early once the largest absolute change of any alpha within one epoch
     * drops below a fixed tolerance.
     *
     * @param dataSet the training data, must have exactly two classes
     * @throws FailedToFitException if the data set does not have exactly two classes
     */
    @Override
    public void trainC(ClassificationDataSet dataSet) {
        if (dataSet.getClassSize() != 2) {
            throw new FailedToFitException("Logistic Regression is a binary classifier, can can not handle " + dataSet.getClassSize() + " class problems");
        }
        final int N = dataSet.getSampleSize();
        final List<Vec> x = dataSet.getDataVectors();
        final double[] alpha = new double[N];
        final double[] alphaPrime = new double[N];
        final double[] Q_ii = new double[N];
        final int[] y = new int[N];

        // Every alpha starts at the same small positive value. Keeping it in a
        // local (instead of reading alpha[0]) also keeps an empty data set from
        // failing with an ArrayIndexOutOfBoundsException.
        final double alphaInit = Math.min(eps_1 * this.C, eps_2);
        Arrays.fill(alpha, alphaInit);
        Arrays.fill(alphaPrime, this.C - alphaInit);

        this.w = new DenseVector(dataSet.getNumNumericalVars());
        this.bias = 0.0;
        for (int i = 0; i < N; i++) {
            // map class index 0 / 1 to label -1 / +1
            y[i] = dataSet.getDataPointCategory(i) * 2 - 1;
            Vec x_i = x.get(i);
            // NOTE(review): with useBias enabled the implicit constant feature
            // is not counted here (no "+ 1") - confirm whether that is intended.
            Q_ii[i] = x_i.dot(x_i);
            // w (and bias) are kept equal to sum_i alpha_i * y_i * x_i
            this.w.mutableAdd(alphaInit * (double) y[i], x_i);
            if (this.useBias) {
                this.bias += alphaInit * (double) y[i];
            }
        }

        IntList permutation = new IntList(N);
        ListUtils.addRange(permutation, 0, N, 1);

        for (int iter = 0; iter < this.maxIterations; iter++) {
            Collections.shuffle(permutation);
            double maxChange = 0.0;

            for (int i : permutation) {
                final Vec x_i = x.get(i);
                final double y_i = y[i];
                final double c1 = alpha[i];
                final double c2 = alphaPrime[i];
                final double a = Q_ii[i];
                final double b = y_i * (this.w.dot(x_i) + this.bias);

                final double z_m = (c2 - c1) / 2.0;
                final double s = c1 + c2;
                // Decides whether alpha_i (case 1) or alphaPrime_i is solved for.
                final boolean case1 = z_m >= -b / a;

                // The variable being solved for starts from its current value,
                // shrunk when that value is at least half of the pair's sum.
                final double anchor = case1 ? c1 : c2;
                final double start = anchor >= s / 2.0 ? SHRINK * anchor : anchor;
                if (start < MIN_START) {
                    continue;
                }

                final double z = this.newtonSolve(start, a, case1 ? b : -b, anchor, s);

                if (case1) {
                    alpha[i] = z;
                    alphaPrime[i] = this.C - z;
                } else {
                    alpha[i] = this.C - z;
                    alphaPrime[i] = z;
                }

                final double change = alpha[i] - c1;
                this.w.mutableAdd(change * y_i, x_i);
                if (this.useBias) {
                    this.bias += change * y_i;
                }
                // The magnitude is what matters: tracking the signed change
                // would let an epoch of large decreases look like convergence.
                maxChange = Math.max(maxChange, Math.abs(change));
            }

            if (maxChange < CONVERGENCE_TOLERANCE) {
                return;
            }
        }
    }

    /**
     * Runs the one dimensional Newton iteration for a single coordinate,
     * looking for a root of
     * {@code log(z / (C - z)) + a * (z - anchor) + signedB}.
     *
     * @param z the starting point
     * @param a the quadratic coefficient (squared norm of the data point)
     * @param signedB the linear term, already carrying the sign of the active case
     * @param anchor the value of the solved variable before this update
     * @param s the sum of the coordinate's alpha and alphaPrime before this update
     * @return the final iterate, reached either when the derivative is within
     *         tolerance or when the step limit is exhausted
     */
    private double newtonSolve(double z, double a, double signedB, double anchor, double s) {
        for (int subIter = 0; subIter < MAX_NEWTON_STEPS; subIter++) {
            final double gP = Math.log(z / (this.C - z)) + (a * (z - anchor) + signedB);
            if (Math.abs(gP) < NEWTON_TOLERANCE) {
                break;
            }
            final double gPP = a + s / (z * (s - z));
            final double d = -gP / gPP;
            if (z + d <= 0.0) {
                // the full step would leave the positive range, so shrink instead
                z *= SHRINK;
            } else {
                z += d;
            }
        }
        return z;
    }

    /**
     * @return always {@code false}
     */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * @return a copy of this model created through the copy constructor
     */
    @Override
    public Classifier clone() {
        return new LogisticRegressionDCD(this);
    }

    /**
     * @return the internal weight vector (not a copy), or {@code null} if the
     *         model has not been trained
     */
    public Vec getWeightVec() {
        return this.w;
    }

    /**
     * @return the parameters discovered from this object's methods by
     *         {@link Parameter#getParamsFromMethods(Object)}
     */
    @Override
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    /**
     * Looks up a single parameter by name.
     *
     * @param paramName the name of the parameter
     * @return the matching entry of the parameter map built from
     *         {@link #getParameters()}
     */
    @Override
    public Parameter getParameter(String paramName) {
        return Parameter.toParameterMap(this.getParameters()).get(paramName);
    }

    /**
     * Suggests a distribution of values for {@code C}. This delegates to
     * {@link PlattSMO#guessC(DataSet)}.
     *
     * @param d the data set the guess is made for
     * @return the distribution returned by {@link PlattSMO#guessC(DataSet)}
     */
    public static Distribution guessC(DataSet d) {
        return PlattSMO.guessC(d);
    }
}

