/*
 * Decompiled with CFR 0.152.
 */
package elki.clustering;

import elki.Algorithm;
import elki.clustering.ClusteringAlgorithm;
import elki.data.Cluster;
import elki.data.Clustering;
import elki.data.NumberVector;
import elki.data.model.MeanModel;
import elki.data.type.TypeInformation;
import elki.data.type.TypeUtil;
import elki.database.ids.ArrayModifiableDBIDs;
import elki.database.ids.DBIDIter;
import elki.database.ids.DBIDRef;
import elki.database.ids.DBIDUtil;
import elki.database.ids.DBIDs;
import elki.database.ids.DoubleDBIDList;
import elki.database.ids.DoubleDBIDListIter;
import elki.database.ids.ModifiableDBIDs;
import elki.database.query.QueryBuilder;
import elki.database.query.distance.DistanceQuery;
import elki.database.query.range.RangeSearcher;
import elki.database.relation.Relation;
import elki.database.relation.RelationUtil;
import elki.distance.NumberVectorDistance;
import elki.distance.minkowski.EuclideanDistance;
import elki.logging.Logging;
import elki.logging.progress.AbstractProgress;
import elki.logging.progress.FiniteProgress;
import elki.math.linearalgebra.Centroid;
import elki.math.statistics.kernelfunctions.EpanechnikovKernelDensityFunction;
import elki.math.statistics.kernelfunctions.KernelDensityFunction;
import elki.result.Metadata;
import elki.utilities.documentation.Reference;
import elki.utilities.optionhandling.OptionID;
import elki.utilities.optionhandling.Parameterizer;
import elki.utilities.optionhandling.parameterization.Parameterization;
import elki.utilities.optionhandling.parameters.DoubleParameter;
import elki.utilities.optionhandling.parameters.ObjectParameter;
import elki.utilities.pairs.Pair;
import java.util.ArrayList;

@Reference(authors="Y. Cheng", title="Mean shift, mode seeking, and clustering", booktitle="IEEE Transactions on Pattern Analysis and Machine Intelligence 17-8", url="https://doi.org/10.1109/34.400568", bibkey="DBLP:journals/pami/Cheng95")
public class NaiveMeanShiftClustering<V extends NumberVector>
implements ClusteringAlgorithm<Clustering<MeanModel>> {
    private static final Logging LOG = Logging.getLogger(NaiveMeanShiftClustering.class);
    protected NumberVectorDistance<? super V> distance;
    protected KernelDensityFunction kernel = EpanechnikovKernelDensityFunction.KERNEL;
    protected double bandwidth;
    protected static final int MAXITER = 1000;

    public NaiveMeanShiftClustering(NumberVectorDistance<? super V> distance, KernelDensityFunction kernel, double range) {
        this.distance = distance;
        this.kernel = kernel;
        this.bandwidth = range;
    }

    public TypeInformation[] getInputTypeRestriction() {
        return TypeUtil.array((TypeInformation[])new TypeInformation[]{TypeUtil.NUMBER_VECTOR_FIELD});
    }

    public Clustering<MeanModel> run(Relation<V> relation) {
        QueryBuilder qb = new QueryBuilder(relation, this.distance);
        RangeSearcher rangeq = qb.rangeByObject(this.bandwidth);
        DistanceQuery distq = qb.distanceQuery();
        NumberVector.Factory factory = RelationUtil.getNumberVectorFactory(relation);
        int dim = RelationUtil.dimensionality(relation);
        double threshold = this.bandwidth * 1.0E-10;
        ArrayList<Pair> clusters = new ArrayList<Pair>();
        ArrayModifiableDBIDs noise = DBIDUtil.newArray();
        FiniteProgress prog = LOG.isVerbose() ? new FiniteProgress("Mean-shift clustering", relation.size(), LOG) : null;
        DBIDIter iter = relation.iterDBIDs();
        while (iter.valid()) {
            Object position = (NumberVector)relation.get((DBIDRef)iter);
            int j = 1;
            while (true) {
                boolean okay;
                NumberVector newvec = null;
                DoubleDBIDList neigh = rangeq.getRange(position, this.bandwidth);
                boolean bl = okay = neigh.size() > 1 || neigh.size() >= 1 && j > 1;
                if (okay) {
                    Centroid newpos = new Centroid(dim);
                    DoubleDBIDListIter niter = neigh.iter();
                    while (niter.valid()) {
                        double weight = this.kernel.density(niter.doubleValue() / this.bandwidth);
                        newpos.put((NumberVector)relation.get((DBIDRef)niter), weight);
                        niter.advance();
                    }
                    newvec = factory.newNumberVector(newpos.getArrayRef());
                }
                if (!okay) {
                    noise.add((DBIDRef)iter);
                    break;
                }
                double bestd = Double.POSITIVE_INFINITY;
                Pair bestp = null;
                for (Pair pair : clusters) {
                    double merged = distq.distance((Object)newvec, (Object)((NumberVector)pair.first));
                    if (!(merged < bestd)) continue;
                    bestd = merged;
                    bestp = pair;
                }
                double delta = distq.distance(position, (Object)newvec);
                if (bestd < 10.0 * threshold || bestd * 2.0 < delta) {
                    assert (bestp != null);
                    ((ModifiableDBIDs)bestp.second).add((DBIDRef)iter);
                    break;
                }
                if (Double.isNaN(delta)) {
                    LOG.warning((CharSequence)("Encountered NaN distance. Invalid center vector? " + newvec.toString()));
                    break;
                }
                if (j == 1000 || delta < threshold) {
                    if (j == 1000) {
                        LOG.warning((CharSequence)("No convergence after 1000 iterations. Distance: " + delta));
                    }
                    if (LOG.isDebuggingFine()) {
                        LOG.debugFine((CharSequence)("New cluster:" + newvec + " delta: " + delta + " threshold: " + threshold + " bestd: " + bestd));
                    }
                    ArrayModifiableDBIDs cids = DBIDUtil.newArray((int)1);
                    cids.add((DBIDRef)iter);
                    clusters.add(new Pair((Object)newvec, (Object)cids));
                    break;
                }
                position = newvec;
                ++j;
            }
            LOG.incrementProcessed((AbstractProgress)prog);
            iter.advance();
        }
        LOG.ensureCompleted(prog);
        ArrayList cs = new ArrayList(clusters.size());
        for (Pair pair : clusters) {
            cs.add(new Cluster<MeanModel>((DBIDs)pair.second, new MeanModel(((NumberVector)pair.first).toArray())));
        }
        if (noise.size() > 0) {
            cs.add(new Cluster((DBIDs)noise, true));
        }
        Clustering<MeanModel> c = new Clustering<MeanModel>(cs);
        Metadata.of(c).setLongName("Mean-shift Clustering");
        return c;
    }

    public static class Par<V extends NumberVector>
    implements Parameterizer {
        public static final OptionID KERNEL_ID = new OptionID("meanshift.kernel", "Kernel function to use with mean-shift clustering.");
        public static final OptionID RANGE_ID = new OptionID("meanshift.kernel-bandwidth", "Range of the kernel to use (aka: radius, bandwidth).");
        KernelDensityFunction kernel = EpanechnikovKernelDensityFunction.KERNEL;
        double range;
        protected NumberVectorDistance<? super V> distance;

        public void configure(Parameterization config) {
            new ObjectParameter(Algorithm.Utils.DISTANCE_FUNCTION_ID, NumberVectorDistance.class, EuclideanDistance.class).grab(config, x -> {
                this.distance = x;
            });
            new ObjectParameter(KERNEL_ID, KernelDensityFunction.class, EpanechnikovKernelDensityFunction.class).grab(config, x -> {
                this.kernel = x;
            });
            new DoubleParameter(RANGE_ID).grab(config, x -> {
                this.range = x;
            });
        }

        public NaiveMeanShiftClustering<V> make() {
            return new NaiveMeanShiftClustering<V>(this.distance, this.kernel, this.range);
        }
    }
}

