/*
 * Decompiled with CFR 0.152.
 */
package elki.clustering.em;

import elki.clustering.ClusteringAlgorithm;
import elki.clustering.em.EM;
import elki.clustering.em.models.TextbookMultivariateGaussianModel;
import elki.clustering.em.models.TextbookMultivariateGaussianModelFactory;
import elki.data.Cluster;
import elki.data.Clustering;
import elki.data.DoubleVector;
import elki.data.NumberVector;
import elki.data.model.EMModel;
import elki.data.type.SimpleTypeInformation;
import elki.data.type.TypeInformation;
import elki.data.type.TypeUtil;
import elki.database.datastore.DataStore;
import elki.database.datastore.DataStoreUtil;
import elki.database.datastore.WritableDataStore;
import elki.database.ids.ArrayModifiableDBIDs;
import elki.database.ids.DBIDArrayIter;
import elki.database.ids.DBIDArrayMIter;
import elki.database.ids.DBIDIter;
import elki.database.ids.DBIDRef;
import elki.database.ids.DBIDUtil;
import elki.database.ids.DBIDs;
import elki.database.ids.ModifiableDBIDs;
import elki.database.relation.MaterializedRelation;
import elki.database.relation.Relation;
import elki.logging.Logging;
import elki.logging.statistics.DoubleStatistic;
import elki.logging.statistics.Duration;
import elki.logging.statistics.LongStatistic;
import elki.logging.statistics.Statistic;
import elki.math.MathUtil;
import elki.math.linearalgebra.ConstrainedQuadraticProblemSolver;
import elki.math.linearalgebra.VMath;
import elki.result.Metadata;
import elki.utilities.datastructures.arraylike.IntegerArray;
import elki.utilities.documentation.Description;
import elki.utilities.documentation.Reference;
import elki.utilities.optionhandling.OptionID;
import elki.utilities.optionhandling.Parameterizer;
import elki.utilities.optionhandling.constraints.CommonConstraints;
import elki.utilities.optionhandling.constraints.ParameterConstraint;
import elki.utilities.optionhandling.parameterization.Parameterization;
import elki.utilities.optionhandling.parameters.DoubleParameter;
import elki.utilities.optionhandling.parameters.Flag;
import elki.utilities.optionhandling.parameters.IntParameter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import net.jafama.FastMath;

@Description(value="Gaussian mixture modeling accelerated using a kd-tree")
@Reference(authors="Andrew W. Moore", booktitle="Advances in Neural Information Processing Systems 11 (NIPS 1998)", title="Very Fast EM-based Mixture Model Clustering using Multiresolution kd-trees", bibkey="DBLP:conf/nips/Moore98")
public class KDTreeEM
implements ClusteringAlgorithm<Clustering<EMModel>> {
    /** Class logger. */
    private static final Logging LOG = Logging.getLogger(KDTreeEM.class);

    /** Type information of the optional soft-assignment relation attached to the result. */
    public static final SimpleTypeInformation<double[]> SOFT_TYPE = new SimpleTypeInformation<>(double[].class);

    /** Factory that builds the initial models at the start of {@code run}. */
    private final TextbookMultivariateGaussianModelFactory mfactory;

    /** If set, the per-point cluster probabilities are attached to the result; otherwise they are destroyed. */
    private final boolean soft;

    /** Threshold on the change of the (per-point) log-likelihood used by the stopping test in {@code run}. */
    private final double delta;

    /** Number of models (clusters) to fit. */
    private final int k;

    /** Construction-time pruning: a node is not split when its widest extent is below {@code mbw * dimWidth}. */
    private final double mbw;

    /** Traversal-time pruning factor compared against the weight error bound in {@code checkStoppingCondition}. */
    private final double tau;

    /** Factor for dropping clusters at a node; values {@code <= 0} disable dropping. */
    private final double tauClass;

    /** Iteration count before which the loop in {@code run} never terminates early. */
    private final int miniter;

    /** Upper bound on the number of iterations; a negative value means no bound. */
    private final int maxiter;

    /** DBIDs in the order produced by the kd-tree construction; tree nodes refer to index ranges of this array. */
    protected ArrayModifiableDBIDs sorted;

    /** Models of the current iteration (read during the E-step). */
    private List<TextbookMultivariateGaussianModel> models;

    /** Models being accumulated for the next iteration; swapped with {@link #models} after each pass. */
    private List<TextbookMultivariateGaussianModel> newmodels;

    /** Solver handed to the models when computing density limits over a bounding box. */
    private ConstrainedQuadraticProblemSolver solver;

    /** Cached value {@code 1 / sqrt(pi)^dim}, handed to the models together with the solver. */
    private double ipiPow;

    /** Per-cluster weight sums accumulated by {@code makeStats} during one pass. */
    private double[] wsum;

    /** If set, the final assignment evaluates each point individually instead of using the kd-tree. */
    protected boolean exactAssign = false;

    /**
     * Constructor.
     *
     * @param k number of models to fit
     * @param mbw relative width below which kd-tree nodes are no longer split
     * @param tau pruning factor used while traversing the tree
     * @param tauclass factor for dropping clusters at a node ({@code <= 0} disables it)
     * @param delta threshold on the log-likelihood change used for stopping
     * @param mfactory factory producing the initial models
     * @param miniter iteration count before which no early stop happens
     * @param maxiter maximum number of iterations, negative for unbounded
     * @param soft whether to attach the cluster probabilities to the result
     * @param exactAssign whether the final assignment is computed per point
     */
    public KDTreeEM(int k, double mbw, double tau, double tauclass, double delta, TextbookMultivariateGaussianModelFactory mfactory, int miniter, int maxiter, boolean soft, boolean exactAssign) {
        // Model setup
        this.k = k;
        this.mfactory = mfactory;
        // Pruning parameters
        this.mbw = mbw;
        this.tau = tau;
        this.tauClass = tauclass;
        // Iteration control
        this.delta = delta;
        this.miniter = miniter;
        this.maxiter = maxiter;
        // Output options
        this.soft = soft;
        this.exactAssign = exactAssign;
    }

    /**
     * Fit the mixture model to the given relation.
     * <p>
     * Builds the kd-tree, then alternates tree-based E-steps and model updates
     * until the stopping test fires or the iteration limit is reached. Finally,
     * every point is put into the cluster with the largest probability.
     *
     * @param relation data to cluster
     * @return clustering with one top-level cluster per model
     */
    public Clustering<EMModel> run(Relation<? extends NumberVector> relation) {
        final String statPrefix = this.getClass().getName();
        final int dim = relation.get(relation.iterDBIDs()).getDimensionality();
        sorted = DBIDUtil.newArray(relation.getDBIDs());
        final double[] dimWidth = analyseDimWidth(relation);

        // Build the kd-tree over all points (this reorders "sorted").
        Duration buildtime = LOG.newDuration(statPrefix + ".kdtree.buildtime").begin();
        final KDTree tree = new KDTree(relation, sorted, 0, sorted.size(), dimWidth, mbw);
        LOG.statistics(buildtime.end());

        // Current models, and a second set of models to accumulate into.
        models = mfactory.buildInitialModels(relation, k);
        newmodels = new ArrayList<>(k);
        for (int c = 0; c < k; ++c) {
            newmodels.add(new TextbookMultivariateGaussianModel(0.0, new double[dim]));
        }
        wsum = new double[k];
        DoubleStatistic likeStat = new DoubleStatistic(statPrefix + ".loglikelihood");
        solver = new ConstrainedQuadraticProblemSolver(dim);
        ipiPow = 1.0 / FastMath.pow(MathUtil.SQRTPI, dim);

        int lastImprovement = 0;
        double bestLogLikelihood = Double.NEGATIVE_INFINITY;
        double logLikelihood = 0.0;
        int it = 0;
        for (; maxiter < 0 || it < maxiter; ++it) {
            final double previous = logLikelihood;
            for (TextbookMultivariateGaussianModel m : newmodels) {
                m.beginEStep();
            }
            Arrays.fill(wsum, 0.0);
            logLikelihood = makeStats(tree, MathUtil.sequence(0, k), null) / (double) relation.size();

            for (int c = 0; c < k; ++c) {
                final double weight = wsum[c] / (double) relation.size();
                if (weight <= Double.MIN_NORMAL) {
                    // Nothing was accumulated for this cluster: keep the previous model.
                    LOG.warning("A cluster has degenerated by pruning.");
                    newmodels.get(c).clone(models.get(c));
                } else {
                    newmodels.get(c).finalizeEStep(weight, 0.0);
                }
            }

            // The freshly accumulated models become the current ones.
            final List<TextbookMultivariateGaussianModel> previousModels = models;
            models = newmodels;
            newmodels = previousModels;

            LOG.statistics(likeStat.setDouble(logLikelihood));
            if (logLikelihood - bestLogLikelihood > delta) {
                lastImprovement = it;
                bestLogLikelihood = logLikelihood;
            }
            if (it < miniter) {
                continue;
            }
            final boolean converged = Math.abs(logLikelihood - previous) <= delta;
            final boolean stalled = lastImprovement < (it >> 1);
            if (converged || stalled) {
                break;
            }
        }

        // Final pass: compute the cluster probabilities of every point.
        List<ArrayModifiableDBIDs> hardClusters = new ArrayList<>(k);
        for (int c = 0; c < k; ++c) {
            hardClusters.add(DBIDUtil.newArray());
        }
        WritableDataStore<double[]> probClusterIGivenX = DataStoreUtil.makeStorage(relation.getDBIDs(), 10, double[].class);
        if (exactAssign) {
            logLikelihood = EM.assignProbabilitiesToInstances(relation, models, probClusterIGivenX, null);
        } else {
            logLikelihood = makeStats(tree, MathUtil.sequence(0, k), probClusterIGivenX) / (double) relation.size();
        }
        LOG.statistics(new LongStatistic(statPrefix + ".iterations", it));
        LOG.statistics(new DoubleStatistic(statPrefix + ".loglikelihood", logLikelihood));

        // Hard assignment to the most probable cluster.
        for (DBIDIter id = relation.iterDBIDs(); id.valid(); id.advance()) {
            final int best = VMath.argmax(probClusterIGivenX.get(id));
            hardClusters.get(best).add(id);
        }

        Clustering<EMModel> result = new Clustering<EMModel>();
        Metadata.of(result).setLongName("KDTreeEM Clustering");
        for (int c = 0; c < k; ++c) {
            result.addToplevelCluster(new Cluster<EMModel>(hardClusters.get(c), models.get(c).finalizeCluster()));
        }
        if (soft) {
            Metadata.hierarchyOf(result).addChild(new MaterializedRelation<>("KDTreeEM Cluster Probabilities", SOFT_TYPE, relation.getDBIDs(), probClusterIGivenX));
        } else {
            probClusterIGivenX.destroy();
        }
        // Release per-run helper state.
        solver = null;
        newmodels = null;
        return result;
    }

    /**
     * Compute the extent (maximum minus minimum) of the data in every dimension.
     *
     * @param relation data relation; the first object supplies the dimensionality
     * @return per-dimension width of the data set
     */
    private double[] analyseDimWidth(Relation<? extends NumberVector> relation) {
        DBIDIter cursor = relation.iterDBIDs();
        final NumberVector first = relation.get(cursor);
        final int dims = first.getDimensionality();
        // Both bounds start at the first vector.
        final double[] mins = first.toArray();
        final double[] maxs = mins.clone();
        for (cursor.advance(); cursor.valid(); cursor.advance()) {
            final NumberVector vec = relation.get(cursor);
            for (int j = 0; j < dims; ++j) {
                final double v = vec.doubleValue(j);
                if (!(mins[j] < v)) {
                    mins[j] = v;
                }
                if (!(maxs[j] > v)) {
                    maxs[j] = v;
                }
            }
        }
        // maxs is modified in place and returned.
        return VMath.minusEquals(maxs, mins);
    }

    /**
     * Decide whether the traversal may stop at this node, and which clusters
     * remain relevant below it.
     * <p>
     * For every candidate cluster, lower and upper bounds of its responsibility
     * inside the node's bounding box are derived from the density limits of all
     * candidates. The node may be pruned when, for every cluster, the possible
     * error {@code size * (wmax - wmin)} does not exceed {@code tau} times the
     * smallest weight the cluster can end up with.
     *
     * @param node tree node to test
     * @param indices indices of the clusters still considered at this node
     * @return {@code null} if the node can be pruned; otherwise the cluster
     *         indices to use for the children (filtered by {@code tauClass}
     *         when that is positive)
     */
    private int[] checkStoppingCondition(KDTree node, int[] indices) {
        if (!(models.get(0) instanceof TextbookMultivariateGaussianModel)) {
            return indices;
        }
        final int numModels = models.size();
        final int dim = node.sum.length;
        final double[][] maxPnts = new double[numModels][dim];
        final double[][] minPnts = new double[numModels][dim];
        // limits[c][0] and limits[c][1] are filled in by the model for cluster c.
        final double[][] limits = new double[numModels][2];
        for (int c : indices) {
            calculateModelLimits(node, models.get(c), minPnts[c], maxPnts[c], limits[c]);
        }

        // Weighted totals of the two limits over all candidate clusters.
        double maxDenomTotal = 0.0;
        double minDenomTotal = 0.0;
        for (int c : indices) {
            maxDenomTotal += models.get(c).getWeight() * limits[c][0];
            minDenomTotal += models.get(c).getWeight() * limits[c][1];
        }

        final double size = node.right - node.left;
        final double[] wmaxs = new double[numModels];
        double maxMinWeight = Double.NEGATIVE_INFINITY;
        boolean prune = true;
        for (int c : indices) {
            final double weight = models.get(c).getWeight();
            final double lim0 = limits[c][0];
            final double lim1 = limits[c][1];
            // Replace this cluster's own contribution by its other limit.
            final double wminDenom = maxDenomTotal + weight * (lim0 - lim1);
            final double wmaxDenom = minDenomTotal + weight * (lim1 - lim0);
            final double wmin = MathUtil.clamp(weight * lim0 / wminDenom, 0.0, 1.0);
            if (wmin > maxMinWeight) {
                maxMinWeight = wmin;
            }
            wmaxs[c] = MathUtil.clamp(weight * lim1 / wmaxDenom, 0.0, 1.0);
            final double maximumError = size * (wmaxs[c] - wmin);
            final double minPossibleWeight = newmodels.get(c).getWeight() + wmin * size;
            if (maximumError > tau * minPossibleWeight) {
                prune = false;
            }
        }
        if (prune) {
            return null;
        }
        if (tauClass <= 0.0) {
            return indices;
        }
        // Keep only clusters whose upper bound is not negligible.
        final IntegerArray kept = new IntegerArray(indices.length);
        for (int c : indices) {
            if (wmaxs[c] >= tauClass * maxMinWeight) {
                kept.add(c);
            }
        }
        return kept.toArray();
    }

    /**
     * Let a model compute its limits over the bounding box of a tree node.
     *
     * @param node tree node supplying midpoint and half width of the box
     * @param model model to evaluate
     * @param minpnt output buffer handed to the model
     * @param maxpnt output buffer handed to the model
     * @param ret output buffer (two entries) handed to the model
     */
    private void calculateModelLimits(KDTree node, TextbookMultivariateGaussianModel model, double[] minpnt, double[] maxpnt, double[] ret) {
        // Lower corner: midpoint - halfwidth; upper corner: lower + 2 * halfwidth.
        final double[] lowerCorner = VMath.minus(node.midpoint, node.halfwidth);
        final double[] upperCorner = VMath.plusTimes(lowerCorner, node.halfwidth, 2.0);
        model.calculateModelLimits(lowerCorner, upperCorner, solver, ipiPow, minpnt, maxpnt, ret);
    }

    /**
     * Recursively accumulate the E-step statistics of a subtree into
     * {@code newmodels} and {@code wsum}.
     * <p>
     * A node is evaluated as a whole (using its mean as representative) when
     * only one cluster remains, when it has no children, or when
     * {@code checkStoppingCondition} allows pruning. Otherwise both children are
     * visited with the remaining cluster indices.
     *
     * @param node current tree node
     * @param indices indices of the clusters to consider at this node
     * @param probs optional storage for the cluster probabilities of the points
     *        in evaluated nodes; may be {@code null}
     * @return log density sum of the node mean, multiplied by the node size
     */
    private double makeStats(KDTree node, int[] indices, WritableDataStore<double[]> probs) {
        final int size = node.right - node.left;

        // Only one cluster left: all points of the node are assigned to it fully.
        if (indices.length == 1) {
            final int cluster = indices[0];
            final DoubleVector mean = DoubleVector.wrap(VMath.times(node.sum, 1.0 / (double) size));
            final double logDensity = models.get(cluster).estimateLogDensity(mean);
            wsum[cluster] += (double) size;
            newmodels.get(cluster).updateE(node.sum, node.sumSq, 1.0, size);
            if (probs != null) {
                final double[] p = new double[k];
                p[cluster] = 1.0;
                storeProbabilities(node, probs, p);
            }
            return logDensity * (double) size;
        }

        // Inner node that must not be pruned: descend into both children.
        if (node.leftChild != null) {
            final int[] nextIndices = checkStoppingCondition(node, indices);
            if (nextIndices != null) {
                return makeStats(node.leftChild, nextIndices, probs) + makeStats(node.rightChild, nextIndices, probs);
            }
        }

        // Leaf or pruned node: evaluate all candidate clusters at the node mean.
        final DoubleVector mean = DoubleVector.wrap(VMath.times(node.sum, 1.0 / (double) size));
        final double[] logProb = new double[indices.length];
        for (int i = 0; i < indices.length; ++i) {
            logProb[i] = models.get(indices[i]).estimateLogDensity(mean);
        }
        final double logDenSum = EM.logSumExp(logProb);
        // Normalize in log space, so that exp() yields the responsibilities.
        VMath.minusEquals(logProb, logDenSum);
        final double[] ps = probs != null ? new double[k] : null;
        for (int i = 0; i < indices.length; ++i) {
            final int cluster = indices[i];
            final double p = FastMath.exp(logProb[i]);
            wsum[cluster] += p * (double) size;
            newmodels.get(cluster).updateE(node.sum, node.sumSq, p, p * (double) size);
            if (ps != null) {
                ps[cluster] = p;
            }
        }
        if (probs != null) {
            storeProbabilities(node, probs, ps);
        }
        return logDenSum * (double) size;
    }

    /**
     * Store the same probability vector for every point covered by a node.
     * Note that all points of the node share the one array instance.
     *
     * @param node tree node whose index range in {@code sorted} is written
     * @param probs storage to write to
     * @param values probability vector to store
     */
    private void storeProbabilities(KDTree node, WritableDataStore<double[]> probs, double[] values) {
        for (DBIDArrayMIter it = sorted.iter().seek(node.left); it.getOffset() < node.right; it.advance()) {
            probs.put(it, values);
        }
    }

    /**
     * Input type restriction of this algorithm.
     *
     * @return a single-element array holding {@code TypeUtil.NUMBER_VECTOR_FIELD}
     */
    public TypeInformation[] getInputTypeRestriction() {
        final TypeInformation[] accepted = { TypeUtil.NUMBER_VECTOR_FIELD };
        return TypeUtil.array(accepted);
    }

    public static class Par
    implements Parameterizer {
        public static final OptionID K_ID = EM.Par.K_ID;
        public static final OptionID DELTA_ID = EM.Par.DELTA_ID;
        public static final OptionID MBW_ID = new OptionID("emkd.mbw", "Pruning criterion for the KD-Tree during construction. Stop splitting when leafwidth < mbw * width.");
        public static final OptionID TAU_ID = new OptionID("emkd.tau", "Pruning criterion for the KD-Tree during algorithm. Stop traversing when error e < tau * totalweight.");
        public static final OptionID TAU_CLASS_ID = new OptionID("emkd.tauclass", "Parameter for pruning. Drop a class if w[c] < tauclass * max(wmins). Set to 0 to disable dropping of classes.");
        public static final OptionID MINITER_ID = EM.Par.MINITER_ID;
        public static final OptionID MAXITER_ID = EM.Par.MAXITER_ID;
        public static final OptionID SOFT_ID = EM.Par.SOFT_ID;
        public static final OptionID EXACT_ASSIGN_ID = new OptionID("emkd.exactassign", "Assign each point individually, not using the kd-tree in the final step.");
        protected int k;
        protected double mbw;
        protected double tau;
        protected double tauclass;
        protected double delta;
        protected TextbookMultivariateGaussianModelFactory mfactory;
        protected int miniter = 1;
        protected int maxiter = -1;
        boolean soft = false;
        boolean exactAssign = false;

        public void configure(Parameterization config) {
            ((IntParameter)new IntParameter(K_ID).addConstraint((ParameterConstraint)CommonConstraints.GREATER_EQUAL_ONE_INT)).grab(config, x -> {
                this.k = x;
            });
            ((DoubleParameter)((DoubleParameter)new DoubleParameter(MBW_ID, 0.01).addConstraint((ParameterConstraint)CommonConstraints.GREATER_EQUAL_ZERO_DOUBLE)).addConstraint((ParameterConstraint)CommonConstraints.LESS_THAN_ONE_DOUBLE)).grab(config, x -> {
                this.mbw = x;
            });
            ((DoubleParameter)((DoubleParameter)new DoubleParameter(TAU_ID, 0.01).addConstraint((ParameterConstraint)CommonConstraints.GREATER_EQUAL_ZERO_DOUBLE)).addConstraint((ParameterConstraint)CommonConstraints.LESS_THAN_ONE_DOUBLE)).grab(config, x -> {
                this.tau = x;
            });
            ((DoubleParameter)((DoubleParameter)new DoubleParameter(TAU_CLASS_ID, 1.0E-4).addConstraint((ParameterConstraint)CommonConstraints.GREATER_EQUAL_ZERO_DOUBLE)).addConstraint((ParameterConstraint)CommonConstraints.LESS_THAN_ONE_DOUBLE)).grab(config, x -> {
                this.tauclass = x;
            });
            ((DoubleParameter)new DoubleParameter(DELTA_ID, 1.0E-7).addConstraint((ParameterConstraint)CommonConstraints.GREATER_EQUAL_ZERO_DOUBLE)).grab(config, x -> {
                this.delta = x;
            });
            this.mfactory = (TextbookMultivariateGaussianModelFactory)config.tryInstantiate(TextbookMultivariateGaussianModelFactory.class);
            ((IntParameter)((IntParameter)new IntParameter(MINITER_ID).addConstraint((ParameterConstraint)CommonConstraints.GREATER_EQUAL_ZERO_INT)).setOptional(true)).grab(config, x -> {
                this.miniter = x;
            });
            ((IntParameter)((IntParameter)new IntParameter(MAXITER_ID).addConstraint((ParameterConstraint)CommonConstraints.GREATER_EQUAL_ZERO_INT)).setOptional(true)).grab(config, x -> {
                this.maxiter = x;
            });
            new Flag(SOFT_ID).grab(config, x -> {
                this.soft = x;
            });
            new Flag(EXACT_ASSIGN_ID).grab(config, x -> {
                this.exactAssign = x;
            });
        }

        public KDTreeEM make() {
            return new KDTreeEM(this.k, this.mbw, this.tau, this.tauclass, this.delta, this.mfactory, this.miniter, this.maxiter, this.soft, this.exactAssign);
        }
    }

    /**
     * kd-tree node. Each node covers the index range {@code [left, right)} of
     * the sorted DBID array and caches the sum and sum of outer products of its
     * points, as well as the midpoint and half width of its bounding box.
     */
    static class KDTree {
        /** Child nodes; both are {@code null} for leaves. */
        KDTree leftChild;
        KDTree rightChild;
        /** Covered index range in the sorted array: left inclusive, right exclusive. */
        int left;
        int right;
        /** Sum of all vectors in this node. */
        double[] sum;
        /** Sum of the outer products {@code x * x^T} of all vectors in this node. */
        double[][] sumSq;
        /** Center of the bounding box. */
        double[] midpoint;
        /** Half of the bounding box extent, per dimension. */
        double[] halfwidth;

        /**
         * Build the (sub-)tree for an index range, reordering the array in place.
         *
         * @param relation data relation
         * @param sorted DBID array that is partitioned during construction
         * @param left first index covered (inclusive)
         * @param right end of the covered range (exclusive)
         * @param dimWidth width of the whole data set per dimension
         * @param mbw nodes narrower than {@code mbw * dimWidth} in their widest
         *        dimension become leaves
         */
        public KDTree(Relation<? extends NumberVector> relation, ArrayModifiableDBIDs sorted, int left, int right, double[] dimWidth, double mbw) {
            final DBIDArrayMIter iter = sorted.iter();
            final int dim = relation.get(iter).toArray().length;
            this.left = left;
            this.right = right;
            computeBoundingBox(relation, iter);

            // Split along the widest dimension, unless the node is small enough.
            final int splitDim = VMath.argmax(halfwidth);
            final double extent = 2.0 * halfwidth[splitDim];
            if (extent < mbw * dimWidth[splitDim]) {
                aggregateStats(relation, iter, dim);
                return;
            }

            // Partition the range around the midpoint of the split dimension.
            final double splitPoint = midpoint[splitDim];
            int l = left;
            int r = right - 1;
            while (true) {
                while (l <= r && relation.get(iter.seek(l)).doubleValue(splitDim) <= splitPoint) {
                    ++l;
                }
                while (l <= r && relation.get(iter.seek(r)).doubleValue(splitDim) >= splitPoint) {
                    --r;
                }
                if (l >= r) {
                    break;
                }
                sorted.swap(l++, r--);
            }
            assert (relation.get(iter.seek(r)).doubleValue(splitDim) <= splitPoint) : relation.get(iter.seek(r)).doubleValue(splitDim) + " not less than " + splitPoint;
            ++r;
            if (r == right) {
                // Everything ended up on one side: keep this node as a leaf.
                aggregateStats(relation, iter, dim);
                return;
            }
            leftChild = new KDTree(relation, sorted, left, r, dimWidth, mbw);
            rightChild = new KDTree(relation, sorted, r, right, dimWidth, mbw);
            // Inner nodes combine the statistics of their children.
            sum = VMath.plus(leftChild.sum, rightChild.sum);
            sumSq = VMath.plus(leftChild.sumSq, rightChild.sumSq);
        }

        /**
         * Compute midpoint and half width of the bounding box of this node.
         *
         * @param relation data relation
         * @param iter iterator over the sorted array; repositioned by this method
         */
        private void computeBoundingBox(Relation<? extends NumberVector> relation, DBIDArrayIter iter) {
            final double[] lows = relation.get(iter.seek(left)).toArray();
            final double[] highs = lows.clone();
            final int dims = lows.length;
            for (iter.advance(); iter.getOffset() < right; iter.advance()) {
                final NumberVector vec = relation.get(iter);
                for (int d = 0; d < dims; ++d) {
                    final double v = vec.doubleValue(d);
                    if (v < lows[d]) {
                        lows[d] = v;
                    }
                    if (v > highs[d]) {
                        highs[d] = v;
                    }
                }
            }
            // Convert (min, max) to (center, half extent), reusing the arrays.
            for (int d = 0; d < dims; ++d) {
                final double lo = lows[d];
                final double hi = highs[d];
                lows[d] = (lo + hi) * 0.5;
                highs[d] = (hi - lo) * 0.5;
            }
            midpoint = lows;
            halfwidth = highs;
        }

        /**
         * Compute sum and sum of outer products from the points of this node.
         *
         * @param relation data relation
         * @param iter iterator over the sorted array; repositioned by this method
         * @param dim dimensionality of the vectors
         */
        private void aggregateStats(Relation<? extends NumberVector> relation, DBIDArrayIter iter, int dim) {
            sum = new double[dim];
            sumSq = new double[dim][dim];
            for (iter.seek(left); iter.getOffset() < right; iter.advance()) {
                final NumberVector vec = relation.get(iter);
                for (int d1 = 0; d1 < dim; ++d1) {
                    final double v1 = vec.doubleValue(d1);
                    sum[d1] += v1;
                    final double[] row = sumSq[d1];
                    for (int d2 = 0; d2 < dim; ++d2) {
                        row[d2] += v1 * vec.doubleValue(d2);
                    }
                }
            }
        }
    }
}

