/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.classification.singlelabel.timeseries.learner.neighbors;

import ai.libs.jaicore.basic.IOwnerBasedAlgorithmConfig;
import ai.libs.jaicore.basic.metric.ShotgunDistance;
import ai.libs.jaicore.basic.sets.Pair;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.dataset.TimeSeriesDataset2;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.learner.ASimplifiedTSCLearningAlgorithm;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.learner.neighbors.NearestNeighborClassifier;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.learner.neighbors.ShotgunEnsembleClassifier;
import java.util.ArrayList;
import org.aeonbits.owner.Config;
import org.api4.java.algorithm.events.IAlgorithmEvent;
import org.api4.java.algorithm.exceptions.AlgorithmException;
import org.api4.java.common.metric.IDistanceMetric;

/**
 * Learning algorithm that fits a {@link ShotgunEnsembleClassifier} on a univariate
 * {@link TimeSeriesDataset2}.
 *
 * <p>For every window length from the configured maximum down to the configured minimum
 * (both inclusive), the algorithm runs a leave-one-out 1-nearest-neighbor evaluation on the
 * training data using a {@link ShotgunDistance} and records the number of correctly predicted
 * instances together with the window length. The resulting list of pairs and a
 * {@link NearestNeighborClassifier} trained with the maximum window length are then stored in
 * the classifier.
 */
public class ShotgunEnsembleLearnerAlgorithm
extends ASimplifiedTSCLearningAlgorithm<Integer, ShotgunEnsembleClassifier> {
    /**
     * Creates the learning algorithm.
     *
     * @param config the configuration providing the window length range and the mean normalization flag
     * @param classifier the model that is updated by {@link #call()}
     * @param dataset the training data
     */
    public ShotgunEnsembleLearnerAlgorithm(IShotgunEnsembleLearnerConfig config, ShotgunEnsembleClassifier classifier, TimeSeriesDataset2 dataset) {
        super(config, classifier, dataset);
    }

    /**
     * Step-wise execution is not supported by this algorithm.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public IAlgorithmEvent nextWithException() {
        throw new UnsupportedOperationException();
    }

    /**
     * Returns the configuration of this algorithm narrowed to {@link IShotgunEnsembleLearnerConfig}.
     *
     * @return the configuration of this algorithm
     */
    public IShotgunEnsembleLearnerConfig getConfig() {
        return (IShotgunEnsembleLearnerConfig)super.getConfig();
    }

    /**
     * Trains the classifier handed to the constructor on the input dataset and returns it.
     *
     * @return the updated classifier
     * @throws AlgorithmException if the input dataset, its values or its targets are missing, if there
     *         are fewer targets than instances, or if the nearest neighbor classifier cannot be trained
     * @throws UnsupportedOperationException if the input dataset is multivariate
     */
    public ShotgunEnsembleClassifier call() throws AlgorithmException {
        TimeSeriesDataset2 dataset = (TimeSeriesDataset2)this.getInput();
        if (dataset == null) {
            throw new AlgorithmException("No input data set.");
        }
        if (dataset.isMultivariate()) {
            throw new UnsupportedOperationException("Multivariate datasets are not supported.");
        }
        double[][] values = dataset.getValuesOrNull(0);
        if (values == null) {
            throw new AlgorithmException("Empty input data set.");
        }
        int[] targets = dataset.getTargets();
        if (targets == null) {
            throw new AlgorithmException("Empty targets.");
        }
        // Every instance needs a target; otherwise the evaluation below would run out of bounds.
        if (targets.length < values.length) {
            throw new AlgorithmException("Fewer targets (" + targets.length + ") than instances (" + values.length + ").");
        }

        // Read the loop-invariant settings once instead of on every iteration.
        IShotgunEnsembleLearnerConfig config = this.getConfig();
        int windowSizeMin = config.windowSizeMin();
        int windowSizeMax = config.windowSizeMax();
        boolean meanNormalization = config.meanNormalization();

        // Holds pairs of (number of correct predictions, window length).
        // Kept as a concrete ArrayList because it is handed to the classifier's setter as-is.
        ArrayList<Pair<Integer, Integer>> scores = new ArrayList<>();
        for (int windowLength = windowSizeMax; windowLength >= windowSizeMin; --windowLength) {
            ShotgunDistance shotgunDistance = new ShotgunDistance(windowLength, meanNormalization);
            int correct = ShotgunEnsembleLearnerAlgorithm.countLeaveOneOutHits(shotgunDistance, values, targets);
            scores.add(new Pair<>(correct, windowLength));
        }

        NearestNeighborClassifier nearestNeighborClassifier = new NearestNeighborClassifier((IDistanceMetric)new ShotgunDistance(windowSizeMax, meanNormalization));
        try {
            nearestNeighborClassifier.train(dataset);
        }
        catch (Exception e) {
            throw new AlgorithmException("Cant train nearest neighbor classifier.", (Throwable)e);
        }

        ShotgunEnsembleClassifier model = (ShotgunEnsembleClassifier)this.getClassifier();
        model.setWindows(scores);
        model.setNearestNeighborClassifier(nearestNeighborClassifier);
        return model;
    }

    /**
     * Counts how many instances carry the same target as their nearest other instance
     * (leave-one-out 1-nearest-neighbor evaluation).
     *
     * <p>An instance for which no neighbor is found (only one instance in total, or no distance
     * that is strictly smaller than {@link Double#MAX_VALUE}, e.g. NaN distances) is not counted
     * as correct.
     *
     * @param metric the distance used to compare two instances
     * @param values the instances, one time series per row
     * @param targets the targets, indexed like {@code values}
     * @return the number of instances whose nearest neighbor has the same target
     */
    private static int countLeaveOneOutHits(ShotgunDistance metric, double[][] values, int[] targets) {
        int correct = 0;
        for (int i = 0; i < values.length; ++i) {
            double minDistance = Double.MAX_VALUE;
            int nearest = -1;
            for (int j = 0; j < values.length; ++j) {
                if (i == j) {
                    continue;
                }
                double distance = metric.distance(values[i], values[j]);
                // Strict comparison: ties keep the earlier neighbor, NaN distances are ignored.
                if (distance < minDistance) {
                    minDistance = distance;
                    nearest = j;
                }
            }
            // nearest stays -1 when no neighbor qualified; indexing targets with it would fail.
            if (nearest >= 0 && targets[i] == targets[nearest]) {
                ++correct;
            }
        }
        return correct;
    }

    /**
     * Configuration of the {@link ShotgunEnsembleLearnerAlgorithm}.
     */
    public static interface IShotgunEnsembleLearnerConfig
    extends IOwnerBasedAlgorithmConfig {
        /** Key of the smallest window length that is evaluated (inclusive). */
        public static final String K_WINDOWLENGTH_MIN = "windowlength.min";
        /** Key of the largest window length that is evaluated (inclusive). */
        public static final String K_WINDOWLENGTH_MAX = "windowlength.max";
        /** Key of the flag that is passed to the {@link ShotgunDistance} as mean normalization setting. */
        public static final String K_MEANNORMALIZATION = "meannormalization";

        /**
         * @return the smallest window length that is evaluated (inclusive)
         */
        @Config.Key(value=K_WINDOWLENGTH_MIN)
        public int windowSizeMin();

        /**
         * @return the largest window length that is evaluated (inclusive); also used for the
         *         nearest neighbor classifier stored in the model
         */
        @Config.Key(value=K_WINDOWLENGTH_MAX)
        public int windowSizeMax();

        /**
         * @return the mean normalization flag handed to every {@link ShotgunDistance}; defaults to {@code false}
         */
        @Config.Key(value=K_MEANNORMALIZATION)
        @Config.DefaultValue(value="false")
        public boolean meanNormalization();
    }
}

