/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.tsc.classifier.neighbors;

import ai.libs.jaicore.basic.algorithm.IAlgorithmConfig;
import ai.libs.jaicore.basic.sets.Pair;
import ai.libs.jaicore.ml.core.exception.PredictionException;
import ai.libs.jaicore.ml.tsc.classifier.ASimplifiedTSClassifier;
import ai.libs.jaicore.ml.tsc.classifier.neighbors.NearestNeighborLearningAlgorithm;
import ai.libs.jaicore.ml.tsc.dataset.TimeSeriesDataset;
import ai.libs.jaicore.ml.tsc.distances.ITimeSeriesDistance;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import org.aeonbits.owner.ConfigCache;

/**
 * Time series classifier that predicts the integer label of an instance from the labels of its
 * {@code k} nearest training instances, where "nearest" is defined by a pluggable
 * {@link ITimeSeriesDistance}.
 *
 * <p>Training data is injected through the protected setters ({@link #setValues},
 * {@link #setTimestamps}, {@link #setTargets}); presumably these are called by
 * {@link NearestNeighborLearningAlgorithm} (TODO confirm). Only {@link #values} and
 * {@link #targets} are read when predicting; {@link #timestamps} is stored but unused here.
 */
public class NearestNeighborClassifier
extends ASimplifiedTSClassifier<Integer> {
    /**
     * Orders neighbors by descending distance, so the head of a queue using it is the farthest
     * neighbor currently retained.
     */
    protected static final NearestNeighborComparator nearestNeighborComparator = new NearestNeighborComparator();
    /** Number of neighbors consulted per prediction (at least 1). */
    private int k;
    /** Distance used to compare a test instance with each training instance. */
    private ITimeSeriesDistance distanceMeasure;
    /** Strategy used to combine the neighbors' labels into one prediction. */
    private VoteType voteType;
    /** Training instances; row {@code i} is compared against the test instance. */
    protected double[][] values;
    /** Timestamps of the training instances (stored only; not used for prediction in this class). */
    protected double[][] timestamps;
    /** Label of each training instance; {@code targets[i]} belongs to {@code values[i]}. */
    protected int[] targets;

    /**
     * Creates a k-nearest-neighbor classifier.
     *
     * @param k number of neighbors consulted per prediction; must be at least 1
     * @param distanceMeasure distance used to compare instances; must not be {@code null}
     * @param voteType how the neighbors' labels are combined; must not be {@code null}
     * @throws IllegalArgumentException if an argument is {@code null} or {@code k < 1}
     */
    public NearestNeighborClassifier(int k, ITimeSeriesDistance distanceMeasure, VoteType voteType) {
        if (distanceMeasure == null) {
            throw new IllegalArgumentException("Distance measure must not be null");
        }
        if (voteType == null) {
            throw new IllegalArgumentException("Vote type must not be null.");
        }
        // With k < 1 every candidate would be evicted immediately and each prediction would be -1.
        if (k < 1) {
            throw new IllegalArgumentException("Number of neighbors k must be at least 1, but was " + k + ".");
        }
        this.distanceMeasure = distanceMeasure;
        this.k = k;
        this.voteType = voteType;
    }

    /**
     * Creates a k-nearest-neighbor classifier using {@link VoteType#MAJORITY} voting.
     *
     * @param k number of neighbors consulted per prediction; must be at least 1
     * @param distanceMeasure distance used to compare instances; must not be {@code null}
     */
    public NearestNeighborClassifier(int k, ITimeSeriesDistance distanceMeasure) {
        this(k, distanceMeasure, VoteType.MAJORITY);
    }

    /**
     * Creates a 1-nearest-neighbor classifier using {@link VoteType#MAJORITY} voting.
     *
     * @param distanceMeasure distance used to compare instances; must not be {@code null}
     */
    public NearestNeighborClassifier(ITimeSeriesDistance distanceMeasure) {
        this(1, distanceMeasure, VoteType.MAJORITY);
    }

    /**
     * Predicts the label of a single univariate instance.
     *
     * @param univInstance the instance to classify
     * @return the predicted label, or {@code -1} if no neighbor could be found
     * @throws IllegalArgumentException if {@code univInstance} is {@code null}
     * @throws IllegalStateException if the classifier has not been given training data
     */
    @Override
    public Integer predict(double[] univInstance) throws PredictionException {
        if (univInstance == null) {
            throw new IllegalArgumentException("Instance to predict must not be null.");
        }
        return this.calculatePrediction(univInstance);
    }

    /**
     * Predicts the label of every instance in {@code dataset}, in order.
     *
     * @param dataset the instances to classify
     * @return one predicted label per instance
     */
    @Override
    public List<Integer> predict(TimeSeriesDataset dataset) throws PredictionException {
        double[][] testInstances = this.checkWhetherPredictionIsPossible(dataset);
        List<Integer> predictions = new ArrayList<>(testInstances.length);
        for (double[] testInstance : testInstances) {
            predictions.add(this.calculatePrediction(testInstance));
        }
        return predictions;
    }

    /**
     * Finds the nearest neighbors of {@code testInstance} and lets them vote.
     *
     * @param testInstance the instance to classify
     * @return the winning label, or {@code -1} if there were no neighbors
     */
    protected int calculatePrediction(double[] testInstance) {
        PriorityQueue<Pair<Integer, Double>> nearestNeighbors = this.calculateNearestNeigbors(testInstance);
        return this.vote(nearestNeighbors);
    }

    /**
     * Computes the (at most) {@code k} training instances closest to {@code testInstance}.
     *
     * @param testInstance the instance to compare against every training instance
     * @return a queue of (label, distance) pairs whose head is the farthest retained neighbor
     * @throws IllegalStateException if the training values or targets have not been set
     */
    protected PriorityQueue<Pair<Integer, Double>> calculateNearestNeigbors(double[] testInstance) {
        if (this.values == null || this.targets == null) {
            throw new IllegalStateException("Classifier has not been trained: values and targets must be set before predicting.");
        }
        PriorityQueue<Pair<Integer, Double>> nearestNeighbors = new PriorityQueue<>(nearestNeighborComparator);
        for (int i = 0; i < this.values.length; i++) {
            double distance = this.distanceMeasure.distance(testInstance, this.values[i]);
            nearestNeighbors.add(new Pair<Integer, Double>(this.targets[i], distance));
            // The head is the farthest neighbor held; drop it once more than k are retained.
            if (nearestNeighbors.size() > this.k) {
                nearestNeighbors.poll();
            }
        }
        return nearestNeighbors;
    }

    /**
     * Combines the neighbors' labels according to the configured {@link VoteType}.
     *
     * @param nearestNeighbors (label, distance) pairs ordered farthest-first
     * @return the winning label, or {@code -1} if there were no neighbors
     */
    protected int vote(PriorityQueue<Pair<Integer, Double>> nearestNeighbors) {
        switch (this.voteType) {
            case WEIGHTED_STEPWISE:
                return this.voteWeightedStepwise(nearestNeighbors);
            case WEIGHTED_PROPORTIONAL_TO_DISTANCE:
                return this.voteWeightedProportionalToDistance(nearestNeighbors);
            case MAJORITY:
            default:
                return this.voteMajority(nearestNeighbors);
        }
    }

    /**
     * Rank-weighted vote: the neighbors are consumed farthest-first with weights 1, 2, 3, ...,
     * so the nearest of {@code m} neighbors contributes weight {@code m}.
     *
     * <p>The passed queue is not modified; a copy (with the same comparator) is drained instead.
     *
     * @param nearestNeighbors (label, distance) pairs ordered farthest-first
     * @return the label with the largest summed weight, or {@code -1} if there were no neighbors
     */
    protected int voteWeightedStepwise(PriorityQueue<Pair<Integer, Double>> nearestNeighbors) {
        PriorityQueue<Pair<Integer, Double>> farthestFirst = new PriorityQueue<>(nearestNeighbors);
        Map<Integer, Integer> votes = new HashMap<>();
        int weight = 1;
        while (!farthestFirst.isEmpty()) {
            votes.merge(farthestFirst.poll().getX(), weight, Integer::sum);
            weight++;
        }
        return mostVotedTargetClass(votes);
    }

    /**
     * Inverse-distance vote: each neighbor contributes {@code 1 / distance} to its label.
     * A neighbor at distance 0 contributes positive infinity (IEEE 754 division).
     *
     * @param nearestNeighbors (label, distance) pairs
     * @return the label with the largest summed weight, or {@code -1} if there were no neighbors
     */
    protected int voteWeightedProportionalToDistance(PriorityQueue<Pair<Integer, Double>> nearestNeighbors) {
        Map<Integer, Double> votes = new HashMap<>();
        for (Pair<Integer, Double> neighbor : nearestNeighbors) {
            votes.merge(neighbor.getX(), 1.0 / neighbor.getY(), Double::sum);
        }
        return mostVotedTargetClass(votes);
    }

    /**
     * Unweighted vote: each neighbor contributes one vote to its label.
     *
     * @param nearestNeighbors (label, distance) pairs
     * @return the label with the most votes, or {@code -1} if there were no neighbors
     */
    protected int voteMajority(PriorityQueue<Pair<Integer, Double>> nearestNeighbors) {
        Map<Integer, Integer> votes = new HashMap<>();
        for (Pair<Integer, Double> neighbor : nearestNeighbors) {
            votes.merge(neighbor.getX(), 1, Integer::sum);
        }
        return mostVotedTargetClass(votes);
    }

    /**
     * Returns the label with the strictly largest vote weight. On ties, the entry met first in
     * the map's iteration order wins. NaN weights never win.
     *
     * @param votes accumulated weight per label
     * @return the winning label, or {@code -1} if {@code votes} is empty or no weight exceeds
     *         negative infinity
     */
    private static int mostVotedTargetClass(Map<Integer, ? extends Number> votes) {
        // Start below every real weight; Double.MIN_VALUE would be wrong (it is the smallest positive double).
        double maxWeight = Double.NEGATIVE_INFINITY;
        int mostVoted = -1;
        for (Map.Entry<Integer, ? extends Number> entry : votes.entrySet()) {
            double weight = entry.getValue().doubleValue();
            if (weight > maxWeight) {
                maxWeight = weight;
                mostVoted = entry.getKey();
            }
        }
        return mostVoted;
    }

    /**
     * Sets the training instances.
     *
     * @param values training instances, one per row
     * @throws IllegalArgumentException if {@code values} is {@code null}
     */
    protected void setValues(double[][] values) {
        if (values == null) {
            throw new IllegalArgumentException("Values must not be null");
        }
        this.values = values;
    }

    /**
     * Sets the timestamps of the training instances (may be {@code null}; not read by this class).
     *
     * @param timestamps timestamps, one row per training instance
     */
    protected void setTimestamps(double[][] timestamps) {
        this.timestamps = timestamps;
    }

    /**
     * Sets the training labels.
     *
     * @param targets label of each training instance, index-aligned with the values
     * @throws IllegalArgumentException if {@code targets} is {@code null}
     */
    protected void setTargets(int[] targets) {
        if (targets == null) {
            throw new IllegalArgumentException("Targets must not be null");
        }
        this.targets = targets;
    }

    /** @return the number of neighbors consulted per prediction */
    public int getK() {
        return this.k;
    }

    /** @return the voting strategy */
    public VoteType getVoteType() {
        return this.voteType;
    }

    /** @return the distance measure used to compare instances */
    public ITimeSeriesDistance getDistanceMeasure() {
        return this.distanceMeasure;
    }

    /**
     * Creates a learning algorithm that trains this classifier on {@code dataset}, using the
     * cached default {@link IAlgorithmConfig}.
     *
     * @param dataset the training data
     * @return a new learning algorithm bound to this classifier
     */
    public NearestNeighborLearningAlgorithm getLearningAlgorithm(TimeSeriesDataset dataset) {
        IAlgorithmConfig config = ConfigCache.getOrCreate(IAlgorithmConfig.class);
        return new NearestNeighborLearningAlgorithm(config, this, dataset);
    }

    /** Compares neighbors by descending distance (farthest first). */
    private static class NearestNeighborComparator
    implements Comparator<Pair<Integer, Double>> {
        private NearestNeighborComparator() {
        }

        @Override
        public int compare(Pair<Integer, Double> o1, Pair<Integer, Double> o2) {
            // Arguments swapped to reverse the natural (ascending) order.
            return Double.compare(o2.getY(), o1.getY());
        }
    }

    /** Strategies for turning the neighbors' labels into a single prediction. */
    public static enum VoteType {
        /** One vote per neighbor. */
        MAJORITY,
        /** Rank-based weights: the nearest neighbor weighs most. */
        WEIGHTED_STEPWISE,
        /** Each neighbor weighs {@code 1 / distance}. */
        WEIGHTED_PROPORTIONAL_TO_DISTANCE;

    }
}

