/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.clustering;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.ml.clustering.CentroidCluster;
import org.apache.commons.math3.ml.clustering.Clusterable;
import org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer;
import org.apache.commons.math3.ml.distance.DistanceMeasure;
import org.apache.commons.math3.ml.distance.ManhattanDistance;
import org.apache.commons.math3.random.JDKRandomGenerator;
import org.apache.commons.math3.random.RandomGenerator;

/**
 * G-means style clustering over {@link Clusterable} points.
 *
 * <p>The algorithm starts from a single cluster (k-means++ with {@code k = 1}) and then visits the
 * clusters one after another. A visited cluster is split into two children with k-means++. Its
 * points are projected onto the vector joining the two child centers, and the projection is checked
 * with an Anderson-Darling normality test. If normality is rejected, the cluster is replaced by its
 * two children, which are both visited later. Otherwise the cluster is kept. Finally, clusters with
 * at most two points are merged into the nearest remaining cluster.
 *
 * <p>Cluster centers ({@code double[]}) are used as {@link HashMap} keys. Arrays do not override
 * {@code equals}/{@code hashCode}, so these maps compare keys by reference. Distinct clusters
 * therefore stay distinct even if their coordinates coincide.
 *
 * <p>{@link #cluster()} mutates instance fields without synchronization, so an instance should not
 * be shared between threads.
 */
public class GMeans<C extends Clusterable> {

    /** Centers produced by the initial single-cluster run of the most recent {@link #cluster()}. */
    private List<double[]> center = new ArrayList<>();
    /** Result list of the most recent {@link #cluster()} call. */
    private List<CentroidCluster<C>> gmeansCluster = new ArrayList<>();
    /** Current clusters: center array (compared by reference) mapped to its member points. */
    private Map<double[], List<C>> currentPoints = new HashMap<>();
    /** Children produced by 2-means splits: child center array mapped to its member points. */
    private Map<double[], List<C>> intermediatePoints = new HashMap<>();
    /** Points to cluster (a copy of the constructor argument). */
    private final List<C> points;
    /** Distance used by k-means++ and when merging small clusters. */
    private final DistanceMeasure distanceMeasure;
    /** Random source handed to k-means++; seeded in the constructor. */
    private final RandomGenerator randomGenerator;

    /**
     * Creates a clusterer that uses Manhattan distance and seed {@code 1}.
     *
     * @param toClusterPoints the points to cluster
     */
    public GMeans(final Collection<C> toClusterPoints) {
        this(toClusterPoints, new ManhattanDistance(), 1L);
    }

    /**
     * Creates a clusterer.
     *
     * @param toClusterPoints the points to cluster; copied into an internal list
     * @param distanceMeasure distance used by k-means++ and for merging small clusters
     * @param seed seed for the {@link JDKRandomGenerator} handed to k-means++
     */
    public GMeans(final Collection<C> toClusterPoints, final DistanceMeasure distanceMeasure, final long seed) {
        this.points = new ArrayList<>(toClusterPoints);
        this.distanceMeasure = distanceMeasure;
        this.randomGenerator = new JDKRandomGenerator();
        this.randomGenerator.setSeed(seed);
    }

    /**
     * Runs G-means on the points given at construction time.
     *
     * <p>Every call starts from scratch. State left over from earlier calls is discarded, and lists
     * returned by earlier calls are not modified.
     *
     * @return the resulting clusters (empty if there are no points)
     * @throws IllegalStateException if the two child centers of a split differ only in NaN
     *         coordinates, so that no projection direction can be computed
     */
    public List<CentroidCluster<C>> cluster() {
        // Fresh containers per run. Reusing them would mix centers/points of earlier runs into the
        // position -> center bookkeeping below.
        this.center = new ArrayList<>();
        this.gmeansCluster = new ArrayList<>();
        this.currentPoints = new HashMap<>();
        this.intermediatePoints = new HashMap<>();
        if (this.points.isEmpty()) {
            // k-means++ rejects fewer points than clusters; there is nothing to cluster anyway.
            return this.gmeansCluster;
        }

        final KMeansPlusPlusClusterer<C> initialClusterer =
                new KMeansPlusPlusClusterer<>(1, -1, this.distanceMeasure, this.randomGenerator);
        for (final CentroidCluster<C> cluster : initialClusterer.cluster(this.points)) {
            this.currentPoints.put(cluster.getCenter().getPoint(), cluster.getPoints());
        }
        this.center.addAll(this.currentPoints.keySet());

        // 1-based visiting order of the clusters; second children of splits are appended at the end.
        final Map<Integer, double[]> positionOfCenter = new HashMap<>();
        for (final double[] c : this.center) {
            positionOfCenter.put(positionOfCenter.size() + 1, c);
        }

        int k = positionOfCenter.size();
        int i = 1;
        while (i <= k) {
            final double[] parentCenter = positionOfCenter.get(i);
            final List<C> loopPoints = this.currentPoints.get(parentCenter);
            if (loopPoints.size() < 2) {
                // Fewer than two points cannot be split: keep this cluster and continue with the
                // next one instead of aborting the search for all remaining clusters.
                ++i;
                continue;
            }

            final KMeansPlusPlusClusterer<C> splitter =
                    new KMeansPlusPlusClusterer<>(2, -1, this.distanceMeasure, this.randomGenerator);
            final List<double[]> childCenters = new ArrayList<>(2);
            for (final CentroidCluster<C> child : splitter.cluster(loopPoints)) {
                final double[] childCenter = child.getCenter().getPoint();
                this.intermediatePoints.put(childCenter, child.getPoints());
                childCenters.add(childCenter);
            }
            final double[] first = childCenters.get(0);
            final double[] second = childCenters.get(1);

            // v is the direction between the children; w = ||v||^2 over its non-NaN components.
            final double[] v = this.difference(first, second);
            double w = 0.0;
            boolean hasDefinedEntry = false;
            for (final double component : v) {
                if (!Double.isNaN(component)) {
                    hasDefinedEntry = true;
                    w += component * component;
                }
            }
            if (!hasDefinedEntry) {
                throw new IllegalStateException("All entries in v are NaN, cannot compute w!");
            }
            if (w == 0.0) {
                // Both children share the same center, so there is no direction to project on.
                // The split carries no information and the cluster is kept.
                ++i;
                continue;
            }

            if (!this.andersonDarlingTest(this.project(loopPoints, v, w))) {
                // Projection is not Gaussian: replace the parent by its children. The first child
                // takes over the parent's position (and is visited next); the second is appended.
                this.currentPoints.remove(parentCenter);
                this.currentPoints.put(first, this.intermediatePoints.get(first));
                positionOfCenter.put(i, first);
                this.currentPoints.put(second, this.intermediatePoints.get(second));
                positionOfCenter.put(++k, second);
            } else {
                ++i;
            }
        }

        this.mergeCluster(this.currentPoints);
        for (final Map.Entry<double[], List<C>> entry : this.currentPoints.entrySet()) {
            final double[] centerPoint = entry.getKey();
            final CentroidCluster<C> cluster = new CentroidCluster<C>(() -> centerPoint);
            for (final C point : entry.getValue()) {
                cluster.addPoint(point);
            }
            this.gmeansCluster.add(cluster);
        }
        return this.gmeansCluster;
    }

    /**
     * Projects every point onto {@code v} and divides by {@code w}, i.e. {@code y_r = <x_r, v> / w}.
     * Coordinates where either the point or {@code v} is NaN are skipped.
     */
    private double[] project(final List<C> pts, final double[] v, final double w) {
        final double[] y = new double[pts.size()];
        for (int r = 0; r < y.length; ++r) {
            final double[] x = pts.get(r).getPoint();
            double sum = 0.0;
            for (int p = 0; p < x.length; ++p) {
                if (!Double.isNaN(x[p]) && !Double.isNaN(v[p])) {
                    sum += v[p] * x[p] / w;
                }
            }
            y[r] = sum;
        }
        return y;
    }

    /**
     * Dissolves every cluster with at most two points and reassigns its points to the nearest
     * remaining center, as measured by the configured distance. The map is modified in place.
     *
     * <p>If all clusters are small, the last remaining one is kept, so no points are lost. Centers
     * are not recomputed after points are added.
     *
     * @param currentPoints clusters keyed by their center array (compared by reference)
     */
    protected void mergeCluster(final Map<double[], List<C>> currentPoints) {
        final List<double[]> toMergeCenter = new ArrayList<>();
        for (final Map.Entry<double[], List<C>> entry : currentPoints.entrySet()) {
            if (entry.getValue().size() <= 2) {
                toMergeCenter.add(entry.getKey());
            }
        }
        for (final double[] d : toMergeCenter) {
            if (currentPoints.size() <= 1) {
                // d is the only cluster left; there is nothing to merge it into.
                break;
            }
            final List<C> orphans = currentPoints.remove(d);
            for (final C orphan : orphans) {
                double minDist = Double.MAX_VALUE;
                double[] nearest = null;
                for (final double[] c : currentPoints.keySet()) {
                    final double dist = this.distanceMeasure.compute(orphan.getPoint(), c);
                    if (dist <= minDist) {
                        nearest = c;
                        minDist = dist;
                    }
                }
                if (nearest == null) {
                    // Every distance was NaN or infinite; fall back to any remaining cluster.
                    nearest = currentPoints.keySet().iterator().next();
                }
                currentPoints.get(nearest).add(orphan);
            }
        }
    }

    /**
     * Anderson-Darling test of whether {@code d} looks normally distributed, with the mean and the
     * (sample) variance estimated from the data.
     *
     * <p>NaN entries are ignored, and the input array is not modified.
     *
     * @param d the sample
     * @return {@code true} if normality is not rejected, i.e. the statistic {@code A^2} does not
     *         exceed the size-dependent threshold. Also {@code true} for fewer than two values or a
     *         constant sample, where the statistic is undefined.
     */
    protected boolean andersonDarlingTest(final double[] d) {
        // The statistic needs order statistics. Sort a NaN-free copy so the caller's array is untouched.
        final double[] sorted = Arrays.stream(d).filter(x -> !Double.isNaN(x)).sorted().toArray();
        final int n = sorted.length;
        if (n < 2) {
            return true;
        }

        double mean = 0.0;
        for (final double x : sorted) {
            mean += x;
        }
        mean /= n;
        double variance = 0.0;
        for (final double x : sorted) {
            variance += (x - mean) * (x - mean);
        }
        variance /= (n - 1);
        if (!(variance > 0.0)) {
            // Constant sample: standardization would divide by zero.
            return true;
        }

        final double[] z = this.standardize(sorted, mean, variance);
        final NormalDistribution normal = new NormalDistribution(null, 0.0, 1.0);
        // A^2 = -n - (1/n) * sum_{i=1..n} (2i-1) * [ln F(z_i) + ln(1 - F(z_{n+1-i}))], 1-based.
        // 1 - F(t) is evaluated as F(-t), which avoids cancellation in the upper tail.
        double sum = 0.0;
        for (int i = 1; i <= n; ++i) {
            sum += (2.0 * i - 1.0) * (Math.log(normal.cumulativeProbability(z[i - 1]))
                    + Math.log(normal.cumulativeProbability(-z[n - i])));
        }
        final double aSquare = -n - sum / n;

        // Size-dependent thresholds for the uncorrected A^2.
        // NOTE(review): these look like the finite-sample 5% critical values for the case of estimated
        // mean and variance (Stephens' tables) -- confirm the source.
        final double criticalValue;
        if (n <= 10) {
            criticalValue = 0.683;
        } else if (n <= 20) {
            criticalValue = 0.704;
        } else if (n <= 50) {
            criticalValue = 0.735;
        } else if (n <= 100) {
            criticalValue = 0.754;
        } else {
            criticalValue = 0.787;
        }
        return aSquare <= criticalValue;
    }

    /** Returns {@code (d[i] - mean) / sqrt(variance)} for every entry; NaN entries stay NaN. */
    private double[] standardize(final double[] d, final double mean, final double variance) {
        final double sd = Math.sqrt(variance);
        final double[] z = new double[d.length];
        for (int i = 0; i < d.length; ++i) {
            z[i] = Double.isNaN(d[i]) ? Double.NaN : (d[i] - mean) / sd;
        }
        return z;
    }

    /**
     * Returns the component-wise difference {@code a - b}. A component is NaN if either input is NaN
     * there. The result has {@code a.length} components, so {@code b} must be at least as long.
     */
    protected double[] difference(final double[] a, final double[] b) {
        final double[] c = new double[a.length];
        for (int i = 0; i < a.length; ++i) {
            c[i] = !Double.isNaN(a[i]) && !Double.isNaN(b[i]) ? a[i] - b[i] : Double.NaN;
        }
        return c;
    }

    /**
     * Returns the live, mutable list of centers from the initial single-cluster run of the most
     * recent {@link #cluster()} call.
     */
    protected List<double[]> getCentersModifiable() {
        return this.center;
    }

    /** Returns the internal (mutable) list of points to be clustered. */
    public List<C> getPoints() {
        return this.points;
    }
}

