/*
 * Decompiled with CFR 0.152.
 */
package jsat.linear.distancemetrics;

import java.util.List;
import java.util.concurrent.ExecutorService;
import jsat.DataSet;
import jsat.classifiers.ClassificationDataSet;
import jsat.linear.LUPDecomposition;
import jsat.linear.Matrix;
import jsat.linear.MatrixStatistics;
import jsat.linear.SingularValueDecomposition;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.TrainableDistanceMetric;
import jsat.regression.RegressionDataSet;

/**
 * The Mahalanobis distance, {@code sqrt((a - b)^T S (a - b))}, where {@code S} is the
 * inverse of the covariance matrix estimated from the training data. When that
 * covariance matrix looks numerically singular (see {@link #train(List, ExecutorService)}),
 * its pseudo-inverse is used instead.
 * <p>
 * The metric must be trained before {@link #dist(Vec, Vec)} can be called.
 */
public class MahalanobisDistance
extends TrainableDistanceMetric {
    private static final long serialVersionUID = 7878528119699276817L;

    /**
     * A covariance determinant whose magnitude is at or below this value (or is NaN or
     * infinite) is treated as singular. In that case the pseudo-inverse is used instead
     * of an LU-based inverse.
     */
    private static final double SINGULAR_DET_THRESHOLD = 1.0E-13;

    /** Whether {@link #needsTraining()} should keep requesting training once trained. */
    private boolean reTrain = true;

    /** The (pseudo-)inverse covariance matrix; {@code null} until the metric is trained. */
    private Matrix S;

    /**
     * Returns whether this metric asks to be retrained even after it has been trained.
     *
     * @return {@code true} if retraining is requested
     */
    public boolean isReTrain() {
        return this.reTrain;
    }

    /**
     * Sets whether this metric asks to be retrained even after it has been trained.
     *
     * @param reTrain {@code true} to keep requesting training
     */
    public void setReTrain(boolean reTrain) {
        this.reTrain = reTrain;
    }

    @Override
    public <V extends Vec> void train(List<V> dataSet) {
        this.train(dataSet, null);
    }

    /**
     * Estimates the covariance matrix of {@code dataSet} and stores its inverse. The
     * pseudo-inverse is stored instead when the covariance appears singular.
     *
     * @param dataSet    the vectors to learn the covariance from; must not be empty
     * @param threadpool optional executor for the LU decomposition and solve; may be
     *                   {@code null} for single-threaded execution
     * @throws IllegalArgumentException if {@code dataSet} is empty
     */
    @Override
    public <V extends Vec> void train(List<V> dataSet, ExecutorService threadpool) {
        if (dataSet.isEmpty()) {
            throw new IllegalArgumentException("Cannot train Mahalanobis distance on an empty data set");
        }
        Vec mean = MatrixStatistics.meanVector(dataSet);
        Matrix covariance = MatrixStatistics.covarianceMatrix(mean, dataSet);
        // Decompose a clone so the untouched covariance is still available for the SVD
        // fallback and for its dimension below, in case the decomposition alters its input.
        LUPDecomposition lup = threadpool != null
                ? new LUPDecomposition(covariance.clone(), threadpool)
                : new LUPDecomposition(covariance.clone());
        double det = lup.det();
        boolean singular = Double.isNaN(det) || Double.isInfinite(det)
                || Math.abs(det) <= SINGULAR_DET_THRESHOLD;
        if (singular) {
            // An LU-based inverse would be unstable or undefined; the pseudo-inverse is not.
            SingularValueDecomposition svd = new SingularValueDecomposition(covariance);
            this.S = svd.getPseudoInverse();
        } else {
            // Solving against the identity yields the inverse covariance matrix.
            Matrix identity = Matrix.eye(covariance.cols());
            this.S = threadpool != null ? lup.solve(identity, threadpool) : lup.solve(identity);
        }
    }

    @Override
    public void train(DataSet dataSet) {
        this.train(dataSet, null);
    }

    @Override
    public void train(DataSet dataSet, ExecutorService threadpool) {
        this.train(dataSet.getDataVectors(), threadpool);
    }

    @Override
    public void train(ClassificationDataSet dataSet) {
        this.train((DataSet)dataSet);
    }

    @Override
    public void train(ClassificationDataSet dataSet, ExecutorService threadpool) {
        this.train((DataSet)dataSet, threadpool);
    }

    @Override
    public boolean supportsClassificationTraining() {
        return true;
    }

    @Override
    public void train(RegressionDataSet dataSet) {
        this.train((DataSet)dataSet);
    }

    @Override
    public void train(RegressionDataSet dataSet, ExecutorService threadpool) {
        this.train((DataSet)dataSet, threadpool);
    }

    @Override
    public boolean supportsRegressionTraining() {
        return true;
    }

    /**
     * Reports whether training is required.
     *
     * @return {@code true} if the metric has never been trained; otherwise the value of
     *         {@link #isReTrain()}
     */
    @Override
    public boolean needsTraining() {
        if (this.S == null) {
            return true;
        }
        return this.isReTrain();
    }

    /**
     * Computes {@code sqrt((a - b)^T S (a - b))}.
     *
     * @param a the first vector
     * @param b the second vector
     * @return the Mahalanobis distance between {@code a} and {@code b}
     * @throws IllegalStateException if the metric has not been trained
     */
    @Override
    public double dist(Vec a, Vec b) {
        if (this.S == null) {
            throw new IllegalStateException("Mahalanobis distance must be trained before use");
        }
        Vec aMb = a.subtract(b);
        double quadraticForm = aMb.dot(this.S.multiply(aMb));
        // S is positive semi-definite in exact arithmetic. Rounding can push the form
        // slightly below zero, so clamp it to keep sqrt from returning NaN. A NaN form
        // (e.g. from NaN inputs) still propagates because Math.max returns NaN.
        return Math.sqrt(Math.max(0.0, quadraticForm));
    }

    @Override
    public boolean isSymmetric() {
        return true;
    }

    @Override
    public boolean isSubadditive() {
        return true;
    }

    /**
     * Reports the indiscernibility property.
     * <p>
     * NOTE(review): when {@code S} is a pseudo-inverse of a singular covariance, distinct
     * vectors whose difference lies in the covariance's null space get distance 0, so
     * this does not strictly hold in that case.
     */
    @Override
    public boolean isIndiscemible() {
        return true;
    }

    @Override
    public double metricBound() {
        return Double.POSITIVE_INFINITY;
    }

    @Override
    public String toString() {
        return "Mahalanobis Distance";
    }

    /**
     * Creates a deep copy, including a copy of the learned matrix if one exists.
     *
     * @return an independent copy of this metric
     */
    @Override
    public MahalanobisDistance clone() {
        MahalanobisDistance clone = new MahalanobisDistance();
        clone.reTrain = this.reTrain;
        if (this.S != null) {
            clone.S = this.S.clone();
        }
        return clone;
    }

    /** Acceleration caching is not supported; the cache methods below return {@code null}. */
    @Override
    public boolean supportsAcceleration() {
        return false;
    }

    @Override
    public List<Double> getAccelerationCache(List<? extends Vec> vecs) {
        return null;
    }

    @Override
    public double dist(int a, int b, List<? extends Vec> vecs, List<Double> cache) {
        return this.dist(vecs.get(a), vecs.get(b));
    }

    @Override
    public double dist(int a, Vec b, List<? extends Vec> vecs, List<Double> cache) {
        return this.dist(vecs.get(a), b);
    }

    @Override
    public List<Double> getQueryInfo(Vec q) {
        return null;
    }

    @Override
    public List<Double> getAccelerationCache(List<? extends Vec> vecs, ExecutorService threadpool) {
        return null;
    }

    @Override
    public double dist(int a, Vec b, List<Double> qi, List<? extends Vec> vecs, List<Double> cache) {
        return this.dist(vecs.get(a), b);
    }
}

