/*
 * Decompiled with CFR 0.152.
 */
package elki.algorithm.statistics;

import elki.Algorithm;
import elki.data.LabelList;
import elki.data.type.AlternativeTypeInformation;
import elki.data.type.TypeInformation;
import elki.data.type.TypeUtil;
import elki.database.ids.DBIDIter;
import elki.database.ids.DBIDRef;
import elki.database.ids.DBIDUtil;
import elki.database.ids.DBIDs;
import elki.database.ids.DoubleDBIDListIter;
import elki.database.query.QueryBuilder;
import elki.database.query.knn.KNNSearcher;
import elki.database.relation.Relation;
import elki.distance.Distance;
import elki.distance.minkowski.EuclideanDistance;
import elki.logging.Logging;
import elki.logging.progress.AbstractProgress;
import elki.logging.progress.FiniteProgress;
import elki.math.MeanVarianceMinMax;
import elki.result.CollectionResult;
import elki.result.Metadata;
import elki.utilities.optionhandling.OptionID;
import elki.utilities.optionhandling.Parameterizer;
import elki.utilities.optionhandling.constraints.CommonConstraints;
import elki.utilities.optionhandling.constraints.ParameterConstraint;
import elki.utilities.optionhandling.parameterization.Parameterization;
import elki.utilities.optionhandling.parameters.DoubleParameter;
import elki.utilities.optionhandling.parameters.Flag;
import elki.utilities.optionhandling.parameters.IntParameter;
import elki.utilities.optionhandling.parameters.ObjectParameter;
import elki.utilities.optionhandling.parameters.RandomParameter;
import elki.utilities.random.RandomFactory;
import java.util.ArrayList;

public class AveragePrecisionAtK<O>
implements Algorithm {
    private static final Logging LOG = Logging.getLogger(AveragePrecisionAtK.class);
    private Distance<? super O> distance;
    private int k;
    private double sampling = 1.0;
    private RandomFactory random = null;
    private boolean includeSelf;

    public AveragePrecisionAtK(Distance<? super O> distance, int k, double sampling, RandomFactory random, boolean includeSelf) {
        this.distance = distance;
        this.k = k;
        this.sampling = sampling;
        this.random = random;
        this.includeSelf = includeSelf;
    }

    public TypeInformation[] getInputTypeRestriction() {
        return TypeUtil.array((TypeInformation[])new TypeInformation[]{this.distance.getInputTypeRestriction(), new AlternativeTypeInformation(new TypeInformation[]{TypeUtil.CLASSLABEL, TypeUtil.LABELLIST})});
    }

    public CollectionResult<double[]> run(Relation<O> relation, Relation<?> lrelation) {
        int qk = this.k + (this.includeSelf ? 0 : 1);
        KNNSearcher knnQuery = new QueryBuilder(relation, this.distance).kNNByDBID(qk);
        DBIDs ids = DBIDUtil.randomSample((DBIDs)relation.getDBIDs(), (double)this.sampling, (RandomFactory)this.random);
        MeanVarianceMinMax[] mvs = MeanVarianceMinMax.newArray((int)this.k);
        FiniteProgress objloop = LOG.isVerbose() ? new FiniteProgress("Computing nearest neighbors", ids.size(), LOG) : null;
        DBIDIter iter = ids.iter();
        while (iter.valid()) {
            Object label = lrelation.get((DBIDRef)iter);
            int positive = 0;
            int i = 0;
            DoubleDBIDListIter ri = knnQuery.getKNN((Object)iter, qk).iter();
            while (i < this.k && ri.valid()) {
                if (this.includeSelf || !DBIDUtil.equal((DBIDRef)iter, (DBIDRef)ri)) {
                    double precision = (double)(positive += AveragePrecisionAtK.match(label, lrelation.get((DBIDRef)ri)) ? 1 : 0) / (double)(i + 1);
                    mvs[i].put(precision);
                    ++i;
                }
                ri.advance();
            }
            LOG.incrementProcessed((AbstractProgress)objloop);
            iter.advance();
        }
        LOG.ensureCompleted(objloop);
        ArrayList<double[]> res = new ArrayList<double[]>(this.k);
        for (int i = 0; i < this.k; ++i) {
            MeanVarianceMinMax mv = mvs[i];
            double std = mv.getCount() > 1.0 ? mv.getSampleStddev() : 0.0;
            res.add(new double[]{i + 1, mv.getMean(), std, mv.getMin(), mv.getMax(), mv.getCount()});
        }
        CollectionResult result = new CollectionResult(res);
        Metadata.of((Object)result).setLongName("Average Precision");
        return result;
    }

    protected static boolean match(Object ref, Object test) {
        if (ref == null) {
            return false;
        }
        if (ref == test) {
            return true;
        }
        if (ref instanceof LabelList && test instanceof LabelList) {
            LabelList lref = (LabelList)ref;
            LabelList ltest = (LabelList)test;
            int s1 = lref.size();
            int s2 = ltest.size();
            if (s1 == 0 || s2 == 0) {
                return false;
            }
            for (int i = 0; i < s1; ++i) {
                String l1 = lref.get(i);
                if (l1 == null) continue;
                for (int j = 0; j < s2; ++j) {
                    if (!l1.equals(ltest.get(j))) continue;
                    return true;
                }
            }
        }
        return ref.equals(test);
    }

    public static class Par<O>
    implements Parameterizer {
        private static final OptionID K_ID = new OptionID("avep.k", "K to compute the average precision at.");
        public static final OptionID SAMPLING_ID = new OptionID("avep.sampling", "Relative amount of object to sample.");
        public static final OptionID SEED_ID = new OptionID("avep.sampling-seed", "Random seed for deterministic sampling.");
        public static final OptionID INCLUDESELF_ID = new OptionID("avep.includeself", "Include the query object in the evaluation.");
        protected Distance<? super O> distance;
        protected int k = 20;
        protected double sampling = 1.0;
        protected RandomFactory seed = null;
        protected boolean includeSelf;

        public void configure(Parameterization config) {
            new ObjectParameter(Algorithm.Utils.DISTANCE_FUNCTION_ID, Distance.class, EuclideanDistance.class).grab(config, x -> {
                this.distance = x;
            });
            ((IntParameter)new IntParameter(K_ID).addConstraint((ParameterConstraint)CommonConstraints.GREATER_THAN_ONE_INT)).grab(config, x -> {
                this.k = x;
            });
            ((DoubleParameter)((DoubleParameter)((DoubleParameter)new DoubleParameter(SAMPLING_ID).addConstraint((ParameterConstraint)CommonConstraints.GREATER_THAN_ZERO_DOUBLE)).addConstraint((ParameterConstraint)CommonConstraints.LESS_EQUAL_ONE_DOUBLE)).setOptional(true)).grab(config, x -> {
                this.sampling = x;
            });
            new RandomParameter(SEED_ID).grab(config, x -> {
                this.seed = x;
            });
            new Flag(INCLUDESELF_ID).grab(config, x -> {
                this.includeSelf = x;
            });
        }

        public AveragePrecisionAtK<O> make() {
            return new AveragePrecisionAtK<O>(this.distance, this.k, this.sampling, this.seed, this.includeSelf);
        }
    }
}

