/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.weka.classification.singlelabel.timeseries.learner.trees;

import ai.libs.jaicore.ml.classification.singlelabel.timeseries.dataset.TimeSeriesDataset2;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.learner.ASimplifiedTSClassifier;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.util.TimeSeriesUtil;
import ai.libs.jaicore.ml.weka.classification.singlelabel.timeseries.learner.trees.TimeSeriesForestLearningAlgorithm;
import ai.libs.jaicore.ml.weka.classification.singlelabel.timeseries.learner.trees.TimeSeriesTreeClassifier;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.aeonbits.owner.ConfigCache;
import org.api4.java.ai.ml.core.exception.PredictionException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Time series forest classifier: an ensemble of {@link TimeSeriesTreeClassifier}s whose individual
 * predictions are combined by majority vote.
 *
 * <p>Hyperparameters (number of trees, maximum depth, feature caching, seed) are written into a
 * {@link TimeSeriesForestLearningAlgorithm.ITimeSeriesForestConfig}. That config is handed to the
 * learning algorithm created by {@link #getLearningAlgorithm(TimeSeriesDataset2)}. The trees
 * themselves are installed via {@link #setTrees(TimeSeriesTreeClassifier[])} (presumably by that
 * learning algorithm — verify against {@code TimeSeriesForestLearningAlgorithm}).
 *
 * <p>Only univariate time series are supported for prediction.
 */
public class TimeSeriesForestClassifier extends ASimplifiedTSClassifier<Integer> {
    private static final Logger LOGGER = LoggerFactory.getLogger(TimeSeriesForestClassifier.class);

    /** Hyperparameter configuration; mutated by the {@code set*} methods below. */
    private final TimeSeriesForestLearningAlgorithm.ITimeSeriesForestConfig config;

    /** The trees of the forest; each one casts one vote per prediction. */
    private TimeSeriesTreeClassifier[] trees;

    /**
     * Creates a classifier using the config obtained from {@link ConfigCache}.
     *
     * <p>NOTE(review): {@code ConfigCache.getOrCreate} returns a cached instance per config class. All
     * classifiers built with this constructor therefore appear to share one config object, and the
     * {@code set*} methods of one classifier would affect the others. Confirm this is intended.
     */
    public TimeSeriesForestClassifier() {
        this(ConfigCache.getOrCreate(TimeSeriesForestLearningAlgorithm.ITimeSeriesForestConfig.class));
    }

    /**
     * Creates a classifier backed by the given configuration.
     *
     * @param config the hyperparameter configuration to use (stored by reference, not copied)
     */
    public TimeSeriesForestClassifier(TimeSeriesForestLearningAlgorithm.ITimeSeriesForestConfig config) {
        this.config = config;
    }

    /**
     * Sets the {@code numtrees} property of the configuration.
     *
     * @param numTrees the number of trees to build
     */
    public void setNumberOfTrees(int numTrees) {
        this.config.setProperty("numtrees", String.valueOf(numTrees));
    }

    /**
     * Sets the {@code maxdepth} property of the configuration.
     *
     * @param maxDepth the maximum depth of each tree
     */
    public void setMaxDepth(int maxDepth) {
        this.config.setProperty("maxdepth", String.valueOf(maxDepth));
    }

    /**
     * Sets the {@code featurecaching} property of the configuration.
     *
     * @param enableFeatureCaching whether feature caching should be enabled
     */
    public void setFeatureCaching(boolean enableFeatureCaching) {
        this.config.setProperty("featurecaching", String.valueOf(enableFeatureCaching));
    }

    /**
     * Sets the {@code seed} property of the configuration.
     *
     * @param seed the random seed
     */
    public void setSeed(int seed) {
        this.config.setProperty("seed", String.valueOf(seed));
    }

    /**
     * Predicts the class of a univariate instance by majority vote over all trees.
     *
     * <p>The winning class is chosen by {@code TimeSeriesUtil.getMaximumKeyByValue}. How that helper
     * breaks ties is not visible here.
     *
     * @param univInstance the univariate time series to classify
     * @return the class receiving the most votes
     * @throws PredictionException if the model has not been trained or contains no trees
     * @throws IllegalArgumentException if {@code univInstance} is {@code null} or empty
     */
    public Integer predict(double[] univInstance) throws PredictionException {
        if (!this.isTrained()) {
            throw new PredictionException("Model has not been built before!");
        }
        // Without trees there are no votes, so no meaningful class can be returned.
        if (this.trees == null || this.trees.length == 0) {
            throw new PredictionException("Model does not contain any trees to vote with!");
        }
        if (univInstance == null || univInstance.length == 0) {
            throw new IllegalArgumentException("Instance to be predicted must not be null or empty!");
        }
        Map<Integer, Integer> votes = new HashMap<>();
        for (TimeSeriesTreeClassifier tree : this.trees) {
            int prediction = tree.predict(univInstance);
            votes.merge(prediction, 1, Integer::sum);
        }
        return (Integer) TimeSeriesUtil.getMaximumKeyByValue(votes);
    }

    /**
     * Predicts the class of a (possibly multivariate) instance using only its first time series.
     *
     * @param multivInstance the instance, given as a list of time series
     * @return the predicted class of the first time series
     * @throws PredictionException if prediction on the first series fails
     * @throws IllegalArgumentException if {@code multivInstance} is {@code null} or empty
     */
    public Integer predict(List<double[]> multivInstance) throws PredictionException {
        if (multivInstance == null || multivInstance.isEmpty()) {
            throw new IllegalArgumentException("Instance to be predicted must not be null or empty!");
        }
        // Only warn when series are actually being discarded.
        if (multivInstance.size() > 1) {
            LOGGER.warn("Dataset to be predicted is multivariate but only first time series (univariate) will be considered.");
        }
        return this.predict(multivInstance.get(0));
    }

    /**
     * Predicts the class of every instance in a univariate dataset.
     *
     * @param dataset the dataset to classify
     * @return one predicted class per instance, in dataset order
     * @throws PredictionException if the superclass check rejects the dataset or a single prediction fails
     * @throws UnsupportedOperationException if the dataset is multivariate
     */
    public List<Integer> predict(TimeSeriesDataset2 dataset) throws PredictionException {
        double[][] data = this.checkWhetherPredictionIsPossible(dataset);
        if (dataset.isMultivariate()) {
            throw new UnsupportedOperationException("Multivariate instances are not supported yet.");
        }
        List<Integer> predictions = new ArrayList<>(data.length);
        LOGGER.debug("Starting prediction...");
        for (double[] instance : data) {
            predictions.add(this.predict(instance));
        }
        LOGGER.debug("Finished prediction.");
        return predictions;
    }

    /**
     * Returns the trees of this forest.
     *
     * @return the internal tree array (not a copy); may be {@code null} if never set
     */
    public TimeSeriesTreeClassifier[] getTrees() {
        return this.trees;
    }

    /**
     * Replaces the trees of this forest.
     *
     * @param trees the trees to use for voting (stored by reference, not copied)
     */
    public void setTrees(TimeSeriesTreeClassifier[] trees) {
        this.trees = trees;
    }

    /**
     * Creates a learning algorithm that trains this classifier on the given dataset.
     *
     * @param dataset the training dataset
     * @return a new {@link TimeSeriesForestLearningAlgorithm} bound to this classifier's config, this
     *     classifier and {@code dataset}
     */
    public TimeSeriesForestLearningAlgorithm getLearningAlgorithm(TimeSeriesDataset2 dataset) {
        return new TimeSeriesForestLearningAlgorithm(this.config, this, dataset);
    }
}

