/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.weka.classification.singlelabel.timeseries.learner.trees;

import ai.libs.jaicore.ml.classification.singlelabel.timeseries.dataset.TimeSeriesDataset2;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.learner.ASimplifiedTSClassifier;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.util.MathUtil;
import ai.libs.jaicore.ml.weka.classification.singlelabel.timeseries.learner.trees.AccessibleRandomTree;
import ai.libs.jaicore.ml.weka.classification.singlelabel.timeseries.learner.trees.LearnPatternSimilarityLearningAlgorithm;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.aeonbits.owner.ConfigCache;
import org.aeonbits.owner.ConfigFactory;
import org.api4.java.ai.ml.core.exception.PredictionException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;

/**
 * Univariate time series classifier following the learned pattern similarity approach.
 * <p>
 * The model consists of an ensemble of {@link AccessibleRandomTree}s. Each tree has its own
 * segment configuration ({@link #getSegments()}, {@link #getSegmentsDifference()}), number of
 * generated subseries instances ({@link #getLengthPerTree()}) and class attribute index
 * ({@link #getClassAttIndexPerTree()}). To classify a series, the subseries feature instances
 * are generated for every tree and per-leaf counts are accumulated for each tree. The result is
 * the target of the training instance whose stored leaf-count vectors
 * ({@link #getTrainLeafNodes()}) have the smallest summed Manhattan distance to them
 * (1-nearest-neighbour).
 * <p>
 * The model state is assigned through the setters. Presumably this is done by the
 * {@link LearnPatternSimilarityLearningAlgorithm} returned from {@link #getLearningAlgorithm}
 * (TODO confirm).
 */
public class LearnPatternSimilarityClassifier extends ASimplifiedTSClassifier<Integer> {
    private static final Logger LOGGER = LoggerFactory.getLogger(LearnPatternSimilarityClassifier.class);

    /** Per-tree segment definitions used to generate subseries features. */
    private int[][] segments;
    /** Per-tree segment definitions used for the difference-series features. */
    private int[][] segmentsDifference;
    /** Per-tree number of feature instances generated from one time series. */
    private int[] lengthPerTree;
    /** Per-tree index of the attribute used as class attribute. */
    private int[] classAttIndexPerTree;
    /** The trained tree ensemble. */
    private AccessibleRandomTree[] trees;
    /** Leaf-count vectors of the training instances, indexed [instance][tree][leaf]. */
    private int[][][] trainLeafNodes;
    /** Targets of the training instances, aligned with the first index of {@link #trainLeafNodes}. */
    private int[] trainTargets;
    /** Attribute schema of the generated subseries feature instances. */
    private List<Attribute> attributes;

    /**
     * Configuration owned by this classifier instance.
     * <p>
     * It is created with {@link ConfigFactory#create} rather than {@link ConfigCache#getOrCreate}.
     * The cache hands out a single shared instance per config class, so the {@code setProperty}
     * calls in the constructor would otherwise leak into, and be overwritten by, every other
     * instance of this classifier.
     */
    private final LearnPatternSimilarityLearningAlgorithm.IPatternSimilarityConfig config =
            ConfigFactory.create(LearnPatternSimilarityLearningAlgorithm.IPatternSimilarityConfig.class);

    /**
     * Creates an untrained classifier and stores its hyperparameters in the configuration.
     *
     * @param seed random seed (config property {@code seed})
     * @param numTrees number of trees (config property {@code numtrees})
     * @param maxTreeDepth maximum tree depth (config property {@code maxdepth})
     * @param numSegments number of segments (config property {@code numsegments})
     */
    public LearnPatternSimilarityClassifier(final int seed, final int numTrees, final int maxTreeDepth, final int numSegments) {
        this.config.setProperty("seed", "" + seed);
        this.config.setProperty("numtrees", "" + numTrees);
        this.config.setProperty("maxdepth", "" + maxTreeDepth);
        this.config.setProperty("numsegments", "" + numSegments);
    }

    /**
     * Predicts the target of a single univariate time series.
     *
     * @param univInstance the time series values; must be non-null and non-empty
     * @return the training target of the nearest training instance
     * @throws PredictionException if the model has not been trained yet
     * @throws IllegalArgumentException if {@code univInstance} is {@code null} or empty
     */
    public Integer predict(final double[] univInstance) throws PredictionException {
        if (!this.isTrained()) {
            throw new PredictionException("Model has not been built before!");
        }
        if (univInstance == null || univInstance.length == 0) {
            throw new IllegalArgumentException("Instance to be predicted must not be null or empty!");
        }
        int[][] leafNodeCounts = new int[this.trees.length][];
        for (int i = 0; i < this.trees.length; ++i) {
            leafNodeCounts[i] = this.computeLeafNodeCounts(univInstance, i);
        }
        return this.trainTargets[this.findNearestInstanceIndex(leafNodeCounts)];
    }

    /**
     * Generates the subseries feature instances of a series for one tree and accumulates that
     * tree's per-leaf counts over them.
     */
    private int[] computeLeafNodeCounts(final double[] univInstance, final int treeIdx) {
        final int numFeatureInstances = this.lengthPerTree[treeIdx];
        // A fresh copy of the attribute list is handed to each Instances object, as in the original.
        Instances seqInstances = new Instances("SeqFeatures", new ArrayList<Attribute>(this.attributes), numFeatureInstances);
        for (int len = 0; len < numFeatureInstances; ++len) {
            Instance instance = LearnPatternSimilarityLearningAlgorithm.generateSubseriesFeatureInstance(univInstance,
                    this.segments[treeIdx], this.segmentsDifference[treeIdx], len);
            seqInstances.add(instance);
        }
        seqInstances.setClassIndex(this.classAttIndexPerTree[treeIdx]);

        AccessibleRandomTree tree = this.trees[treeIdx];
        int[] counts = new int[tree.getNosLeafNodes()];
        for (int inst = 0; inst < seqInstances.numInstances(); ++inst) {
            LearnPatternSimilarityLearningAlgorithm.collectLeafCounts(counts, seqInstances.get(inst), tree);
        }
        return counts;
    }

    /**
     * Finds the training instance whose leaf-count vectors are closest to the given ones.
     * <p>
     * The distance is the sum, over all trees, of the Manhattan distance between the training
     * instance's and the query's leaf-count vectors. On ties the lowest index wins, because only a
     * strictly smaller distance replaces the current best.
     *
     * @param leafNodeCounts leaf-count vectors of the query, indexed [tree][leaf]
     * @return index into the training instances; {@code 0} if there are no training instances
     */
    public int findNearestInstanceIndex(final int[][] leafNodeCounts) {
        double minDistance = Double.MAX_VALUE;
        int nearestInstIdx = 0;
        for (int inst = 0; inst < this.trainLeafNodes.length; ++inst) {
            final int[][] trainCounts = this.trainLeafNodes[inst];
            double tmpDist = 0.0;
            for (int i = 0; i < trainCounts.length; ++i) {
                tmpDist += MathUtil.intManhattanDistance(trainCounts[i], leafNodeCounts[i]);
            }
            if (tmpDist < minDistance) {
                minDistance = tmpDist;
                nearestInstIdx = inst;
            }
        }
        return nearestInstIdx;
    }

    /**
     * Predicts a multivariate instance by using only its first time series.
     *
     * @param multivInstance the series of the instance; must be non-null and non-empty
     * @return the prediction for the first series
     * @throws PredictionException if the model has not been trained yet
     * @throws IllegalArgumentException if {@code multivInstance} is {@code null} or empty
     */
    public Integer predict(final List<double[]> multivInstance) throws PredictionException {
        if (multivInstance == null || multivInstance.isEmpty()) {
            throw new IllegalArgumentException("Instance to be predicted must not be null or empty!");
        }
        // Only warn when series are actually being discarded.
        if (multivInstance.size() > 1) {
            LOGGER.warn("Dataset to be predicted is multivariate but only first time series (univariate) will be considered.");
        }
        return this.predict(multivInstance.get(0));
    }

    /**
     * Predicts every instance of a univariate dataset.
     *
     * @param dataset the dataset to predict
     * @return one prediction per instance, in dataset order
     * @throws PredictionException if prediction is not possible or a single prediction fails
     * @throws UnsupportedOperationException if the dataset is multivariate
     */
    public List<Integer> predict(final TimeSeriesDataset2 dataset) throws PredictionException {
        double[][] data = this.checkWhetherPredictionIsPossible(dataset);
        if (dataset.isMultivariate()) {
            throw new UnsupportedOperationException("Multivariate instances are not supported yet.");
        }
        List<Integer> predictions = new ArrayList<>(data.length);
        LOGGER.debug("Starting prediction...");
        for (double[] series : data) {
            predictions.add(this.predict(series));
        }
        LOGGER.debug("Finished prediction.");
        return predictions;
    }

    public int[][] getSegments() {
        return this.segments;
    }

    public void setSegments(final int[][] segments) {
        this.segments = segments;
    }

    public int[][] getSegmentsDifference() {
        return this.segmentsDifference;
    }

    public void setSegmentsDifference(final int[][] segmentsDifference) {
        this.segmentsDifference = segmentsDifference;
    }

    public int[] getLengthPerTree() {
        return this.lengthPerTree;
    }

    public void setLengthPerTree(final int[] lengthPerTree) {
        this.lengthPerTree = lengthPerTree;
    }

    public int[] getClassAttIndexPerTree() {
        return this.classAttIndexPerTree;
    }

    public void setClassAttIndexPerTree(final int[] classAttIndexPerTree) {
        this.classAttIndexPerTree = classAttIndexPerTree;
    }

    public AccessibleRandomTree[] getTrees() {
        return this.trees;
    }

    public void setTrees(final AccessibleRandomTree[] trees) {
        this.trees = trees;
    }

    public int[][][] getTrainLeafNodes() {
        return this.trainLeafNodes;
    }

    public void setTrainLeafNodes(final int[][][] trainLeafNodes) {
        this.trainLeafNodes = trainLeafNodes;
    }

    public int[] getTrainTargets() {
        return this.trainTargets;
    }

    public void setTrainTargets(final int[] trainTargets) {
        this.trainTargets = trainTargets;
    }

    public List<Attribute> getAttributes() {
        return this.attributes;
    }

    public void setAttributes(final List<Attribute> attributes) {
        this.attributes = attributes;
    }

    /**
     * Creates the learning algorithm that trains this classifier on the given dataset using this
     * instance's configuration.
     *
     * @param dataset the training dataset
     * @return a new learning algorithm bound to this classifier
     */
    public LearnPatternSimilarityLearningAlgorithm getLearningAlgorithm(final TimeSeriesDataset2 dataset) {
        return new LearnPatternSimilarityLearningAlgorithm(this.config, this, dataset);
    }
}

