/*
 * Decompiled with CFR 0.152.
 */
package jsat.linear.distancemetrics;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.DataSet;
import jsat.classifiers.ClassificationDataSet;
import jsat.linear.MatrixStatistics;
import jsat.linear.Vec;
import jsat.linear.VecOps;
import jsat.linear.distancemetrics.TrainableDistanceMetric;
import jsat.math.FunctionBase;
import jsat.math.MathTricks;
import jsat.regression.RegressionDataSet;
import jsat.utils.DoubleList;
import jsat.utils.FakeExecutor;
import jsat.utils.SystemInfo;
import jsat.utils.concurrent.ParallelUtils;

/**
 * Normalized (standardized) Euclidean distance. Each feature {@code i} is weighted by the
 * reciprocal of its variance as estimated during training, giving
 * {@code d(a, b) = sqrt(sum_i (a_i - b_i)^2 / var_i)}.
 * <p>
 * The metric must be trained (see {@link #needsTraining()}) before any distance, acceleration
 * cache, or query information is computed; until then the weight vector is {@code null}.
 */
public class NormalizedEuclideanDistance
extends TrainableDistanceMetric {
    private static final long serialVersionUID = 210109457671623688L;

    /**
     * Squares the first component of its argument. Used as the per-feature evaluator passed to
     * {@link VecOps#accumulateSum}. It is stateless, so a single shared instance replaces the
     * anonymous object the distance computation used to allocate on every call.
     */
    private static final FunctionBase SQUARE = new FunctionBase(){
        private static final long serialVersionUID = 3190953661114076430L;

        @Override
        public double f(Vec x) {
            double d = x.get(0);
            return d * d;
        }
    };

    /**
     * Per-feature distance weights: the reciprocal of each feature's training variance. The
     * field name is historical and is kept for serialization compatibility; the values are
     * inverse variances, not inverse standard deviations. {@code null} until trained.
     */
    private Vec invStndDevs;

    /**
     * Estimates the per-feature variances of {@code dataSet} and stores their reciprocals as
     * the distance weights.
     *
     * @param dataSet the vectors to learn the feature variances from
     */
    @Override
    public <V extends Vec> void train(List<V> dataSet) {
        Vec variances = MatrixStatistics.covarianceDiag(MatrixStatistics.meanVector(dataSet), dataSet);
        this.invStndDevs = toInverseVarianceWeights(variances);
    }

    /**
     * Same as {@link #train(List)}; the thread pool is ignored and training runs on the
     * calling thread.
     */
    @Override
    public <V extends Vec> void train(List<V> dataSet, ExecutorService threadpool) {
        this.train(dataSet);
    }

    /**
     * Uses the column variances of {@code dataSet} and stores their reciprocals as the
     * distance weights.
     *
     * @param dataSet the data set to learn the feature variances from
     */
    @Override
    public void train(DataSet dataSet) {
        this.invStndDevs = toInverseVarianceWeights(dataSet.getColumnMeanVariance()[1]);
    }

    /**
     * Same as {@link #train(DataSet)}; the thread pool is ignored.
     */
    @Override
    public void train(DataSet dataSet, ExecutorService threadpool) {
        this.train(dataSet);
    }

    /**
     * Trains on the features only; class labels are not used.
     */
    @Override
    public void train(ClassificationDataSet dataSet) {
        this.train((DataSet)dataSet);
    }

    /**
     * Same as {@link #train(ClassificationDataSet)}; the thread pool is ignored.
     */
    @Override
    public void train(ClassificationDataSet dataSet, ExecutorService threadpool) {
        this.train(dataSet);
    }

    @Override
    public boolean supportsClassificationTraining() {
        return true;
    }

    /**
     * Trains on the features only; regression targets are not used.
     */
    @Override
    public void train(RegressionDataSet dataSet) {
        this.train((DataSet)dataSet);
    }

    /**
     * Same as {@link #train(RegressionDataSet)}; the thread pool is ignored.
     */
    @Override
    public void train(RegressionDataSet dataSet, ExecutorService threadpool) {
        this.train(dataSet);
    }

    @Override
    public boolean supportsRegressionTraining() {
        return true;
    }

    /**
     * @return {@code true} while no weights have been learned yet
     */
    @Override
    public boolean needsTraining() {
        return this.invStndDevs == null;
    }

    /**
     * @return an independent copy; the learned weights (if any) are deep-copied
     */
    @Override
    public NormalizedEuclideanDistance clone() {
        NormalizedEuclideanDistance clone = new NormalizedEuclideanDistance();
        if (this.invStndDevs != null) {
            clone.invStndDevs = this.invStndDevs.clone();
        }
        return clone;
    }

    /**
     * Computes {@code sqrt(sum_i w_i * (a_i - b_i)^2)} with the learned inverse-variance
     * weights {@code w}.
     */
    @Override
    public double dist(Vec a, Vec b) {
        return Math.sqrt(VecOps.accumulateSum(this.invStndDevs, a, b, SQUARE));
    }

    @Override
    public boolean isSymmetric() {
        return true;
    }

    @Override
    public boolean isSubadditive() {
        return true;
    }

    @Override
    public boolean isIndiscemible() {
        return true;
    }

    @Override
    public double metricBound() {
        return Double.POSITIVE_INFINITY;
    }

    @Override
    public boolean supportsAcceleration() {
        return true;
    }

    /**
     * Precomputes the weighted squared norm {@code sum_i w_i * v_i^2} of every vector, so the
     * cached {@code dist} overloads only need one weighted dot product per pair.
     *
     * @param vecs the vectors to build the cache for
     * @return one cached value per vector, in the same order as {@code vecs}
     */
    @Override
    public List<Double> getAccelerationCache(List<? extends Vec> vecs) {
        DoubleList cache = new DoubleList(vecs.size());
        for (Vec vec : vecs) {
            cache.add(VecOps.weightedDot(this.invStndDevs, vec, vec));
        }
        return cache;
    }

    /**
     * Parallel version of {@link #getAccelerationCache(List)}. The vectors are split into
     * contiguous blocks, one task per block.
     * <p>
     * If a task fails, its exception is rethrown here instead of leaving the caller waiting
     * forever. If the calling thread is interrupted while waiting, the interrupt flag is
     * restored and the cache is recomputed serially, so a partially filled cache is never
     * returned.
     *
     * @param vecs the vectors to build the cache for
     * @param threadpool the pool to run on; {@code null} or a {@link FakeExecutor} means serial
     * @return one cached value per vector, in the same order as {@code vecs}
     */
    @Override
    public List<Double> getAccelerationCache(final List<? extends Vec> vecs, ExecutorService threadpool) {
        if (threadpool == null || threadpool instanceof FakeExecutor) {
            return this.getAccelerationCache(vecs);
        }
        final Vec weights = this.invStndDevs;
        final double[] cache = new double[vecs.size()];
        int P = Math.min(SystemInfo.LogicalCores, cache.length);
        List<Future<?>> futures = new ArrayList<Future<?>>(P);
        for (int id = 0; id < P; ++id) {
            final int start = ParallelUtils.getStartBlock(cache.length, id, P);
            final int end = ParallelUtils.getEndBlock(cache.length, id, P);
            futures.add(threadpool.submit(new Runnable(){

                @Override
                public void run() {
                    for (int i = start; i < end; ++i) {
                        Vec v = vecs.get(i);
                        cache[i] = VecOps.weightedDot(weights, v, v);
                    }
                }
            }));
        }
        try {
            for (Future<?> future : futures) {
                future.get();
            }
        }
        catch (InterruptedException ex) {
            // Keep the interrupt visible to callers, and never hand back a half-filled cache.
            Thread.currentThread().interrupt();
            Logger.getLogger(NormalizedEuclideanDistance.class.getName()).log(Level.SEVERE, null, ex);
            return this.getAccelerationCache(vecs);
        }
        catch (ExecutionException ex) {
            Throwable cause = ex.getCause();
            if (cause instanceof RuntimeException) {
                throw (RuntimeException) cause;
            }
            if (cause instanceof Error) {
                throw (Error) cause;
            }
            throw new RuntimeException(cause);
        }
        return DoubleList.view(cache, cache.length);
    }

    /**
     * Distance between two indexed vectors. With a cache, this uses the expansion
     * {@code |a|_w^2 + |b|_w^2 - 2 <a, b>_w}.
     */
    @Override
    public double dist(int a, int b, List<? extends Vec> vecs, List<Double> cache) {
        if (cache == null) {
            return this.dist(vecs.get(a), vecs.get(b));
        }
        double cross = VecOps.weightedDot(this.invStndDevs, vecs.get(a), vecs.get(b));
        return sqrtOfSqrdDist(cache.get(a) + cache.get(b) - 2.0 * cross);
    }

    /**
     * Distance between an indexed vector and an arbitrary vector {@code b}. The weighted norm
     * of {@code b} is computed on the fly.
     */
    @Override
    public double dist(int a, Vec b, List<? extends Vec> vecs, List<Double> cache) {
        if (cache == null) {
            return this.dist(vecs.get(a), b);
        }
        double bNorm = VecOps.weightedDot(this.invStndDevs, b, b);
        double cross = VecOps.weightedDot(this.invStndDevs, vecs.get(a), b);
        return sqrtOfSqrdDist(cache.get(a) + bNorm - 2.0 * cross);
    }

    /**
     * @return a single-element list holding the weighted squared norm of {@code q}
     */
    @Override
    public List<Double> getQueryInfo(Vec q) {
        DoubleList qi = new DoubleList(1);
        qi.add(VecOps.weightedDot(this.invStndDevs, q, q));
        return qi;
    }

    /**
     * Distance between an indexed vector and {@code b}, reusing {@code b}'s weighted norm
     * from {@code qi} (as produced by {@link #getQueryInfo(Vec)}).
     */
    @Override
    public double dist(int a, Vec b, List<Double> qi, List<? extends Vec> vecs, List<Double> cache) {
        if (cache == null) {
            return this.dist(vecs.get(a), b);
        }
        double cross = VecOps.weightedDot(this.invStndDevs, vecs.get(a), b);
        return sqrtOfSqrdDist(cache.get(a) + qi.get(0) - 2.0 * cross);
    }

    /**
     * Converts a vector of per-feature variances, in place, into distance weights
     * {@code 1 / variance}. The weighted squared difference is then
     * {@code (a_i - b_i)^2 / sigma_i^2}. Squaring the variance before inverting would instead
     * weight each feature by {@code 1 / sigma_i^4}.
     * <p>
     * NOTE(review): a feature with zero training variance presumably gets a non-finite weight
     * here (assuming {@code MathTricks.invsFunc} computes {@code 1/x}); confirm how constant
     * features should be handled.
     *
     * @param variances per-feature variances; overwritten with their reciprocals
     * @return {@code variances}, now holding the weights
     */
    private static Vec toInverseVarianceWeights(Vec variances) {
        variances.applyFunction(MathTricks.invsFunc);
        return variances;
    }

    /**
     * Square root of a squared distance obtained from the expansion
     * {@code |a|^2 + |b|^2 - 2 a.b}.
     * <p>
     * Floating-point cancellation can make that expansion slightly negative for (nearly) equal
     * vectors, and {@code Math.sqrt} would then return {@code NaN}. The value is clamped at
     * zero first. A {@code NaN} input stays {@code NaN}.
     */
    private static double sqrtOfSqrdDist(double sqrdDist) {
        return Math.sqrt(Math.max(0.0, sqrdDist));
    }
}

