/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.tsc.classifier.trees;

import ai.libs.jaicore.basic.sets.Pair;
import ai.libs.jaicore.ml.core.exception.PredictionException;
import ai.libs.jaicore.ml.tsc.classifier.ASimplifiedTSClassifier;
import ai.libs.jaicore.ml.tsc.classifier.trees.TimeSeriesBagOfFeaturesLearningAlgorithm;
import ai.libs.jaicore.ml.tsc.dataset.TimeSeriesDataset;
import ai.libs.jaicore.ml.tsc.features.TimeSeriesFeature;
import ai.libs.jaicore.ml.tsc.util.TimeSeriesUtil;
import ai.libs.jaicore.ml.tsc.util.WekaUtil;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.aeonbits.owner.ConfigCache;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.classifiers.trees.RandomForest;
import weka.core.Instances;

public class TimeSeriesBagOfFeaturesClassifier
extends ASimplifiedTSClassifier<Integer> {
    private static final Logger LOGGER = LoggerFactory.getLogger(TimeSeriesBagOfFeaturesClassifier.class);
    private RandomForest subseriesClf;
    private RandomForest finalClf;
    private int numClasses;
    private int[][][] intervals;
    private int[][] subsequences;
    private final TimeSeriesBagOfFeaturesLearningAlgorithm.ITimeSeriesBagOfFeaturesConfig config = (TimeSeriesBagOfFeaturesLearningAlgorithm.ITimeSeriesBagOfFeaturesConfig)ConfigCache.getOrCreate(TimeSeriesBagOfFeaturesLearningAlgorithm.ITimeSeriesBagOfFeaturesConfig.class, (Map[])new Map[0]);

    public TimeSeriesBagOfFeaturesClassifier(int seed) {
        this(seed, 10, 10, 0.1, 5, false);
    }

    public TimeSeriesBagOfFeaturesClassifier(int seed, int numBins, int numFolds, double zProp, int minIntervalLength) {
        this(seed, numBins, numFolds, zProp, minIntervalLength, false);
    }

    public TimeSeriesBagOfFeaturesClassifier(int seed, int numBins, int numFolds, double zProp, int minIntervalLength, boolean useZNormalization) {
        this.config.setProperty("seed", "" + seed);
        this.setNumBins(numBins);
        this.config.setProperty("numfolds", "" + numFolds);
        this.config.setProperty("zprop", "" + zProp);
        this.config.setProperty("minintervallength", "" + minIntervalLength);
        this.config.setProperty("useznormalization", "" + useZNormalization);
    }

    @Override
    public Integer predict(double[] univInstance) throws PredictionException {
        if (!this.isTrained()) {
            throw new PredictionException("Model has not been built before!");
        }
        if (this.config.zNormalization()) {
            univInstance = TimeSeriesUtil.zNormalize(univInstance, true);
        }
        double[][] intervalFeatures = new double[this.intervals.length][(this.intervals[0].length + 1) * 3 + 2];
        for (int i = 0; i < this.intervals.length; ++i) {
            for (int j = 0; j < this.intervals[i].length; ++j) {
                double[] tmpFeatures = TimeSeriesFeature.getFeatures(univInstance, this.intervals[i][j][0], this.intervals[i][j][1] - 1, false);
                intervalFeatures[i][j * 3] = tmpFeatures[0];
                intervalFeatures[i][j * 3 + 1] = tmpFeatures[1] * tmpFeatures[1];
                intervalFeatures[i][j * 3 + 2] = tmpFeatures[2];
            }
            double[] subseriesFeatures = TimeSeriesFeature.getFeatures(univInstance, this.subsequences[i][0], this.subsequences[i][1] - 1, false);
            intervalFeatures[i][this.intervals[i].length * 3] = subseriesFeatures[0];
            intervalFeatures[i][this.intervals[i].length * 3 + 1] = subseriesFeatures[1] * subseriesFeatures[1];
            intervalFeatures[i][this.intervals[i].length * 3 + 2] = subseriesFeatures[2];
            intervalFeatures[i][intervalFeatures[i].length - 2] = this.subsequences[i][0];
            intervalFeatures[i][intervalFeatures[i].length - 1] = this.subsequences[i][1];
        }
        Instances subseriesInstances = WekaUtil.simplifiedTimeSeriesDatasetToWekaInstances(TimeSeriesUtil.createDatasetForMatrix(new double[][][]{intervalFeatures}), IntStream.rangeClosed(0, this.numClasses - 1).boxed().map(String::valueOf).collect(Collectors.toList()));
        double[][] probs = null;
        int[] predictedTargets = new int[subseriesInstances.numInstances()];
        try {
            probs = this.subseriesClf.distributionsForInstances(subseriesInstances);
            for (int i = 0; i < subseriesInstances.numInstances(); ++i) {
                predictedTargets[i] = (int)this.subseriesClf.classifyInstance(subseriesInstances.get(i));
            }
        }
        catch (Exception e) {
            throw new PredictionException("Cannot derive the probabilities using the subseries classifier due to an internal Weka exception.", e);
        }
        int[][] discretizedProbs = TimeSeriesBagOfFeaturesLearningAlgorithm.discretizeProbs(this.getNumBins(), probs);
        Pair<int[][][], int[][]> histFreqPair = TimeSeriesBagOfFeaturesLearningAlgorithm.formHistogramsAndRelativeFreqs(discretizedProbs, predictedTargets, 1, this.numClasses, this.getNumBins());
        int[][][] histograms = (int[][][])histFreqPair.getX();
        int[][] relativeFrequencies = (int[][])histFreqPair.getY();
        double[][] finalHistogramInstances = TimeSeriesBagOfFeaturesLearningAlgorithm.generateHistogramInstances(histograms, relativeFrequencies);
        Instances finalInstances = WekaUtil.simplifiedTimeSeriesDatasetToWekaInstances(TimeSeriesUtil.createDatasetForMatrix(new double[][][]{finalHistogramInstances}), IntStream.rangeClosed(0, this.numClasses - 1).boxed().map(String::valueOf).collect(Collectors.toList()));
        if (finalInstances.size() != 1) {
            String errorMessage = "There should be only one instance given to the final Random Forest classifier.";
            throw new PredictionException("There should be only one instance given to the final Random Forest classifier.", new IllegalStateException("There should be only one instance given to the final Random Forest classifier."));
        }
        try {
            return (int)this.finalClf.classifyInstance(finalInstances.firstInstance());
        }
        catch (Exception e) {
            throw new PredictionException("Could not predict instance due to an internal Weka exception.", e);
        }
    }

    @Override
    public Integer predict(List<double[]> multivInstance) throws PredictionException {
        LOGGER.warn("Dataset to be predicted is multivariate but only first time series (univariate) will be considered.");
        return this.predict(multivInstance.get(0));
    }

    @Override
    public List<Integer> predict(TimeSeriesDataset dataset) throws PredictionException {
        if (!this.isTrained()) {
            throw new PredictionException("Model has not been built before!");
        }
        ArrayList<Integer> result = new ArrayList<Integer>();
        for (int i = 0; i < dataset.getValues(0).length; ++i) {
            result.add(this.predict(dataset.getValues(0)[i]));
        }
        return result;
    }

    public RandomForest getSubseriesClf() {
        return this.subseriesClf;
    }

    public void setSubseriesClf(RandomForest subseriesClf) {
        this.subseriesClf = subseriesClf;
    }

    public RandomForest getFinalClf() {
        return this.finalClf;
    }

    public void setFinalClf(RandomForest finalClf) {
        this.finalClf = finalClf;
    }

    public int getNumBins() {
        return this.config.numBins();
    }

    public void setNumBins(int numBins) {
        this.config.setProperty("numbins", "" + numBins);
    }

    public int getNumClasses() {
        return this.numClasses;
    }

    public void setNumClasses(int numClasses) {
        this.numClasses = numClasses;
    }

    public int[][][] getIntervals() {
        return this.intervals;
    }

    public void setIntervals(int[][][] intervals) {
        this.intervals = intervals;
    }

    public int[][] getSubsequences() {
        return this.subsequences;
    }

    public void setSubsequences(int[][] subsequences) {
        this.subsequences = subsequences;
    }

    public TimeSeriesBagOfFeaturesLearningAlgorithm getLearningAlgorithm(TimeSeriesDataset dataset) {
        return new TimeSeriesBagOfFeaturesLearningAlgorithm(this.config, this, dataset);
    }

    public TimeSeriesBagOfFeaturesLearningAlgorithm.ITimeSeriesBagOfFeaturesConfig getConfig() {
        return this.config;
    }
}

