/*
 * Decompiled with CFR 0.152.
 */
package jsat.linear;

import java.util.List;
import jsat.DataSet;
import jsat.classifiers.DataPoint;
import jsat.linear.DenseMatrix;
import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.Matrix;
import jsat.linear.Vec;

/**
 * Static helpers that compute mean vectors, covariance matrices and the
 * diagonal of the covariance matrix from either a list of vectors or a
 * {@link DataSet} of weighted {@link DataPoint}s.
 *
 * <p>All methods that receive a caller-supplied output ({@code mean},
 * {@code covariance}, {@code diag}) <em>accumulate into it</em> without
 * clearing it first, so the storage must hold zeros on entry for the result
 * to be the statistic described.
 */
public class MatrixStatistics {
    /** Utility class; not meant to be instantiated. */
    private MatrixStatistics() {
    }

    /**
     * Computes the unweighted mean of a list of vectors into a new vector.
     *
     * @param <V> the vector type held by the list
     * @param dataSet the vectors to average
     * @return a new {@link DenseVector} sized like the first vector of the list
     * @throws ArithmeticException if the list is empty, or if the first
     *         vector's length does not match (see
     *         {@link #meanVector(Vec, List)})
     */
    public static <V extends Vec> Vec meanVector(List<V> dataSet) {
        if (dataSet.isEmpty()) {
            throw new ArithmeticException("Can not compute the mean of zero data points");
        }
        Vec mean = new DenseVector(dataSet.get(0).length());
        meanVector(mean, dataSet);
        return mean;
    }

    /**
     * Computes the weighted mean of the numerical values of a data set into a
     * new vector.
     *
     * @param dataSet the data set to average
     * @return a new {@link DenseVector} with one entry per numerical variable
     * @throws ArithmeticException if the data set has no samples
     */
    public static Vec meanVector(DataSet dataSet) {
        Vec mean = new DenseVector(dataSet.getNumNumericalVars());
        meanVector(mean, dataSet);
        return mean;
    }

    /**
     * Adds every vector of the list into {@code mean} and then divides
     * {@code mean} by the number of vectors.
     *
     * @param <V> the vector type held by the list
     * @param mean the storage to accumulate into; it is not cleared first
     * @param dataSet the vectors to average
     * @throws ArithmeticException if the list is empty or the first vector's
     *         length differs from {@code mean.length()}
     */
    public static <V extends Vec> void meanVector(Vec mean, List<V> dataSet) {
        if (dataSet.isEmpty()) {
            throw new ArithmeticException("Can not compute the mean of zero data points");
        }
        if (dataSet.get(0).length() != mean.length()) {
            throw new ArithmeticException("Vector dimensions do not agree");
        }
        for (Vec x : dataSet) {
            mean.mutableAdd(x);
        }
        mean.mutableDivide(dataSet.size());
    }

    /**
     * Adds each data point's numerical values, scaled by the point's weight,
     * into {@code mean} and then divides {@code mean} by the sum of weights.
     *
     * @param mean the storage to accumulate into; it is not cleared first
     * @param dataSet the data set to average
     * @throws ArithmeticException if the data set has no samples
     */
    public static void meanVector(Vec mean, DataSet dataSet) {
        if (dataSet.getSampleSize() == 0) {
            throw new ArithmeticException("Can not compute the mean of zero data points");
        }
        double sumOfWeights = 0.0;
        for (int i = 0; i < dataSet.getSampleSize(); ++i) {
            DataPoint dp = dataSet.getDataPoint(i);
            double w = dp.getWeight();
            sumOfWeights += w;
            mean.mutableAdd(w, dp.getNumericalValues());
        }
        mean.mutableDivide(sumOfWeights);
    }

    /**
     * Computes the covariance matrix of a list of vectors into a new matrix.
     *
     * @param <V> the vector type held by the list
     * @param mean the mean of the vectors
     * @param dataSet the vectors to compute the covariance of
     * @return a new square {@link DenseMatrix} of size {@code mean.length()}
     * @throws ArithmeticException under the conditions of
     *         {@link #covarianceMatrix(Vec, Matrix, List)}
     */
    public static <V extends Vec> Matrix covarianceMatrix(Vec mean, List<V> dataSet) {
        Matrix coMatrix = new DenseMatrix(mean.length(), mean.length());
        covarianceMatrix(mean, coMatrix, dataSet);
        return coMatrix;
    }

    /**
     * Accumulates the outer product of each mean-centred vector into
     * {@code covariance} and scales the matrix by {@code 1 / (n - 1)}, where
     * {@code n} is the number of vectors.
     *
     * <p>Note that for a single vector the scale factor is {@code 1.0 / 0.0}.
     *
     * @param <V> the vector type held by the list
     * @param mean the mean of the vectors
     * @param covariance the square storage to accumulate into; not cleared first
     * @param dataSet the vectors to compute the covariance of
     * @throws ArithmeticException if {@code covariance} is not square, its
     *         size differs from {@code mean.length()}, the list is empty, or
     *         the first vector's length differs from {@code mean.length()}
     */
    public static <V extends Vec> void covarianceMatrix(Vec mean, Matrix covariance, List<V> dataSet) {
        requireCovarianceStorage(mean, covariance);
        if (dataSet.isEmpty()) {
            throw new ArithmeticException("No data points to compute covariance from");
        }
        if (mean.length() != dataSet.get(0).length()) {
            throw new ArithmeticException("Data vectors do not agree with mean and covariance matrix");
        }
        // Reused buffer holding (x - mean) for the current vector.
        DenseVector scratch = new DenseVector(mean.length());
        for (Vec x : dataSet) {
            x.copyTo(scratch);
            scratch.mutableSubtract(mean);
            Matrix.OuterProductUpdate(covariance, scratch, scratch, 1.0);
        }
        covariance.mutableMultiply(1.0 / ((double) dataSet.size() - 1.0));
    }

    /**
     * Computes the weighted covariance of a list of data points, first summing
     * the weights and squared weights and then delegating to
     * {@link #covarianceMatrix(Vec, List, Matrix, double, double)}.
     *
     * @param mean the mean of the data points
     * @param dataSet the weighted data points
     * @param covariance the square storage to accumulate into; not cleared first
     * @throws ArithmeticException under the conditions of the delegate
     */
    public static void covarianceMatrix(Vec mean, List<DataPoint> dataSet, Matrix covariance) {
        double sumOfWeights = 0.0;
        double sumOfSquaredWeights = 0.0;
        for (DataPoint dp : dataSet) {
            sumOfWeights += dp.getWeight();
            sumOfSquaredWeights += Math.pow(dp.getWeight(), 2.0);
        }
        covarianceMatrix(mean, dataSet, covariance, sumOfWeights, sumOfSquaredWeights);
    }

    /**
     * Accumulates the weight-scaled outer product of each mean-centred data
     * point into {@code covariance} and scales the matrix by
     * {@code sumOfWeights / (sumOfWeights^2 - sumOfSquaredWeights)}.
     *
     * @param mean the mean of the data points
     * @param dataSet the weighted data points
     * @param covariance the square storage to accumulate into; not cleared first
     * @param sumOfWeights the sum of all data point weights
     * @param sumOfSquaredWeights the sum of the squares of all data point weights
     * @throws ArithmeticException if {@code covariance} is not square, its
     *         size differs from {@code mean.length()}, the list is empty, or
     *         the first point's numerical vector length differs from
     *         {@code mean.length()}
     */
    public static void covarianceMatrix(Vec mean, List<DataPoint> dataSet, Matrix covariance, double sumOfWeights, double sumOfSquaredWeights) {
        requireCovarianceStorage(mean, covariance);
        if (dataSet.isEmpty()) {
            throw new ArithmeticException("No data points to compute covariance from");
        }
        if (mean.length() != dataSet.get(0).getNumericalValues().length()) {
            throw new ArithmeticException("Data vectors do not agree with mean and covariance matrix");
        }
        // Reused buffer holding (x - mean) for the current data point.
        DenseVector scratch = new DenseVector(mean.length());
        for (int i = 0; i < dataSet.size(); ++i) {
            DataPoint dp = dataSet.get(i);
            dp.getNumericalValues().copyTo(scratch);
            scratch.mutableSubtract(mean);
            Matrix.OuterProductUpdate(covariance, scratch, scratch, dp.getWeight());
        }
        covariance.mutableMultiply(sumOfWeights / (Math.pow(sumOfWeights, 2.0) - sumOfSquaredWeights));
    }

    /**
     * Computes the weighted covariance matrix of a data set into a new matrix.
     *
     * @param mean the mean of the data set
     * @param dataSet the data set to compute the covariance of
     * @return a new square {@link DenseMatrix} of size {@code mean.length()}
     * @throws ArithmeticException under the conditions of
     *         {@link #covarianceMatrix(Vec, DataSet, Matrix, double, double)}
     */
    public static Matrix covarianceMatrix(Vec mean, DataSet dataSet) {
        Matrix covariance = new DenseMatrix(mean.length(), mean.length());
        covarianceMatrix(mean, dataSet, covariance);
        return covariance;
    }

    /**
     * Computes the weighted covariance of a data set, first summing the
     * weights and squared weights and then delegating to
     * {@link #covarianceMatrix(Vec, DataSet, Matrix, double, double)}.
     *
     * @param mean the mean of the data set
     * @param dataSet the data set to compute the covariance of
     * @param covariance the square storage to accumulate into; not cleared first
     * @throws ArithmeticException under the conditions of the delegate
     */
    public static void covarianceMatrix(Vec mean, DataSet dataSet, Matrix covariance) {
        double sumOfWeights = 0.0;
        double sumOfSquaredWeights = 0.0;
        for (int i = 0; i < dataSet.getSampleSize(); ++i) {
            DataPoint dp = dataSet.getDataPoint(i);
            sumOfWeights += dp.getWeight();
            sumOfSquaredWeights += Math.pow(dp.getWeight(), 2.0);
        }
        covarianceMatrix(mean, dataSet, covariance, sumOfWeights, sumOfSquaredWeights);
    }

    /**
     * Accumulates the weight-scaled outer product of each mean-centred data
     * point into {@code covariance} and scales the matrix by
     * {@code sumOfWeights / (sumOfWeights^2 - sumOfSquaredWeights)}.
     *
     * @param mean the mean of the data set
     * @param dataSet the data set to compute the covariance of
     * @param covariance the square storage to accumulate into; not cleared first
     * @param sumOfWeights the sum of all data point weights
     * @param sumOfSquaredWeights the sum of the squares of all data point weights
     * @throws ArithmeticException if {@code covariance} is not square, its
     *         size differs from {@code mean.length()}, the data set has no
     *         samples, or its number of numerical variables differs from
     *         {@code mean.length()}
     */
    public static void covarianceMatrix(Vec mean, DataSet dataSet, Matrix covariance, double sumOfWeights, double sumOfSquaredWeights) {
        requireCovarianceStorage(mean, covariance);
        if (dataSet.getSampleSize() == 0) {
            throw new ArithmeticException("No data points to compute covariance from");
        }
        if (mean.length() != dataSet.getNumNumericalVars()) {
            throw new ArithmeticException("Data vectors do not agree with mean and covariance matrix");
        }
        // Reused buffer holding (x - mean) for the current data point.
        DenseVector scratch = new DenseVector(mean.length());
        for (int i = 0; i < dataSet.getSampleSize(); ++i) {
            DataPoint dp = dataSet.getDataPoint(i);
            dp.getNumericalValues().copyTo(scratch);
            scratch.mutableSubtract(mean);
            Matrix.OuterProductUpdate(covariance, scratch, scratch, dp.getWeight());
        }
        covariance.mutableMultiply(sumOfWeights / (Math.pow(sumOfWeights, 2.0) - sumOfSquaredWeights));
    }

    /**
     * Accumulates the weighted per-variable squared deviation from
     * {@code means} into {@code diag} and divides by the sum of weights.
     *
     * <p>Only the entries yielded by iterating each point's numerical vector
     * are visited directly. An entry that a point does not yield is treated as
     * the value zero, contributing {@code weight * mean^2}; the total weight of
     * such points is obtained per index as the overall weight sum minus the
     * weight of the points that did yield that index, so the correction is
     * weighted consistently with the visited entries and with the divisor.
     *
     * @param means the per-variable means of the data set
     * @param diag the storage to accumulate into; it is not cleared first
     * @param dataset the data set to compute the covariance diagonal of
     */
    public static void covarianceDiag(Vec means, Vec diag, DataSet dataset) {
        final int n = dataset.getSampleSize();
        final int d = dataset.getNumNumericalVars();
        // storedWeight[j] = total weight of the points that yielded an entry for index j.
        double[] storedWeight = new double[d];
        double sumOfWeights = 0.0;
        for (int i = 0; i < n; ++i) {
            DataPoint dp = dataset.getDataPoint(i);
            double w = dp.getWeight();
            sumOfWeights += w;
            for (IndexValue iv : dp.getNumericalValues()) {
                int indx = iv.getIndex();
                storedWeight[indx] += w;
                diag.increment(indx, w * Math.pow(iv.getValue() - means.get(indx), 2.0));
            }
        }
        // Add the contribution of the implicit zeros: (0 - mean)^2 times the weight
        // of the points that did not yield this index. Using the weight mass here
        // (rather than a plain count of points) keeps it consistent with the
        // weighted terms above; with unit weights both are equal.
        for (int j = 0; j < d; ++j) {
            diag.increment(j, Math.pow(means.get(j), 2.0) * (sumOfWeights - storedWeight[j]));
        }
        diag.mutableDivide(sumOfWeights);
    }

    /**
     * Computes the weighted covariance diagonal of a data set into a new vector.
     *
     * @param means the per-variable means of the data set
     * @param dataset the data set to compute the covariance diagonal of
     * @return a new {@link DenseVector} with one entry per numerical variable
     */
    public static Vec covarianceDiag(Vec means, DataSet dataset) {
        Vec diag = new DenseVector(dataset.getNumNumericalVars());
        covarianceDiag(means, diag, dataset);
        return diag;
    }

    /**
     * Accumulates the per-variable squared deviation from {@code means} into
     * {@code diag} and divides by the number of vectors.
     *
     * <p>Entries a vector does not yield when iterated are treated as the
     * value zero, each contributing {@code mean^2}.
     *
     * @param <V> the vector type held by the list
     * @param means the per-variable means of the vectors
     * @param diag the storage to accumulate into; it is not cleared first
     * @param dataset the vectors to compute the covariance diagonal of; the
     *        first element is read to obtain the dimension
     */
    public static <V extends Vec> void covarianceDiag(Vec means, Vec diag, List<V> dataset) {
        final int n = dataset.size();
        final int d = dataset.get(0).length();
        // nnzCounts[j] = number of vectors that yielded an entry for index j.
        int[] nnzCounts = new int[d];
        for (int i = 0; i < n; ++i) {
            for (IndexValue iv : dataset.get(i)) {
                int indx = iv.getIndex();
                nnzCounts[indx]++;
                diag.increment(indx, Math.pow(iv.getValue() - means.get(indx), 2.0));
            }
        }
        // Add the contribution of the implicit zeros: (0 - mean)^2 per missing entry.
        for (int j = 0; j < d; ++j) {
            diag.increment(j, Math.pow(means.get(j), 2.0) * (double) (n - nnzCounts[j]));
        }
        diag.mutableDivide(n);
    }

    /**
     * Computes the covariance diagonal of a list of vectors into a new vector.
     *
     * @param <V> the vector type held by the list
     * @param means the per-variable means of the vectors
     * @param dataset the vectors to compute the covariance diagonal of; the
     *        first element is read to obtain the dimension
     * @return a new {@link DenseVector} sized like the first vector of the list
     */
    public static <V extends Vec> Vec covarianceDiag(Vec means, List<V> dataset) {
        Vec diag = new DenseVector(dataset.get(0).length());
        covarianceDiag(means, diag, dataset);
        return diag;
    }

    /**
     * Checks that the covariance storage is square and sized to match the mean.
     *
     * @throws ArithmeticException if either condition does not hold
     */
    private static void requireCovarianceStorage(Vec mean, Matrix covariance) {
        if (!covariance.isSquare()) {
            throw new ArithmeticException("Storage for covariance matrix must be square");
        }
        if (covariance.rows() != mean.length()) {
            throw new ArithmeticException("Covariance Matrix size and mean size do not agree");
        }
    }
}

