/*
 * Decompiled with CFR 0.152.
 */
package jsat.distributions.multivariate;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.Set;
import jsat.classifiers.DataPoint;
import jsat.distributions.empirical.KernelDensityEstimator;
import jsat.distributions.empirical.kernelfunc.EpanechnikovKF;
import jsat.distributions.empirical.kernelfunc.KernelFunction;
import jsat.distributions.multivariate.MultivariateKDE;
import jsat.exceptions.UntrainedModelException;
import jsat.linear.DenseVector;
import jsat.linear.SparseVector;
import jsat.linear.Vec;
import jsat.linear.VecPaired;
import jsat.utils.IndexTable;
import jsat.utils.IntSet;

/**
 * Product-kernel density estimator. The joint density is modeled as the average, over all
 * training points, of a product of independent one-dimensional kernels. There is one kernel
 * per dimension, and each dimension has its own bandwidth.
 *
 * <p>For a query {@code x} with {@code D} dimensions and {@code N} training points,
 * {@link #pdf(Vec)} returns
 * {@code (1/N) * sum_i prod_d k((x_d - X_id) / h_d) / h_d}.
 * Only training points that fall inside the kernel window in every dimension are visited.
 * The windows are located by binary search over per-dimension sorted copies of the data.
 */
public class ProductKDE extends MultivariateKDE {
    private static final long serialVersionUID = 7298078759216991650L;
    /** Kernel applied independently in every dimension. */
    private KernelFunction k;
    /**
     * {@code sortedDimVals[d]} holds every training value of dimension {@code d}, ordered by
     * {@link IndexTable} (assumed ascending; the window searches in queryWork require it).
     */
    private double[][] sortedDimVals;
    /** Per-dimension bandwidth {@code h_d}. */
    private double[] bandwidth;
    /** {@code sortedIndexVals[d][j]} is the index of the training point whose value is {@code sortedDimVals[d][j]}. */
    private int[][] sortedIndexVals;
    /** The training vectors in the order they were supplied; {@code null} until fitted. */
    private List<Vec> originalVecs;

    /** Creates a product KDE that uses the Epanechnikov kernel in every dimension. */
    public ProductKDE() {
        this(EpanechnikovKF.getInstance());
    }

    /**
     * Creates a product KDE that uses the given kernel in every dimension.
     *
     * @param k the one-dimensional kernel function
     */
    public ProductKDE(KernelFunction k) {
        this.k = k;
    }

    /**
     * Returns a copy whose fitted arrays are independent of this instance. Scaling the copy's
     * bandwidth therefore does not affect the original. The kernel and the training vector
     * objects themselves are shared.
     */
    @Override
    public ProductKDE clone() {
        ProductKDE copy = new ProductKDE(this.k);
        if (this.sortedDimVals != null) {
            copy.sortedDimVals = new double[this.sortedDimVals.length][];
            for (int i = 0; i < this.sortedDimVals.length; ++i) {
                copy.sortedDimVals[i] = Arrays.copyOf(this.sortedDimVals[i], this.sortedDimVals[i].length);
            }
        }
        if (this.sortedIndexVals != null) {
            copy.sortedIndexVals = new int[this.sortedIndexVals.length][];
            for (int i = 0; i < this.sortedIndexVals.length; ++i) {
                copy.sortedIndexVals[i] = Arrays.copyOf(this.sortedIndexVals[i], this.sortedIndexVals[i].length);
            }
        }
        if (this.bandwidth != null) {
            copy.bandwidth = Arrays.copyOf(this.bandwidth, this.bandwidth.length);
        }
        if (this.originalVecs != null) {
            copy.originalVecs = new ArrayList<Vec>(this.originalVecs);
        }
        return copy;
    }

    /**
     * Returns every training point that lies inside the kernel window of {@code x} in all
     * dimensions. Each point is paired with its training index and with the product of its
     * kernel values {@code prod_d k((x_d - X_id) / h_d)}, which is NOT divided by the bandwidths.
     *
     * @param x the query point; must have the same length as the training vectors
     * @return the nearby training points with their kernel-product weights
     * @throws UntrainedModelException if the model has not been fitted
     * @throws IllegalArgumentException if {@code x} has the wrong length
     */
    public List<VecPaired<VecPaired<Vec, Integer>, Double>> getNearby(Vec x) {
        requireTrained();
        SparseVector logProd = new SparseVector(this.originalVecs.size());
        IntSet validIndices = new IntSet();
        this.queryWork(x, validIndices, logProd);
        List<VecPaired<VecPaired<Vec, Integer>, Double>> results =
                new ArrayList<VecPaired<VecPaired<Vec, Integer>, Double>>(validIndices.size());
        for (int i : validIndices) {
            Vec v = this.originalVecs.get(i);
            results.add(new VecPaired<VecPaired<Vec, Integer>, Double>(
                    new VecPaired<Vec, Integer>(v, i), Math.exp(logProd.get(i))));
        }
        return results;
    }

    /**
     * Not supported. The per-point scores are accumulated as log kernel products, so raw
     * scores cannot be recovered.
     *
     * @throws UnsupportedOperationException always
     */
    public List<VecPaired<VecPaired<Vec, Integer>, Double>> getNearbyRaw(Vec x) {
        throw new UnsupportedOperationException("Product KDE can not recover raw Score values");
    }

    /**
     * Evaluates the estimated density at {@code x}.
     *
     * @param x the query point; must have the same length as the training vectors
     * @return {@code (1/N) * sum_i prod_d k((x_d - X_id) / h_d) / h_d}
     * @throws UntrainedModelException if the model has not been fitted
     * @throws IllegalArgumentException if {@code x} has the wrong length
     */
    @Override
    public double pdf(Vec x) {
        requireTrained();
        int n = this.originalVecs.size();
        SparseVector logProd = new SparseVector(n);
        IntSet validIndices = new IntSet();
        double logH = this.queryWork(x, validIndices, logProd);
        double sum = 0.0;
        for (int i : validIndices) {
            // exp(sum_d log k - sum_d log h) == prod_d k / h_d, computed in log space to avoid underflow
            sum += Math.exp(logProd.get(i) - logH);
        }
        return sum / n;
    }

    /**
     * Finds the training points inside the kernel window of {@code x} in every dimension and
     * accumulates their log kernel products.
     *
     * @param x the query point
     * @param validIndices receives the indices of the training points inside every window
     * @param logProd receives {@code sum_d log k((x_d - X_id) / h_d)} for each index in
     *     {@code validIndices}; entries for other indices are meaningless
     * @return {@code sum_d log h_d} over the dimensions processed. If the search stops early
     *     because no candidate survives, this sum is partial, but {@code validIndices} is then
     *     empty.
     */
    private double queryWork(Vec x, Set<Integer> validIndices, SparseVector logProd) {
        requireTrained();
        if (x.length() != this.sortedDimVals.length) {
            throw new IllegalArgumentException("Expected a vector of length " + this.sortedDimVals.length
                    + ", not " + x.length());
        }
        // presumably the half-width of the kernel's support in bandwidth units
        double cutOff = this.k.cutOff();
        double logH = 0.0;
        for (int d = 0; d < this.sortedDimVals.length; ++d) {
            double[] values = this.sortedDimVals[d];
            double h = this.bandwidth[d];
            logH += Math.log(h);
            double xd = x.get(d);
            // [from, to) is exactly the samples inside the closed window
            // [xd - h*cutOff, xd + h*cutOff], including every tie at either boundary.
            int from = lowerBound(values, xd - h * cutOff);
            int to = upperBound(values, xd + h * cutOff);
            IntSet survivors = new IntSet();
            for (int j = from; j < to; ++j) {
                int trueIndex = this.sortedIndexVals[d][j];
                if (d == 0) {
                    // The first dimension seeds the candidate set.
                    validIndices.add(trueIndex);
                    logProd.set(trueIndex, Math.log(this.k.k((xd - values[j]) / h)));
                } else if (validIndices.contains(trueIndex)) {
                    logProd.increment(trueIndex, Math.log(this.k.k((xd - values[j]) / h)));
                    survivors.add(trueIndex);
                }
            }
            if (d > 0) {
                // A point must fall inside the window in every dimension.
                validIndices.retainAll(survivors);
            }
            if (validIndices.isEmpty()) {
                break;
            }
        }
        return logH;
    }

    /** Throws if the model has not been fitted yet. */
    private void requireTrained() {
        if (this.originalVecs == null) {
            throw new UntrainedModelException("Model has not yet been created, queries can not be perfomed");
        }
    }

    /** Returns the first index {@code j} of sorted {@code a} with {@code a[j] >= key}, or {@code a.length}. */
    private static int lowerBound(double[] a, double key) {
        int lo = 0;
        int hi = a.length;
        while (lo < hi) {
            int mid = (lo + hi) >>> 1;
            if (a[mid] < key) {
                lo = mid + 1;
            } else {
                hi = mid;
            }
        }
        return lo;
    }

    /** Returns the first index {@code j} of sorted {@code a} with {@code a[j] > key}, or {@code a.length}. */
    private static int upperBound(double[] a, double key) {
        int lo = 0;
        int hi = a.length;
        while (lo < hi) {
            int mid = (lo + hi) >>> 1;
            if (a[mid] <= key) {
                lo = mid + 1;
            } else {
                hi = mid;
            }
        }
        return lo;
    }

    /**
     * Fits the estimator to {@code dataSet}. The model state is replaced only after the whole
     * data set has been validated and processed, so a failed call leaves any previous fit intact.
     *
     * @param dataSet the training vectors; all must have the same, non-zero length
     * @return {@code true}
     * @throws IllegalArgumentException if {@code dataSet} is empty, the vectors have no
     *     dimensions, or the vectors have differing lengths
     */
    @Override
    public <V extends Vec> boolean setUsingData(List<V> dataSet) {
        if (dataSet.isEmpty()) {
            throw new IllegalArgumentException("Can not fit a ProductKDE to an empty data set");
        }
        int n = dataSet.size();
        int dimSize = dataSet.get(0).length();
        if (dimSize == 0) {
            throw new IllegalArgumentException("Data set vectors must have at least one dimension");
        }
        double[][] dimVals = new double[dimSize][n];
        int[][] indexVals = new int[dimSize][n];
        double[] bw = new double[dimSize];

        // Transpose the data into one column per dimension.
        int row = 0;
        for (Vec v : dataSet) {
            if (v.length() != dimSize) {
                throw new IllegalArgumentException("Vector " + row + " has length " + v.length()
                        + ", expected " + dimSize);
            }
            for (int d = 0; d < dimSize; ++d) {
                dimVals[d][row] = v.get(d);
            }
            row++;
        }

        for (int d = 0; d < dimSize; ++d) {
            IndexTable idt = new IndexTable(dimVals[d]);
            // Assumed: idt.index(j) is the original row of the j-th sorted value, and apply()
            // reorders the column into that same order.
            for (int j = 0; j < idt.length(); ++j) {
                indexVals[d][j] = idt.index(j);
            }
            idt.apply(dimVals[d]);
            // NOTE(review): the per-dimension estimate is widened by the dimension count; this
            // heuristic is kept from the original implementation.
            bw[d] = KernelDensityEstimator.BandwithGuassEstimate(DenseVector.toDenseVec(dimVals[d])) * (double) dimSize;
        }

        this.sortedDimVals = dimVals;
        this.sortedIndexVals = indexVals;
        this.bandwidth = bw;
        // Defensive copy: later changes to the caller's list must not corrupt the index arrays.
        this.originalVecs = new ArrayList<Vec>(dataSet);
        return true;
    }

    /**
     * Fits the estimator to the numerical values of {@code dataPoints}.
     *
     * @param dataPoints the training points
     * @return the result of {@link #setUsingData(List)}
     */
    @Override
    public boolean setUsingDataList(List<DataPoint> dataPoints) {
        List<Vec> dataSet = new ArrayList<Vec>(dataPoints.size());
        for (DataPoint dp : dataPoints) {
            dataSet.add(dp.getNumericalValues());
        }
        return this.setUsingData(dataSet);
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public List<Vec> sample(int count, Random rand) {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    /** Returns the kernel applied in every dimension. */
    @Override
    public KernelFunction getKernelFunction() {
        return this.k;
    }

    /**
     * Multiplies every per-dimension bandwidth by {@code scale}.
     *
     * @param scale the positive, finite factor to apply
     * @throws IllegalArgumentException if {@code scale} is not positive and finite
     * @throws UntrainedModelException if the model has not been fitted
     */
    @Override
    public void scaleBandwidth(double scale) {
        if (!(scale > 0.0) || Double.isInfinite(scale)) {
            throw new IllegalArgumentException("scale must be a positive finite value, not " + scale);
        }
        if (this.bandwidth == null) {
            throw new UntrainedModelException("Model has not yet been created, bandwidth can not be scaled");
        }
        for (int i = 0; i < this.bandwidth.length; ++i) {
            this.bandwidth[i] *= scale;
        }
    }
}

