/*
 * Decompiled with CFR 0.152.
 */
package elki.datasource.filter.transform;

import elki.data.ClassLabel;
import elki.data.NumberVector;
import elki.datasource.filter.transform.AbstractSupervisedProjectionVectorFilter;
import elki.logging.Logging;
import elki.math.linearalgebra.Centroid;
import elki.math.linearalgebra.CovarianceMatrix;
import elki.math.linearalgebra.EigenvalueDecomposition;
import elki.math.linearalgebra.LUDecomposition;
import elki.math.linearalgebra.VMath;
import elki.math.linearalgebra.pca.PCAResult;
import elki.utilities.Alias;
import elki.utilities.documentation.Reference;
import it.unimi.dsi.fastutil.ints.IntList;
import it.unimi.dsi.fastutil.ints.IntListIterator;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

/**
 * Linear Discriminant Analysis (Fisher's linear discriminant) as a supervised
 * projection filter.
 * <p>
 * The projection is taken from the eigenvectors of
 * {@code inverse(sigmaI) * sigmaB}, where {@code sigmaB} is the sample
 * covariance of the class centroids (between-class scatter) and
 * {@code sigmaI} is the sample covariance of all vectors after subtracting the
 * centroid of their own class (pooled within-class scatter).
 *
 * @param <V> vector type
 */
@Alias(value={"lda"})
@Reference(authors="R. A. Fisher", title="The use of multiple measurements in taxonomic problems", booktitle="Annals of Eugenics 7.2", url="https://doi.org/10.1111/j.1469-1809.1936.tb02137.x", bibkey="doi:10.1111/j.1469-1809.1936.tb02137.x")
public class LinearDiscriminantAnalysisFilter<V extends NumberVector>
extends AbstractSupervisedProjectionVectorFilter<V> {
    /**
     * Class logger.
     */
    private static final Logging LOG = Logging.getLogger(LinearDiscriminantAnalysisFilter.class);

    /**
     * Value added to every diagonal entry of a singular within-class
     * covariance matrix so that it can be inverted.
     */
    private static final double SINGULAR_REGULARIZATION = 1.0E-10;

    /**
     * Constructor.
     *
     * @param projdimension number of dimensions to project to (passed on to
     *        the superclass)
     */
    public LinearDiscriminantAnalysisFilter(int projdimension) {
        super(projdimension);
    }

    /**
     * Compute the LDA projection matrix.
     *
     * @param vectorcolumn data vectors
     * @param classcolumn class label of each vector (parallel to
     *        {@code vectorcolumn})
     * @param dim dimensionality of the input vectors
     * @return the first {@code tdim} entries of the eigenvector array of
     *         {@code inverse(sigmaI) * sigmaB}, in the order produced by
     *         {@link PCAResult} (presumably descending eigenvalue — verify)
     */
    @Override
    protected double[][] computeProjectionMatrix(List<V> vectorcolumn, List<? extends ClassLabel> classcolumn, int dim) {
        Map<ClassLabel, IntList> classes = this.partition(classcolumn);
        // Fix one iteration order so that centroids.get(i) belongs to keys.get(i).
        List<ClassLabel> keys = new ArrayList<>(classes.keySet());
        List<Centroid> centroids = this.computeCentroids(dim, vectorcolumn, keys, classes);

        double[][] sigmaB = betweenClassCovariance(dim, centroids);
        double[][] sigmaI = this.withinClassCovariance(dim, vectorcolumn, keys, classes, centroids);
        if (new LUDecomposition(sigmaI).det() == 0.0) {
            // Singular within-class scatter cannot be inverted: add a tiny ridge.
            for (int d = 0; d < dim; ++d) {
                sigmaI[d][d] += SINGULAR_REGULARIZATION;
            }
        }
        double[][] sol = VMath.times(VMath.inverse(sigmaI), sigmaB);
        EigenvalueDecomposition evd = new EigenvalueDecomposition(sol);
        return Arrays.copyOf(new PCAResult(evd).getEigenvectors(), this.tdim);
    }

    /**
     * Between-class scatter: sample covariance of the class centroids. Every
     * class contributes with equal weight, independent of its size.
     *
     * @param dim dimensionality
     * @param centroids one centroid per class
     * @return sample covariance matrix of the centroids
     */
    private static double[][] betweenClassCovariance(int dim, List<Centroid> centroids) {
        CovarianceMatrix covmake = new CovarianceMatrix(dim);
        for (Centroid c : centroids) {
            covmake.put((NumberVector) c);
        }
        return covmake.destroyToSampleMatrix();
    }

    /**
     * Pooled within-class scatter: sample covariance of all vectors after
     * subtracting the centroid of their own class.
     *
     * @param dim dimensionality
     * @param vectorcolumn data vectors
     * @param keys class labels, parallel to {@code centroids}
     * @param classes map from class label to member indexes into
     *        {@code vectorcolumn}
     * @param centroids one centroid per entry of {@code keys}
     * @return sample covariance matrix of the class-centered vectors
     */
    private double[][] withinClassCovariance(int dim, List<V> vectorcolumn, List<ClassLabel> keys, Map<ClassLabel, IntList> classes, List<Centroid> centroids) {
        // Use a fresh accumulator rather than reset() on the between-class one:
        // destroyToSampleMatrix() consumes its accumulator (in ELKI it hands out
        // the internal array), so reset-and-reuse would clear sigmaB and make
        // sigmaI share its storage.
        CovarianceMatrix covmake = new CovarianceMatrix(dim);
        for (int i = 0; i < keys.size(); ++i) {
            double[] centroid = centroids.get(i).getArrayRef();
            IntListIterator it = classes.get(keys.get(i)).iterator();
            while (it.hasNext()) {
                // minusEquals centers in place; NOTE(review): this relies on
                // toArray() returning a fresh copy, not the vector's storage.
                covmake.put(VMath.minusEquals(vectorcolumn.get(it.nextInt()).toArray(), centroid));
            }
        }
        return covmake.destroyToSampleMatrix();
    }

    /**
     * Compute the centroid of each class.
     *
     * @param dim dimensionality of the vectors
     * @param vectorcolumn data vectors
     * @param keys class labels; the result follows this order
     * @param classes map from class label to member indexes into
     *        {@code vectorcolumn}
     * @return one centroid per entry of {@code keys}, in the same order
     */
    protected List<Centroid> computeCentroids(int dim, List<V> vectorcolumn, List<ClassLabel> keys, Map<ClassLabel, IntList> classes) {
        List<Centroid> centroids = new ArrayList<>(keys.size());
        for (ClassLabel key : keys) {
            Centroid c = new Centroid(dim);
            IntListIterator it = classes.get(key).iterator();
            while (it.hasNext()) {
                c.put(vectorcolumn.get(it.nextInt()));
            }
            centroids.add(c);
        }
        return centroids;
    }

    @Override
    protected Logging getLogger() {
        return LOG;
    }

    /**
     * Parameterization class.
     *
     * @param <V> vector type
     */
    public static class Par<V extends NumberVector>
    extends AbstractSupervisedProjectionVectorFilter.Par<V> {
        public LinearDiscriminantAnalysisFilter<V> make() {
            return new LinearDiscriminantAnalysisFilter<>(this.tdim);
        }
    }
}

