/*
 * Decompiled with CFR 0.152.
 */
package jsat.regression;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicReference;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.DataSet;
import jsat.classifiers.DataPoint;
import jsat.distributions.Distribution;
import jsat.distributions.LogUniform;
import jsat.distributions.kernels.KernelTrick;
import jsat.distributions.kernels.RBFKernel;
import jsat.linear.CholeskyDecomposition;
import jsat.linear.DenseMatrix;
import jsat.linear.Vec;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.utils.FakeExecutor;
import jsat.utils.SystemInfo;

/**
 * Kernel ridge regression.
 *
 * <p>Training builds the kernel (Gram) matrix {@code K} of the training inputs, adds
 * {@code lambda} to its diagonal, and solves {@code (K + lambda*I) alpha = y} with a Cholesky
 * decomposition. Prediction is {@code f(x) = sum_i alpha_i * k(x_i, x)} over all training
 * points {@code x_i}.
 */
public class KernelRidgeRegression
implements Regressor,
Parameterized {
    private static final long serialVersionUID = 6275333785663250072L;
    /** Regularization added to the kernel matrix diagonal; always finite and positive. */
    private double lambda;
    @Parameter.ParameterHolder
    private KernelTrick k;
    /** Training inputs; index i pairs with {@code alphas[i]}. Null until trained. */
    private List<Vec> vecs;
    /** Dual coefficients solved for during training. Null until trained. */
    private double[] alphas;

    /**
     * Creates a model with {@code lambda = 1e-6} and a default {@link RBFKernel}.
     */
    public KernelRidgeRegression() {
        this(1.0E-6, new RBFKernel());
    }

    /**
     * Creates a new kernel ridge regression model.
     *
     * @param lambda the regularization constant; must be finite and positive
     * @param kernel the kernel to use; must not be null
     * @throws IllegalArgumentException if {@code lambda} is not a finite positive value
     * @throws NullPointerException if {@code kernel} is null
     */
    public KernelRidgeRegression(double lambda, KernelTrick kernel) {
        this.setLambda(lambda);
        this.setKernel(kernel);
    }

    /**
     * Copy constructor. The kernel is cloned and the coefficient array is copied. The list of
     * training vectors is copied, but the {@link Vec} objects in it are shared with
     * {@code toCopy} (not deep-copied).
     *
     * @param toCopy the model to copy
     */
    protected KernelRidgeRegression(KernelRidgeRegression toCopy) {
        this(toCopy.lambda, toCopy.getKernel().clone());
        if (toCopy.alphas != null) {
            this.alphas = Arrays.copyOf(toCopy.alphas, toCopy.alphas.length);
        }
        if (toCopy.vecs != null) {
            this.vecs = new ArrayList<Vec>(toCopy.vecs);
        }
    }

    /**
     * Returns a search distribution for {@code lambda}: a {@code LogUniform(1e-7, 0.01)}.
     * The data set is not inspected.
     *
     * @param d the data set (unused)
     * @return the distribution to draw candidate {@code lambda} values from
     */
    public static Distribution guessLambda(DataSet d) {
        return new LogUniform(1.0E-7, 0.01);
    }

    /**
     * Sets the regularization constant added to the kernel matrix diagonal.
     *
     * @param lambda the new value; must be finite and strictly positive
     * @throws IllegalArgumentException if {@code lambda} is NaN, infinite, or not positive
     */
    public void setLambda(double lambda) {
        if (Double.isNaN(lambda) || Double.isInfinite(lambda) || lambda <= 0.0) {
            throw new IllegalArgumentException("lambda must be a positive constant, not " + lambda);
        }
        this.lambda = lambda;
    }

    /**
     * @return the regularization constant
     */
    public double getLambda() {
        return this.lambda;
    }

    /**
     * Sets the kernel used for training and prediction.
     *
     * @param k the kernel; must not be null
     * @throws NullPointerException if {@code k} is null
     */
    public void setKernel(KernelTrick k) {
        if (k == null) {
            // Fail fast: a null kernel would otherwise only surface later, in train/regress/clone.
            throw new NullPointerException("kernel must not be null");
        }
        this.k = k;
    }

    /**
     * @return the kernel in use
     */
    public KernelTrick getKernel() {
        return this.k;
    }

    /**
     * Predicts the target for {@code data} as {@code sum_i alpha_i * k(x_i, x)}.
     *
     * @param data the point to predict
     * @return the predicted value
     * @throws IllegalStateException if the model has not been trained
     */
    @Override
    public double regress(DataPoint data) {
        if (this.alphas == null || this.vecs == null) {
            throw new IllegalStateException("Model has not been trained");
        }
        Vec x = data.getNumericalValues();
        double score = 0.0;
        for (int i = 0; i < this.alphas.length; ++i) {
            score += this.alphas[i] * this.k.eval(this.vecs.get(i), x);
        }
        return score;
    }

    /**
     * Trains the model, computing the kernel matrix in parallel on {@code threadPool}.
     *
     * <p>The model keeps the vectors returned by each data point's {@code getNumericalValues()}.
     * NOTE(review): whether those are copies or views of the data set depends on DataPoint; confirm.
     * The model's fitted state is replaced only after training succeeds.
     *
     * @param dataSet the training data; sample weights are ignored
     * @param threadPool executor used for kernel evaluations and, unless it is a
     *        {@link FakeExecutor}, for the Cholesky decomposition
     * @throws RuntimeException if a kernel evaluation fails or the calling thread is interrupted
     *         while waiting (the interrupt flag is restored in that case)
     */
    @Override
    public void train(RegressionDataSet dataSet, ExecutorService threadPool) {
        final int N = dataSet.getSampleSize();
        final List<Vec> trainingVecs = new ArrayList<Vec>(N);
        Vec Y = dataSet.getTargetValues();
        for (int i = 0; i < N; ++i) {
            trainingVecs.add(dataSet.getDataPoint(i).getNumericalValues());
        }
        DenseMatrix K = this.buildRegularizedKernelMatrix(trainingVecs, threadPool);
        CholeskyDecomposition cd = threadPool instanceof FakeExecutor
                ? new CholeskyDecomposition(K)
                : new CholeskyDecomposition(K, threadPool);
        double[] newAlphas = cd.solve(Y).arrayCopy();
        // Publish both fields together so a failed fit never leaves vecs/alphas out of sync.
        this.vecs = trainingVecs;
        this.alphas = newAlphas;
    }

    /**
     * Builds {@code K + lambda*I} for the given inputs. The rows are split across
     * {@code SystemInfo.LogicalCores} tasks: task t handles rows t, t + cores, t + 2*cores, ...
     * and fills each row's diagonal and upper part plus the mirrored lower entries. This takes
     * N(N+1)/2 kernel evaluations in total.
     */
    private DenseMatrix buildRegularizedKernelMatrix(final List<Vec> X, ExecutorService threadPool) {
        final int N = X.size();
        final DenseMatrix K = new DenseMatrix(N, N);
        final int workers = SystemInfo.LogicalCores;
        final CountDownLatch cdl = new CountDownLatch(workers);
        final AtomicReference<Throwable> failure = new AtomicReference<Throwable>();
        final KernelTrick kernel = this.k;
        final double reg = this.lambda;
        for (int id = 0; id < workers; ++id) {
            final int ID = id;
            threadPool.submit(new Runnable() {

                @Override
                public void run() {
                    try {
                        for (int i = ID; i < N; i += workers) {
                            K.set(i, i, kernel.eval(X.get(i), X.get(i)) + reg);
                            for (int j = i + 1; j < N; ++j) {
                                double K_ij = kernel.eval(X.get(i), X.get(j));
                                K.set(i, j, K_ij);
                                K.set(j, i, K_ij);
                            }
                        }
                    } catch (Throwable t) {
                        // Worker boundary: the executor would hide this inside an unread Future,
                        // so record it for the waiting thread to rethrow.
                        failure.compareAndSet(null, t);
                    } finally {
                        // Always release the latch, otherwise a failing task deadlocks train().
                        cdl.countDown();
                    }
                }
            });
        }
        try {
            cdl.await();
        } catch (InterruptedException ex) {
            Thread.currentThread().interrupt();
            throw new RuntimeException("Interrupted while computing the kernel matrix", ex);
        }
        Throwable t = failure.get();
        if (t instanceof Error) {
            throw (Error) t;
        }
        if (t != null) {
            throw new RuntimeException("Failed to compute the kernel matrix", t);
        }
        return K;
    }

    /**
     * Trains the model serially on the calling thread's {@link FakeExecutor}.
     *
     * @param dataSet the training data
     */
    @Override
    public void train(RegressionDataSet dataSet) {
        this.train(dataSet, new FakeExecutor());
    }

    /**
     * @return {@code false}; training never reads sample weights
     */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    @Override
    public KernelRidgeRegression clone() {
        return new KernelRidgeRegression(this);
    }

    @Override
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override
    public Parameter getParameter(String paramName) {
        return Parameter.toParameterMap(this.getParameters()).get(paramName);
    }
}

