/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.knn;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import jsat.DataSet;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.distributions.Distribution;
import jsat.distributions.discrete.UniformDiscrete;
import jsat.distributions.empirical.kernelfunc.EpanechnikovKF;
import jsat.distributions.empirical.kernelfunc.KernelFunction;
import jsat.exceptions.UntrainedModelException;
import jsat.linear.Vec;
import jsat.linear.VecPaired;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.TrainableDistanceMetric;
import jsat.linear.vectorcollection.DefaultVectorCollectionFactory;
import jsat.linear.vectorcollection.VectorCollection;
import jsat.linear.vectorcollection.VectorCollectionFactory;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;

/**
 * Locally Weighted Learner (LWL). Training only builds a {@link VectorCollection}
 * index over the training vectors, each paired with its target value.
 * <p>
 * All real work happens at prediction time:
 * <ul>
 * <li>the {@code k} nearest indexed points to the query are retrieved;</li>
 * <li>each one is weighted by the kernel function evaluated at its distance divided
 * by the largest retrieved distance;</li>
 * <li>a fresh clone of the base {@link Classifier} or {@link Regressor} is trained
 * on that weighted local set and asked for the prediction.</li>
 * </ul>
 * <p>
 * NOTE(review): the kernel weights only matter if the base learner honors data
 * point weights. This class does not check that.
 */
public class LWL
implements Classifier,
Regressor,
Parameterized {
    private static final long serialVersionUID = 6942465758987345997L;
    /** Target description from the last classification training; null if not trained via trainC. */
    private CategoricalData predicting;
    /** Template classifier cloned for every query; null if none was supplied. */
    private Classifier classifier;
    /** Template regressor cloned for every query; null if none was supplied. */
    private Regressor regressor;
    /** Number of neighbors used to build each local training set (always at least 2). */
    private int k;
    private DistanceMetric dm;
    private KernelFunction kf;
    private VectorCollectionFactory<VecPaired<Vec, Double>> vcf;
    /** Index of training vectors paired with their target (class index as a double, or regression value). */
    private VectorCollection<VecPaired<Vec, Double>> vc;

    /**
     * Copy constructor used by {@link #clone()}.
     * <p>
     * The template learners, the target description, the distance metric and the
     * trained index are cloned. The kernel function and the collection factory are
     * shared.
     */
    private LWL(LWL toCopy) {
        if (toCopy.predicting != null) {
            this.predicting = toCopy.predicting.clone();
        }
        // Clone the templates so the copy shares no mutable state with the original.
        // A learner implementing both interfaces is cloned once; setClassifier
        // already fills in the regressor field in that case.
        if (toCopy.classifier != null) {
            this.setClassifier(toCopy.classifier.clone());
        }
        if (toCopy.regressor != null && this.regressor == null) {
            this.setRegressor(toCopy.regressor.clone());
        }
        this.setNeighbors(toCopy.k);
        this.setDistanceMetric(toCopy.dm == null ? null : toCopy.dm.clone());
        this.setKernelFunction(toCopy.kf);
        this.vcf = toCopy.vcf;
        if (toCopy.vc != null) {
            this.vc = toCopy.vc.clone();
        }
    }

    /**
     * Creates an LWL classifier using the Epanechnikov kernel and the default
     * vector collection factory.
     *
     * @param classifier the base classifier cloned and trained for every query
     * @param k the number of neighbors, at least 2
     * @param dm the distance metric used for the neighbor search
     */
    public LWL(Classifier classifier, int k, DistanceMetric dm) {
        this(classifier, k, dm, (KernelFunction)EpanechnikovKF.getInstance());
    }

    /**
     * Creates an LWL classifier using the default vector collection factory.
     *
     * @param classifier the base classifier cloned and trained for every query
     * @param k the number of neighbors, at least 2
     * @param dm the distance metric used for the neighbor search
     * @param kf the kernel that turns scaled distances into weights
     */
    public LWL(Classifier classifier, int k, DistanceMetric dm, KernelFunction kf) {
        this(classifier, k, dm, kf, new DefaultVectorCollectionFactory<VecPaired<Vec, Double>>());
    }

    /**
     * Creates an LWL classifier.
     *
     * @param classifier the base classifier cloned and trained for every query
     * @param k the number of neighbors, at least 2
     * @param dm the distance metric used for the neighbor search
     * @param kf the kernel that turns scaled distances into weights
     * @param vcf the factory that builds the neighbor-search index at training time
     * @throws IllegalArgumentException if {@code k < 2}
     */
    public LWL(Classifier classifier, int k, DistanceMetric dm, KernelFunction kf, VectorCollectionFactory<VecPaired<Vec, Double>> vcf) {
        this.setClassifier(classifier);
        this.setNeighbors(k);
        this.setDistanceMetric(dm);
        this.setKernelFunction(kf);
        this.vcf = vcf;
    }

    /**
     * Creates an LWL regressor using the Epanechnikov kernel and the default
     * vector collection factory.
     *
     * @param regressor the base regressor cloned and trained for every query
     * @param k the number of neighbors, at least 2
     * @param dm the distance metric used for the neighbor search
     */
    public LWL(Regressor regressor, int k, DistanceMetric dm) {
        this(regressor, k, dm, (KernelFunction)EpanechnikovKF.getInstance());
    }

    /**
     * Creates an LWL regressor using the default vector collection factory.
     *
     * @param regressor the base regressor cloned and trained for every query
     * @param k the number of neighbors, at least 2
     * @param dm the distance metric used for the neighbor search
     * @param kf the kernel that turns scaled distances into weights
     */
    public LWL(Regressor regressor, int k, DistanceMetric dm, KernelFunction kf) {
        this(regressor, k, dm, kf, new DefaultVectorCollectionFactory<VecPaired<Vec, Double>>());
    }

    /**
     * Creates an LWL regressor.
     *
     * @param regressor the base regressor cloned and trained for every query
     * @param k the number of neighbors, at least 2
     * @param dm the distance metric used for the neighbor search
     * @param kf the kernel that turns scaled distances into weights
     * @param vcf the factory that builds the neighbor-search index at training time
     * @throws IllegalArgumentException if {@code k < 2}
     */
    public LWL(Regressor regressor, int k, DistanceMetric dm, KernelFunction kf, VectorCollectionFactory<VecPaired<Vec, Double>> vcf) {
        this.setRegressor(regressor);
        this.setNeighbors(k);
        this.setDistanceMetric(dm);
        this.setKernelFunction(kf);
        this.vcf = vcf;
    }

    /**
     * Classifies {@code data} by training a fresh clone of the base classifier
     * on the kernel-weighted nearest neighbors of {@code data}.
     *
     * @param data the point to classify
     * @return the prediction of the locally trained classifier
     * @throws UntrainedModelException if there is no base classifier, or the
     *         model has not been trained on a classification data set
     * @throws IllegalStateException if the neighbor search returns no points
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        // predicting is required too: after regression training, the index holds
        // regression targets rather than class indices.
        if (this.classifier == null || this.vc == null || this.predicting == null) {
            throw new UntrainedModelException("Model has not been trained");
        }
        List<VecPaired<VecPaired<Vec, Double>, Double>> knn = this.nearestNeighbors(data);
        double maxD = maxDistance(knn);
        List<DataPointPair<Integer>> localPoints = new ArrayList<DataPointPair<Integer>>(knn.size());
        for (VecPaired<VecPaired<Vec, Double>, Double> v : knn) {
            // The paired Double stores the class index (see getVecList).
            localPoints.add(new DataPointPair<Integer>(this.toWeightedPoint(v, maxD), v.getVector().getPair().intValue()));
        }
        ClassificationDataSet localSet = new ClassificationDataSet(localPoints, this.predicting);
        Classifier localClassifier = this.classifier.clone();
        localClassifier.trainC(localSet);
        return localClassifier.classify(data);
    }

    /**
     * Indexes the classification data set for neighbor search. The distance
     * metric is trained first if it requires training.
     *
     * @param dataSet the training data
     * @param threadPool the executor used for metric training and index building
     */
    @Override
    public void trainC(ClassificationDataSet dataSet, ExecutorService threadPool) {
        List<VecPaired<Vec, Double>> trainList = this.getVecList(dataSet);
        TrainableDistanceMetric.trainIfNeeded(this.dm, dataSet, threadPool);
        this.vc = this.vcf.getVectorCollection(trainList, this.dm, threadPool);
        this.predicting = dataSet.getPredicting();
    }

    /**
     * Indexes the classification data set for neighbor search on the calling
     * thread. The distance metric is trained first if it requires training.
     *
     * @param dataSet the training data
     */
    @Override
    public void trainC(ClassificationDataSet dataSet) {
        List<VecPaired<Vec, Double>> trainList = this.getVecList(dataSet);
        TrainableDistanceMetric.trainIfNeeded(this.dm, dataSet);
        this.vc = this.vcf.getVectorCollection(trainList, this.dm);
        this.predicting = dataSet.getPredicting();
    }

    /**
     * Reports whether this learner uses the weights of its training points.
     *
     * @return always {@code false}; training copies only the numeric values and
     *         targets of each point, so their weights are ignored
     */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Predicts a value for {@code data} by training a fresh clone of the base
     * regressor on the kernel-weighted nearest neighbors of {@code data}.
     *
     * @param data the point to predict for
     * @return the prediction of the locally trained regressor
     * @throws UntrainedModelException if there is no base regressor or the
     *         model has not been trained
     * @throws IllegalStateException if the neighbor search returns no points
     */
    @Override
    public double regress(DataPoint data) {
        if (this.regressor == null || this.vc == null) {
            throw new UntrainedModelException("Model has not been trained");
        }
        List<VecPaired<VecPaired<Vec, Double>, Double>> knn = this.nearestNeighbors(data);
        double maxD = maxDistance(knn);
        List<DataPointPair<Double>> localPoints = new ArrayList<DataPointPair<Double>>(knn.size());
        for (VecPaired<VecPaired<Vec, Double>, Double> v : knn) {
            localPoints.add(new DataPointPair<Double>(this.toWeightedPoint(v, maxD), v.getVector().getPair()));
        }
        RegressionDataSet localSet = new RegressionDataSet(localPoints);
        Regressor localRegressor = this.regressor.clone();
        localRegressor.train(localSet);
        return localRegressor.regress(data);
    }

    /**
     * Indexes the regression data set for neighbor search. The distance metric
     * is trained first if it requires training.
     *
     * @param dataSet the training data
     * @param threadPool the executor used for metric training and index building
     */
    @Override
    public void train(RegressionDataSet dataSet, ExecutorService threadPool) {
        List<VecPaired<Vec, Double>> trainList = this.getVecList(dataSet);
        TrainableDistanceMetric.trainIfNeeded(this.dm, dataSet, threadPool);
        this.vc = this.vcf.getVectorCollection(trainList, this.dm, threadPool);
        // The index now holds regression targets, so forget any earlier class description.
        this.predicting = null;
    }

    /**
     * Indexes the regression data set for neighbor search on the calling thread.
     * The distance metric is trained first if it requires training.
     *
     * @param dataSet the training data
     */
    @Override
    public void train(RegressionDataSet dataSet) {
        List<VecPaired<Vec, Double>> trainList = this.getVecList(dataSet);
        TrainableDistanceMetric.trainIfNeeded(this.dm, dataSet);
        this.vc = this.vcf.getVectorCollection(trainList, this.dm);
        // The index now holds regression targets, so forget any earlier class description.
        this.predicting = null;
    }

    /** @return a copy of this learner, including any trained index */
    @Override
    public LWL clone() {
        return new LWL(this);
    }

    /**
     * Retrieves the {@code k} nearest indexed training points to {@code data}.
     *
     * @throws IllegalStateException if no neighbors are found, e.g. because the
     *         model was trained on an empty data set
     */
    private List<VecPaired<VecPaired<Vec, Double>, Double>> nearestNeighbors(DataPoint data) {
        List<VecPaired<VecPaired<Vec, Double>, Double>> knn = this.vc.search(data.getNumericalValues(), this.k);
        if (knn.isEmpty()) {
            throw new IllegalStateException("No neighbors found; the model may have been trained on an empty data set");
        }
        return knn;
    }

    /**
     * Returns the largest neighbor distance. Computed explicitly so the result
     * does not depend on the order of the search results.
     */
    private static double maxDistance(List<VecPaired<VecPaired<Vec, Double>, Double>> knn) {
        double maxD = 0.0;
        for (VecPaired<VecPaired<Vec, Double>, Double> v : knn) {
            maxD = Math.max(maxD, v.getPair());
        }
        return maxD;
    }

    /**
     * Wraps a neighbor in a DataPoint whose weight is the kernel evaluated at
     * the neighbor's distance divided by {@code maxD}.
     * <p>
     * If every neighbor is at distance zero ({@code maxD == 0}), the ratio is
     * taken as 0 rather than the NaN that 0/0 would produce.
     */
    private DataPoint toWeightedPoint(VecPaired<VecPaired<Vec, Double>, Double> neighbor, double maxD) {
        double scaled = maxD > 0 ? neighbor.getPair() / maxD : 0.0;
        return new DataPoint(neighbor, new int[0], new CategoricalData[0], this.kf.k(scaled));
    }

    /** Pairs each training vector with its class index stored as a double. */
    private List<VecPaired<Vec, Double>> getVecList(ClassificationDataSet dataSet) {
        List<VecPaired<Vec, Double>> trainList = new ArrayList<VecPaired<Vec, Double>>(dataSet.getSampleSize());
        for (int i = 0; i < dataSet.getSampleSize(); ++i) {
            trainList.add(new VecPaired<Vec, Double>(dataSet.getDataPoint(i).getNumericalValues(), Double.valueOf(dataSet.getDataPointCategory(i))));
        }
        return trainList;
    }

    /** Pairs each training vector with its regression target value. */
    private List<VecPaired<Vec, Double>> getVecList(RegressionDataSet dataSet) {
        List<VecPaired<Vec, Double>> trainList = new ArrayList<VecPaired<Vec, Double>>(dataSet.getSampleSize());
        for (int i = 0; i < dataSet.getSampleSize(); ++i) {
            trainList.add(new VecPaired<Vec, Double>(dataSet.getDataPoint(i).getNumericalValues(), dataSet.getTargetValue(i)));
        }
        return trainList;
    }

    /** Sets the template classifier; if it is also a Regressor it becomes the template regressor too. */
    private void setClassifier(Classifier classifier) {
        this.classifier = classifier;
        if (classifier instanceof Regressor) {
            this.regressor = (Regressor) classifier;
        }
    }

    /** Sets the template regressor; if it is also a Classifier it becomes the template classifier too. */
    private void setRegressor(Regressor regressor) {
        this.regressor = regressor;
        if (regressor instanceof Classifier) {
            this.classifier = (Classifier) regressor;
        }
    }

    /**
     * Sets the number of neighbors used to build each local training set.
     *
     * @param k the number of neighbors, at least 2
     * @throws IllegalArgumentException if {@code k < 2}
     */
    public void setNeighbors(int k) {
        if (k <= 1) {
            throw new IllegalArgumentException("An average requires at least 2 neighbors to be taken into account");
        }
        this.k = k;
    }

    /** @return the number of neighbors used to build each local training set */
    public int getNeighbors() {
        return this.k;
    }

    /** @param dm the distance metric used for the neighbor search */
    public void setDistanceMetric(DistanceMetric dm) {
        this.dm = dm;
    }

    /** @return the distance metric used for the neighbor search */
    public DistanceMetric getDistanceMetric() {
        return this.dm;
    }

    /** @param kf the kernel applied to scaled neighbor distances to produce weights */
    public void setKernelFunction(KernelFunction kf) {
        this.kf = kf;
    }

    /** @return the kernel applied to scaled neighbor distances to produce weights */
    public KernelFunction getKernelFunction() {
        return this.kf;
    }

    /**
     * Suggests a search distribution for the number of neighbors: uniform over
     * 25 to min(200, n/5), where n is the sample size of {@code d}.
     * <p>
     * NOTE(review): for fewer than 125 samples the upper bound falls below 25.
     * Confirm how UniformDiscrete handles min &gt; max.
     *
     * @param d the data set the model will be trained on
     * @return a discrete distribution of candidate neighbor counts
     */
    public static Distribution guessNeighbors(DataSet d) {
        return new UniformDiscrete(25, Math.min(200, d.getSampleSize() / 5));
    }

    /**
     * Returns the tunable parameters of this learner. They are discovered by
     * {@code Parameter.getParamsFromMethods} (presumably from its getter/setter
     * pairs; confirm in Parameter).
     */
    @Override
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    /**
     * Looks up a parameter by name.
     *
     * @param paramName the parameter name
     * @return the matching parameter, or null if none has that name
     */
    @Override
    public Parameter getParameter(String paramName) {
        return Parameter.toParameterMap(this.getParameters()).get(paramName);
    }
}

