/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.tsc.classifier.shapelets;

import ai.libs.jaicore.basic.TimeOut;
import ai.libs.jaicore.basic.algorithm.IAlgorithmConfig;
import ai.libs.jaicore.basic.algorithm.IRandomAlgorithmConfig;
import ai.libs.jaicore.basic.algorithm.events.AlgorithmEvent;
import ai.libs.jaicore.basic.algorithm.exceptions.AlgorithmException;
import ai.libs.jaicore.ml.core.exception.TrainingException;
import ai.libs.jaicore.ml.tsc.classifier.ASimplifiedTSCLearningAlgorithm;
import ai.libs.jaicore.ml.tsc.classifier.ensemble.EnsembleProvider;
import ai.libs.jaicore.ml.tsc.classifier.shapelets.ShapeletTransformTSClassifier;
import ai.libs.jaicore.ml.tsc.dataset.TimeSeriesDataset;
import ai.libs.jaicore.ml.tsc.quality_measures.IQualityMeasure;
import ai.libs.jaicore.ml.tsc.shapelets.Shapelet;
import ai.libs.jaicore.ml.tsc.shapelets.search.AMinimumDistanceSearchStrategy;
import ai.libs.jaicore.ml.tsc.shapelets.search.EarlyAbandonMinimumDistanceSearchStrategy;
import ai.libs.jaicore.ml.tsc.util.TimeSeriesUtil;
import ai.libs.jaicore.ml.tsc.util.WekaUtil;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.stream.Collectors;
import org.aeonbits.owner.Config;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.classifiers.Classifier;

/**
 * Training algorithm for {@link ShapeletTransformTSClassifier}.
 *
 * <p>It extracts shapelets from the univariate training series using the configured
 * {@link IQualityMeasure}, optionally clusters them, and transforms every series into its vector of
 * minimum distances to the selected shapelets. It then trains a Weka ensemble (HIVE-COTE or CAWPE,
 * provided by {@link EnsembleProvider}) on the transformed data.
 */
public class ShapeletTransformLearningAlgorithm
extends ASimplifiedTSCLearningAlgorithm<Integer, ShapeletTransformTSClassifier> {
    private static final Logger logger = LoggerFactory.getLogger(ShapeletTransformLearningAlgorithm.class);

    /**
     * Number of sampling rounds of the shapelet length estimation. It also serves as the number of
     * instances drawn per round and the number of shapelets kept per round.
     */
    private static final int MIN_MAX_ESTIMATION_SAMPLES = 10;

    /** Shortest candidate length considered while estimating the shapelet length borders. */
    private static final int MIN_MAX_ESTIMATION_MIN_LENGTH = 3;

    /** Flag handed to the default {@link EarlyAbandonMinimumDistanceSearchStrategy}. */
    private static final boolean USE_BIAS_CORRECTION = true;

    private static final String INTERRUPTION_MESSAGE = "Interrupted training due to timeout.";

    private final IQualityMeasure qualityMeasure;
    private AMinimumDistanceSearchStrategy minDistanceSearchStrategy = new EarlyAbandonMinimumDistanceSearchStrategy(USE_BIAS_CORRECTION);

    /**
     * Creates the learning algorithm.
     *
     * @param config algorithm configuration
     * @param classifier model instance that is filled with shapelets and the trained ensemble
     * @param dataset univariate training data
     * @param qualityMeasure measure used to rate shapelet candidates
     */
    public ShapeletTransformLearningAlgorithm(IShapeletTransformLearningAlgorithmConfig config, ShapeletTransformTSClassifier classifier, TimeSeriesDataset dataset, IQualityMeasure qualityMeasure) {
        super((IAlgorithmConfig) config, classifier, dataset);
        this.qualityMeasure = qualityMeasure;
    }

    /**
     * Trains the shapelet transform classifier.
     *
     * <p>NOTE: the training dataset is transformed in place (see
     * {@link #shapeletTransform(TimeSeriesDataset, List, TimeOut, long, AMinimumDistanceSearchStrategy)}).
     *
     * @return the trained model (the classifier instance passed to the constructor)
     * @throws AlgorithmException if the ensemble cannot be created or trained
     * @throws InterruptedException if the configured timeout is exceeded
     */
    public ShapeletTransformTSClassifier call() throws AlgorithmException, InterruptedException {
        if (this.getNumCPUs() > 1) {
            logger.warn("Multithreading is not supported for the shapelet transform yet. Therefore, the number of CPUs is not considered.");
        }
        long beginTime = System.currentTimeMillis();

        TimeSeriesDataset data = (TimeSeriesDataset) this.getInput();
        if (data == null || data.isEmpty()) {
            throw new IllegalStateException("The time series input data must not be null or empty!");
        }
        if (data.isMultivariate()) {
            throw new UnsupportedOperationException("Multivariate datasets are not supported.");
        }
        double[][] dataMatrix = data.getValuesOrNull(0);
        if (dataMatrix == null) {
            throw new IllegalArgumentException("Value matrix must be a valid 2D matrix containing the time series values for all instances!");
        }
        int[] targetMatrix = data.getTargets();
        int seed = this.getConfig().seed();
        int numShapelets = this.getConfig().numShapelets();
        ShapeletTransformTSClassifier model = (ShapeletTransformTSClassifier) this.getClassifier();
        int timeSeriesLength = dataMatrix[0].length;

        // Determine the range of shapelet lengths to search: either estimated from data samples or taken from the config.
        int minShapeletLength = this.getConfig().minShapeletLength();
        int maxShapeletLength = this.getConfig().maxShapeletLength();
        if (this.getConfig().estimateShapeletLengthBorders()) {
            logger.debug("Starting min max estimation.");
            int[] minMax = this.estimateMinMax(dataMatrix, targetMatrix, beginTime);
            minShapeletLength = minMax[0];
            maxShapeletLength = minMax[1];
            logger.debug("Finished min max estimation. min={}, max={}", minShapeletLength, maxShapeletLength);
        } else if (maxShapeletLength == -1) {
            // -1 is the "unset" sentinel for the maximum shapelet length.
            maxShapeletLength = timeSeriesLength - 1;
        }
        if (maxShapeletLength >= timeSeriesLength) {
            logger.debug("The maximum shapelet length was larger than the total time series length. Therefore, it will be set to time series length - 1.");
            maxShapeletLength = timeSeriesLength - 1;
        }

        logger.debug("Starting cached shapelet selection with min={}, max={} and k={}...", new Object[] { minShapeletLength, maxShapeletLength, numShapelets });
        List<Shapelet> shapelets = this.shapeletCachedSelection(dataMatrix, minShapeletLength, maxShapeletLength, numShapelets, targetMatrix, beginTime);
        logger.debug("Finished cached shapelet selection. Extracted {} shapelets.", shapelets.size());
        if (this.getConfig().clusterShapelets()) {
            logger.debug("Starting shapelet clustering...");
            shapelets = this.clusterShapelets(shapelets, this.getConfig().numClusters(), beginTime);
            logger.debug("Finished shapelet clustering. Staying with {} shapelets.", shapelets.size());
        }
        model.setShapelets(shapelets);

        logger.debug("Transforming the training data using the extracted shapelets.");
        TimeSeriesDataset transfTrainingData = ShapeletTransformLearningAlgorithm.shapeletTransform(data, model.getShapelets(), this.getTimeout(), beginTime, this.minDistanceSearchStrategy);
        logger.debug("Finished transforming the training data.");

        logger.debug("Initializing ensemble classifier...");
        Classifier classifier;
        try {
            classifier = this.getConfig().useHIVECOTEEnsemble() ? EnsembleProvider.provideHIVECOTEEnsembleModel(seed, this.getConfig().numFolds())
                    : EnsembleProvider.provideCAWPEEnsembleModel(seed, this.getConfig().numFolds());
        } catch (Exception e) {
            throw new AlgorithmException((Throwable) e, "Could not train model due to ensemble exception.");
        }
        logger.debug("Initialized ensemble classifier.");

        logger.debug("Starting ensemble training...");
        try {
            WekaUtil.buildWekaClassifierFromSimplifiedTS(classifier, transfTrainingData);
        } catch (TrainingException e) {
            throw new AlgorithmException((Throwable) e, "Could not train classifier due to a training exception.");
        }
        logger.debug("Finished ensemble training.");
        model.setClassifier(classifier);
        return model;
    }

    /**
     * Estimates the minimum and maximum shapelet length by sampling.
     *
     * <p>In each of {@link #MIN_MAX_ESTIMATION_SAMPLES} rounds, that many instances are drawn at random
     * (with replacement) and the best shapelets are extracted from them. All extracted shapelets are
     * then sorted by length, and the lengths at the lower and upper quartile are returned (indices 25
     * and 75 for the usual 100 shapelets).
     *
     * @param data value matrix (one row per instance), must be non-empty
     * @param classes class of each row of {@code data}
     * @param beginTime training start time used for timeout checks
     * @return {@code {min, max}} shapelet lengths
     * @throws InterruptedException if the timeout is exceeded
     */
    private int[] estimateMinMax(double[][] data, int[] classes, long beginTime) throws InterruptedException {
        int numInstances = data.length;
        int timeSeriesLength = data[0].length;
        List<Shapelet> shapelets = new ArrayList<>();
        // A single generator for all rounds: re-seeding per round would draw the identical sample every time.
        Random rand = new Random(this.getConfig().seed());
        for (int round = 0; round < MIN_MAX_ESTIMATION_SAMPLES; ++round) {
            double[][] sampleMatrix = new double[MIN_MAX_ESTIMATION_SAMPLES][];
            int[] sampleClasses = new int[MIN_MAX_ESTIMATION_SAMPLES];
            for (int j = 0; j < MIN_MAX_ESTIMATION_SAMPLES; ++j) {
                int sampledIndex = rand.nextInt(numInstances);
                sampleMatrix[j] = Arrays.copyOf(data[sampledIndex], timeSeriesLength);
                sampleClasses[j] = classes[sampledIndex];
            }
            shapelets.addAll(this.shapeletCachedSelection(sampleMatrix, MIN_MAX_ESTIMATION_MIN_LENGTH, timeSeriesLength, MIN_MAX_ESTIMATION_SAMPLES, sampleClasses, beginTime));
        }
        Shapelet.sortByLengthAsc(shapelets);
        logger.debug("Number of shapelets found in min/max estimation: {}", shapelets.size());
        if (shapelets.isEmpty()) {
            int fallbackMax = timeSeriesLength - 1;
            logger.warn("No shapelets found during min/max estimation. Falling back to min={} and max={}.", MIN_MAX_ESTIMATION_MIN_LENGTH, fallbackMax);
            return new int[] { MIN_MAX_ESTIMATION_MIN_LENGTH, fallbackMax };
        }
        int numFound = shapelets.size();
        // Lower and upper quartile positions; 3 * n / 4 < n holds for every n >= 1.
        return new int[] { shapelets.get(numFound / 4).getLength(), shapelets.get(3 * numFound / 4).getLength() };
    }

    /**
     * Reduces the given shapelets by agglomerative clustering.
     *
     * <p>Starting with one cluster per shapelet, the two clusters with the smallest average pairwise
     * minimum distance are merged repeatedly until at most {@code noClusters} clusters remain (at least
     * one). Each merged cluster is replaced by a singleton containing the shapelet selected by
     * {@link Shapelet#getHighestQualityShapeletInList}.
     *
     * @param shapelets shapelets to cluster
     * @param noClusters desired number of clusters; values below 1 are treated as 1
     * @param beginTime training start time used for timeout checks
     * @return the shapelets of all remaining clusters
     * @throws InterruptedException if the timeout is exceeded
     */
    public List<Shapelet> clusterShapelets(List<Shapelet> shapelets, int noClusters, long beginTime) throws InterruptedException {
        List<List<Shapelet>> clusters = new ArrayList<>();
        for (Shapelet shapelet : shapelets) {
            List<Shapelet> singleton = new ArrayList<>();
            singleton.add(shapelet);
            clusters.add(singleton);
        }
        // A merge needs two distinct clusters, so never try to go below one cluster.
        int targetNumClusters = Math.max(noClusters, 1);
        while (clusters.size() > targetNumClusters) {
            if (hasTimedOut(this.getTimeout(), beginTime)) {
                throw new InterruptedException(INTERRUPTION_MESSAGE);
            }
            // Find the closest ordered pair (first minimum in row-major order). Fall back to (0, 1) if no
            // distance is below Double.MAX_VALUE (e.g. NaN), so that two distinct clusters are always merged.
            int x = 0;
            int y = 1;
            double best = Double.MAX_VALUE;
            for (int i = 0; i < clusters.size(); ++i) {
                for (int j = 0; j < clusters.size(); ++j) {
                    if (i == j) {
                        continue;
                    }
                    double distance = this.averageClusterDistance(clusters.get(i), clusters.get(j));
                    if (distance < best) {
                        best = distance;
                        x = i;
                        y = j;
                    }
                }
            }
            // Build a new list instead of mutating a stored cluster (which may be a fixed-size list).
            List<Shapelet> merged = new ArrayList<>(clusters.get(x));
            merged.addAll(clusters.get(y));
            Shapelet representative = Shapelet.getHighestQualityShapeletInList(merged);
            // Remove the higher index first so the lower one stays valid.
            clusters.remove(Math.max(x, y));
            clusters.remove(Math.min(x, y));
            List<Shapelet> representativeCluster = new ArrayList<>();
            representativeCluster.add(representative);
            clusters.add(representativeCluster);
        }
        return clusters.stream().flatMap(Collection::stream).collect(Collectors.toList());
    }

    /**
     * Computes the average minimum distance over all shapelet pairs of two clusters.
     *
     * <p>For each pair, the shorter shapelet is searched in the longer one's values.
     */
    private double averageClusterDistance(List<Shapelet> first, List<Shapelet> second) {
        double distance = 0.0;
        for (Shapelet cl : first) {
            for (Shapelet ck : second) {
                if (cl.getLength() > ck.getLength()) {
                    distance += this.minDistanceSearchStrategy.findMinimumDistance(ck, cl.getData());
                } else {
                    distance += this.minDistanceSearchStrategy.findMinimumDistance(cl, ck.getData());
                }
            }
        }
        return distance / ((double) first.size() * second.size());
    }

    /**
     * Selects the {@code k} best shapelets of the given data.
     *
     * <p>Candidates are generated from every instance. Overlapping candidates of the same instance are
     * pruned, and the survivors are merged into the running top-k list.
     *
     * <p>NOTE(review): candidate lengths run from {@code min} inclusive to {@code max} exclusive. Confirm
     * that this is the intended meaning of the configured maximum length.
     *
     * @throws InterruptedException if the timeout is exceeded
     */
    private List<Shapelet> shapeletCachedSelection(double[][] data, int min, int max, int k, int[] classes, long beginTime) throws InterruptedException {
        List<Map.Entry<Shapelet, Double>> kShapelets = new ArrayList<>();
        for (int i = 0; i < data.length; ++i) {
            if (hasTimedOut(this.getTimeout(), beginTime)) {
                throw new InterruptedException(INTERRUPTION_MESSAGE);
            }
            List<Map.Entry<Shapelet, Double>> shapelets = new ArrayList<>();
            for (int l = min; l < max; ++l) {
                for (Shapelet s : ShapeletTransformLearningAlgorithm.generateCandidates(data[i], l, i)) {
                    List<Double> distances = this.findDistances(s, data);
                    double quality = this.qualityMeasure.assessQuality(distances, classes);
                    s.setDeterminedQuality(quality);
                    shapelets.add(new AbstractMap.SimpleEntry<>(s, quality));
                }
            }
            ShapeletTransformLearningAlgorithm.sortByQualityDesc(shapelets);
            shapelets = ShapeletTransformLearningAlgorithm.removeSelfSimilar(shapelets);
            kShapelets = ShapeletTransformLearningAlgorithm.merge(k, kShapelets, shapelets);
        }
        return kShapelets.stream().map(Map.Entry::getKey).collect(Collectors.toList());
    }

    /**
     * Adds {@code shapelets} to {@code kShapelets}, sorts by descending quality and keeps the best {@code k}.
     *
     * @param k number of entries to keep; negative values keep none
     * @param kShapelets current top list; it is modified in place and returned
     * @param shapelets new entries to merge in
     * @return {@code kShapelets}
     */
    public static List<Map.Entry<Shapelet, Double>> merge(int k, List<Map.Entry<Shapelet, Double>> kShapelets, List<Map.Entry<Shapelet, Double>> shapelets) {
        kShapelets.addAll(shapelets);
        ShapeletTransformLearningAlgorithm.sortByQualityDesc(kShapelets);
        int numToKeep = Math.max(k, 0);
        if (kShapelets.size() > numToKeep) {
            kShapelets.subList(numToKeep, kShapelets.size()).clear();
        }
        return kShapelets;
    }

    /** Sorts the entries by descending quality value (stable). */
    private static void sortByQualityDesc(List<Map.Entry<Shapelet, Double>> list) {
        list.sort(Map.Entry.<Shapelet, Double> comparingByValue().reversed());
    }

    /**
     * Keeps, in order, only entries whose shapelet does not overlap an already kept shapelet of the same
     * instance.
     *
     * @param shapelets entries, typically sorted by descending quality so the better one of two overlapping shapelets survives
     * @return a new list with the non-self-similar entries
     */
    public static List<Map.Entry<Shapelet, Double>> removeSelfSimilar(List<Map.Entry<Shapelet, Double>> shapelets) {
        List<Map.Entry<Shapelet, Double>> result = new ArrayList<>();
        for (Map.Entry<Shapelet, Double> candidate : shapelets) {
            boolean overlapsKept = false;
            for (Map.Entry<Shapelet, Double> kept : result) {
                if (ShapeletTransformLearningAlgorithm.isSelfSimilar(candidate.getKey(), kept.getKey())) {
                    overlapsKept = true;
                    break;
                }
            }
            if (!overlapsKept) {
                result.add(candidate);
            }
        }
        return result;
    }

    /** Returns whether both shapelets stem from the same instance and their index intervals overlap. */
    private static boolean isSelfSimilar(Shapelet s1, Shapelet s2) {
        if (s1.getInstanceIndex() != s2.getInstanceIndex()) {
            return false;
        }
        return s1.getStartIndex() < s2.getStartIndex() + s2.getLength() && s2.getStartIndex() < s1.getStartIndex() + s1.getLength();
    }

    /** Returns the minimum distance of {@code s} to each row of {@code matrix}, in row order. */
    public List<Double> findDistances(Shapelet s, double[][] matrix) {
        List<Double> result = new ArrayList<>(matrix.length);
        for (double[] row : matrix) {
            result.add(this.minDistanceSearchStrategy.findMinimumDistance(s, row));
        }
        return result;
    }

    /**
     * Generates all z-normalized subsequences of length {@code l} of {@code data} as shapelet candidates.
     *
     * <p>The returned set iterates in start-index order, so that downstream tie-breaking is deterministic.
     *
     * @param data the time series
     * @param l candidate length
     * @param candidateIndex index of the instance the candidates stem from
     * @return candidate shapelets
     */
    public static Set<Shapelet> generateCandidates(double[] data, int l, int candidateIndex) {
        Set<Shapelet> result = new LinkedHashSet<>();
        for (int i = 0; i <= data.length - l; ++i) {
            double[] tmpData = TimeSeriesUtil.getInterval(data, i, i + l);
            result.add(new Shapelet(TimeSeriesUtil.zNormalize(tmpData, true), i, l, candidateIndex));
        }
        return result;
    }

    /**
     * Replaces the values (matrix 0) of {@code dataSet} in place with their shapelet-transformed
     * representation.
     *
     * @param dataSet univariate dataset; it is modified in place
     * @param shapelets shapelets defining the transformed features
     * @param timeout optional timeout ({@code null} disables the check)
     * @param beginTime start time the timeout is measured from
     * @param searchStrategy strategy used to compute minimum distances
     * @return {@code dataSet} itself
     * @throws InterruptedException if the timeout is exceeded
     */
    public static TimeSeriesDataset shapeletTransform(TimeSeriesDataset dataSet, List<Shapelet> shapelets, TimeOut timeout, long beginTime, AMinimumDistanceSearchStrategy searchStrategy) throws InterruptedException {
        if (dataSet.isMultivariate()) {
            throw new UnsupportedOperationException("Multivariate datasets are not supported yet!");
        }
        double[][] timeSeries = dataSet.getValuesOrNull(0);
        if (timeSeries == null) {
            throw new IllegalArgumentException("Time series matrix must be a valid 2d matrix!");
        }
        double[][] transformedTS = new double[timeSeries.length][];
        for (int i = 0; i < timeSeries.length; ++i) {
            if (hasTimedOut(timeout, beginTime)) {
                throw new InterruptedException(INTERRUPTION_MESSAGE);
            }
            transformedTS[i] = ShapeletTransformLearningAlgorithm.shapeletTransform(timeSeries[i], shapelets, searchStrategy);
        }
        dataSet.replace(0, transformedTS, dataSet.getTimestampsOrNull(0));
        return dataSet;
    }

    /** Returns, for each shapelet, its minimum distance to {@code instance}. */
    public static double[] shapeletTransform(double[] instance, List<Shapelet> shapelets, AMinimumDistanceSearchStrategy searchStrategy) {
        double[] transformedTS = new double[shapelets.size()];
        for (int j = 0; j < shapelets.size(); ++j) {
            transformedTS[j] = searchStrategy.findMinimumDistance(shapelets.get(j), instance);
        }
        return transformedTS;
    }

    /** Returns whether more time than {@code timeout} has passed since {@code beginTime}; a {@code null} timeout never expires. */
    private static boolean hasTimedOut(TimeOut timeout, long beginTime) {
        return timeout != null && System.currentTimeMillis() - beginTime > timeout.milliseconds();
    }

    public AMinimumDistanceSearchStrategy getMinDistanceSearchStrategy() {
        return this.minDistanceSearchStrategy;
    }

    public void setMinDistanceSearchStrategy(AMinimumDistanceSearchStrategy minDistanceSearchStrategy) {
        this.minDistanceSearchStrategy = minDistanceSearchStrategy;
    }

    /** Not supported by this algorithm. */
    @Override
    public void registerListener(Object listener) {
        throw new UnsupportedOperationException("The operation to be performed is not supported.");
    }

    /** Not supported by this algorithm. */
    @Override
    public AlgorithmEvent nextWithException() {
        throw new UnsupportedOperationException("The operation to be performed is not supported.");
    }

    public IShapeletTransformLearningAlgorithmConfig getConfig() {
        return (IShapeletTransformLearningAlgorithmConfig) super.getConfig();
    }

    /** Configuration of {@link ShapeletTransformLearningAlgorithm}. */
    public static interface IShapeletTransformLearningAlgorithmConfig
    extends IRandomAlgorithmConfig {
        public static final String K_NUMSHAPELETS = "numshapelets";
        public static final String K_NUMCLUSTERS = "numclusters";
        public static final String K_CLUSTERSHAPELETS = "clustershapelets";
        public static final String K_SHAPELETLENGTH_MIN = "minshapeletlength";
        public static final String K_SHAPELETLENGTH_MAX = "maxshapeletlength";
        public static final String K_USEHIVECOTEENSEMBLE = "usehivecoteensemble";
        public static final String K_ESTIMATESHAPELETLENGTHBORDERS = "estimateshapeletlengthborders";
        public static final String K_NUMFOLDS = "numfolds";

        /** Number of shapelets kept by the shapelet selection. */
        @Config.Key(K_NUMSHAPELETS)
        @Config.DefaultValue("10")
        public int numShapelets();

        /** Target number of clusters when {@link #clusterShapelets()} is enabled. */
        @Config.Key(K_NUMCLUSTERS)
        @Config.DefaultValue("10")
        public int numClusters();

        /** Whether the selected shapelets are reduced by clustering. */
        @Config.Key(K_CLUSTERSHAPELETS)
        @Config.DefaultValue("false")
        public boolean clusterShapelets();

        /** Minimum shapelet length. */
        @Config.Key(K_SHAPELETLENGTH_MIN)
        @Config.DefaultValue("3")
        public int minShapeletLength();

        /** Maximum shapelet length; -1 means "time series length - 1". */
        @Config.Key(K_SHAPELETLENGTH_MAX)
        @Config.DefaultValue("-1")
        public int maxShapeletLength();

        /** Whether to use the HIVE-COTE ensemble instead of CAWPE. */
        @Config.Key(K_USEHIVECOTEENSEMBLE)
        @Config.DefaultValue("false")
        public boolean useHIVECOTEEnsemble();

        /** Whether the shapelet length borders are estimated from the data instead of taken from the config. */
        @Config.Key(K_ESTIMATESHAPELETLENGTHBORDERS)
        @Config.DefaultValue("false")
        public boolean estimateShapeletLengthBorders();

        /** Number of folds handed to the ensemble provider. */
        @Config.Key(K_NUMFOLDS)
        @Config.DefaultValue("5")
        public int numFolds();
    }
}

