/*
 * Decompiled with CFR 0.152.
 */
package elki.algorithm.statistics;

import elki.Algorithm;
import elki.data.LabelList;
import elki.data.type.AlternativeTypeInformation;
import elki.data.type.TypeInformation;
import elki.data.type.TypeUtil;
import elki.database.ids.DBIDIter;
import elki.database.ids.DBIDRef;
import elki.database.ids.DBIDUtil;
import elki.database.ids.DBIDs;
import elki.database.ids.DoubleDBIDList;
import elki.database.ids.DoubleDBIDListMIter;
import elki.database.ids.HashSetModifiableDBIDs;
import elki.database.ids.ModifiableDBIDs;
import elki.database.ids.ModifiableDoubleDBIDList;
import elki.database.query.QueryBuilder;
import elki.database.query.distance.DistanceQuery;
import elki.database.relation.Relation;
import elki.distance.Distance;
import elki.distance.minkowski.EuclideanDistance;
import elki.evaluation.scores.AveragePrecisionEvaluation;
import elki.evaluation.scores.ROCEvaluation;
import elki.logging.Logging;
import elki.logging.progress.AbstractProgress;
import elki.logging.progress.FiniteProgress;
import elki.logging.statistics.DoubleStatistic;
import elki.logging.statistics.Statistic;
import elki.result.textwriter.TextWriteable;
import elki.result.textwriter.TextWriterStream;
import elki.utilities.exceptions.AbortException;
import elki.utilities.optionhandling.OptionID;
import elki.utilities.optionhandling.Parameterizer;
import elki.utilities.optionhandling.constraints.CommonConstraints;
import elki.utilities.optionhandling.constraints.ParameterConstraint;
import elki.utilities.optionhandling.parameterization.Parameterization;
import elki.utilities.optionhandling.parameters.DoubleParameter;
import elki.utilities.optionhandling.parameters.Flag;
import elki.utilities.optionhandling.parameters.IntParameter;
import elki.utilities.optionhandling.parameters.ObjectParameter;
import elki.utilities.optionhandling.parameters.RandomParameter;
import elki.utilities.random.RandomFactory;
import it.unimi.dsi.fastutil.objects.Object2IntMap;
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
import it.unimi.dsi.fastutil.objects.ObjectIterator;

/**
 * Evaluate a distance function with respect to retrieval performance.
 * <p>
 * Each (optionally sampled) object is used as a query: the objects of the
 * relation are ranked by distance to it (the query itself only if
 * {@code includeSelf} is set), and the ranking is scored against the set of
 * objects whose label matches the query label. Reported are the mean average
 * precision (MAP), the mean ROC AUC, and for k = 1..maxk the mean score of a
 * k-nearest-neighbor majority vote (ties between labels are counted
 * fractionally).
 *
 * @param <O> Object type
 */
public class EvaluateRetrievalPerformance<O>
implements Algorithm {
    /** Class logger. */
    private static final Logging LOG = Logging.getLogger(EvaluateRetrievalPerformance.class);

    /** Prefix for the statistics keys, derived from the runtime class name. */
    private final String PREFIX = this.getClass().getName();

    /** Distance function to evaluate. */
    protected Distance<? super O> distance;

    /** Relative share of objects used as queries; the parameterizer restricts it to (0, 1]. */
    protected double sampling = 1.0;

    /** Random factory used to sample the query objects. */
    protected RandomFactory random = null;

    /** Whether the query object itself is part of the ranked result list. */
    protected boolean includeSelf;

    /** Largest k for which kNN performance is reported (0 disables it). */
    protected int maxk = 100;

    /**
     * Constructor.
     *
     * @param distance Distance function to evaluate
     * @param sampling Relative share of objects to use as queries
     * @param random Random factory for the query sample
     * @param includeSelf Include the query object in its own result list
     * @param maxk Largest k for kNN evaluation (size of the kNN result array)
     */
    public EvaluateRetrievalPerformance(Distance<? super O> distance, double sampling, RandomFactory random, boolean includeSelf, int maxk) {
        this.distance = distance;
        this.sampling = sampling;
        this.random = random;
        this.includeSelf = includeSelf;
        this.maxk = maxk;
    }

    @Override
    public TypeInformation[] getInputTypeRestriction() {
        // First input: whatever the distance accepts; second input: class labels or label lists.
        return TypeUtil.array(new TypeInformation[] { this.distance.getInputTypeRestriction(), //
            new AlternativeTypeInformation(new TypeInformation[] { TypeUtil.CLASSLABEL, TypeUtil.LABELLIST }) });
    }

    /**
     * Run the evaluation.
     *
     * @param relation Relation containing the objects to compare by distance
     * @param lrelation Relation containing the labels (class label or label list)
     * @return Averaged retrieval performance over all evaluated queries
     * @throws AbortException if no query had any matching object, or MAP/ROC
     *         accumulated a NaN
     */
    public RetrievalPerformanceResult run(Relation<O> relation, Relation<?> lrelation) {
        final DBIDs ids = DBIDUtil.randomSample(relation.getDBIDs(), this.sampling, this.random);
        final DistanceQuery<O> distQuery = new QueryBuilder<>(relation, this.distance).distanceQuery();
        // Per-query scratch storage, reused across iterations:
        // relevant objects, ranked candidate list, label counts for kNN voting.
        final HashSetModifiableDBIDs posn = DBIDUtil.newHashSet();
        final ModifiableDoubleDBIDList nlist = DBIDUtil.newDistanceDBIDList(relation.size());
        final Object2IntOpenHashMap<Object> counters = new Object2IntOpenHashMap<>();
        double map = 0.0;
        double mauroc = 0.0;
        final double[] knnperf = new double[this.maxk];
        int samples = 0;
        FiniteProgress objloop = LOG.isVerbose() ? new FiniteProgress("Processing query objects", ids.size(), LOG) : null;
        for (DBIDIter iter = ids.iter(); iter.valid(); iter.advance()) {
            final Object label = lrelation.get(iter);
            posn.clear();
            this.findMatches(posn, lrelation, label);
            if (!this.includeSelf) {
                // The query is not part of the ranking, so it must not count as a
                // relevant result either (it would always "match" itself).
                posn.remove(iter);
            }
            if (posn.size() > 0) {
                this.computeDistances(nlist, iter, distQuery, relation);
                if (nlist.size() != relation.size() - (this.includeSelf ? 0 : 1)) {
                    LOG.warning("Neighbor list does not have the desired size: " + nlist.size());
                }
                map += AveragePrecisionEvaluation.STATIC.evaluate(posn, nlist);
                mauroc += ROCEvaluation.STATIC.evaluate(posn, nlist);
                KNNEvaluator.STATIC.evaluateKNN(knnperf, nlist, lrelation, counters, label);
                ++samples;
            }
            LOG.incrementProcessed(objloop);
        }
        LOG.ensureCompleted(objloop);
        if (samples < 1) {
            throw new AbortException("No object matched - are labels parsed correctly?");
        }
        // Negated comparison so that NaN is caught as well.
        if (!(map >= 0.0) || !(mauroc >= 0.0)) {
            throw new AbortException("NaN in MAP/ROC.");
        }
        map /= samples;
        mauroc /= samples;
        LOG.statistics(new DoubleStatistic(this.PREFIX + ".map", map));
        LOG.statistics(new DoubleStatistic(this.PREFIX + ".auroc", mauroc));
        LOG.statistics(new DoubleStatistic(this.PREFIX + ".samples", (double) samples));
        for (int k = 0; k < this.maxk; ++k) {
            knnperf[k] /= samples;
            LOG.statistics(new DoubleStatistic(this.PREFIX + ".knn-" + (k + 1), knnperf[k]));
        }
        return new RetrievalPerformanceResult(samples, map, mauroc, knnperf);
    }

    /**
     * Test whether a label matches a reference label.
     * <p>
     * A {@code null} reference never matches. Two label lists match if they
     * share at least one non-null label; otherwise {@link Object#equals} decides.
     *
     * @param ref Reference (query) label
     * @param test Label to test
     * @return {@code true} on match
     */
    protected static boolean match(Object ref, Object test) {
        if (ref == null) {
            return false;
        }
        if (ref == test) {
            return true;
        }
        if (ref instanceof LabelList && test instanceof LabelList) {
            final LabelList lref = (LabelList) ref;
            final LabelList ltest = (LabelList) test;
            final int s1 = lref.size();
            final int s2 = ltest.size();
            if (s1 == 0 || s2 == 0) {
                return false;
            }
            for (int i = 0; i < s1; ++i) {
                final String l1 = lref.get(i);
                if (l1 == null) {
                    continue;
                }
                for (int j = 0; j < s2; ++j) {
                    if (l1.equals(ltest.get(j))) {
                        return true;
                    }
                }
            }
        }
        return ref.equals(test);
    }

    /**
     * Collect all objects whose label matches the given label.
     *
     * @param posn Output set; matches are added, existing content is kept
     * @param lrelation Label relation
     * @param label Query label
     */
    private void findMatches(ModifiableDBIDs posn, Relation<?> lrelation, Object label) {
        for (DBIDIter ri = lrelation.iterDBIDs(); ri.valid(); ri.advance()) {
            if (EvaluateRetrievalPerformance.match(label, lrelation.get(ri))) {
                posn.add(ri);
            }
        }
    }

    /**
     * Compute the distances from the query to all objects and sort ascending.
     * NaN distances are replaced by positive infinity so they rank last.
     *
     * @param nlist Output list; cleared first
     * @param query Query object
     * @param distQuery Distance query
     * @param relation Data relation
     */
    private void computeDistances(ModifiableDoubleDBIDList nlist, DBIDIter query, DistanceQuery<O> distQuery, Relation<O> relation) {
        nlist.clear();
        final O qo = relation.get(query);
        for (DBIDIter ri = relation.iterDBIDs(); ri.valid(); ri.advance()) {
            if (!this.includeSelf && DBIDUtil.equal(ri, query)) {
                continue;
            }
            final double dist = distQuery.distance(qo, ri);
            nlist.add(Double.isNaN(dist) ? Double.POSITIVE_INFINITY : dist, ri);
        }
        nlist.sort();
    }

    /**
     * Parameterization class.
     *
     * @param <O> Object type
     */
    public static class Par<O>
    implements Parameterizer {
        /** Relative amount of objects to sample as queries. */
        public static final OptionID SAMPLING_ID = new OptionID("map.sampling", "Relative amount of object to sample.");

        /** Random seed for sampling. */
        public static final OptionID SEED_ID = new OptionID("map.sampling-seed", "Random seed for deterministic sampling.");

        /** Flag to include the query object in its own result list. */
        public static final OptionID INCLUDESELF_ID = new OptionID("map.includeself", "Include the query object in the evaluation.");

        /** Maximum k for kNN evaluation. */
        public static final OptionID MAXK_ID = new OptionID("map.maxk", "Maximum value of k for kNN evaluation.");

        /** Distance function to evaluate. */
        protected Distance<? super O> distance;

        /** Sampling rate. */
        protected double sampling = 1.0;

        /** Random factory for sampling. */
        protected RandomFactory seed = null;

        /** Include the query object. */
        protected boolean includeSelf;

        /** Maximum k; stays 0 (kNN evaluation disabled) when the option is not given. */
        protected int maxk = 0;

        @Override
        public void configure(Parameterization config) {
            new ObjectParameter<Distance<? super O>>(Algorithm.Utils.DISTANCE_FUNCTION_ID, Distance.class, EuclideanDistance.class) //
                .grab(config, x -> this.distance = x);
            new DoubleParameter(SAMPLING_ID) //
                .addConstraint(CommonConstraints.GREATER_THAN_ZERO_DOUBLE) //
                .addConstraint(CommonConstraints.LESS_EQUAL_ONE_DOUBLE) //
                .setOptional(true) //
                .grab(config, x -> this.sampling = x);
            new RandomParameter(SEED_ID).grab(config, x -> this.seed = x);
            new Flag(INCLUDESELF_ID).grab(config, x -> this.includeSelf = x);
            new IntParameter(MAXK_ID) //
                .addConstraint(CommonConstraints.GREATER_EQUAL_ONE_INT) //
                .setOptional(true) //
                .grab(config, x -> this.maxk = x);
        }

        @Override
        public EvaluateRetrievalPerformance<O> make() {
            return new EvaluateRetrievalPerformance<O>(this.distance, this.sampling, this.seed, this.includeSelf, this.maxk);
        }
    }

    /**
     * Result of the retrieval evaluation: sample size, MAP, mean ROC AUC, and
     * the averaged kNN scores for k = 1..knnperf.length.
     */
    public static class RetrievalPerformanceResult
    implements TextWriteable {
        /** Number of queries evaluated. */
        private final int samplesize;

        /** Mean average precision. */
        private final double map;

        /** Mean ROC AUC. */
        private final double auroc;

        /** Averaged kNN scores; index i holds k = i + 1. */
        private final double[] knnperf;

        /**
         * Constructor.
         *
         * @param samplesize Number of queries evaluated
         * @param map Mean average precision
         * @param auroc Mean ROC AUC
         * @param knnperf Averaged kNN scores (stored by reference, not copied)
         */
        public RetrievalPerformanceResult(int samplesize, double map, double auroc, double[] knnperf) {
            this.map = map;
            this.auroc = auroc;
            this.samplesize = samplesize;
            this.knnperf = knnperf;
        }

        /** @return Mean ROC AUC */
        public double getAUROC() {
            return this.auroc;
        }

        /** @return Mean average precision */
        public double getMAP() {
            return this.map;
        }

        public String getLongName() {
            return "Distance function retrieval evaluation.";
        }

        public String getShortName() {
            return "distance-retrieval-evaluation";
        }

        @Override
        public void writeToText(TextWriterStream out, String label) {
            out.inlinePrintNoQuotes("MAP");
            out.inlinePrint(this.map);
            out.flush();
            out.inlinePrintNoQuotes("AUROC");
            out.inlinePrint(this.auroc);
            out.flush();
            out.inlinePrintNoQuotes("Samplesize");
            out.inlinePrint(this.samplesize);
            out.flush();
            for (int i = 0; i < this.knnperf.length; ++i) {
                out.inlinePrintNoQuotes("knn-" + (i + 1));
                out.inlinePrint(this.knnperf[i]);
                out.flush();
            }
        }
    }

    /**
     * Evaluate kNN majority-vote classification on a sorted neighbor list.
     */
    public static class KNNEvaluator {
        /** Shared stateless instance. */
        public static final KNNEvaluator STATIC = new KNNEvaluator();

        /**
         * Add the kNN scores of one query to {@code knnperf}.
         * <p>
         * For each k, the labels of the first k neighbors are counted. Neighbors
         * tied in distance with the k-th neighbor are included as well. The score
         * is the fraction of the most frequent labels that match the query label,
         * i.e. the expected accuracy of a majority vote with random tie-breaking.
         * If the list has fewer than {@code knnperf.length} entries, the trailing
         * entries are left unchanged.
         *
         * @param knnperf Accumulator; index i receives the score for k = i + 1
         * @param nlist Neighbor list, sorted by ascending distance
         * @param lrelation Label relation
         * @param counters Scratch map for label counts (cleared first)
         * @param label Query label
         */
        public void evaluateKNN(double[] knnperf, ModifiableDoubleDBIDList nlist, Relation<?> lrelation, Object2IntOpenHashMap<Object> counters, Object label) {
            final int maxk = knnperf.length;
            int seen = 0; // number of neighbors consumed so far
            int filled = 0; // number of knnperf entries already updated
            int max = 0; // highest label count among the consumed neighbors
            counters.clear();
            DoubleDBIDListMIter iter = nlist.iter();
            while (iter.valid() && filled < maxk) {
                final double prev = iter.doubleValue();
                max = Math.max(max, this.countkNN(counters, lrelation.get(iter)));
                iter.advance();
                ++seen;
                // Tied distance: the next neighbor belongs to the same k, consume it first.
                if (iter.valid() && !(iter.doubleValue() > prev)) {
                    continue;
                }
                int pties = 0;
                int ties = 0;
                ObjectIterator<Object2IntMap.Entry<Object>> cit = counters.object2IntEntrySet().fastIterator();
                while (cit.hasNext()) {
                    Object2IntMap.Entry<Object> entry = cit.next();
                    if (entry.getIntValue() < max) {
                        continue;
                    }
                    ++ties;
                    if (isQueryLabel(entry.getKey(), label)) {
                        ++pties;
                    }
                }
                // No countable labels yet (e.g. only empty label lists): no correct vote.
                final double score = ties > 0 ? pties / (double) ties : 0.0;
                // Every k up to the number of consumed neighbors shares this vote.
                while (filled < seen && filled < maxk) {
                    knnperf[filled++] += score;
                }
            }
        }

        /**
         * Test whether a counted label key equals the query label, or one of its
         * labels if the query label is a {@link LabelList}.
         *
         * @param key Counted label (may be {@code null})
         * @param label Query label
         * @return {@code true} if the key matches
         */
        private static boolean isQueryLabel(Object key, Object label) {
            if (key == null) {
                return false;
            }
            if (key.equals(label)) {
                return true;
            }
            if (label instanceof LabelList) {
                final LabelList ll = (LabelList) label;
                for (int i = 0, e = ll.size(); i < e; ++i) {
                    if (key.equals(ll.get(i))) {
                        return true;
                    }
                }
            }
            return false;
        }

        /**
         * Increment the count of a label, or of every label in a
         * {@link LabelList}.
         *
         * @param counters Label counts
         * @param l Label or label list
         * @return Largest count after incrementing (0 for an empty label list)
         */
        public int countkNN(Object2IntOpenHashMap<Object> counters, Object l) {
            // Object2IntOpenHashMap.addTo returns the value *before* the increment,
            // hence the + 1 to report the updated count.
            if (l instanceof LabelList) {
                final LabelList ll = (LabelList) l;
                int m = 0;
                for (int i = 0, e = ll.size(); i < e; ++i) {
                    m = Math.max(m, counters.addTo(ll.get(i), 1) + 1);
                }
                return m;
            }
            return counters.addTo(l, 1) + 1;
        }
    }
}

