/*
 * Decompiled with CFR 0.152.
 */
package elki.clustering.silhouette;

import elki.clustering.kmeans.initialization.RandomlyChosen;
import elki.clustering.kmedoids.initialization.KMedoidsInitialization;
import elki.clustering.silhouette.PAMMEDSIL;
import elki.data.Clustering;
import elki.data.model.MedoidModel;
import elki.database.datastore.DataStoreUtil;
import elki.database.datastore.DoubleDataStore;
import elki.database.datastore.WritableDataStore;
import elki.database.datastore.WritableDoubleDataStore;
import elki.database.datastore.WritableIntegerDataStore;
import elki.database.ids.ArrayDBIDs;
import elki.database.ids.ArrayModifiableDBIDs;
import elki.database.ids.DBIDArrayIter;
import elki.database.ids.DBIDArrayMIter;
import elki.database.ids.DBIDIter;
import elki.database.ids.DBIDRef;
import elki.database.ids.DBIDUtil;
import elki.database.ids.DBIDVar;
import elki.database.ids.DBIDs;
import elki.database.query.distance.DistanceQuery;
import elki.database.relation.MaterializedDoubleRelation;
import elki.database.relation.Relation;
import elki.distance.Distance;
import elki.logging.Logging;
import elki.logging.progress.AbstractProgress;
import elki.logging.progress.IndefiniteProgress;
import elki.logging.statistics.DoubleStatistic;
import elki.logging.statistics.Duration;
import elki.logging.statistics.LongStatistic;
import elki.logging.statistics.Statistic;
import elki.math.linearalgebra.VMath;
import elki.result.EvaluationResult;
import elki.result.Metadata;
import elki.utilities.documentation.Reference;
import elki.utilities.exceptions.AbortException;
import java.util.Arrays;

@Reference(authors="Lars Lenssen and Erich Schubert", title="Clustering by Direct Optimization of the Medoid Silhouette", booktitle="Int. Conf. on Similarity Search and Applications, SISAP 2022", url="https://doi.org/10.1007/978-3-031-17849-8_15", bibkey="DBLP:conf/sisap/LenssenS22")
public class FastMSC<O>
extends PAMMEDSIL<O> {
    /** Class logger, returned by {@link #getLogger()} and used by the inner swap loops. */
    private static final Logging LOG = Logging.getLogger(FastMSC.class);

    /**
     * Constructor.
     *
     * @param distance Distance function to use
     * @param k Number of clusters (medoids)
     * @param maxiter Maximum number of swap iterations; the swap loops in this class
     *        run until no improving swap exists when this is {@code <= 0}
     * @param initializer Medoid initialization method
     */
    public FastMSC(Distance<? super O> distance, int k, int maxiter, KMedoidsInitialization<O> initializer) {
        super(distance, k, maxiter, initializer);
    }

    /**
     * Cluster the relation by directly optimizing the medoid silhouette.
     * <p>
     * Medoids are initialized, then improved by swaps. A specialized implementation is
     * used for exactly two medoids. The result carries the per-object silhouette scores
     * as a child relation and the overall medoid silhouette as an evaluation measure.
     *
     * @param relation Data relation
     * @param k Number of medoids
     * @param distQ Distance query
     * @return Clustering with medoid models
     */
    @Override
    public Clustering<MedoidModel> run(Relation<O> relation, int k, DistanceQuery<? super O> distQ) {
        final DBIDs ids = relation.getDBIDs();
        final ArrayModifiableDBIDs medoids = initialMedoids(distQ, ids, k);
        final WritableIntegerDataStore assignment = DataStoreUtil.makeIntegerStorage(ids, 3, -1);
        final Duration optd = getLogger().newDuration(getClass().getName() + ".optimization-time").begin();
        final double sil;
        final DoubleDataStore silhouettes;
        if (k != 2) {
            // General case: each object caches its three nearest medoids.
            final Instance general = new Instance(distQ, ids, assignment);
            sil = general.run(medoids, maxiter);
            silhouettes = general.silhouetteScores();
        } else {
            // Two medoids: a cheaper specialization storing both distances directly.
            final Instance2 twoMedoids = new Instance2(distQ, ids, assignment);
            sil = twoMedoids.run(medoids, maxiter);
            silhouettes = twoMedoids.silhouetteScores();
        }
        getLogger().statistics(optd.end());
        final Clustering<MedoidModel> res = wrapResult(ids, assignment, medoids, "FastMSC Clustering");
        Metadata.hierarchyOf(res).addChild(new MaterializedDoubleRelation("Silhouette scores", ids, silhouettes));
        EvaluationResult.findOrCreate(res, "Internal Clustering Evaluation") //
            .findOrCreateGroup("Distance-based") //
            .addMeasure("Medoid Silhouette", sil, -1.0, 1.0, 0.0, false);
        return res;
    }

    /**
     * Medoid silhouette loss term {@code a / b} of one object, where {@code a} is the
     * distance to the nearest and {@code b} to the second-nearest medoid.
     * <p>
     * Returns 0 unless both arguments are strictly positive. This covers zero
     * distances (which would otherwise give 0/0) and NaN inputs.
     *
     * @param a Distance to the nearest medoid
     * @param b Distance to the second-nearest medoid
     * @return {@code a / b}, or 0 if either argument is not positive
     */
    protected static final double loss(double a, double b) {
        if (a > 0.0 && b > 0.0) {
            return a / b;
        }
        return 0.0;
    }

    /**
     * Get the logger of this class.
     *
     * @return the class logger {@code LOG}
     */
    @Override
    protected Logging getLogger() {
        return LOG;
    }

    /**
     * Parameterization class.
     * <p>
     * Same parameters as PAMMEDSIL, but random medoid initialization by default.
     *
     * @param <O> Object type
     */
    public static class Par<O>
    extends PAMMEDSIL.Par<O> {
        @Override
        protected Class<? extends KMedoidsInitialization> defaultInitializer() {
            return RandomlyChosen.class;
        }

        @Override
        public FastMSC<O> make() {
            // Diamond instead of the raw type FastMSC: avoids an unchecked conversion.
            return new FastMSC<>(this.distance, this.k, this.maxiter, this.initializer);
        }
    }

    protected class Instance {
        /** Objects being clustered. */
        protected DBIDs ids;

        /** Distance query used for all distance computations. */
        protected DistanceQuery<?> distQ;

        /** Per-object cache of the three nearest medoids (offsets and distances). */
        protected WritableDataStore<Record> assignment;

        /** Caller-provided store; receives the final nearest-medoid offset of each object. */
        protected WritableIntegerDataStore output;

        /**
         * Create a solver instance for the general case of more than two medoids.
         *
         * @param distQ Distance query
         * @param ids Objects to cluster
         * @param assignment Output store for the final cluster assignment
         */
        public Instance(DistanceQuery<?> distQ, DBIDs ids, WritableIntegerDataStore assignment) {
            this.ids = ids;
            this.distQ = distQ;
            this.output = assignment;
            this.assignment = DataStoreUtil.makeStorage(ids, 3, Record.class);
        }

        /**
         * Perform FastMSC swaps until no swap improves the medoid silhouette.
         * <p>
         * Stops early after {@code maxiter} iterations; {@code maxiter <= 0} means no
         * iteration limit. Afterwards, the nearest-medoid offset of every object is
         * written to the output store.
         *
         * @param medoids Initial medoids; updated in place by accepted swaps
         * @param maxiter Iteration limit
         * @return Medoid silhouette of the final medoids
         */
        protected double run(ArrayModifiableDBIDs medoids, int maxiter) {
            final int numMedoids = medoids.size();
            double sil = assignToNearestCluster(medoids);
            final DBIDArrayMIter medoidIter = medoids.iter();
            final String key = getClass().getName().replace("$Instance", "");
            if (LOG.isStatistics()) {
                LOG.statistics(new DoubleStatistic(key + ".iteration-" + 0 + ".medoid-silhouette", sil));
            }
            // Gain of removing each medoid; refreshed after every accepted swap.
            final double[] removalGains = new double[numMedoids];
            // Working copy of removalGains, completed per candidate by findBestSwap.
            final double[] swapGains = new double[numMedoids];
            updateRemovalLoss(removalGains);
            final IndefiniteProgress prog = LOG.isVerbose() ? new IndefiniteProgress("FastMSC iteration", LOG) : null;
            final DBIDVar bestid = DBIDUtil.newVar();
            int iteration = 0;
            while (maxiter <= 0 || iteration < maxiter) {
                iteration++;
                LOG.incrementProcessed(prog);
                double bestGain = 0.0;
                int bestMedoid = -1;
                for (DBIDIter cand = ids.iter(); cand.valid(); cand.advance()) {
                    // Current medoids are not swap candidates.
                    if (DBIDUtil.equal(medoidIter.seek(assignment.get(cand).m1), cand)) {
                        continue;
                    }
                    System.arraycopy(removalGains, 0, swapGains, 0, numMedoids);
                    final double sharedGain = findBestSwap(cand, swapGains);
                    final int medoidIdx = VMath.argmax(swapGains);
                    final double gain = swapGains[medoidIdx] + sharedGain;
                    if (gain > bestGain) {
                        bestGain = gain;
                        bestid.set(cand);
                        bestMedoid = medoidIdx;
                    }
                }
                if (bestGain <= 0.0) {
                    break; // converged: no improving swap
                }
                medoids.set(bestMedoid, bestid);
                sil = doSwap(medoids, bestMedoid, bestid);
                if (LOG.isStatistics()) {
                    LOG.statistics(new DoubleStatistic(key + ".iteration-" + iteration + ".medoid-silhouette", sil));
                }
                updateRemovalLoss(removalGains);
            }
            LOG.setCompleted(prog);
            if (LOG.isStatistics()) {
                LOG.statistics(new LongStatistic(key + ".iterations", (long) iteration));
                LOG.statistics(new DoubleStatistic(key + ".final-medoid-silhouette", sil));
            }
            // Publish the nearest-medoid offsets as the final cluster assignment.
            for (DBIDIter it = ids.iter(); it.valid(); it.advance()) {
                output.putInt(it, assignment.get(it).m1);
            }
            return sil;
        }

        /**
         * Compute, for every object, its three nearest medoids and store them in
         * {@code assignment}.
         *
         * @param means Current medoids
         * @return Medoid silhouette of this assignment, {@code 1 - mean(d1/d2)}
         * @throws AbortException if an object has fewer than two medoids at finite distance
         */
        protected double assignToNearestCluster(ArrayDBIDs means) {
            final DBIDArrayIter miter = means.iter();
            double losssum = 0.0;
            for (DBIDIter iditer = this.ids.iter(); iditer.valid(); iditer.advance()) {
                final Record rec = new Record();
                // Insert each medoid into the sorted top-3 list (d1 <= d2 <= d3).
                for (miter.seek(0); miter.valid(); miter.advance()) {
                    final double dist = this.distQ.distance(iditer, miter);
                    if (dist < rec.d1) {
                        rec.m3 = rec.m2;
                        rec.d3 = rec.d2;
                        rec.m2 = rec.m1;
                        rec.d2 = rec.d1;
                        rec.m1 = miter.getOffset();
                        rec.d1 = dist;
                    } else if (dist < rec.d2) {
                        rec.m3 = rec.m2;
                        rec.d3 = rec.d2;
                        rec.m2 = miter.getOffset();
                        rec.d2 = dist;
                    } else if (dist < rec.d3) {
                        rec.m3 = miter.getOffset();
                        rec.d3 = dist;
                    }
                }
                if (rec.m2 < 0) {
                    throw new AbortException("Too many infinite distances. Cannot assign objects.");
                }
                this.assignment.put(iditer, rec);
                // Use the zero-guarded loss() like doSwap/findBestSwap do: a raw d1 / d2
                // is 0/0 = NaN when an object coincides with two (duplicate) medoids.
                losssum += FastMSC.loss(rec.d1, rec.d2);
                assert (rec.m1 != rec.m2 && rec.m1 != rec.m3 && rec.m2 != rec.m3);
                assert (rec.d1 <= rec.d2);
                assert (rec.d2 <= rec.d3);
            }
            return 1.0 - losssum / (double) this.ids.size();
        }

        /**
         * Evaluate making object {@code j} a medoid in place of each current medoid.
         * <p>
         * On entry, {@code ploss} holds the per-medoid removal gains computed by
         * {@code updateRemovalLoss}. For each medoid i, this method adds the correction
         * that turns "remove medoid i" into "replace medoid i by j". The part of the gain
         * that is the same whichever medoid is replaced is accumulated separately and
         * returned. The total gain of replacing medoid i is therefore
         * {@code ploss[i] + result}. Gains are reductions of the summed loss
         * {@code d1/d2}; positive means the medoid silhouette improves.
         *
         * @param j Swap candidate
         * @param ploss Per-medoid gains; updated in place
         * @return Gain shared by all choices of the replaced medoid
         */
        protected double findBestSwap(DBIDRef j, double[] ploss) {
            double acc = 0.0;
            DBIDIter o = this.ids.iter();
            while (o.valid()) {
                double djo = this.distQ.distance(j, (DBIDRef)o);
                Record reco = (Record)this.assignment.get((DBIDRef)o);
                if (djo < reco.d1) {
                    // j would become o's nearest medoid.
                    // Replacing any other medoid: o's loss changes from d1/d2 to djo/d1.
                    acc += FastMSC.loss(reco.d1, reco.d2) - FastMSC.loss(djo, reco.d1);
                    // Replacing o's nearest (m1): new loss is djo/d2; fix up acc + removal term.
                    int n = reco.m1;
                    ploss[n] = ploss[n] + (FastMSC.loss(djo, reco.d1) + FastMSC.loss(reco.d2, reco.d3) - FastMSC.loss(reco.d1 + djo, reco.d2));
                    // Replacing o's second nearest (m2): new loss is djo/d1 as in acc; undo the removal term.
                    int n2 = reco.m2;
                    ploss[n2] = ploss[n2] + (FastMSC.loss(reco.d1, reco.d3) - FastMSC.loss(reco.d1, reco.d2));
                } else if (djo < reco.d2) {
                    // j would become o's second-nearest medoid.
                    // Replacing any other medoid: o's loss changes from d1/d2 to d1/djo.
                    acc += FastMSC.loss(reco.d1, reco.d2) - FastMSC.loss(reco.d1, djo);
                    // Replacing m1: j becomes nearest and m2 stays second, new loss djo/d2.
                    int n = reco.m1;
                    ploss[n] = ploss[n] + (FastMSC.loss(reco.d1, djo) + FastMSC.loss(reco.d2, reco.d3) - FastMSC.loss(reco.d1 + djo, reco.d2));
                    // Replacing m2: new loss is d1/djo as in acc; undo the removal term.
                    int n3 = reco.m2;
                    ploss[n3] = ploss[n3] + (FastMSC.loss(reco.d1, reco.d3) - FastMSC.loss(reco.d1, reco.d2));
                } else if (djo < reco.d3) {
                    // j would become o's third-nearest medoid: only matters if m1 or m2 is removed.
                    // Removing m1: j becomes second instead of m3, new loss d2/djo instead of d2/d3.
                    int n = reco.m1;
                    ploss[n] = ploss[n] + (FastMSC.loss(reco.d2, reco.d3) - FastMSC.loss(reco.d2, djo));
                    // Removing m2: new loss d1/djo instead of d1/d3.
                    int n4 = reco.m2;
                    ploss[n4] = ploss[n4] + (FastMSC.loss(reco.d1, reco.d3) - FastMSC.loss(reco.d1, djo));
                }
                // Otherwise j is farther than o's third-nearest medoid: removal gains already exact.
                o.advance();
            }
            return acc;
        }

        /**
         * Apply the swap "medoid number {@code b} becomes object {@code j}" to the cached
         * three-nearest-medoid records of all objects.
         *
         * @param medoids Medoids, already updated so that position {@code b} holds {@code j}
         * @param b Index of the replaced medoid
         * @param j New medoid
         * @return Medoid silhouette after the swap, {@code 1 - sum(loss(d1, d2)) / n}; the
         *         new medoid itself contributes no loss
         */
        protected double doSwap(ArrayDBIDs medoids, int b, DBIDRef j) {
            DBIDArrayIter miter = medoids.iter();
            assert (DBIDUtil.equal((DBIDRef)j, (DBIDRef)miter.seek(b)));
            double silsum = 0.0;
            DBIDIter o = this.ids.iter();
            while (o.valid()) {
                Record rec = (Record)this.assignment.get((DBIDRef)o);
                if (DBIDUtil.equal((DBIDRef)j, (DBIDRef)o)) {
                    // The new medoid itself: distance 0, so it becomes its own nearest medoid.
                    if (rec.m1 != b) {
                        // Keep the old nearest as second; the old second moves to third
                        // unless it was the replaced medoid b.
                        if (rec.m2 != b) {
                            rec.m3 = rec.m2;
                            rec.d3 = rec.d2;
                        }
                        rec.m2 = rec.m1;
                        rec.d2 = rec.d1;
                    }
                    rec.m1 = b;
                    rec.d1 = 0.0;
                } else {
                    double djo = this.distQ.distance(j, (DBIDRef)o);
                    if (rec.m1 == b) {
                        // The replaced medoid was o's nearest.
                        if (djo < rec.d2) {
                            // j is still nearest: only the distance changes.
                            rec.d1 = djo;
                        } else if (djo < rec.d3) {
                            // j falls to second place.
                            rec.m1 = rec.m2;
                            rec.d1 = rec.d2;
                            rec.m2 = b;
                            rec.d2 = djo;
                        } else {
                            // j falls out of the top two: shift up and search a new third.
                            rec.m1 = rec.m2;
                            rec.d1 = rec.d2;
                            rec.m2 = rec.m3;
                            rec.d2 = rec.d3;
                            this.updateThirdNearest((DBIDRef)o, rec, b, djo, miter);
                        }
                    } else if (rec.m2 == b) {
                        // The replaced medoid was o's second nearest.
                        if (djo < rec.d1) {
                            // j becomes nearest.
                            rec.m2 = rec.m1;
                            rec.d2 = rec.d1;
                            rec.m1 = b;
                            rec.d1 = djo;
                        } else if (djo < rec.d3) {
                            // j stays second.
                            rec.m2 = b;
                            rec.d2 = djo;
                        } else {
                            // j falls out of the top two: shift up and search a new third.
                            rec.m2 = rec.m3;
                            rec.d2 = rec.d3;
                            this.updateThirdNearest((DBIDRef)o, rec, b, djo, miter);
                        }
                    } else if (djo < rec.d1) {
                        // b was third or not in the top three; j becomes nearest.
                        rec.m3 = rec.m2;
                        rec.d3 = rec.d2;
                        rec.m2 = rec.m1;
                        rec.d2 = rec.d1;
                        rec.m1 = b;
                        rec.d1 = djo;
                    } else if (djo < rec.d2) {
                        // j becomes second nearest.
                        rec.m3 = rec.m2;
                        rec.d3 = rec.d2;
                        rec.m2 = b;
                        rec.d2 = djo;
                    } else if (djo < rec.d3) {
                        // j becomes third nearest (replacing the old third, which may have been b).
                        rec.m3 = b;
                        rec.d3 = djo;
                    } else if (rec.m3 == b) {
                        // The old third was b and j is farther away: search a new third.
                        this.updateThirdNearest((DBIDRef)o, rec, b, djo, miter);
                    }
                    silsum += FastMSC.loss(rec.d1, rec.d2);
                }
                o.advance();
            }
            return 1.0 - silsum / (double)this.ids.size();
        }

        /**
         * Recompute the third-nearest medoid of an object after a swap.
         * <p>
         * {@code rec.m1} and {@code rec.m2} must already be up to date. The medoid at
         * offset {@code m} (at distance {@code bestd}) is the starting candidate. Every
         * medoid other than {@code m}, {@code rec.m1} and {@code rec.m2} is checked for
         * a closer one.
         *
         * @param j Object whose record is updated
         * @param rec Record of {@code j}; {@code m3}/{@code d3} are overwritten
         * @param m Offset of the starting candidate (the newly swapped-in medoid)
         * @param bestd Distance from {@code j} to medoid {@code m}
         * @param miter Iterator over the current medoids (repositioned by this method)
         */
        protected void updateThirdNearest(DBIDRef j, Record rec, int m, double bestd, DBIDArrayIter miter) {
            // No "k == 3" shortcut. The configured k need not equal the number of medoids
            // passed to run(). With exactly three medoids, this loop skips all of them
            // without computing any distance anyway.
            int best = m;
            for (miter.seek(0); miter.valid(); miter.advance()) {
                final int off = miter.getOffset();
                if (off == m || off == rec.m1 || off == rec.m2) {
                    continue;
                }
                final double d = this.distQ.distance(j, miter);
                if (d < bestd) {
                    best = off;
                    bestd = d;
                }
            }
            rec.m3 = best;
            rec.d3 = bestd;
            assert (rec.m1 != rec.m2 && rec.m1 != rec.m3 && rec.m2 != rec.m3);
            assert (rec.d1 <= rec.d2);
            assert (rec.d2 <= rec.d3);
        }

        /**
         * Recompute, for every medoid, the gain of removing it.
         * <p>
         * Removing an object's nearest medoid changes its loss from d1/d2 to d2/d3.
         * Removing its second-nearest medoid changes it to d1/d3. Other medoids do not
         * affect the object. Entries are "old loss minus new loss", summed over objects.
         *
         * @param losses Output array indexed by medoid offset; overwritten
         */
        protected void updateRemovalLoss(double[] losses) {
            Arrays.fill(losses, 0.0);
            for (DBIDIter it = ids.iter(); it.valid(); it.advance()) {
                final Record r = assignment.get(it);
                final double current = FastMSC.loss(r.d1, r.d2);
                losses[r.m1] += current - FastMSC.loss(r.d2, r.d3);
                losses[r.m2] += current - FastMSC.loss(r.d1, r.d3);
            }
        }

        /**
         * Compute each object's medoid silhouette {@code 1 - d1/d2} from the cached records.
         * An object at distance 0 from its nearest medoid scores 1.
         *
         * @return Per-object silhouette scores
         */
        public DoubleDataStore silhouetteScores() {
            final WritableDoubleDataStore scores = DataStoreUtil.makeDoubleStorage(ids, 30);
            for (DBIDIter it = ids.iter(); it.valid(); it.advance()) {
                final Record r = assignment.get(it);
                double score = 1.0;
                if (r.d1 > 0.0) {
                    score = 1.0 - r.d1 / r.d2;
                }
                scores.putDouble(it, score);
            }
            return scores;
        }
    }

    /**
     * Cached three nearest medoids of one object: medoid offsets {@code m1..m3} and
     * distances {@code d1 <= d2 <= d3}. Unset entries have offset -1 and distance
     * positive infinity.
     */
    protected static class Record {
        /** Offset of the nearest medoid. */
        int m1 = -1;
        /** Offset of the second-nearest medoid. */
        int m2 = -1;
        /** Offset of the third-nearest medoid. */
        int m3 = -1;
        /** Distance to the nearest medoid. */
        double d1 = Double.POSITIVE_INFINITY;
        /** Distance to the second-nearest medoid. */
        double d2 = Double.POSITIVE_INFINITY;
        /** Distance to the third-nearest medoid. */
        double d3 = Double.POSITIVE_INFINITY;

        /** Create an empty record; all entries unset. */
        protected Record() {
            // Field initializers provide the unset state.
        }

        @Override
        public String toString() {
            return new StringBuilder("Record [m1=").append(m1) //
                .append(", m2=").append(m2) //
                .append(", m3=").append(m3) //
                .append(", d1=").append(d1) //
                .append(", d2=").append(d2) //
                .append(", d3=").append(d3) //
                .append(']').toString();
        }
    }

    protected class Instance2 {
        /** Objects being clustered. */
        protected DBIDs ids;

        /** Distance query used for all distance computations. */
        protected DistanceQuery<?> distQ;

        /** Distance of each object to medoid 0. */
        protected WritableDoubleDataStore dm0;

        /** Distance of each object to medoid 1. */
        protected WritableDoubleDataStore dm1;

        /** Cluster assignment (0 or 1); the store passed in by the caller. */
        protected WritableIntegerDataStore assignment;

        /**
         * Create a solver instance specialized for exactly two medoids.
         *
         * @param distQ Distance query
         * @param ids Objects to cluster
         * @param assignment Output store for the cluster assignment
         */
        public Instance2(DistanceQuery<?> distQ, DBIDs ids, WritableIntegerDataStore assignment) {
            this.assignment = assignment;
            this.ids = ids;
            this.distQ = distQ;
            this.dm0 = DataStoreUtil.makeDoubleStorage(ids, 3);
            this.dm1 = DataStoreUtil.makeDoubleStorage(ids, 3);
        }

        /**
         * Perform swaps for the two-medoid case until no swap improves the medoid silhouette.
         * <p>
         * Stops early after {@code maxiter} iterations; {@code maxiter <= 0} means no
         * iteration limit. The assignment store is kept up to date throughout.
         *
         * @param medoids Initial medoids (exactly two); updated in place by accepted swaps
         * @param maxiter Iteration limit
         * @return Medoid silhouette of the final medoids
         */
        protected double run(ArrayModifiableDBIDs medoids, int maxiter) {
            final int k = medoids.size();
            assert (k == 2);
            double sil = this.assignToNearestCluster(medoids);
            final DBIDArrayMIter m = medoids.iter();
            // Strip "$Instance2" as well, so k=2 statistics use the same key prefix
            // ("...FastMSC") as the general path; removing only "$Instance" left "...FastMSC2".
            final String key = this.getClass().getName().replace("$Instance2", "").replace("$Instance", "");
            if (LOG.isStatistics()) {
                LOG.statistics(new DoubleStatistic(key + ".iteration-" + 0 + ".medoid-silhouette", sil));
            }
            // Silhouette after replacing medoid 0 resp. medoid 1 by the current candidate.
            final double[] scratch = new double[k];
            final IndefiniteProgress prog = LOG.isVerbose() ? new IndefiniteProgress("FastMSC iteration", LOG) : null;
            final DBIDVar bestid = DBIDUtil.newVar();
            int iteration = 0;
            while (iteration < maxiter || maxiter <= 0) {
                ++iteration;
                LOG.incrementProcessed(prog);
                double best = 0.0;
                int bestcluster = -1;
                for (DBIDIter j = this.ids.iter(); j.valid(); j.advance()) {
                    // Current medoids are not swap candidates.
                    if (DBIDUtil.equal(m.seek(this.assignment.intValue(j)), j)) {
                        continue;
                    }
                    Arrays.fill(scratch, 0.0);
                    this.findBestSwap(j, scratch);
                    final int b = scratch[0] > scratch[1] ? 0 : 1;
                    if (scratch[b] > best) {
                        best = scratch[b];
                        bestid.set(j);
                        bestcluster = b;
                    }
                }
                // Here candidates are scored by absolute silhouette, not by gain.
                if (best <= sil) {
                    break;
                }
                medoids.set(bestcluster, bestid);
                sil = this.doSwap(medoids, bestcluster, bestid);
                if (LOG.isStatistics()) {
                    LOG.statistics(new DoubleStatistic(key + ".iteration-" + iteration + ".medoid-silhouette", sil));
                }
            }
            LOG.setCompleted(prog);
            if (LOG.isStatistics()) {
                LOG.statistics(new LongStatistic(key + ".iterations", (long) iteration));
                LOG.statistics(new DoubleStatistic(key + ".final-medoid-silhouette", sil));
            }
            return sil;
        }

        /**
         * Assign each object to the closer of the two medoids and cache both distances.
         * Ties are assigned to medoid 1.
         *
         * @param means Current medoids (exactly two)
         * @return Medoid silhouette of this assignment
         * @throws AbortException if an object does not have a finite distance to both medoids
         */
        protected double assignToNearestCluster(ArrayDBIDs means) {
            final DBIDArrayIter miter = means.iter();
            double silsum = 0.0;
            for (DBIDIter iditer = this.ids.iter(); iditer.valid(); iditer.advance()) {
                final double di0 = this.distQ.distance(iditer, miter.seek(0));
                final double di1 = this.distQ.distance(iditer, miter.seek(1));
                // Same rule as the general (k > 2) path: two finite medoid distances are
                // required. Otherwise d1/d2 is undefined (inf/inf = NaN poisons the total).
                if (!(di0 < Double.POSITIVE_INFINITY && di1 < Double.POSITIVE_INFINITY)) {
                    throw new AbortException("Too many infinite distances. Cannot assign objects.");
                }
                this.assignment.putInt(iditer, di0 < di1 ? 0 : 1);
                this.dm0.putDouble(iditer, di0);
                this.dm1.putDouble(iditer, di1);
                silsum += di0 < di1 ? FastMSC.loss(di0, di1) : FastMSC.loss(di1, di0);
            }
            return 1.0 - silsum / (double) this.ids.size();
        }

        /**
         * Compute the medoid silhouette that would result from replacing medoid 0 (stored
         * in {@code ploss[0]}) or medoid 1 (stored in {@code ploss[1]}) by object {@code j}.
         * <p>
         * The caller must zero {@code ploss} first; the values are accumulated on top of it.
         *
         * @param j Swap candidate
         * @param ploss Output: silhouettes after each of the two possible swaps
         */
        protected void findBestSwap(DBIDRef j, double[] ploss) {
            double sumReplace0 = ploss[0];
            double sumReplace1 = ploss[1];
            for (DBIDIter o = ids.iter(); o.valid(); o.advance()) {
                final double djo = distQ.distance(j, o);
                final double toFirst = dm0.doubleValue(o);
                final double toSecond = dm1.doubleValue(o);
                // Replacing medoid 0 leaves j and medoid 1; replacing medoid 1 leaves j and medoid 0.
                sumReplace0 += djo < toSecond ? FastMSC.loss(djo, toSecond) : FastMSC.loss(toSecond, djo);
                sumReplace1 += djo < toFirst ? FastMSC.loss(djo, toFirst) : FastMSC.loss(toFirst, djo);
            }
            final double n = ids.size();
            ploss[0] = 1.0 - sumReplace0 / n;
            ploss[1] = 1.0 - sumReplace1 / n;
        }

        /**
         * Replace medoid {@code b} by object {@code j}. This updates the cached distances
         * and the assignment; on ties, an object keeps its current cluster.
         *
         * @param medoids Medoids (not read here)
         * @param b Index (0 or 1) of the replaced medoid
         * @param j New medoid
         * @return Medoid silhouette after the swap
         */
        protected double doSwap(ArrayDBIDs medoids, int b, DBIDRef j) {
            final WritableDoubleDataStore replacedDist;
            final WritableDoubleDataStore keptDist;
            if (b == 0) {
                replacedDist = dm0;
                keptDist = dm1;
            } else {
                replacedDist = dm1;
                keptDist = dm0;
            }
            double silsum = 0.0;
            for (DBIDIter o = ids.iter(); o.valid(); o.advance()) {
                final double djo = distQ.distance(j, o);
                replacedDist.putDouble(o, djo);
                final double dkept = keptDist.doubleValue(o);
                final int cluster;
                final double objLoss;
                if (djo < dkept) {
                    cluster = b;
                    objLoss = FastMSC.loss(djo, dkept);
                } else {
                    cluster = djo > dkept ? 1 - b : assignment.intValue(o);
                    objLoss = FastMSC.loss(dkept, djo);
                }
                assignment.putInt(o, cluster);
                silsum += objLoss;
            }
            return 1.0 - silsum / (double) ids.size();
        }

        /**
         * Compute each object's medoid silhouette {@code 1 - d_own/d_other} from the cached
         * distances. An object at distance 0 from its own medoid scores 1.
         *
         * @return Per-object silhouette scores
         */
        public DoubleDataStore silhouetteScores() {
            final WritableDoubleDataStore scores = DataStoreUtil.makeDoubleStorage(ids, 30);
            for (DBIDIter it = ids.iter(); it.valid(); it.advance()) {
                final boolean inFirst = assignment.intValue(it) == 0;
                final double own = (inFirst ? dm0 : dm1).doubleValue(it);
                final double other = (inFirst ? dm1 : dm0).doubleValue(it);
                scores.putDouble(it, own > 0.0 ? 1.0 - own / other : 1.0);
            }
            return scores;
        }
    }
}

