/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.linear;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ExecutorService;
import jsat.SimpleWeightVectorModel;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.exceptions.FailedToFitException;
import jsat.exceptions.UntrainedModelException;
import jsat.linear.ConstantVector;
import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.Vec;
import jsat.math.MathTricks;
import jsat.math.decayrates.DecayRate;
import jsat.math.decayrates.ExponetialDecay;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.IntList;
import jsat.utils.ListUtils;

/**
 * Multinomial logistic regression fitted with stochastic (optionally
 * mini-batch) gradient steps and a prior-based penalty on the coefficients.
 *
 * <p>For a problem with {@code K} classes the model stores {@code K - 1}
 * coefficient vectors and biases; the last class has no stored vector and is
 * reported as an all-zero weight vector with zero bias by
 * {@link #getRawWeight(int)} and {@link #getBias(int)}.
 *
 * <p>The penalty on a coefficient is applied lazily: it is only brought up to
 * date when the corresponding feature appears in a sampled data point, or at
 * the end of an epoch.
 *
 * <p>Instances are not synchronized; training mutates the model in place.
 */
public class StochasticMultinomialLogisticRegression
implements Classifier,
Parameterized,
SimpleWeightVectorModel {
    private static final long serialVersionUID = -492707881682847556L;
    /** Maximum number of passes over the training data. */
    private int epochs;
    /** When true, a penalty step that would flip a coefficient's sign sets it to zero instead. */
    private boolean clipping = true;
    /** Multiplier applied to the prior's gradient and log probability. */
    private double regularization;
    /** Relative change in the epoch objective below which training stops early. */
    private double tolerance = 1.0E-4;
    /** Learning rate handed to the decay schedule as its initial value. */
    private double initialLearningRate;
    /** Extra parameter forwarded to the prior (used by ELASTIC and CAUCHY). */
    private double alpha = 0.5;
    /** Schedule producing the per-epoch learning rate. */
    private DecayRate learningRateDecay = new ExponetialDecay();
    /** Prior that defines the coefficient penalty. */
    private Prior prior;
    /** When true, the penalty is evaluated on a rescaled/shifted coefficient value. */
    private boolean standardized = true;
    /** When true, bias terms are updated during training; otherwise they stay at zero. */
    private boolean useBias = true;
    /** Number of data points that share one penalty time step. */
    private int miniBatchSize = 1;
    /** One coefficient vector per class except the last; null until trained. */
    private Vec[] B;
    /** One bias per class except the last; null until trained. */
    private double[] biases;

    /**
     * Creates a new, untrained model.
     *
     * @param initialLearningRate the starting learning rate, must be positive and finite
     * @param epochs the maximum number of training epochs, must be positive
     * @param regularization the penalty multiplier, must be non negative and finite
     * @param prior the prior defining the penalty
     * @throws IllegalArgumentException if a numeric argument is out of range
     */
    public StochasticMultinomialLogisticRegression(double initialLearningRate, int epochs, double regularization, Prior prior) {
        this.setEpochs(epochs);
        this.setRegularization(regularization);
        this.setInitialLearningRate(initialLearningRate);
        this.setPrior(prior);
    }

    /**
     * Creates a new, untrained model with a regularization of {@code 1e-6}
     * and a {@link Prior#GAUSSIAN} prior.
     *
     * @param initialLearningRate the starting learning rate, must be positive and finite
     * @param epochs the maximum number of training epochs, must be positive
     */
    public StochasticMultinomialLogisticRegression(double initialLearningRate, int epochs) {
        this(initialLearningRate, epochs, 1.0E-6, Prior.GAUSSIAN);
    }

    /**
     * Creates a new, untrained model with an initial learning rate of
     * {@code 0.1} and {@code 50} epochs.
     */
    public StochasticMultinomialLogisticRegression() {
        this(0.1, 50);
    }

    /**
     * Copy constructor used by {@link #clone()}.
     *
     * <p>All configuration values are copied, including the bias and
     * mini-batch settings. Learned coefficient vectors and biases are deep
     * copied; the learning rate decay object is shared with the source.
     *
     * @param toClone the model to copy
     */
    protected StochasticMultinomialLogisticRegression(StochasticMultinomialLogisticRegression toClone) {
        this.epochs = toClone.epochs;
        this.clipping = toClone.clipping;
        this.regularization = toClone.regularization;
        this.tolerance = toClone.tolerance;
        this.initialLearningRate = toClone.initialLearningRate;
        this.alpha = toClone.alpha;
        this.learningRateDecay = toClone.learningRateDecay;
        this.prior = toClone.prior;
        this.standardized = toClone.standardized;
        // These two were previously dropped, so a clone silently reverted to the defaults.
        this.useBias = toClone.useBias;
        this.miniBatchSize = toClone.miniBatchSize;
        if (toClone.B != null) {
            this.B = new Vec[toClone.B.length];
            for (int i = 0; i < toClone.B.length; ++i) {
                this.B[i] = toClone.B[i].clone();
            }
        }
        if (toClone.biases != null) {
            this.biases = Arrays.copyOf(toClone.biases, toClone.biases.length);
        }
    }

    /**
     * Sets whether bias terms are learned during training.
     *
     * @param useBias {@code true} to update the biases while training
     */
    public void setUseBias(boolean useBias) {
        this.useBias = useBias;
    }

    /**
     * @return {@code true} if bias terms are learned during training
     */
    public boolean isUseBias() {
        return this.useBias;
    }

    /**
     * Sets the maximum number of training epochs.
     *
     * @param epochs the number of epochs, must be positive
     * @throws IllegalArgumentException if {@code epochs} is not positive
     */
    public void setEpochs(int epochs) {
        if (epochs <= 0) {
            throw new IllegalArgumentException("Number of epochs must be positive");
        }
        this.epochs = epochs;
    }

    /**
     * @return the maximum number of training epochs
     */
    public int getEpochs() {
        return this.epochs;
    }

    /**
     * Sets the extra parameter that is forwarded to the prior.
     *
     * @param alpha the extra parameter, must be non negative and finite
     * @throws IllegalArgumentException if {@code alpha} is negative, NaN or infinite
     */
    public void setAlpha(double alpha) {
        if (alpha < 0.0 || Double.isNaN(alpha) || Double.isInfinite(alpha)) {
            throw new IllegalArgumentException("Extra parameter must be non negative, not " + alpha);
        }
        this.alpha = alpha;
    }

    /**
     * @return the extra parameter forwarded to the prior
     */
    public double getAlpha() {
        return this.alpha;
    }

    /**
     * Sets whether a penalty step that would change the sign of a coefficient
     * sets that coefficient to zero instead.
     *
     * @param clipping {@code true} to enable clipping
     */
    public void setClipping(boolean clipping) {
        this.clipping = clipping;
    }

    /**
     * @return {@code true} if clipping is enabled
     */
    public boolean isClipping() {
        return this.clipping;
    }

    /**
     * Sets the learning rate passed to the decay schedule as its initial value.
     *
     * @param initialLearningRate the learning rate, must be positive and finite
     * @throws IllegalArgumentException if the value is not positive, is NaN or is infinite
     */
    public void setInitialLearningRate(double initialLearningRate) {
        if (initialLearningRate <= 0.0 || Double.isInfinite(initialLearningRate) || Double.isNaN(initialLearningRate)) {
            throw new IllegalArgumentException("Learning rate must be a positive constant, not " + initialLearningRate);
        }
        this.initialLearningRate = initialLearningRate;
    }

    /**
     * @return the initial learning rate
     */
    public double getInitialLearningRate() {
        return this.initialLearningRate;
    }

    /**
     * Sets the schedule used to obtain the learning rate of each epoch.
     *
     * @param learningRateDecay the decay schedule
     */
    public void setLearningRateDecay(DecayRate learningRateDecay) {
        this.learningRateDecay = learningRateDecay;
    }

    /**
     * @return the learning rate decay schedule
     */
    public DecayRate getLearningRateDecay() {
        return this.learningRateDecay;
    }

    /**
     * Sets the multiplier applied to the prior's penalty.
     *
     * @param regularization the multiplier, must be non negative and finite
     * @throws IllegalArgumentException if the value is negative, NaN or infinite
     */
    public void setRegularization(double regularization) {
        if (regularization < 0.0 || Double.isNaN(regularization) || Double.isInfinite(regularization)) {
            throw new IllegalArgumentException("Regualrization must be a non negative constant, not " + regularization);
        }
        this.regularization = regularization;
    }

    /**
     * @return the penalty multiplier
     */
    public double getRegularization() {
        return this.regularization;
    }

    /**
     * Sets the prior that defines the coefficient penalty.
     *
     * @param prior the prior to use
     */
    public void setPrior(Prior prior) {
        this.prior = prior;
    }

    /**
     * @return the prior that defines the coefficient penalty
     */
    public Prior getPrior() {
        return this.prior;
    }

    /**
     * Sets the early-stopping threshold. Training stops after an epoch when
     * the relative change of the epoch objective falls below this value. The
     * value is stored without validation.
     *
     * @param tolerance the relative change threshold
     */
    public void setTolerance(double tolerance) {
        this.tolerance = tolerance;
    }

    /**
     * @return the early-stopping threshold
     */
    public double getTolerance() {
        return this.tolerance;
    }

    /**
     * Sets whether the penalty is evaluated on a rescaled and shifted
     * coefficient value derived from the per-column statistics of the
     * training data.
     *
     * @param standardized {@code true} to use the column statistics
     */
    public void setStandardized(boolean standardized) {
        this.standardized = standardized;
    }

    /**
     * @return {@code true} if the penalty uses the column statistics
     */
    public boolean isStandardized() {
        return this.standardized;
    }

    /**
     * Sets the number of data points processed per mini batch.
     *
     * @param miniBatchSize the batch size, must be at least 1
     * @throws IllegalArgumentException if {@code miniBatchSize} is less than 1
     */
    public void setMiniBatchSize(int miniBatchSize) {
        // The training loop advances by this amount, so a value below 1 would never terminate.
        if (miniBatchSize < 1) {
            throw new IllegalArgumentException("Mini batch size must be positive, not " + miniBatchSize);
        }
        this.miniBatchSize = miniBatchSize;
    }

    /**
     * @return the number of data points processed per mini batch
     */
    public int getMiniBatchSize() {
        return this.miniBatchSize;
    }

    /**
     * Returns the weight vector for the given class index. The index equal to
     * the number of stored vectors yields a constant zero vector.
     *
     * @param index the class index
     * @return the stored vector (not a copy), or a zero vector for the last class
     */
    @Override
    public Vec getRawWeight(int index) {
        if (index == this.B.length) {
            return new ConstantVector(0.0, this.B[0].length());
        }
        return this.B[index];
    }

    /**
     * Returns the bias for the given class index. The index equal to the
     * number of stored biases yields {@code 0.0}.
     *
     * @param index the class index
     * @return the bias of that class
     */
    @Override
    public double getBias(int index) {
        if (index == this.biases.length) {
            return 0.0;
        }
        return this.biases[index];
    }

    /**
     * @return the number of stored weight vectors plus one for the last class
     */
    @Override
    public int numWeightsVecs() {
        return this.B.length + 1;
    }

    /**
     * Scores the data point against every stored coefficient vector and turns
     * the scores into class probabilities.
     *
     * @param data the data point to classify
     * @return the class probabilities
     * @throws UntrainedModelException if the model has not been trained
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        if (this.B == null) {
            throw new UntrainedModelException("Model has not yet been trained");
        }
        Vec x = data.getNumericalValues();
        double[] probs = new double[this.B.length + 1];
        for (int i = 0; i < this.B.length; ++i) {
            probs[i] = x.dot(this.B[i]) + this.biases[i];
        }
        // The last class has no stored vector; its slot is filled explicitly before the softmax.
        probs[this.B.length] = 1.0;
        MathTricks.softmax(probs, false);
        return new CategoricalResults(probs);
    }

    /**
     * Trains the model; the thread pool is ignored and training runs on the
     * calling thread.
     *
     * @param dataSet the training data
     * @param threadPool unused
     */
    @Override
    public void trainC(ClassificationDataSet dataSet, ExecutorService threadPool) {
        this.trainC(dataSet);
    }

    /**
     * Trains the model from scratch on the given data set, discarding any
     * previously learned coefficients.
     *
     * <p>Each epoch shuffles the data, walks it in mini batches, and for each
     * point first brings the penalty of the point's features up to date and
     * then takes a gradient step. At the end of the epoch all remaining
     * penalties are applied and the epoch objective is compared against the
     * previous one for early stopping.
     *
     * @param dataSet the training data
     * @throws FailedToFitException if the data set has no numeric attributes
     */
    @Override
    public void trainC(ClassificationDataSet dataSet) {
        final int n = dataSet.getSampleSize();
        final double N = n;
        final int d = dataSet.getNumNumericalVars();
        if (d < 1) {
            throw new FailedToFitException("Data set has no numeric attributes to train on");
        }
        this.B = new Vec[dataSet.getClassSize() - 1];
        this.biases = new double[this.B.length];
        for (int i = 0; i < this.B.length; ++i) {
            this.B[i] = new DenseVector(d);
        }
        IntList randOrder = new IntList(n);
        ListUtils.addRange(randOrder, 0, n, 1);
        Vec means = null;
        Vec stdDevs = null;
        if (this.standardized) {
            // NOTE(review): presumably leaves means = mean / stdDev and stdDevs = 1 / stdDev;
            // depends on MathTricks.sqrtFunc / invsFunc semantics - confirm.
            Vec[] ms = dataSet.getColumnMeanVariance();
            means = ms[0];
            stdDevs = ms[1];
            stdDevs.applyFunction(MathTricks.sqrtFunc);
            means.pairwiseDivide(stdDevs);
            stdDevs.applyFunction(MathTricks.invsFunc);
        }
        // Per-class scores / probabilities of the current data point.
        double[] zs = new double[this.B.length];
        // u[i] is the penalty time step at which feature i was last brought up to date.
        int[] u = new int[d];
        // Penalty time step, advanced once per mini batch.
        int q = 0;
        double prevLogLike = Double.POSITIVE_INFINITY;
        for (int iter = 0; iter < this.epochs; ++iter) {
            Collections.shuffle(randOrder);
            double logLike = 0.0;
            final double eta = this.learningRateDecay.rate(iter, this.epochs, this.initialLearningRate);
            final double etaReg = this.regularization * eta;
            for (int batch = 0; batch < randOrder.size(); batch += this.miniBatchSize) {
                final int batchCount = Math.min(this.miniBatchSize, randOrder.size() - batch);
                final double batchFrac = 1.0 / (double)batchCount;
                for (int k = 0; k < batchCount; ++k) {
                    final int j = randOrder.get(batch + k);
                    final int c_j = dataSet.getDataPointCategory(j);
                    final Vec x_j = dataSet.getDataPoint(j).getNumericalValues();
                    for (int i = 0; i < this.B.length; ++i) {
                        zs[i] = x_j.dot(this.B[i]) + this.biases[i];
                    }
                    // NOTE(review): the 'true' flag presumably accounts for the last class,
                    // which has no entry in zs - confirm against MathTricks.softmax.
                    MathTricks.softmax(zs, true);
                    if (this.prior != Prior.UNIFORM) {
                        // Lazily catch up the penalty for the features present in this point.
                        for (IndexValue iv : x_j) {
                            final int i = iv.getIndex();
                            if (u[i] == 0) {
                                continue;
                            }
                            final double etaRegScaled = etaReg * (double)(u[i] - q) / N;
                            for (Vec b : this.B) {
                                this.applyPenaltyStep(b, i, b.get(i), etaRegScaled, means, stdDevs);
                            }
                            u[i] = q;
                        }
                    }
                    // Gradient step for every stored class, using the probabilities computed above.
                    for (int c = 0; c < this.B.length; ++c) {
                        final Vec b = this.B[c];
                        final double p_c = zs[c];
                        final double log_pc = Math.log(p_c);
                        if (!Double.isInfinite(log_pc)) {
                            logLike += log_pc;
                        }
                        final double errScaling = (double)(c == c_j ? 1 : 0) - p_c;
                        b.mutableAdd(batchFrac * eta * errScaling, x_j);
                        if (this.useBias) {
                            this.biases[c] += batchFrac * eta * errScaling + etaReg * this.prior.gradientError(this.biases[c] - 1.0, 1.0, this.alpha);
                        }
                    }
                }
                ++q;
            }
            logLike *= -1.0;
            if (this.prior != Prior.UNIFORM) {
                for (int i = 0; i < d; ++i) {
                    if (u[i] - q == 0) {
                        // Already up to date: only account for the prior term in the objective.
                        for (Vec b : this.B) {
                            logLike += this.regularization * this.prior.logProb(this.priorArgument(b.get(i), i, means, stdDevs), 1.0, this.alpha);
                        }
                        continue;
                    }
                    final double etaRegScaled = etaReg * (double)(u[i] - q) / N;
                    for (Vec b : this.B) {
                        final double bVal = b.get(i);
                        // Zero coefficients are neither penalized nor counted here.
                        if (bVal == 0.0) {
                            continue;
                        }
                        this.applyPenaltyStep(b, i, bVal, etaRegScaled, means, stdDevs);
                        logLike += this.regularization * this.prior.logProb(this.priorArgument(b.get(i), i, means, stdDevs), 1.0, this.alpha);
                    }
                    u[i] = q;
                }
            }
            // On the first epoch prevLogLike is infinite, the ratio is NaN and the test is false.
            final double dif = Math.abs(prevLogLike - logLike) / (Math.abs(prevLogLike) + Math.abs(logLike));
            if (dif < this.tolerance) {
                break;
            }
            prevLogLike = logLike;
        }
    }

    /**
     * Computes the value handed to the prior for one coefficient.
     *
     * @param coefficient the raw coefficient value
     * @param index the feature index of the coefficient
     * @param means the shift vector prepared by training; only read when standardized
     * @param stdDevs the scale vector prepared by training; only read when standardized
     * @return {@code coefficient * stdDevs[index] - means[index]} when
     * standardization is enabled, otherwise the coefficient itself
     */
    private double priorArgument(double coefficient, int index, Vec means, Vec stdDevs) {
        if (this.standardized) {
            return coefficient * stdDevs.get(index) - means.get(index);
        }
        return coefficient;
    }

    /**
     * Applies one penalty step to a single coefficient and stores the result,
     * writing zero instead when clipping is enabled and the step changes the
     * coefficient's sign.
     *
     * @param b the coefficient vector to update in place
     * @param index the feature index of the coefficient
     * @param bVal the current value of {@code b} at {@code index}
     * @param etaRegScaled the scaled step size for the penalty
     * @param means the shift vector prepared by training; only read when standardized
     * @param stdDevs the scale vector prepared by training; only read when standardized
     */
    private void applyPenaltyStep(Vec b, int index, double bVal, double etaRegScaled, Vec means, Vec stdDevs) {
        final double gradient = this.prior.gradientError(this.priorArgument(bVal, index, means, stdDevs), 1.0, this.alpha);
        final double bNewVal = bVal + etaRegScaled * gradient;
        if (this.clipping && Math.signum(bVal) != Math.signum(bNewVal)) {
            b.set(index, 0.0);
        } else {
            b.set(index, bNewVal);
        }
    }

    /**
     * @return {@code false}; training does not use data point weights
     */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Returns the stored coefficient vector for the given class.
     *
     * @param id the index of a class that has a stored vector
     * @return the stored vector (not a copy)
     */
    public Vec getCoefficientVector(int id) {
        return this.B[id];
    }

    /**
     * @return a copy of this model created through the copy constructor
     */
    @Override
    public Classifier clone() {
        return new StochasticMultinomialLogisticRegression(this);
    }

    @Override
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override
    public Parameter getParameter(String paramName) {
        return Parameter.toParameterMap(this.getParameters()).get(paramName);
    }

    /**
     * The priors available for penalizing coefficients. Each prior supplies a
     * gradient term and a log probability for a coefficient value
     * {@code b_i} with scale {@code s_i}; the three-argument variants also
     * receive the model's extra parameter {@code alpha}.
     */
    public static enum Prior {
        /** Gradient is {@code -b_i / s_i}. */
        GAUSSIAN{

            @Override
            protected double gradientError(double b_i, double s_i) {
                return -b_i / s_i;
            }

            @Override
            protected double logProb(double b_i, double s_i) {
                return -0.5 * Math.log(Math.PI * 2 * s_i) - 2.0 * b_i * b_i * s_i / 2.0;
            }
        }
        ,
        /** Gradient depends only on the sign of {@code b_i}. */
        LAPLACE{

            @Override
            protected double gradientError(double b_i, double s_i) {
                return -Math.sqrt(2.0) * Math.signum(b_i) / Math.sqrt(s_i);
            }

            @Override
            protected double logProb(double b_i, double s_i) {
                return -Math.signum(b_i) * Math.sqrt(2.0) * b_i / Math.sqrt(s_i) - 0.5 * Math.log(2.0 * s_i);
            }
        }
        ,
        /**
         * Mix of {@link #LAPLACE} (weight {@code alpha}) and {@link #GAUSSIAN}
         * (weight {@code 1 - alpha}). Only the three-argument methods are usable.
         */
        ELASTIC{

            @Override
            protected double gradientError(double b_i, double s_i) {
                throw new UnsupportedOperationException();
            }

            @Override
            protected double gradientError(double b_i, double s_i, double alpha) {
                return alpha * LAPLACE.gradientError(b_i, s_i) + (1.0 - alpha) * GAUSSIAN.gradientError(b_i, s_i);
            }

            @Override
            protected double logProb(double b_i, double s_i) {
                return Double.NaN;
            }

            @Override
            protected double logProb(double b_i, double s_i, double alpha) {
                return alpha * LAPLACE.logProb(b_i, s_i) + (1.0 - alpha) * GAUSSIAN.logProb(b_i, s_i);
            }
        }
        ,
        /**
         * Uses {@code alpha} in place of {@code s_i}, which is ignored. Only
         * the three-argument methods are usable.
         */
        CAUCHY{

            @Override
            protected double gradientError(double b_i, double s_i) {
                throw new UnsupportedOperationException();
            }

            @Override
            protected double gradientError(double b_i, double s_i, double alpha) {
                return -2.0 * b_i / (b_i * b_i + alpha * alpha);
            }

            @Override
            protected double logProb(double b_i, double s_i) {
                return Double.NaN;
            }

            @Override
            protected double logProb(double b_i, double s_i, double alpha) {
                return -Math.log(Math.PI) + Math.log(alpha) - Math.log(b_i * b_i + alpha * alpha);
            }
        }
        ,
        /** No penalty: gradient and log probability are always zero. */
        UNIFORM{

            @Override
            protected double gradientError(double b_i, double s_i) {
                return 0.0;
            }

            @Override
            protected double logProb(double b_i, double s_i) {
                return 0.0;
            }
        };


        /**
         * Computes the penalty gradient term for a coefficient.
         *
         * @param b_i the coefficient value
         * @param s_i the scale of the prior
         * @return the gradient term
         */
        protected abstract double gradientError(double b_i, double s_i);

        /**
         * Computes the penalty gradient term for a coefficient. The default
         * ignores {@code alpha} and delegates to the two-argument variant.
         *
         * @param b_i the coefficient value
         * @param s_i the scale of the prior
         * @param alpha the extra parameter of the prior
         * @return the gradient term
         */
        protected double gradientError(double b_i, double s_i, double alpha) {
            return this.gradientError(b_i, s_i);
        }

        /**
         * Computes the log probability term of a coefficient under this prior.
         *
         * @param b_i the coefficient value
         * @param s_i the scale of the prior
         * @return the log probability term
         */
        protected abstract double logProb(double b_i, double s_i);

        /**
         * Computes the log probability term of a coefficient under this prior.
         * The default ignores {@code alpha} and delegates to the two-argument
         * variant.
         *
         * @param b_i the coefficient value
         * @param s_i the scale of the prior
         * @param alpha the extra parameter of the prior
         * @return the log probability term
         */
        protected double logProb(double b_i, double s_i, double alpha) {
            return this.logProb(b_i, s_i);
        }
    }
}

