/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.weka.classification.singlelabel.timeseries.learner.trees;

import ai.libs.jaicore.basic.IOwnerBasedAlgorithmConfig;
import ai.libs.jaicore.basic.IOwnerBasedRandomizedAlgorithmConfig;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.dataset.TimeSeriesDataset2;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.learner.ASimplifiedTSCLearningAlgorithm;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.learner.ASimplifiedTSClassifier;
import ai.libs.jaicore.ml.weka.classification.singlelabel.timeseries.learner.trees.AccessibleRandomTree;
import ai.libs.jaicore.ml.weka.classification.singlelabel.timeseries.learner.trees.LearnPatternSimilarityClassifier;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import org.aeonbits.owner.Config;
import org.api4.java.ai.ml.core.exception.PredictionException;
import org.api4.java.algorithm.exceptions.AlgorithmException;
import org.api4.java.algorithm.exceptions.AlgorithmTimeoutedException;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;

public class LearnPatternSimilarityLearningAlgorithm
extends ASimplifiedTSCLearningAlgorithm<Integer, LearnPatternSimilarityClassifier> {
    public LearnPatternSimilarityLearningAlgorithm(IPatternSimilarityConfig config, LearnPatternSimilarityClassifier model, TimeSeriesDataset2 dataset) {
        super((IOwnerBasedAlgorithmConfig)config, (ASimplifiedTSClassifier)model, dataset);
    }

    public LearnPatternSimilarityClassifier call() throws AlgorithmException, AlgorithmTimeoutedException {
        long beginTimeMs = System.currentTimeMillis();
        TimeSeriesDataset2 data = (TimeSeriesDataset2)this.getInput();
        if (data == null || data.isEmpty()) {
            throw new IllegalStateException("The time series input data must not be null or empty!");
        }
        double[][] dataMatrix = data.getValuesOrNull(0);
        if (dataMatrix == null) {
            throw new IllegalArgumentException("Value matrix must be a valid 2D matrix containing the time series values for all instances!");
        }
        int[] targetMatrix = data.getTargets();
        int timeSeriesLength = dataMatrix[0].length;
        int minLength = (int)(0.1 * (double)timeSeriesLength);
        int maxLength = (int)(0.9 * (double)timeSeriesLength);
        Random random = new Random(this.getConfig().seed());
        int numTrees = this.getConfig().numTrees();
        int numSegments = this.getConfig().numSegments();
        int[][] segments = new int[numTrees][numSegments];
        int[][] segmentsDifference = new int[numTrees][numSegments];
        int[] lengthPerTree = new int[numTrees];
        int[] classAttIndex = new int[numTrees];
        AccessibleRandomTree[] trees = new AccessibleRandomTree[numTrees];
        int[] numLeavesPerTree = new int[numTrees];
        int[][][] leafNodeCounts = new int[data.getNumberOfInstances()][numTrees][];
        ArrayList<Attribute> attributes = new ArrayList<Attribute>();
        for (int j = 0; j < 2 * numSegments; ++j) {
            attributes.add(new Attribute("val" + j));
        }
        for (int i = 0; i < numTrees; ++i) {
            if (System.currentTimeMillis() - beginTimeMs > this.getTimeout().milliseconds()) {
                throw new AlgorithmTimeoutedException(System.currentTimeMillis() - beginTimeMs - this.getTimeout().milliseconds());
            }
            lengthPerTree[i] = random.nextInt(maxLength - minLength) + minLength;
            this.generateSegmentsAndDifferencesForTree(segments[i], segmentsDifference[i], lengthPerTree[i], timeSeriesLength, random);
            Instances seqInstances = LearnPatternSimilarityLearningAlgorithm.generateSubseriesFeaturesInstances(attributes, lengthPerTree[i], segments[i], segmentsDifference[i], dataMatrix);
            classAttIndex[i] = random.nextInt(attributes.size());
            seqInstances.setClassIndex(classAttIndex[i]);
            trees[i] = this.initializeRegressionTree(seqInstances.numInstances());
            try {
                trees[i].buildClassifier(seqInstances);
            }
            catch (Exception e) {
                throw new AlgorithmException("Could not build tree in iteration " + i + " due to the following exception: " + e.getMessage());
            }
            numLeavesPerTree[i] = trees[i].getNosLeafNodes();
            for (int inst = 0; inst < data.getNumberOfInstances(); ++inst) {
                leafNodeCounts[inst][i] = new int[numLeavesPerTree[i]];
                for (int len = 0; len < lengthPerTree[i]; ++len) {
                    int instanceIdx = inst * lengthPerTree[i] + len;
                    try {
                        LearnPatternSimilarityLearningAlgorithm.collectLeafCounts(leafNodeCounts[inst][i], seqInstances.get(instanceIdx), trees[i]);
                        continue;
                    }
                    catch (PredictionException e1) {
                        throw new AlgorithmException("Could not prediction using the tree in iteration " + i + " due to the following exception: " + e1.getMessage());
                    }
                }
            }
        }
        LearnPatternSimilarityClassifier model = (LearnPatternSimilarityClassifier)this.getClassifier();
        model.setSegments(segments);
        model.setSegmentsDifference(segmentsDifference);
        model.setLengthPerTree(lengthPerTree);
        model.setClassAttIndexPerTree(classAttIndex);
        model.setTrees(trees);
        model.setTrainLeafNodes(leafNodeCounts);
        model.setTrainTargets(targetMatrix);
        model.setAttributes(attributes);
        return model;
    }

    public void generateSegmentsAndDifferencesForTree(int[] segments, int[] segmentsDifference, int length, int timeSeriesLength, Random random) {
        for (int i = 0; i < this.getConfig().numSegments(); ++i) {
            segments[i] = random.nextInt(timeSeriesLength - length);
            segmentsDifference[i] = random.nextInt(timeSeriesLength - length - 1);
        }
    }

    public AccessibleRandomTree initializeRegressionTree(int numInstances) {
        AccessibleRandomTree regTree = new AccessibleRandomTree();
        regTree.setSeed((int)this.getConfig().seed());
        regTree.setMaxDepth(this.getConfig().maxDepth());
        regTree.setKValue(1);
        regTree.setMinNum((int)((double)numInstances * 0.01));
        return regTree;
    }

    public static void collectLeafCounts(int[] leafNodeCountsForInstance, Instance instance, AccessibleRandomTree regTree) throws PredictionException {
        int leafNodeIdx;
        try {
            regTree.distributionForInstance(instance);
        }
        catch (Exception e) {
            throw new PredictionException("Could not predict the distribution for instance for the given instance '" + instance.toString() + "' due to an internal Weka exception.", (Throwable)e);
        }
        int n = leafNodeIdx = regTree.getLastNode();
        leafNodeCountsForInstance[n] = leafNodeCountsForInstance[n] + 1;
    }

    public static Instances generateSubseriesFeaturesInstances(List<Attribute> attributes, int length, int[] segments, int[] segmentsDifference, double[][] dataMatrix) {
        Instances seqInstances = new Instances("SeqFeatures", new ArrayList<Attribute>(attributes), dataMatrix.length * length);
        for (int inst = 0; inst < dataMatrix.length; ++inst) {
            double[] instValues = dataMatrix[inst];
            for (int len = 0; len < length; ++len) {
                seqInstances.add(LearnPatternSimilarityLearningAlgorithm.generateSubseriesFeatureInstance(instValues, segments, segmentsDifference, len));
            }
        }
        return seqInstances;
    }

    public static Instance generateSubseriesFeatureInstance(double[] instValues, int[] segments, int[] segmentsDifference, int len) {
        if (segments.length != segmentsDifference.length) {
            throw new IllegalArgumentException("The number of segments and the number of segments differences must be the same!");
        }
        if (instValues.length < len) {
            throw new IllegalArgumentException("If the segments' length is set to '" + len + "', the number of time series variables must be greater or equals!");
        }
        DenseInstance instance = new DenseInstance(2 * segments.length);
        for (int seq = 0; seq < segments.length; ++seq) {
            instance.setValue(seq * 2, instValues[segments[seq] + len]);
            double difference = instValues[segmentsDifference[seq] + len + 1] - instValues[segmentsDifference[seq] + len];
            instance.setValue(seq * 2 + 1, difference);
        }
        return instance;
    }

    public IPatternSimilarityConfig getConfig() {
        return (IPatternSimilarityConfig)super.getConfig();
    }

    public static interface IPatternSimilarityConfig
    extends IOwnerBasedRandomizedAlgorithmConfig {
        public static final String K_NUMTREES = "numtrees";
        public static final String K_MAXDEPTH = "maxdepth";
        public static final String K_NUMSEGMENTS = "numsegments";

        @Config.Key(value="numtrees")
        @Config.DefaultValue(value="-1")
        public int numTrees();

        @Config.Key(value="maxdepth")
        @Config.DefaultValue(value="-1")
        public int maxDepth();

        @Config.Key(value="numsegments")
        @Config.DefaultValue(value="1")
        public int numSegments();
    }
}

