/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.classification.singlelabel.timeseries.learner.trees;

import ai.libs.jaicore.graph.TreeNode;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.dataset.TimeSeriesDataset2;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.dataset.TimeSeriesFeature;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.learner.ASimplifiedTSClassifier;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.learner.trees.TimeSeriesTreeLearningAlgorithm;
import java.util.ArrayList;
import java.util.List;
import org.api4.java.ai.ml.core.exception.PredictionException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Classifier backed by a binary tree of {@link TimeSeriesTreeNodeDecisionFunction} nodes.
 *
 * <p>Each inner node computes a {@link TimeSeriesFeature} on the instance between the
 * indices {@code t1} and {@code t2}. It then descends to the first child when the feature
 * value is {@code <= threshold}, and to the second child otherwise. A node whose
 * {@code classPrediction} differs from {@link #NO_CLASS_PREDICTION} is a leaf, and its
 * value is the prediction. Only univariate instances are evaluated. For multivariate
 * input, only the first series is used (single instance) or the input is rejected
 * (dataset).
 */
public class TimeSeriesTreeClassifier
extends ASimplifiedTSClassifier<Integer> {
    private static final Logger LOGGER = LoggerFactory.getLogger(TimeSeriesTreeClassifier.class);

    /** Sentinel stored in {@code classPrediction} of inner (non-leaf) nodes. */
    static final int NO_CLASS_PREDICTION = -1;

    private final TimeSeriesTreeLearningAlgorithm.ITimeSeriesTreeConfig config;
    private final TreeNode<TimeSeriesTreeNodeDecisionFunction> rootNode;

    /**
     * Creates an untrained classifier whose tree consists of a single root node with a
     * default decision function.
     *
     * @param config configuration handed to the learning algorithm in
     *        {@link #getLearningAlgorithm(TimeSeriesDataset2)}
     */
    public TimeSeriesTreeClassifier(final TimeSeriesTreeLearningAlgorithm.ITimeSeriesTreeConfig config) {
        this.config = config;
        this.rootNode = new TreeNode<>(new TimeSeriesTreeNodeDecisionFunction(), null);
    }

    /**
     * Returns the root node of the tree.
     *
     * @return the root node of the tree (the live instance, not a copy)
     */
    public TreeNode<TimeSeriesTreeNodeDecisionFunction> getRootNode() {
        return this.rootNode;
    }

    /**
     * Predicts the class of a single univariate time series by walking the tree from the
     * root to a leaf.
     *
     * @param univInstance the time series values
     * @return the class stored in the reached leaf
     * @throws PredictionException if the model has not been trained
     * @throws IllegalStateException if an inner node on the path does not have exactly two children
     */
    public Integer predict(final double[] univInstance) throws PredictionException {
        if (!this.isTrained()) {
            throw new PredictionException("Model has not been built before!");
        }
        TreeNode<TimeSeriesTreeNodeDecisionFunction> currNode = this.rootNode;
        TreeNode<TimeSeriesTreeNodeDecisionFunction> nextNode = decide(currNode, univInstance);
        // decide(...) returns null exactly when currNode is a leaf.
        while (nextNode != null) {
            currNode = nextNode;
            nextNode = decide(currNode, univInstance);
        }
        return currNode.getValue().classPrediction;
    }

    /**
     * Predicts a multivariate instance using only its first time series.
     *
     * @param multivInstance the series of the instance; only element 0 is evaluated
     * @return the predicted class of the first series
     * @throws PredictionException if the model has not been trained
     * @throws IllegalArgumentException if {@code multivInstance} is null or empty
     */
    public Integer predict(final List<double[]> multivInstance) throws PredictionException {
        if (multivInstance == null || multivInstance.isEmpty()) {
            throw new IllegalArgumentException("The multivariate instance to be predicted must contain at least one time series!");
        }
        // Only warn when series are actually being discarded.
        if (multivInstance.size() > 1) {
            LOGGER.warn("Dataset to be predicted is multivariate but only first time series (univariate) will be considered.");
        }
        return this.predict(multivInstance.get(0));
    }

    /**
     * Predicts every instance of a univariate dataset.
     *
     * @param dataset the dataset; must be non-null, non-empty and univariate
     * @return one prediction per instance, in dataset order
     * @throws PredictionException if the model has not been trained
     * @throws UnsupportedOperationException if the dataset is multivariate
     * @throws IllegalArgumentException if the dataset is null, empty, or has no values for its first series
     */
    public List<Integer> predict(final TimeSeriesDataset2 dataset) throws PredictionException {
        if (!this.isTrained()) {
            throw new PredictionException("Model has not been built before!");
        }
        if (dataset == null) {
            throw new IllegalArgumentException("The dataset to be predicted must not be null!");
        }
        if (dataset.isMultivariate()) {
            throw new UnsupportedOperationException("Multivariate instances are not supported yet.");
        }
        if (dataset.isEmpty()) {
            throw new IllegalArgumentException("The dataset to be predicted must not be empty!");
        }
        final double[][] data = dataset.getValuesOrNull(0);
        if (data == null) {
            throw new IllegalArgumentException("The dataset to be predicted does not contain values for its first time series!");
        }
        final List<Integer> predictions = new ArrayList<>(data.length);
        for (final double[] instance : data) {
            predictions.add(this.predict(instance));
        }
        return predictions;
    }

    /**
     * Performs one step of the tree traversal.
     *
     * @param treeNode the current node
     * @param instance the univariate time series being classified
     * @return {@code null} if {@code treeNode} is a leaf. Otherwise returns child 0 when
     *         the node's feature value is {@code <= threshold}, and child 1 when it is not.
     * @throws IllegalStateException if an inner node does not have exactly two children
     */
    public static TreeNode<TimeSeriesTreeNodeDecisionFunction> decide(final TreeNode<TimeSeriesTreeNodeDecisionFunction> treeNode, final double[] instance) {
        final TimeSeriesTreeNodeDecisionFunction decisionFunction = treeNode.getValue();
        if (decisionFunction.classPrediction != NO_CLASS_PREDICTION) {
            return null;
        }
        final List<TreeNode<TimeSeriesTreeNodeDecisionFunction>> children = treeNode.getChildren();
        if (children.size() != 2) {
            throw new IllegalStateException("A binary tree node assumed to be complete has not two children nodes.");
        }
        // The trailing 'true' flag is forwarded as in the original implementation; its
        // meaning is defined by TimeSeriesFeature.calculateFeature (TODO confirm semantics).
        final double featureValue = TimeSeriesFeature.calculateFeature(decisionFunction.f, instance, decisionFunction.t1, decisionFunction.t2, true);
        return featureValue <= decisionFunction.threshold ? children.get(0) : children.get(1);
    }

    /**
     * Creates the learning algorithm that trains this classifier on the given dataset.
     *
     * @param dataset the training data
     * @return a new learning algorithm bound to this classifier and this classifier's config
     */
    public TimeSeriesTreeLearningAlgorithm getLearningAlgorithm(final TimeSeriesDataset2 dataset) {
        return new TimeSeriesTreeLearningAlgorithm(this.config, this, dataset);
    }

    /**
     * Decision data stored in each tree node. For inner nodes, this class reads
     * {@code f}, {@code t1}, {@code t2} and {@code threshold}. For leaves, it reads
     * {@code classPrediction}. Fields are presumably filled in by the learning algorithm
     * (not visible here).
     */
    static class TimeSeriesTreeNodeDecisionFunction {
        /** Feature type evaluated at an inner node. */
        protected TimeSeriesFeature.FeatureType f;
        /** First interval index passed to the feature computation. */
        protected int t1;
        /** Second interval index passed to the feature computation. */
        protected int t2;
        /** Split threshold: feature values {@code <=} this go to child 0. */
        protected double threshold;
        /** Predicted class for leaves; {@link #NO_CLASS_PREDICTION} for inner nodes. */
        protected int classPrediction = NO_CLASS_PREDICTION;

        TimeSeriesTreeNodeDecisionFunction() {
        }

        @Override
        public String toString() {
            return "TimeSeriesTreeNodeDecisionFunction [f=" + this.f + ", t1=" + this.t1 + ", t2=" + this.t2 + ", threshold=" + this.threshold + ", classPrediction=" + this.classPrediction + "]";
        }
    }
}

