/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.weka.classification.singlelabel.timeseries.learner.trees;

import ai.libs.jaicore.basic.IOwnerBasedAlgorithmConfig;
import ai.libs.jaicore.basic.IOwnerBasedRandomizedAlgorithmConfig;
import ai.libs.jaicore.basic.sets.Pair;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.dataset.TimeSeriesDataset2;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.dataset.TimeSeriesFeature;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.learner.ASimplifiedTSCLearningAlgorithm;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.learner.ASimplifiedTSClassifier;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.util.MathUtil;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.util.TimeSeriesUtil;
import ai.libs.jaicore.ml.weka.classification.singlelabel.timeseries.learner.trees.TimeSeriesBagOfFeaturesClassifier;
import ai.libs.jaicore.ml.weka.classification.singlelabel.timeseries.util.WekaTimeseriesUtil;
import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.aeonbits.owner.Config;
import org.api4.java.ai.ml.core.exception.TrainingException;
import org.api4.java.algorithm.exceptions.AlgorithmException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.classifiers.Classifier;
import weka.classifiers.trees.RandomForest;
import weka.core.Instances;

public class TimeSeriesBagOfFeaturesLearningAlgorithm
extends ASimplifiedTSCLearningAlgorithm<Integer, TimeSeriesBagOfFeaturesClassifier> {
    /** Logger of this class; used by {@code call()} to report that only the first value matrix of multivariate data is used. */
    private static final Logger LOGGER = LoggerFactory.getLogger(TimeSeriesBagOfFeaturesLearningAlgorithm.class);
    /** Bias-correction switch. Not read anywhere in this file. */
    public static final boolean USE_BIAS_CORRECTION = false;
    /**
     * Number of trees per random forest. Not referenced by name in this file;
     * {@code call()} passes the literal 500 to {@code setNumIterations} for both forests
     * (presumably this constant inlined at compile time - TODO confirm).
     */
    private static final int NUM_TREES_IN_FOREST = 500;

    public TimeSeriesBagOfFeaturesLearningAlgorithm(ITimeSeriesBagOfFeaturesConfig config, TimeSeriesBagOfFeaturesClassifier classifier, TimeSeriesDataset2 data) {
        super((IOwnerBasedAlgorithmConfig)config, (ASimplifiedTSClassifier)classifier, data);
        if (config.zProportion() < 0.0 || config.zProportion() > 1.0) {
            throw new IllegalArgumentException("Parameter zProportion is set to " + config.zProportion() + " but must be between 0 and 1!");
        }
    }

    /**
     * Trains the bag-of-features model on the dataset returned by {@code getInput()}.
     *
     * Steps, in order: shuffle the dataset with the configured seed, optionally
     * z-normalize every series of value matrix 0, generate random subsequences and
     * intervals, compute the features of every (instance, subsequence) pair, estimate
     * class probabilities for these feature rows fold-wise, train a random forest on
     * all feature rows, discretize the probabilities into per-instance histograms and
     * train a second random forest on the histogram instances.
     *
     * @return the classifier returned by {@code getClassifier()} after both forests,
     *         the number of classes, the intervals and the subsequences have been set on it
     * @throws AlgorithmException if the probability estimation or the training of one of
     *         the two forests fails with a {@code TrainingException}
     * @throws IllegalArgumentException if the dataset is null or empty, or if value
     *         matrix 0 or the target array is null or has length 0
     */
    public TimeSeriesBagOfFeaturesClassifier call() throws AlgorithmException {
        int minIntervalLength;
        TimeSeriesDataset2 dataset = (TimeSeriesDataset2)this.getInput();
        if (dataset == null || dataset.isEmpty()) {
            throw new IllegalArgumentException("Dataset used for training must not be null or empty!");
        }
        if (dataset.isMultivariate()) {
            LOGGER.info("Only univariate data is used for training (matrix index 0), although multivariate data is available.");
        }
        // The return value (if any) is ignored, so the shuffle is presumably done in place - TODO confirm.
        TimeSeriesUtil.shuffleTimeSeriesDataset((TimeSeriesDataset2)dataset, (int)((int)this.getConfig().seed()));
        // Only value matrix 0 is used from here on.
        double[][] data = dataset.getValuesOrNull(0);
        int[] targets = dataset.getTargets();
        if (data == null || data.length == 0 || targets == null || targets.length == 0) {
            throw new IllegalArgumentException("The given dataset for training must not contain a null or empty data or target matrix.");
        }
        int numClasses = TimeSeriesUtil.getNumberOfClasses((TimeSeriesDataset2)dataset);
        if (this.getConfig().zNormalization()) {
            // Replaces every row of the local matrix by its normalized version.
            for (int i = 0; i < dataset.getNumberOfInstances(); ++i) {
                data[i] = TimeSeriesUtil.zNormalize((double[])data[i], (boolean)true);
            }
        }
        // The length of the first series is used as the series length for all instances.
        int length = data[0].length;
        // Minimum subsequence length: the configured proportion of the series length ...
        int lMin = (int)(this.getConfig().zProportion() * (double)length);
        // ... raised to the minimum interval length if it is smaller ...
        if (lMin < (minIntervalLength = this.getConfig().minIntervalLength())) {
            lMin = minIntervalLength;
        }
        // ... and reduced by one minimum interval length if it comes too close to the series length.
        if (lMin >= length - minIntervalLength) {
            lMin -= minIntervalLength;
        }
        // d: number of intervals per subsequence; r - d: number of subsequences that are generated.
        int d = this.getD(lMin);
        int r = this.getR(length);
        Pair<int[][], int[][][]> subSeqIntervals = this.generateSubsequencesAndIntervals(r, d, lMin, length);
        int[][] subsequences = (int[][])subSeqIntervals.getX();
        int[][][] intervals = (int[][][])subSeqIntervals.getY();
        double[][][][] generatedFeatures = TimeSeriesBagOfFeaturesLearningAlgorithm.generateFeatures(data, subsequences, intervals);
        // Row layout: three features for each of the d intervals and for the whole
        // subsequence, followed by the two subsequence bounds.
        int numFeatures = (d + 1) * 3 + 2;
        double[][] subSeqValueMatrix = new double[(r - d) * data.length][numFeatures];
        int[] targetMatrix = new int[(r - d) * data.length];
        for (int i = 0; i < r - d; ++i) {
            for (int j = 0; j < data.length; ++j) {
                double[] intervalFeatures = new double[numFeatures];
                for (int k = 0; k < d + 1; ++k) {
                    intervalFeatures[k * 3] = generatedFeatures[j][i][k][0];
                    intervalFeatures[k * 3 + 1] = generatedFeatures[j][i][k][1];
                    intervalFeatures[k * 3 + 2] = generatedFeatures[j][i][k][2];
                }
                intervalFeatures[intervalFeatures.length - 2] = subsequences[i][0];
                intervalFeatures[intervalFeatures.length - 1] = subsequences[i][1];
                // Rows are grouped by instance: the (r - d) rows of instance j are contiguous,
                // and each of them carries the target of instance j.
                subSeqValueMatrix[j * (r - d) + i] = intervalFeatures;
                targetMatrix[j * (r - d) + i] = targets[j];
            }
        }
        RandomForest subseriesClf = new RandomForest();
        subseriesClf.setNumIterations(500);
        double[][] probs = null;
        try {
            // The same forest object is passed in here and is trained again on all rows right below.
            probs = TimeSeriesBagOfFeaturesLearningAlgorithm.measureOOBProbabilitiesUsingCV(subSeqValueMatrix, targetMatrix, (r - d) * data.length, this.getConfig().numFolds(), numClasses, subseriesClf);
        }
        catch (TrainingException e1) {
            throw new AlgorithmException("Could not measure OOB probabilities using CV.", (Throwable)e1);
        }
        try {
            WekaTimeseriesUtil.buildWekaClassifierFromSimplifiedTS((Classifier)subseriesClf, TimeSeriesUtil.createDatasetForMatrix((int[])targetMatrix, (double[][][])new double[][][]{subSeqValueMatrix}));
        }
        catch (TrainingException e) {
            throw new AlgorithmException("Could not train the sub series Random Forest classifier due to an internal Weka exception.", (Throwable)e);
        }
        // Turn the per-row probabilities into bin indices, then into one histogram
        // and one class-frequency vector per original instance.
        int[][] discretizedProbs = TimeSeriesBagOfFeaturesLearningAlgorithm.discretizeProbs(this.getConfig().numBins(), probs);
        Pair<int[][][], int[][]> histFreqPair = TimeSeriesBagOfFeaturesLearningAlgorithm.formHistogramsAndRelativeFreqs(discretizedProbs, data.length, numClasses, this.getConfig().numBins());
        int[][][] histograms = (int[][][])histFreqPair.getX();
        int[][] relativeFrequencies = (int[][])histFreqPair.getY();
        double[][] finalInstances = TimeSeriesBagOfFeaturesLearningAlgorithm.generateHistogramInstances(histograms, relativeFrequencies);
        // The second forest is trained on one histogram instance per original instance,
        // labeled with the original targets.
        RandomForest finalClf = new RandomForest();
        finalClf.setNumIterations(500);
        try {
            WekaTimeseriesUtil.buildWekaClassifierFromSimplifiedTS((Classifier)finalClf, TimeSeriesUtil.createDatasetForMatrix((int[])targets, (double[][][])new double[][][]{finalInstances}));
        }
        catch (TrainingException e) {
            throw new AlgorithmException("Could not train the final Random Forest classifier due to an internal Weka exception.", (Throwable)e);
        }
        // Hand everything needed later over to the classifier.
        TimeSeriesBagOfFeaturesClassifier model = (TimeSeriesBagOfFeaturesClassifier)this.getClassifier();
        model.setSubseriesClf(subseriesClf);
        model.setFinalClf(finalClf);
        model.setNumClasses(numClasses);
        model.setIntervals(intervals);
        model.setSubsequences(subsequences);
        return model;
    }

    /**
     * Randomly generates {@code r - d} subsequences and, for each of them, {@code d}
     * consecutive intervals of equal length, using a {@link Random} seeded with the
     * configured seed.
     *
     * @param r    upper bound used to derive the number of subsequences ({@code r - d})
     * @param d    number of intervals per subsequence
     * @param lMin minimum subsequence length
     * @param T    length of the time series
     * @return pair of the subsequences ({@code [start, exclusive end]} per row) and the
     *         intervals ({@code [start, exclusive end]} per subsequence and interval)
     * @throws IllegalStateException if the interval length induced by a subsequence is
     *         lower than the configured minimum interval length
     */
    public Pair<int[][], int[][][]> generateSubsequencesAndIntervals(int r, int d, int lMin, int T) {
        final int count = r - d;
        final int[][] subsequences = new int[count][2];
        final int[][][] intervals = new int[count][d][2];
        final int shortestInterval = this.getConfig().minIntervalLength();
        final Random rng = new Random(this.getConfig().seed());

        for (int s = 0; s < count; ++s) {
            // Two draws per subsequence, in this order: start index, then length offset.
            final int begin = rng.nextInt(T - lMin);
            final int drawnLength = rng.nextInt(T - lMin - begin) + lMin;
            final int endExclusive = begin + drawnLength + 1;
            subsequences[s][0] = begin;
            subsequences[s][1] = endExclusive;

            // Largest interval length that lets d intervals fit into the subsequence.
            int step = (int)((double)(endExclusive - begin) / (double)d);
            if (step < shortestInterval) {
                throw new IllegalStateException("The induced interval length must not be lower than the minimum interval length!");
            }
            if (step > shortestInterval) {
                // Third draw: pick a length between the minimum and the induced maximum (both inclusive).
                step = rng.nextInt(step - shortestInterval + 1) + shortestInterval;
            }

            for (int part = 0; part < d; ++part) {
                final int from = begin + part * step;
                intervals[s][part][0] = from;
                intervals[s][part][1] = from + step;
            }
        }
        return new Pair<>(subsequences, intervals);
    }

    /**
     * Computes the features of every interval and of every whole subsequence for each
     * series in {@code data}.
     *
     * For subsequence {@code j}, slots {@code 0 .. intervals[j].length - 1} of the result
     * hold the interval features and slot {@code intervals[j].length} holds the features
     * of the whole subsequence. In every slot the value at index 1 is replaced by its square.
     *
     * @param data         one series per row
     * @param subsequences {@code [start, exclusive end]} per subsequence
     * @param intervals    {@code [start, exclusive end]} per subsequence and interval
     * @return features indexed by {@code [instance][subsequence][slot][feature]}
     */
    public static double[][][][] generateFeatures(double[][] data, int[][] subsequences, int[][][] intervals) {
        final double[][][][] features = new double[data.length][subsequences.length][intervals[0].length + 1][TimeSeriesFeature.NUM_FEATURE_TYPES];
        for (int inst = 0; inst < data.length; ++inst) {
            for (int sub = 0; sub < subsequences.length; ++sub) {
                final double[][] slots = features[inst][sub];
                final int numIntervals = intervals[sub].length;
                for (int part = 0; part < numIntervals; ++part) {
                    final int[] bounds = intervals[sub][part];
                    // The helper takes an inclusive end index, hence the - 1.
                    final double[] intervalFeatures = TimeSeriesFeature.getFeatures(data[inst], bounds[0], bounds[1] - 1, false);
                    intervalFeatures[1] *= intervalFeatures[1];
                    slots[part] = intervalFeatures;
                }
                final double[] wholeFeatures = TimeSeriesFeature.getFeatures(data[inst], subsequences[sub][0], subsequences[sub][1] - 1, false);
                wholeFeatures[1] *= wholeFeatures[1];
                slots[numIntervals] = wholeFeatures;
            }
        }
        return features;
    }

    /**
     * Derives the number of intervals per subsequence from the minimum subsequence length.
     *
     * @param lMin minimum subsequence length
     * @return {@code floor(lMin / minIntervalLength)} if {@code lMin} exceeds the configured
     *         minimum interval length, otherwise 1
     */
    private int getD(int lMin) {
        final int shortestInterval = this.getConfig().minIntervalLength();
        if (lMin <= shortestInterval) {
            return 1;
        }
        // Floating point division on purpose, exactly as before (no ArithmeticException for 0).
        final double ratio = (double)lMin / (double)shortestInterval;
        return (int)Math.floor(ratio);
    }

    /**
     * Derives the value {@code r} from the series length.
     *
     * @param T length of the time series
     * @return {@code floor(T / minIntervalLength)} computed in floating point arithmetic
     */
    private int getR(int T) {
        final double shortestInterval = (double)this.getConfig().minIntervalLength();
        final double ratio = (double)T / shortestInterval;
        return (int)Math.floor(ratio);
    }

    /**
     * Flattens the histograms and class frequencies of every instance into one feature row.
     *
     * A row consists of all histogram bins of the instance (histogram by histogram, bin
     * by bin) followed by its class frequency values. The row width is derived from the
     * first instance.
     *
     * @param histograms             counts indexed by {@code [instance][histogram][bin]}
     * @param relativeFreqsOfClasses frequencies indexed by {@code [instance][class]}
     * @return one feature row per instance
     */
    public static double[][] generateHistogramInstances(int[][][] histograms, int[][] relativeFreqsOfClasses) {
        final int rowWidth = histograms[0].length * histograms[0][0].length + relativeFreqsOfClasses[0].length;
        final double[][] instances = new double[histograms.length][];
        for (int inst = 0; inst < instances.length; ++inst) {
            final double[] row = new double[rowWidth];
            int cursor = 0;
            for (int[] histogram : histograms[inst]) {
                for (int binCount : histogram) {
                    row[cursor++] = binCount;
                }
            }
            for (int frequency : relativeFreqsOfClasses[inst]) {
                row[cursor++] = frequency;
            }
            instances[inst] = row;
        }
        return instances;
    }

    /**
     * Estimates class probabilities for every row fold by fold: for each fold the given
     * forest is trained on the training part and the class distributions it predicts for
     * the test part are stored in the result.
     *
     * @param subSeqValueMatrix feature rows
     * @param targetMatrix      targets of the feature rows
     * @param numProbInstances  number of rows of the returned matrix
     * @param numFolds          number of folds
     * @param numClasses        number of classes (columns of the returned matrix)
     * @param rf                forest that is (re-)trained for every fold
     * @return matrix of class probabilities; rows that no fold writes to keep their zeros
     * @throws TrainingException if training fails or Weka cannot compute the distributions
     */
    public static double[][] measureOOBProbabilitiesUsingCV(double[][] subSeqValueMatrix, int[] targetMatrix, int numProbInstances, int numFolds, int numClasses, RandomForest rf) throws TrainingException {
        final double[][] collected = new double[numProbInstances][numClasses];
        // Floating point division on purpose, exactly as before.
        final int rowsPerFold = (int)((double)collected.length / (double)numFolds);

        for (int fold = 0; fold < numFolds; ++fold) {
            final Pair<?, ?> split = TimeSeriesUtil.getTrainingAndTestDataForFold(fold, numFolds, subSeqValueMatrix, targetMatrix);

            final TimeSeriesDataset2 trainPart = (TimeSeriesDataset2)split.getX();
            WekaTimeseriesUtil.buildWekaClassifierFromSimplifiedTS((Classifier)rf, trainPart);

            final TimeSeriesDataset2 testPart = (TimeSeriesDataset2)split.getY();
            // Class labels "0" .. "numClasses - 1".
            final Instances wekaTestPart = WekaTimeseriesUtil.simplifiedTimeSeriesDatasetToWekaInstances(testPart, IntStream.range(0, numClasses).boxed().map(String::valueOf).collect(Collectors.toList()));

            final double[][] foldDistributions;
            try {
                foldDistributions = rf.distributionsForInstances(wekaTestPart);
            }
            catch (Exception e) {
                throw new TrainingException("Could not induce test probabilities in OOB probability estimation due to an internal Weka error.", (Throwable)e);
            }

            // The distributions of fold i are stored starting at row i * rowsPerFold.
            final int offset = fold * rowsPerFold;
            for (int row = 0; row < foldDistributions.length; ++row) {
                collected[offset + row] = foldDistributions[row];
            }
        }
        return collected;
    }

    /**
     * Builds, for every instance, one histogram of bin indices for each of the first
     * {@code numClasses - 1} classes and a vector of class frequencies.
     *
     * The rows of {@code discretizedProbs} are split evenly among the instances in
     * order. Every row increments, for each of the first {@code numClasses - 1} classes,
     * the bin given by its entry, and increments the frequency counter of the class that
     * {@code MathUtil.argmax} selects for the row. Finally every frequency counter is
     * divided by the number of rows per instance.
     *
     * NOTE(review): the counters are {@code int}, so the final division is an integer
     * division and truncates; a counter lower than the number of rows per instance
     * becomes 0. Confirm whether this is intended.
     *
     * @param discretizedProbs bin indices indexed by {@code [row][class]}
     * @param numInstances     number of instances
     * @param numClasses       number of classes
     * @param numBins          number of bins per histogram
     * @return pair of the histograms ({@code [instance][class][bin]}) and the class
     *         frequencies ({@code [instance][class]})
     * @throws IllegalArgumentException if there are fewer rows than instances or if the
     *         number of rows is not divisible by the number of instances
     */
    public static Pair<int[][][], int[][]> formHistogramsAndRelativeFreqs(int[][] discretizedProbs, int numInstances, int numClasses, int numBins) {
        final int totalRows = discretizedProbs.length;
        if (totalRows < numInstances) {
            throw new IllegalArgumentException("The number of discretized probabilities must not be lower than the number of instances!");
        }
        if (totalRows % numInstances != 0) {
            throw new IllegalArgumentException("The number of discretized probabilities must be divisible by the number of instances!");
        }

        final int histogramsPerInstance = numClasses - 1;
        final int[][][] histograms = new int[numInstances][histogramsPerInstance][numBins];
        final int[][] relativeFrequencies = new int[numInstances][numClasses];
        final int rowsPerInstance = totalRows / numInstances;

        for (int row = 0; row < totalRows; ++row) {
            final int owner = row / rowsPerInstance;
            for (int cls = 0; cls < histogramsPerInstance; ++cls) {
                final int bin = discretizedProbs[row][cls];
                histograms[owner][cls][bin]++;
            }
            final int winner = MathUtil.argmax(discretizedProbs[row]);
            relativeFrequencies[owner][winner]++;
        }

        for (int[] counters : relativeFrequencies) {
            for (int cls = 0; cls < counters.length; ++cls) {
                counters[cls] /= rowsPerInstance;
            }
        }
        return new Pair<>(histograms, relativeFrequencies);
    }

    /**
     * Maps every probability to the index of the equal-width bin it falls into.
     *
     * The bin width is {@code 1.0 / numBins} and the index is the truncated quotient of
     * the probability and that width, capped at {@code numBins - 1}. The cap is required
     * because the floating point quotient can round up to {@code numBins} for values just
     * below 1.0 (e.g. 0.9999999999999999 with 3 bins) or for values that exceed 1.0 by a
     * rounding error; such an index would lie outside of a histogram with {@code numBins}
     * bins.
     *
     * @param numBins number of bins the interval [0, 1] is divided into
     * @param probs   probabilities indexed by {@code [row][class]}; rows may differ in length
     * @return bin indices with the same shape as {@code probs}; an empty matrix if
     *         {@code probs} has no rows
     */
    public static int[][] discretizeProbs(int numBins, double[][] probs) {
        if (probs.length == 0) {
            // Nothing to discretize; previously this failed on probs[0].
            return new int[0][0];
        }
        final int lastBin = numBins - 1;
        final double steps = 1.0 / (double)numBins;
        final int[][] results = new int[probs.length][];
        for (int i = 0; i < probs.length; ++i) {
            final int[] discretizedProbs = new int[probs[i].length];
            for (int j = 0; j < discretizedProbs.length; ++j) {
                // Also covers a probability of exactly 1.0, which belongs to the last bin.
                discretizedProbs[j] = Math.min(lastBin, (int)(probs[i][j] / steps));
            }
            results[i] = discretizedProbs;
        }
        return results;
    }

    /**
     * Returns the configuration of the super class, cast to the configuration type of this algorithm.
     *
     * @return the configuration of this algorithm
     */
    public ITimeSeriesBagOfFeaturesConfig getConfig() {
        final Object inheritedConfig = super.getConfig();
        return (ITimeSeriesBagOfFeaturesConfig)inheritedConfig;
    }

    /**
     * Configuration of the bag-of-features learning algorithm. Every property is bound
     * to the key given by the corresponding {@code K_*} constant.
     */
    public interface ITimeSeriesBagOfFeaturesConfig extends IOwnerBasedRandomizedAlgorithmConfig {
        String K_NUMBINS = "numbins";
        String K_NUMFOLDS = "numfolds";
        String K_ZPROP = "zprop";
        String K_MIN_INTERVAL_LENGTH = "minintervallength";
        String K_USE_ZNORMALIZATION = "useznormalization";

        /** @return number of bins used to discretize the probabilities (default -1) */
        @Config.Key(K_NUMBINS)
        @Config.DefaultValue("-1")
        int numBins();

        /** @return number of folds used for the probability estimation (default -1) */
        @Config.Key(K_NUMFOLDS)
        @Config.DefaultValue("-1")
        int numFolds();

        /** @return proportion of the series length used as minimum subsequence length (default 1.0) */
        @Config.Key(K_ZPROP)
        @Config.DefaultValue("1.0")
        double zProportion();

        /** @return whether the series are z-normalized before training (default false) */
        @Config.Key(K_USE_ZNORMALIZATION)
        @Config.DefaultValue("false")
        boolean zNormalization();

        /** @return minimum length of an interval (default 1) */
        @Config.Key(K_MIN_INTERVAL_LENGTH)
        @Config.DefaultValue("1")
        int minIntervalLength();
    }
}

