/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.linear.kernelized;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import jsat.DataSet;
import jsat.classifiers.BaseUpdateableClassifier;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.DataPoint;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.distributions.Distribution;
import jsat.distributions.LogUniform;
import jsat.distributions.kernels.KernelTrick;
import jsat.linear.Vec;
import jsat.lossfunctions.HingeLoss;
import jsat.lossfunctions.LossC;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.DoubleList;
import jsat.utils.random.RandomUtil;

public class BOGD
extends BaseUpdateableClassifier
implements BinaryScoreClassifier,
Parameterized {
    private static final long serialVersionUID = -3547832514098781996L;
    @Parameter.ParameterHolder
    private KernelTrick k;
    private int budget;
    private double eta;
    private double reg;
    private double maxCoeff;
    private LossC lossC;
    private boolean uniformSampling;
    private Random rand;
    private List<Vec> vecs;
    private List<Double> selfK;
    private DoubleList alphas;
    private List<Double> accelCache;
    private double[] dist;

    public BOGD(KernelTrick k, int budget, double eta, double reg, double maxCoeff) {
        this(k, budget, eta, reg, maxCoeff, new HingeLoss());
    }

    public BOGD(KernelTrick k, int budget, double eta, double reg, double maxCoeff, LossC lossC) {
        this.setKernel(k);
        this.setBudget(budget);
        this.setEta(eta);
        this.setRegularization(reg);
        this.setMaxCoeff(maxCoeff);
        this.lossC = lossC;
        this.setUniformSampling(false);
    }

    public BOGD(BOGD toCopy) {
        this.k = toCopy.k.clone();
        this.budget = toCopy.budget;
        this.eta = toCopy.eta;
        this.reg = toCopy.reg;
        this.maxCoeff = toCopy.maxCoeff;
        this.lossC = toCopy.lossC.clone();
        this.uniformSampling = toCopy.uniformSampling;
        this.rand = RandomUtil.getRandom();
        if (toCopy.vecs != null) {
            this.vecs = new ArrayList<Vec>(this.budget);
            for (Vec v : toCopy.vecs) {
                this.vecs.add(v.clone());
            }
            this.selfK = new DoubleList(toCopy.selfK);
            this.alphas = new DoubleList(toCopy.alphas);
        }
        if (toCopy.accelCache != null) {
            this.accelCache = new DoubleList(toCopy.accelCache);
        }
        if (toCopy.dist != null) {
            this.dist = Arrays.copyOf(toCopy.dist, toCopy.dist.length);
        }
    }

    public void setRegularization(double regularization) {
        if (regularization <= 0.0 || Double.isNaN(regularization) || Double.isInfinite(regularization)) {
            throw new IllegalArgumentException("Regularization must be positive, not " + regularization);
        }
        this.reg = regularization;
    }

    public double getRegularization() {
        return this.reg;
    }

    public void setEta(double eta) {
        if (eta <= 0.0 || Double.isNaN(eta) || Double.isInfinite(eta)) {
            throw new IllegalArgumentException("Eta must be positive, not " + eta);
        }
        this.eta = eta;
    }

    public double getEta() {
        return this.eta;
    }

    public void setMaxCoeff(double maxCoeff) {
        if (maxCoeff <= 0.0 || Double.isNaN(maxCoeff) || Double.isInfinite(maxCoeff)) {
            throw new IllegalArgumentException("MaxCoeff must be positive, not " + maxCoeff);
        }
        this.maxCoeff = maxCoeff;
    }

    public double getMaxCoeff() {
        return this.maxCoeff;
    }

    public void setBudget(int budget) {
        if (budget <= 0) {
            throw new IllegalArgumentException("Budget must be positive, not " + budget);
        }
        this.budget = budget;
    }

    public int getBudget() {
        return this.budget;
    }

    public void setKernel(KernelTrick k) {
        this.k = k;
    }

    public KernelTrick getKernel() {
        return this.k;
    }

    public void setUniformSampling(boolean uniformSampling) {
        this.uniformSampling = uniformSampling;
    }

    public boolean isUniformSampling() {
        return this.uniformSampling;
    }

    @Override
    public BOGD clone() {
        return new BOGD(this);
    }

    @Override
    public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes, CategoricalData predicting) {
        this.vecs = new ArrayList<Vec>(this.budget);
        this.alphas = new DoubleList(this.budget);
        this.selfK = new DoubleList(this.budget);
        this.accelCache = this.k.supportsAcceleration() ? new DoubleList(this.budget) : null;
        if (!this.uniformSampling) {
            this.dist = new double[this.budget];
        }
        this.rand = RandomUtil.getRandom();
    }

    @Override
    public double getScore(DataPoint dp) {
        Vec x = dp.getNumericalValues();
        return this.score(x, this.k.getQueryInfo(x));
    }

    private double score(Vec x, List<Double> qi) {
        return this.k.evalSum(this.vecs, this.accelCache, this.alphas.getBackingArray(), x, qi, 0, this.alphas.size());
    }

    @Override
    public void update(DataPoint dataPoint, int targetClass) {
        Vec x_t = dataPoint.getNumericalValues();
        double y_t = targetClass * 2 - 1;
        List<Double> qi = this.k.getQueryInfo(x_t);
        double score = this.score(x_t, qi);
        double lossD = this.lossC.getDeriv(score, y_t);
        if (lossD == 0.0) {
            this.alphas.getVecView().mutableMultiply(1.0 - this.eta * this.reg);
        } else if (this.vecs.size() < this.budget) {
            this.alphas.getVecView().mutableMultiply(1.0 - this.eta * this.reg);
            this.alphas.add(-this.eta * lossD);
            this.selfK.add(Math.sqrt(this.k.eval(0, 0, Arrays.asList(x_t), qi)));
            if (this.k.supportsAcceleration()) {
                this.accelCache.addAll(qi);
            }
            this.vecs.add(x_t);
        } else {
            double normalize;
            int toRemove;
            if (this.uniformSampling) {
                toRemove = this.rand.nextInt(this.budget);
                normalize = 1.0;
            } else {
                double cur;
                double s = 0.0;
                for (int i = 0; i < this.budget; ++i) {
                    s += Math.abs(this.alphas.get(i)) * this.selfK.get(i);
                }
                s = (double)(this.budget - 1) / s;
                double target = this.rand.nextDouble();
                int i = -1;
                for (cur = 0.0; cur < target; cur += this.dist[++i]) {
                    this.dist[++i] = 1.0 - s * this.alphas.get(i) * this.selfK.get(i);
                }
                toRemove = i++;
                while (i < this.budget) {
                    int n = i;
                    double d = 1.0 - s * this.alphas.get(i) * this.selfK.get(i++);
                    this.dist[n] = d;
                    cur += d;
                }
                normalize = cur;
            }
            for (int i = 0; i < this.budget; ++i) {
                if (i == toRemove) continue;
                double alpha_i = this.alphas.getD(i);
                double sign = Math.signum(alpha_i);
                alpha_i = Math.abs(alpha_i);
                double tmp = this.uniformSampling ? 1.0 / (double)this.budget : this.dist[i] / normalize;
                this.alphas.set(i, sign * Math.min((1.0 - this.reg * this.eta) / (1.0 - tmp) * alpha_i, this.maxCoeff * this.eta));
            }
            if (this.k.supportsAcceleration()) {
                int catToRet = this.accelCache.size() / this.budget;
                for (int i = 0; i < catToRet; ++i) {
                    this.accelCache.remove(toRemove * catToRet);
                }
            }
            this.alphas.remove(toRemove);
            this.vecs.remove(toRemove);
            this.selfK.remove(toRemove);
            this.alphas.add(-this.eta * lossD);
            this.selfK.add(Math.sqrt(this.k.eval(0, 0, Arrays.asList(x_t), qi)));
            this.accelCache.addAll(qi);
            this.vecs.add(x_t);
        }
    }

    @Override
    public CategoricalResults classify(DataPoint data) {
        Vec x = data.getNumericalValues();
        return this.lossC.getClassification(this.score(x, this.k.getQueryInfo(x)));
    }

    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    public static Distribution guessRegularization(DataSet d) {
        double T2 = d.getSampleSize();
        T2 *= T2;
        return new LogUniform(Math.pow(2.0, -3.0) / T2, Math.pow(2.0, 3.0) / T2);
    }

    public static Distribution guessEta(DataSet d) {
        return new LogUniform(Math.pow(2.0, -3.0), Math.pow(2.0, 3.0));
    }

    public static Distribution guessMaxCoeff(DataSet d) {
        return new LogUniform(Math.pow(2.0, 0.0), Math.pow(2.0, 4.0));
    }

    @Override
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override
    public Parameter getParameter(String paramName) {
        return Parameter.toParameterMap(this.getParameters()).get(paramName);
    }
}

