/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.classification.singlelabel.timeseries.learner.trees;

import ai.libs.jaicore.basic.IOwnerBasedAlgorithmConfig;
import ai.libs.jaicore.basic.IOwnerBasedRandomizedAlgorithmConfig;
import ai.libs.jaicore.basic.sets.Pair;
import ai.libs.jaicore.graph.TreeNode;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.dataset.TimeSeriesDataset2;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.dataset.TimeSeriesFeature;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.learner.ASimplifiedTSCLearningAlgorithm;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.learner.ASimplifiedTSClassifier;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.learner.trees.TimeSeriesTreeClassifier;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.util.TimeSeriesUtil;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Random;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.aeonbits.owner.Config;
import org.apache.commons.lang3.ArrayUtils;
import org.api4.java.algorithm.Timeout;
import org.api4.java.algorithm.events.IAlgorithmEvent;

public class TimeSeriesTreeLearningAlgorithm
extends ASimplifiedTSCLearningAlgorithm<Integer, TimeSeriesTreeClassifier> {
    public static final int NUM_THRESH_CANDIDATES = 20;
    public static final double ENTROPY_APLHA = 1.0E-22;
    private static final double PRECISION_DELTA = 1.0E-9;
    private HashMap<Long, double[]> transformedFeaturesCache = null;
    public static final boolean USE_BIAS_CORRECTION = true;

    public TimeSeriesTreeLearningAlgorithm(ITimeSeriesTreeConfig config, TimeSeriesTreeClassifier tree, TimeSeriesDataset2 data) {
        super((IOwnerBasedAlgorithmConfig)config, (ASimplifiedTSClassifier)tree, data);
    }

    public void registerListener(Object listener) {
        throw new UnsupportedOperationException();
    }

    public int getNumCPUs() {
        throw new UnsupportedOperationException();
    }

    public void setNumCPUs(int numberOfCPUs) {
        throw new UnsupportedOperationException();
    }

    public void setTimeout(long timeout, TimeUnit timeUnit) {
        throw new UnsupportedOperationException();
    }

    public void setTimeout(Timeout timeout) {
        throw new UnsupportedOperationException();
    }

    public Timeout getTimeout() {
        throw new UnsupportedOperationException();
    }

    public IAlgorithmEvent nextWithException() {
        throw new UnsupportedOperationException();
    }

    public TimeSeriesTreeClassifier call() {
        TimeSeriesDataset2 data = (TimeSeriesDataset2)this.getInput();
        if (data.isEmpty()) {
            throw new IllegalArgumentException("The dataset used for training must not be null!");
        }
        if (data.isMultivariate()) {
            throw new UnsupportedOperationException("Multivariate instances are not supported yet.");
        }
        double[][] dataMatrix = data.getValuesOrNull(0);
        int n = dataMatrix.length;
        if (n <= 0) {
            throw new IllegalArgumentException("The traning data's matrix must contain at least one instance!");
        }
        double parentEntropy = 2.0;
        if (((ITimeSeriesTreeConfig)this.getConfig()).useFeatureCaching()) {
            int q = dataMatrix[0].length;
            this.transformedFeaturesCache = new HashMap(q * q * n);
        }
        this.tree(dataMatrix, data.getTargets(), parentEntropy, ((TimeSeriesTreeClassifier)this.getClassifier()).getRootNode(), 0);
        return (TimeSeriesTreeClassifier)this.getClassifier();
    }

    public Iterator<IAlgorithmEvent> iterator() {
        throw new UnsupportedOperationException();
    }

    public boolean hasNext() {
        throw new UnsupportedOperationException();
    }

    public IAlgorithmEvent next() {
        throw new NoSuchElementException("Cannot enumerate this algorithm!");
    }

    public void cancel() {
        throw new UnsupportedOperationException();
    }

    public void tree(double[][] data, int[] targets, double parentEntropy, TreeNode<TimeSeriesTreeClassifier.TimeSeriesTreeNodeDecisionFunction> nodeToBeFilled, int depth) {
        int i;
        int n = targets.length;
        ITimeSeriesTreeConfig config = (ITimeSeriesTreeConfig)this.getConfig();
        Pair<List<Integer>, List<Integer>> pairOfIntervalLists = TimeSeriesTreeLearningAlgorithm.sampleIntervals(data[0].length, config.seed());
        double[][][] transformedInstances = this.transformInstances(data, pairOfIntervalLists);
        List<List<Double>> thresholdCandidates = TimeSeriesTreeLearningAlgorithm.generateThresholdCandidates(pairOfIntervalLists, 20, transformedInstances);
        ArrayList<Integer> classes = new ArrayList<Integer>(new HashSet<Integer>(Arrays.asList(ArrayUtils.toObject((int[])targets))));
        double deltaEntropyStar = 0.0;
        double thresholdStar = 0.0;
        int t1t2Star = -1;
        int fStar = -1;
        double[] eStarPerFeatureType = new double[TimeSeriesFeature.NUM_FEATURE_TYPES];
        for (int i2 = 0; i2 < eStarPerFeatureType.length; ++i2) {
            eStarPerFeatureType[i2] = -2.147483648E9;
        }
        double[] deltaEntropyStarPerFeatureType = new double[TimeSeriesFeature.NUM_FEATURE_TYPES];
        int[] t1t2StarPerFeatureType = new int[TimeSeriesFeature.NUM_FEATURE_TYPES];
        double[] thresholdStarPerFeatureType = new double[TimeSeriesFeature.NUM_FEATURE_TYPES];
        List t1 = (List)pairOfIntervalLists.getX();
        List t2 = (List)pairOfIntervalLists.getY();
        for (int i3 = 0; i3 < t1.size(); ++i3) {
            for (int k = 0; k < TimeSeriesFeature.NUM_FEATURE_TYPES; ++k) {
                for (double cand : thresholdCandidates.get(k)) {
                    double localDeltaEntropy = TimeSeriesTreeLearningAlgorithm.calculateDeltaEntropy(transformedInstances[k][i3], targets, cand, classes, parentEntropy);
                    double localE = TimeSeriesTreeLearningAlgorithm.calculateEntrance(localDeltaEntropy, TimeSeriesTreeLearningAlgorithm.calculateMargin(transformedInstances[k][i3], cand));
                    if (!(localE > eStarPerFeatureType[k])) continue;
                    eStarPerFeatureType[k] = localE;
                    deltaEntropyStarPerFeatureType[k] = localDeltaEntropy;
                    t1t2StarPerFeatureType[k] = i3;
                    thresholdStarPerFeatureType[k] = cand;
                }
            }
        }
        int bestK = this.getBestSplitIndex(deltaEntropyStarPerFeatureType);
        deltaEntropyStar = deltaEntropyStarPerFeatureType[bestK];
        t1t2Star = t1t2StarPerFeatureType[bestK];
        thresholdStar = thresholdStarPerFeatureType[bestK];
        fStar = bestK;
        if (Math.abs(deltaEntropyStar) <= 1.0E-9 || depth == config.maxDepth() - 1 || depth != 0 && Math.abs(deltaEntropyStar - parentEntropy) <= 1.0E-9) {
            ((TimeSeriesTreeClassifier.TimeSeriesTreeNodeDecisionFunction)nodeToBeFilled.getValue()).classPrediction = TimeSeriesUtil.getMode((int[])targets);
            return;
        }
        ((TimeSeriesTreeClassifier.TimeSeriesTreeNodeDecisionFunction)nodeToBeFilled.getValue()).f = TimeSeriesFeature.FeatureType.values()[fStar];
        ((TimeSeriesTreeClassifier.TimeSeriesTreeNodeDecisionFunction)nodeToBeFilled.getValue()).t1 = (Integer)t1.get(t1t2Star);
        ((TimeSeriesTreeClassifier.TimeSeriesTreeNodeDecisionFunction)nodeToBeFilled.getValue()).t2 = (Integer)t2.get(t1t2Star);
        ((TimeSeriesTreeClassifier.TimeSeriesTreeNodeDecisionFunction)nodeToBeFilled.getValue()).threshold = thresholdStar;
        Pair<List<Integer>, List<Integer>> childDataIndices = TimeSeriesTreeLearningAlgorithm.getChildDataIndices(transformedInstances, n, fStar, t1t2Star, thresholdStar);
        double[][] dataLeft = new double[((List)childDataIndices.getX()).size()][data[0].length];
        int[] targetsLeft = new int[((List)childDataIndices.getX()).size()];
        double[][] dataRight = new double[((List)childDataIndices.getY()).size()][data[0].length];
        int[] targetsRight = new int[((List)childDataIndices.getY()).size()];
        for (i = 0; i < ((List)childDataIndices.getX()).size(); ++i) {
            dataLeft[i] = data[(Integer)((List)childDataIndices.getX()).get(i)];
            targetsLeft[i] = targets[(Integer)((List)childDataIndices.getX()).get(i)];
        }
        for (i = 0; i < ((List)childDataIndices.getY()).size(); ++i) {
            dataRight[i] = data[(Integer)((List)childDataIndices.getY()).get(i)];
            targetsRight[i] = targets[(Integer)((List)childDataIndices.getY()).get(i)];
        }
        TreeNode leftNode = nodeToBeFilled.addChild((Object)new TimeSeriesTreeClassifier.TimeSeriesTreeNodeDecisionFunction());
        TreeNode rightNode = nodeToBeFilled.addChild((Object)new TimeSeriesTreeClassifier.TimeSeriesTreeNodeDecisionFunction());
        this.tree(dataLeft, targetsLeft, deltaEntropyStar, (TreeNode<TimeSeriesTreeClassifier.TimeSeriesTreeNodeDecisionFunction>)leftNode, depth + 1);
        this.tree(dataRight, targetsRight, deltaEntropyStar, (TreeNode<TimeSeriesTreeClassifier.TimeSeriesTreeNodeDecisionFunction>)rightNode, depth + 1);
    }

    public static Pair<List<Integer>, List<Integer>> getChildDataIndices(double[][][] transformedData, int n, int fType, int t1t2, double threshold) {
        ArrayList<Integer> leftIndices = new ArrayList<Integer>();
        ArrayList<Integer> rightIndices = new ArrayList<Integer>();
        for (int i = 0; i < n; ++i) {
            if (transformedData[fType][t1t2][i] <= threshold) {
                leftIndices.add(i);
                continue;
            }
            rightIndices.add(i);
        }
        return new Pair(leftIndices, rightIndices);
    }

    public int getBestSplitIndex(double[] deltaEntropyStarPerFeatureType) {
        if (deltaEntropyStarPerFeatureType.length != TimeSeriesFeature.NUM_FEATURE_TYPES) {
            throw new IllegalArgumentException("A delta entropy star value has to be given for each feature type!");
        }
        double max = -2.147483648E9;
        ArrayList<Integer> maxIndexes = new ArrayList<Integer>();
        for (int i = 0; i < deltaEntropyStarPerFeatureType.length; ++i) {
            if (deltaEntropyStarPerFeatureType[i] > max) {
                max = deltaEntropyStarPerFeatureType[i];
                maxIndexes.clear();
                maxIndexes.add(i);
                continue;
            }
            if (deltaEntropyStarPerFeatureType[i] != max) continue;
            maxIndexes.add(i);
        }
        if (maxIndexes.isEmpty()) {
            throw new IllegalArgumentException("Could not find any maximum delta entropy star for any feature type for the given array " + Arrays.toString(deltaEntropyStarPerFeatureType) + ".");
        }
        if (maxIndexes.size() > 1) {
            Collections.shuffle(maxIndexes, new Random(((ITimeSeriesTreeConfig)this.getConfig()).seed()));
        }
        return (Integer)maxIndexes.get(0);
    }

    public static double calculateDeltaEntropy(double[] dataValues, int[] targets, double thresholdCandidate, List<Integer> classes, double parentEntropy) {
        int i;
        if (dataValues.length != targets.length) {
            throw new IllegalArgumentException("The number of data values must be the same as the number of target values!");
        }
        double[] entropyValues = new double[2];
        int numClasses = classes.size();
        int[][] classNodeStatistic = new int[2][numClasses];
        int[] intCounter = new int[2];
        for (i = 0; i < dataValues.length; ++i) {
            if (dataValues[i] <= thresholdCandidate) {
                int[] nArray = classNodeStatistic[0];
                int n = classes.indexOf(targets[i]);
                nArray[n] = nArray[n] + 1;
                intCounter[0] = intCounter[0] + 1;
                continue;
            }
            int[] nArray = classNodeStatistic[1];
            int n = classes.indexOf(targets[i]);
            nArray[n] = nArray[n] + 1;
            intCounter[1] = intCounter[1] + 1;
        }
        for (i = 0; i < entropyValues.length; ++i) {
            double entropySum = 0.0;
            for (int c = 0; c < numClasses; ++c) {
                double gammaC = 0.0;
                if (intCounter[i] != 0) {
                    gammaC = (double)classNodeStatistic[i][c] / (double)intCounter[i];
                }
                entropySum += gammaC < 1.0E-9 ? 0.0 : gammaC * Math.log(gammaC);
            }
            entropyValues[i] = -1.0 * entropySum;
        }
        double weightedSum = 0.0;
        for (int i2 = 0; i2 < entropyValues.length; ++i2) {
            weightedSum += (double)intCounter[i2] / (double)dataValues.length * entropyValues[i2];
        }
        return parentEntropy - weightedSum;
    }

    public static double calculateEntrance(double deltaEntropy, double margin) {
        return deltaEntropy + 1.0E-22 * margin;
    }

    public static double calculateMargin(double[] dataValues, double thresholdCandidate) {
        double min = Double.MAX_VALUE;
        for (int i = 0; i < dataValues.length; ++i) {
            double localDist = Math.abs(dataValues[i] - thresholdCandidate);
            if (!(localDist < min)) continue;
            min = localDist;
        }
        return min;
    }

    public double[][][] transformInstances(double[][] dataset, Pair<List<Integer>, List<Integer>> pairOfItervalLists) {
        double[][][] result = new double[TimeSeriesFeature.NUM_FEATURE_TYPES][((List)pairOfItervalLists.getX()).size()][dataset.length];
        int n = dataset.length;
        boolean useFeatureCaching = ((ITimeSeriesTreeConfig)this.getConfig()).useFeatureCaching();
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < ((List)pairOfItervalLists.getX()).size(); ++j) {
                double[] features;
                int t1 = (Integer)((List)pairOfItervalLists.getX()).get(j);
                int t2 = (Integer)((List)pairOfItervalLists.getY()).get(j);
                if (useFeatureCaching) {
                    long key = (long)i + (long)(dataset[i].length * t1) + (long)(dataset[i].length * dataset[i].length * t2);
                    if (!this.transformedFeaturesCache.containsKey(key)) {
                        features = TimeSeriesFeature.getFeatures((double[])dataset[i], (int)t1, (int)t2, (boolean)true);
                        this.transformedFeaturesCache.put(key, features);
                    } else {
                        features = this.transformedFeaturesCache.get(key);
                    }
                } else {
                    features = TimeSeriesFeature.getFeatures((double[])dataset[i], (int)t1, (int)t2, (boolean)true);
                }
                result[0][j][i] = features[0];
                result[1][j][i] = features[1];
                result[2][j][i] = features[2];
            }
        }
        return result;
    }

    public static List<List<Double>> generateThresholdCandidates(Pair<List<Integer>, List<Integer>> pairOfIntervalLists, int numOfCandidates, double[][][] transformedFeatures) {
        int i;
        if (numOfCandidates < 1) {
            throw new IllegalArgumentException("At least one candidate must be calculated!");
        }
        ArrayList<List<Double>> result = new ArrayList<List<Double>>();
        int numInstances = transformedFeatures[0][0].length;
        double[] min = new double[TimeSeriesFeature.NUM_FEATURE_TYPES];
        double[] max = new double[TimeSeriesFeature.NUM_FEATURE_TYPES];
        for (i = 0; i < TimeSeriesFeature.NUM_FEATURE_TYPES; ++i) {
            result.add(new ArrayList());
            min[i] = Double.MAX_VALUE;
            max[i] = -2.147483648E9;
        }
        for (i = 0; i < TimeSeriesFeature.NUM_FEATURE_TYPES; ++i) {
            for (int j = 0; j < numInstances; ++j) {
                for (int l = 0; l < ((List)pairOfIntervalLists.getX()).size(); ++l) {
                    if (transformedFeatures[i][l][j] < min[i]) {
                        min[i] = transformedFeatures[i][l][j];
                    }
                    if (!(transformedFeatures[i][l][j] > max[i])) continue;
                    max[i] = transformedFeatures[i][l][j];
                }
            }
        }
        for (i = 0; i < TimeSeriesFeature.NUM_FEATURE_TYPES; ++i) {
            double width = (max[i] - min[i]) / (double)(numOfCandidates + 1);
            for (int j = 0; j < numOfCandidates; ++j) {
                ((List)result.get(i)).add(min[i] + (double)(j + 1) * width);
            }
        }
        return result;
    }

    public static Pair<List<Integer>, List<Integer>> sampleIntervals(int m, long seed) {
        if (m < 1) {
            throw new IllegalArgumentException("The series' length m must be greater than zero.");
        }
        ArrayList<Integer> iList1 = new ArrayList<Integer>();
        ArrayList<Integer> iList2 = new ArrayList<Integer>();
        List<Integer> bigW = TimeSeriesTreeLearningAlgorithm.randomlySampleNoReplacement(IntStream.rangeClosed(1, m).boxed().collect(Collectors.toList()), (int)Math.sqrt(m), seed);
        for (int w : bigW) {
            List<Integer> tmpSampling = TimeSeriesTreeLearningAlgorithm.randomlySampleNoReplacement(IntStream.rangeClosed(0, m - w).boxed().collect(Collectors.toList()), (int)Math.sqrt((double)(m - w) + 1.0), seed);
            iList1.addAll(tmpSampling);
            for (int t1 : tmpSampling) {
                iList2.add(t1 + w - 1);
            }
        }
        return new Pair(iList1, iList2);
    }

    public static List<Integer> randomlySampleNoReplacement(List<Integer> list, int sampleSize, long seed) {
        if (list == null) {
            throw new IllegalArgumentException("The list to be sampled from must not be null!");
        }
        if (sampleSize < 1 || sampleSize > list.size()) {
            throw new IllegalArgumentException("Sample size must lower equals the size of the list to be sampled from without replacement and greater zero.");
        }
        ArrayList<Integer> listCopy = new ArrayList<Integer>(list);
        Collections.shuffle(listCopy, new Random(seed));
        return listCopy.subList(0, sampleSize);
    }

    public static interface ITimeSeriesTreeConfig
    extends IOwnerBasedRandomizedAlgorithmConfig {
        public static final String K_MAXDEPTH = "maxdepth";
        public static final String K_FEATURECACHING = "featurecaching";

        @Config.Key(value="maxdepth")
        @Config.DefaultValue(value="-1")
        public int maxDepth();

        @Config.Key(value="featurecaching")
        @Config.DefaultValue(value="false")
        public boolean useFeatureCaching();
    }
}

