/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.tsc.classifier.shapelets;

import ai.libs.jaicore.ml.core.exception.PredictionException;
import ai.libs.jaicore.ml.tsc.classifier.ASimplifiedTSClassifier;
import ai.libs.jaicore.ml.tsc.classifier.shapelets.ShapeletTransformLearningAlgorithm;
import ai.libs.jaicore.ml.tsc.dataset.TimeSeriesDataset;
import ai.libs.jaicore.ml.tsc.quality_measures.FStat;
import ai.libs.jaicore.ml.tsc.quality_measures.IQualityMeasure;
import ai.libs.jaicore.ml.tsc.shapelets.Shapelet;
import ai.libs.jaicore.ml.tsc.shapelets.search.AMinimumDistanceSearchStrategy;
import ai.libs.jaicore.ml.tsc.util.WekaUtil;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.aeonbits.owner.ConfigCache;
import org.aeonbits.owner.ConfigFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.classifiers.Classifier;
import weka.core.Instance;
import weka.core.Instances;

public class ShapeletTransformTSClassifier
extends ASimplifiedTSClassifier<Integer> {
    private static final Logger LOGGER = LoggerFactory.getLogger(ShapeletTransformTSClassifier.class);
    private List<Shapelet> shapelets;
    private Classifier classifier;
    private AMinimumDistanceSearchStrategy minDistanceSearchStrategy;
    private final ShapeletTransformLearningAlgorithm.IShapeletTransformLearningAlgorithmConfig config = (ShapeletTransformLearningAlgorithm.IShapeletTransformLearningAlgorithmConfig)ConfigCache.getOrCreate(ShapeletTransformLearningAlgorithm.IShapeletTransformLearningAlgorithmConfig.class, (Map[])new Map[0]);
    private final IQualityMeasure qualityMeasure;

    public ShapeletTransformTSClassifier(int k, int seed) {
        this(k, new FStat(), seed, true);
    }

    public ShapeletTransformTSClassifier(int k, IQualityMeasure qm, int seed, boolean clusterShapelets) {
        this(k, k / 2, qm, seed, clusterShapelets, 3, 0, false, 1);
    }

    public ShapeletTransformTSClassifier(int k, int numClusters, IQualityMeasure qm, int seed, boolean clusterShapelets) {
        this(k, numClusters, qm, seed, clusterShapelets, 3, 0, false, 1);
    }

    public ShapeletTransformTSClassifier(int k, int numClusters, IQualityMeasure qm, int seed, boolean clusterShapelets, int minShapeletLength, int maxShapeletLength, boolean useHIVECOTEEnsemble, int numFolds) {
        this.config.setProperty("numshapelets", "" + k);
        this.config.setProperty("seed", "" + seed);
        this.config.setProperty("clustershapelets", "" + clusterShapelets);
        this.config.setProperty("minshapeletlength", "" + minShapeletLength);
        this.config.setProperty("maxshapeletlength", "" + maxShapeletLength);
        this.config.setProperty("usehivecoteensemble", "" + useHIVECOTEEnsemble);
        this.config.setProperty("numfolds", "" + numFolds);
        this.config.setProperty("numclusters", "" + numClusters);
        this.qualityMeasure = qm;
    }

    public List<Shapelet> getShapelets() {
        return this.shapelets;
    }

    public void setShapelets(List<Shapelet> shapelets) {
        this.shapelets = shapelets;
    }

    public void setClassifier(Classifier classifier) {
        this.classifier = classifier;
    }

    @Override
    public Integer predict(double[] univInstance) throws PredictionException {
        if (!this.isTrained()) {
            throw new PredictionException("Model has not been built before!");
        }
        double[] transformedInstance = ShapeletTransformLearningAlgorithm.shapeletTransform(univInstance, this.shapelets, this.minDistanceSearchStrategy);
        Instance inst = WekaUtil.simplifiedTSInstanceToWekaInstance(transformedInstance);
        try {
            return (int)Math.round(this.classifier.classifyInstance(inst));
        }
        catch (Exception e) {
            throw new PredictionException(String.format("Could not predict Weka instance %s.", inst.toString()), e);
        }
    }

    @Override
    public Integer predict(List<double[]> multivInstance) throws PredictionException {
        LOGGER.warn("Dataset to be predicted is multivariate but only first time series (univariate) will be considered.");
        return this.predict(multivInstance.get(0));
    }

    @Override
    public List<Integer> predict(TimeSeriesDataset dataset) throws PredictionException {
        if (!this.isTrained()) {
            throw new PredictionException("Model has not been built before!");
        }
        if (dataset.isMultivariate()) {
            LOGGER.warn("Dataset to be predicted is multivariate but only first time series (univariate) will be considered.");
        }
        LOGGER.debug("Transforming dataset...");
        TimeSeriesDataset transformedDataset = null;
        try {
            transformedDataset = ShapeletTransformLearningAlgorithm.shapeletTransform(dataset, this.shapelets, null, -1L, this.minDistanceSearchStrategy);
        }
        catch (InterruptedException e1) {
            Thread.currentThread().interrupt();
            return new ArrayList<Integer>();
        }
        LOGGER.debug("Transformed dataset.");
        double[][] timeSeries = transformedDataset.getValuesOrNull(0);
        if (timeSeries == null) {
            throw new IllegalArgumentException("Dataset matrix of the instances to be predicted must not be null!");
        }
        LOGGER.debug("Converting time series dataset to Weka instances...");
        Instances insts = WekaUtil.simplifiedTimeSeriesDatasetToWekaInstances(transformedDataset);
        LOGGER.debug("Converted time series dataset to Weka instances.");
        LOGGER.debug("Starting prediction...");
        ArrayList<Integer> predictions = new ArrayList<Integer>();
        for (Instance inst : insts) {
            try {
                double prediction = this.classifier.classifyInstance(inst);
                predictions.add((int)Math.round(prediction));
            }
            catch (Exception e) {
                throw new PredictionException(String.format("Could not predict Weka instance %s.", inst.toString()), e);
            }
        }
        LOGGER.debug("Finished prediction.");
        return predictions;
    }

    public AMinimumDistanceSearchStrategy getMinDistanceSearchStrategy() {
        return this.minDistanceSearchStrategy;
    }

    public ShapeletTransformLearningAlgorithm getLearningAlgorithm(TimeSeriesDataset dataset) {
        return new ShapeletTransformLearningAlgorithm(this.config, this, dataset, this.qualityMeasure);
    }
}

