/*
 * Decompiled with CFR 0.152.
 */
package elki.classification;

import elki.Algorithm;
import elki.classification.Classifier;
import elki.data.ClassLabel;
import elki.data.type.TypeInformation;
import elki.data.type.TypeUtil;
import elki.database.Database;
import elki.database.ids.DBIDRef;
import elki.database.ids.DoubleDBIDListIter;
import elki.database.ids.KNNList;
import elki.database.query.QueryBuilder;
import elki.database.query.knn.KNNSearcher;
import elki.database.relation.Relation;
import elki.distance.Distance;
import elki.distance.minkowski.EuclideanDistance;
import elki.utilities.Priority;
import elki.utilities.documentation.Description;
import elki.utilities.documentation.Title;
import elki.utilities.optionhandling.OptionID;
import elki.utilities.optionhandling.Parameterizer;
import elki.utilities.optionhandling.constraints.CommonConstraints;
import elki.utilities.optionhandling.constraints.ParameterConstraint;
import elki.utilities.optionhandling.parameterization.Parameterization;
import elki.utilities.optionhandling.parameters.IntParameter;
import elki.utilities.optionhandling.parameters.ObjectParameter;
import it.unimi.dsi.fastutil.objects.Object2IntMap;
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
import it.unimi.dsi.fastutil.objects.ObjectIterator;
import java.util.ArrayList;
import java.util.Collections;

@Title(value="kNN-classifier")
@Description(value="Lazy classifier classifies a given instance to the majority class of the k-nearest neighbors.")
@Priority(value=100)
public class KNNClassifier<O>
implements Classifier<O> {
    /**
     * Number of neighbors to query for each classified instance.
     */
    protected int k;

    /**
     * kNN searcher over the training relation; set by {@link #buildClassifier}.
     */
    protected KNNSearcher<O> knnq;

    /**
     * Class labels of the training objects; set by {@link #buildClassifier}.
     */
    protected Relation<? extends ClassLabel> labelrep;

    /**
     * Distance function used for the neighbor search.
     */
    protected Distance<? super O> distance;

    /**
     * Constructor.
     *
     * @param distance distance function used for the kNN search
     * @param k number of neighbors to take into account
     */
    public KNNClassifier(Distance<? super O> distance, int k) {
        this.distance = distance;
        this.k = k;
    }

    /**
     * Input type restriction of this classifier.
     * <p>
     * NOTE(review): this always returns a number-vector field restriction,
     * independent of the configured distance, whereas
     * {@link #buildClassifier} fetches the relation using
     * {@code distance.getInputTypeRestriction()} — confirm this is intended.
     *
     * @return type restriction array containing the number vector field type
     */
    public TypeInformation[] getInputTypeRestriction() {
        return TypeUtil.array(TypeUtil.NUMBER_VECTOR_FIELD);
    }

    /**
     * "Trains" the lazy learner: builds a kNN searcher over the relation that
     * matches the distance's input type and remembers the label relation.
     *
     * @param database database containing the training data
     * @param labels class labels of the training objects
     */
    @Override
    public void buildClassifier(Database database, Relation<? extends ClassLabel> labels) {
        Relation<O> relation = database.getRelation(this.distance.getInputTypeRestriction());
        this.knnq = new QueryBuilder<>(relation, this.distance).kNNByObject(this.k);
        this.labelrep = labels;
    }

    /**
     * Classifies an instance by majority vote among its k nearest neighbors.
     * <p>
     * If several classes share the highest vote count, the tie is broken in
     * favor of the tied class that occurs first in the kNN result (for a
     * distance-sorted result, the class of the nearest tied neighbor).
     *
     * @param instance object to classify
     * @return majority class label, or {@code null} if the kNN query returned
     *         no neighbors
     */
    @Override
    public ClassLabel classify(O instance) {
        KNNList neighbors = this.knnq.getKNN(instance, this.k);
        // Labels in the order returned by the kNN query (expected nearest
        // first), kept to make tie-breaking deterministic.
        ArrayList<ClassLabel> ordered = new ArrayList<>(neighbors.size());
        Object2IntOpenHashMap<ClassLabel> count = new Object2IntOpenHashMap<>();
        for (DoubleDBIDListIter it = neighbors.iter(); it.valid(); it.advance()) {
            ClassLabel label = this.labelrep.get(it);
            ordered.add(label);
            count.addTo(label, 1);
        }
        int bestoccur = Integer.MIN_VALUE;
        ClassLabel bestl = null;
        // Strict '>' keeps the earliest label in kNN order among equal counts.
        for (ClassLabel label : ordered) {
            int occur = count.getInt(label);
            if (occur > bestoccur) {
                bestoccur = occur;
                bestl = label;
            }
        }
        return bestl;
    }

    /**
     * Estimates class probabilities as the relative frequency of each label
     * among the k nearest neighbors of the instance.
     * <p>
     * Neighbors whose label is not contained in {@code labels} are not counted
     * for any class, but still count towards the total, so the result may sum
     * to less than one.
     *
     * @param instance object to classify
     * @param labels candidate class labels; must be sorted in natural order,
     *        because lookups use {@link Collections#binarySearch}
     * @return one probability per entry of {@code labels}; all zeros if the
     *         kNN query returned no neighbors
     */
    public double[] classProbabilities(O instance, ArrayList<ClassLabel> labels) {
        int[] occurences = new int[labels.size()];
        KNNList query = this.knnq.getKNN(instance, this.k);
        for (DoubleDBIDListIter neighbor = query.iter(); neighbor.valid(); neighbor.advance()) {
            int index = Collections.binarySearch(labels, this.labelrep.get(neighbor));
            if (index >= 0) {
                occurences[index]++;
            }
        }
        double[] distribution = new double[labels.size()];
        int total = query.size();
        if (total == 0) {
            // Avoid 0/0 = NaN for every class when no neighbor was found.
            return distribution;
        }
        double norm = 1. / total;
        for (int i = 0; i < distribution.length; ++i) {
            distribution[i] = occurences[i] * norm;
        }
        return distribution;
    }

    /**
     * @return a textual description; a lazy learner has no explicit model
     */
    @Override
    public String model() {
        return "lazy learner - provides no model";
    }

    /**
     * @return the distance function used for the neighbor search
     */
    public Distance<? super O> getDistance() {
        return this.distance;
    }

    /**
     * Parameterization class for {@link KNNClassifier}.
     *
     * @param <O> object type
     */
    public static class Par<O>
    implements Parameterizer {
        /**
         * Option for the number of neighbors; default 1, must be at least 1.
         */
        public static final OptionID K_ID = new OptionID("knnclassifier.k", "The number of neighbors to take into account for classification.");

        /**
         * Configured distance function (defaults to {@link EuclideanDistance}).
         */
        protected Distance<? super O> distanceFunction;

        /**
         * Configured number of neighbors.
         */
        protected int k;

        public void configure(Parameterization config) {
            new ObjectParameter<Distance<? super O>>(Algorithm.Utils.DISTANCE_FUNCTION_ID, Distance.class, EuclideanDistance.class) //
                    .grab(config, x -> this.distanceFunction = x);
            new IntParameter(K_ID, 1) //
                    .addConstraint(CommonConstraints.GREATER_EQUAL_ONE_INT) //
                    .grab(config, x -> this.k = x);
        }

        public KNNClassifier<O> make() {
            return new KNNClassifier<O>(this.distanceFunction, this.k);
        }
    }
}

