/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.tsc.classifier.neighbors;

import ai.libs.jaicore.basic.sets.Pair;
import ai.libs.jaicore.ml.core.exception.PredictionException;
import ai.libs.jaicore.ml.tsc.classifier.ASimplifiedTSClassifier;
import ai.libs.jaicore.ml.tsc.classifier.neighbors.NearestNeighborClassifier;
import ai.libs.jaicore.ml.tsc.classifier.neighbors.ShotgunEnsembleLearnerAlgorithm;
import ai.libs.jaicore.ml.tsc.dataset.TimeSeriesDataset;
import ai.libs.jaicore.ml.tsc.distances.ITimeSeriesDistance;
import ai.libs.jaicore.ml.tsc.distances.ShotgunDistance;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.aeonbits.owner.ConfigCache;

/**
 * Ensemble classifier for univariate time series. For every selected window
 * length it configures the shared {@link ShotgunDistance}, asks the wrapped
 * {@link NearestNeighborClassifier} for a prediction and finally returns the
 * label predicted most often (majority vote over window lengths).
 *
 * <p>A window length takes part in the vote if its score is strictly greater
 * than {@code bestScore * factor}. If no window satisfies this (which happens
 * for {@code factor == 1.0} or a best score of zero), the windows that reach
 * the best score are used instead, so the ensemble is never empty.
 *
 * <p>NOTE(review): per the OWNER documentation {@code ConfigCache.getOrCreate}
 * hands out a cached instance per configuration class, so several instances of
 * this classifier presumably share (and overwrite) the same configuration
 * object - confirm whether this is intended.
 */
public class ShotgunEnsembleClassifier extends ASimplifiedTSClassifier<Integer> {
    /** Relative threshold in (0,1] applied to {@link #bestScore} when selecting windows. */
    protected double factor;
    /** Value matrix stored via {@link #setValues(double[][])}; not read inside this class. */
    protected double[][] values;
    /** Targets stored via {@link #setTargets(int[])}; not read inside this class. */
    protected int[] targets;
    /** Classifier that produces the per-window-length predictions. */
    protected NearestNeighborClassifier nearestNeighborClassifier;
    /** Distance measure of {@link #nearestNeighborClassifier}; its window length is changed while predicting. */
    protected ShotgunDistance shotgunDistance;
    /**
     * Candidate windows as pairs (score, window length). The score is presumably
     * the number of correct training predictions for that window length - verify
     * against the learner algorithm.
     */
    protected ArrayList<Pair<Integer, Integer>> windows;
    /** Highest score among {@link #windows} (or -1 if there is none), see {@link #setWindows(ArrayList)}. */
    protected int bestScore;
    private final ShotgunEnsembleLearnerAlgorithm.IShotgunEnsembleLearnerConfig config = ConfigCache.getOrCreate(ShotgunEnsembleLearnerAlgorithm.IShotgunEnsembleLearnerConfig.class);

    /**
     * Creates an untrained classifier and writes the window length bounds and the
     * normalization flag into the learner configuration.
     *
     * @param minWindowLength   smallest window length, must be at least 1
     * @param maxWindowLength   largest window length, must be at least 1 and not
     *                          smaller than {@code minWindowLength}
     * @param meanNormalization value stored under the key {@code meannormalization}
     * @param factor            selection threshold factor, must be in (0,1]
     * @throws IllegalArgumentException if any argument violates its constraint
     */
    public ShotgunEnsembleClassifier(int minWindowLength, int maxWindowLength, boolean meanNormalization, double factor) {
        // Validate everything first so that an invalid call does not leave a
        // partially updated configuration behind.
        if (minWindowLength < 1) {
            throw new IllegalArgumentException("The parameter minWindowLength must be greater equal to 1.");
        }
        if (maxWindowLength < 1) {
            throw new IllegalArgumentException("The parameter maxWindowLength must be greater equal to 1.");
        }
        if (minWindowLength > maxWindowLength) {
            throw new IllegalArgumentException("The parameter maxWindowsLength must be greater equal to parameter minWindowLength");
        }
        // Written as a negated range check so that NaN is rejected as well.
        if (!(factor > 0.0 && factor <= 1.0)) {
            throw new IllegalArgumentException("The parameter factor must be in (0,1]");
        }
        this.config.setProperty("windowlength.min", String.valueOf(minWindowLength));
        this.config.setProperty("windowlength.max", String.valueOf(maxWindowLength));
        this.config.setProperty("meannormalization", String.valueOf(meanNormalization));
        this.factor = factor;
    }

    /**
     * Determines the window lengths that take part in the vote.
     *
     * @return non-empty list of window lengths in the order of {@link #windows}
     * @throws IllegalStateException if the classifier has no windows or no nearest
     *                               neighbor classifier yet, or if no window qualifies
     */
    private List<Integer> selectEnsembleWindowLengths() {
        if (this.windows == null || this.windows.isEmpty()) {
            throw new IllegalStateException("No windows have been set; the classifier cannot predict yet.");
        }
        if (this.nearestNeighborClassifier == null || this.shotgunDistance == null) {
            throw new IllegalStateException("No nearest neighbor classifier has been set; the classifier cannot predict yet.");
        }
        double threshold = (double)this.bestScore * this.factor;
        List<Integer> selected = new ArrayList<>();
        for (Pair<Integer, Integer> window : this.windows) {
            int correct = window.getX();
            if ((double)correct > threshold) {
                selected.add(window.getY());
            }
        }
        if (selected.isEmpty()) {
            // The strict comparison excludes even the best window when
            // factor == 1.0 or bestScore == 0; fall back to the best windows.
            for (Pair<Integer, Integer> window : this.windows) {
                int correct = window.getX();
                if (correct >= this.bestScore) {
                    selected.add(window.getY());
                }
            }
        }
        if (selected.isEmpty()) {
            throw new IllegalStateException("No window qualifies for the ensemble.");
        }
        return selected;
    }

    /**
     * Predicts the label of one instance separately for every selected window length.
     *
     * @param testInstance instance handed to the nearest neighbor classifier
     * @return map from window length to the label predicted with that window length
     * @throws PredictionException   if the nearest neighbor classifier fails
     * @throws IllegalStateException if the classifier is not ready for prediction
     */
    protected Map<Integer, Integer> calculateWindowLengthPredictions(double[] testInstance) throws PredictionException {
        Map<Integer, Integer> windowLengthPredicitions = new HashMap<>();
        for (int windowLength : this.selectEnsembleWindowLengths()) {
            // The classifier reads the window length from the shared distance object.
            this.shotgunDistance.setWindowLength(windowLength);
            int prediction = this.nearestNeighborClassifier.predict(testInstance);
            windowLengthPredicitions.put(windowLength, prediction);
        }
        return windowLengthPredicitions;
    }

    /**
     * Returns the label that occurs most often among the values of the given map.
     * Ties are resolved in favor of the label encountered first when iterating the
     * internal frequency map.
     *
     * @param windowLengthPredicitions map from window length to predicted label
     * @return the most frequent label, or 0 if the map is empty
     */
    protected Integer mostFrequentLabelFromWindowLengthPredicitions(Map<Integer, Integer> windowLengthPredicitions) {
        Map<Integer, Integer> labelFrequencyMap = new HashMap<>();
        for (Integer label : windowLengthPredicitions.values()) {
            Integer count = labelFrequencyMap.get(label);
            labelFrequencyMap.put(label, count == null ? 1 : count + 1);
        }
        int topFrequency = -1;
        int mostFrequentLabel = 0;
        for (Map.Entry<Integer, Integer> entry : labelFrequencyMap.entrySet()) {
            int labelFrequency = entry.getValue();
            if (labelFrequency > topFrequency) {
                topFrequency = labelFrequency;
                mostFrequentLabel = entry.getKey();
            }
        }
        return mostFrequentLabel;
    }

    /**
     * Predicts the labels of a whole dataset separately for every selected window length.
     *
     * @param dataset dataset handed to the nearest neighbor classifier
     * @return map from window length to the list of labels predicted with that window length
     * @throws PredictionException   if the nearest neighbor classifier fails
     * @throws IllegalStateException if the classifier is not ready for prediction
     */
    protected Map<Integer, List<Integer>> calculateWindowLengthPredictions(TimeSeriesDataset dataset) throws PredictionException {
        Map<Integer, List<Integer>> windowLengthPredicitions = new HashMap<>();
        for (int windowLength : this.selectEnsembleWindowLengths()) {
            // The classifier reads the window length from the shared distance object.
            this.shotgunDistance.setWindowLength(windowLength);
            List<Integer> predictions = this.nearestNeighborClassifier.predict(dataset);
            windowLengthPredicitions.put(windowLength, predictions);
        }
        return windowLengthPredicitions;
    }

    /**
     * Performs the majority vote instance by instance. The number of instances is
     * taken from the first prediction list; all lists are expected to have that size.
     *
     * @param windowLengthPredicitions map from window length to per-instance predictions
     * @return one voted label per instance; an empty list if the map is empty
     */
    protected List<Integer> mostFrequentLabelsFromWindowLengthPredicitions(Map<Integer, List<Integer>> windowLengthPredicitions) {
        if (windowLengthPredicitions.isEmpty()) {
            // Without any per-window predictions there is nothing to vote on.
            return new ArrayList<>();
        }
        int numberOfInstances = windowLengthPredicitions.values().iterator().next().size();
        List<Integer> predicitions = new ArrayList<>(numberOfInstances);
        for (int i = 0; i < numberOfInstances; ++i) {
            // Collect the i-th prediction of every window length and vote on them.
            Map<Integer, Integer> windowLabelsForInstance = new HashMap<>();
            for (Map.Entry<Integer, List<Integer>> entry : windowLengthPredicitions.entrySet()) {
                int windowLength = entry.getKey();
                int predictionForWindowLength = entry.getValue().get(i);
                windowLabelsForInstance.put(windowLength, predictionForWindowLength);
            }
            int mostFrequentLabelForInstance = this.mostFrequentLabelFromWindowLengthPredicitions(windowLabelsForInstance);
            predicitions.add(mostFrequentLabelForInstance);
        }
        return predicitions;
    }

    /**
     * Predicts the label of a single univariate instance by majority vote over the
     * selected window lengths.
     *
     * @param univInstance instance to classify, must not be null
     * @return the voted label
     * @throws IllegalArgumentException if {@code univInstance} is null
     * @throws IllegalStateException    if the classifier is not ready for prediction
     * @throws PredictionException      if the nearest neighbor classifier fails
     */
    @Override
    public Integer predict(double[] univInstance) throws PredictionException {
        if (univInstance == null) {
            throw new IllegalArgumentException("Instance to predict must not be null.");
        }
        Map<Integer, Integer> windowLengthPredicitions = this.calculateWindowLengthPredictions(univInstance);
        return this.mostFrequentLabelFromWindowLengthPredicitions(windowLengthPredicitions);
    }

    /**
     * Predicts the labels of all instances of a dataset by majority vote over the
     * selected window lengths. The dataset is first passed to
     * {@code checkWhetherPredictionIsPossible}.
     *
     * @param dataset dataset to classify
     * @return one voted label per instance
     * @throws IllegalStateException if the classifier is not ready for prediction
     * @throws PredictionException   if the nearest neighbor classifier fails
     */
    @Override
    public List<Integer> predict(TimeSeriesDataset dataset) throws PredictionException {
        this.checkWhetherPredictionIsPossible(dataset);
        Map<Integer, List<Integer>> windowLengthPredicitions = this.calculateWindowLengthPredictions(dataset);
        return this.mostFrequentLabelsFromWindowLengthPredicitions(windowLengthPredicitions);
    }

    /**
     * Stores the value matrix.
     *
     * @param values value matrix, must not be null
     * @throws IllegalArgumentException if {@code values} is null
     */
    protected void setValues(double[][] values) {
        if (values == null) {
            throw new IllegalArgumentException("Values must not be null");
        }
        this.values = values;
    }

    /**
     * Stores the targets.
     *
     * @param targets targets, must not be null
     * @throws IllegalArgumentException if {@code targets} is null
     */
    protected void setTargets(int[] targets) {
        if (targets == null) {
            throw new IllegalArgumentException("Targets must not be null");
        }
        this.targets = targets;
    }

    /**
     * Stores the candidate windows and derives {@link #bestScore} as the largest
     * score among them (-1 if the list is empty).
     *
     * @param windows pairs of (score, window length), must not be null
     * @throws IllegalArgumentException if {@code windows} is null
     */
    protected void setWindows(ArrayList<Pair<Integer, Integer>> windows) {
        if (windows == null) {
            throw new IllegalArgumentException("Windows must not be null");
        }
        this.windows = windows;
        int tBestScore = -1;
        for (Pair<Integer, Integer> window : windows) {
            int correct = window.getX();
            tBestScore = Math.max(tBestScore, correct);
        }
        this.bestScore = tBestScore;
    }

    /**
     * Stores the nearest neighbor classifier together with its distance measure.
     *
     * @param nearestNeighborClassifier classifier whose distance measure is a {@link ShotgunDistance}
     * @throws IllegalArgumentException if the classifier is null or uses another distance measure
     */
    protected void setNearestNeighborClassifier(NearestNeighborClassifier nearestNeighborClassifier) {
        if (nearestNeighborClassifier == null) {
            throw new IllegalArgumentException("The nearest neighbor classifier must not be null");
        }
        ITimeSeriesDistance distanceMeasure = nearestNeighborClassifier.getDistanceMeasure();
        if (!(distanceMeasure instanceof ShotgunDistance)) {
            throw new IllegalArgumentException("The nearest neighbor classifier must use a ShotgunDistance as dsitance measure.");
        }
        this.shotgunDistance = (ShotgunDistance)distanceMeasure;
        this.nearestNeighborClassifier = nearestNeighborClassifier;
    }

    /**
     * Creates the learner algorithm that trains this classifier on the given dataset.
     *
     * @param dataset training data passed to the learner algorithm
     * @return a new learner algorithm bound to this classifier and its configuration
     */
    public ShotgunEnsembleLearnerAlgorithm getLearningAlgorithm(TimeSeriesDataset dataset) {
        return new ShotgunEnsembleLearnerAlgorithm(this.config, this, dataset);
    }
}

