/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.tsc.classifier.shapelets;

import ai.libs.jaicore.basic.TimeOut;
import ai.libs.jaicore.basic.algorithm.IAlgorithmConfig;
import ai.libs.jaicore.basic.algorithm.IRandomAlgorithmConfig;
import ai.libs.jaicore.basic.algorithm.events.AlgorithmEvent;
import ai.libs.jaicore.basic.algorithm.exceptions.AlgorithmException;
import ai.libs.jaicore.ml.core.exception.TrainingException;
import ai.libs.jaicore.ml.tsc.classifier.ASimplifiedTSCLearningAlgorithm;
import ai.libs.jaicore.ml.tsc.classifier.shapelets.LearnShapeletsClassifier;
import ai.libs.jaicore.ml.tsc.dataset.TimeSeriesDataset;
import ai.libs.jaicore.ml.tsc.util.MathUtil;
import ai.libs.jaicore.ml.tsc.util.TimeSeriesUtil;
import ai.libs.jaicore.ml.tsc.util.WekaUtil;
import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.aeonbits.owner.Config;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.clusterers.SimpleKMeans;
import weka.core.Instances;

/**
 * Learning algorithm for the {@link LearnShapeletsClassifier}.
 *
 * <p>Learns shapelets {@code S} (indexed {@code [scale][shapelet][position]}), per-class weights {@code W}
 * (indexed {@code [class][scale][shapelet]}) and per-class biases {@code W0}. Training uses stochastic gradient
 * descent with momentum ({@code gamma}) and AdaGrad-style step scaling: each step is divided by the square root of
 * the accumulated squared steps. Shapelets are initialized by k-means clustering of z-normalized training segments.
 * Only univariate datasets are supported.</p>
 */
public class LearnShapeletsLearningAlgorithm
extends ASimplifiedTSCLearningAlgorithm<Integer, LearnShapeletsClassifier> {
    private static final Logger LOGGER = LoggerFactory.getLogger(LearnShapeletsLearningAlgorithm.class);

    /** Number of training instances; set by {@link #call()}. */
    private int numInstances;
    /** Length of the (univariate) training series; set by {@link #call()}. */
    private int q;
    /** Number of distinct classes in the training data; set by {@link #call()} or {@link #setC(int)}. */
    private int numClasses;

    /** Not referenced by the code in this class. */
    public static final boolean USE_BIAS_CORRECTION = false;
    /** Soft-minimum sharpness: a segment distance {@code d} is weighted with {@code exp(ALPHA * d)}. */
    public static final double ALPHA = -30.0;
    /** Tiny constant guarding divisions/square roots against zero; also the magnitude bound of initial weights. */
    private static final double EPS = 1.0E-21;

    /** Training stops (checked every 10 iterations) once this much wall-clock time has elapsed since {@link #call()} began. */
    private TimeOut timeout = new TimeOut(Integer.MAX_VALUE, TimeUnit.SECONDS);
    /** If true, each epoch visits instances in class-alternating order; otherwise in a plain random order. */
    private boolean useInstanceReordering = true;

    /**
     * Creates the learning algorithm.
     *
     * @param config algorithm configuration
     * @param classifier classifier that receives the learned parameters
     * @param dataset training data
     */
    public LearnShapeletsLearningAlgorithm(ILearnShapeletsLearningAlgorithmConfig config, LearnShapeletsClassifier classifier, TimeSeriesDataset dataset) {
        super((IAlgorithmConfig)config, classifier, dataset);
    }

    /**
     * Initializes the shapelet tensor S by k-means clustering of z-normalized training segments.
     *
     * <p>For each scale {@code r}, every window of length {@code (r + 1) * minShapeletLength} is taken from every
     * training series. The windows start at {@code 0 .. getNumberOfSegments(q, minShapeletLength, r) - 1}. Each
     * window is z-normalized, and the windows are clustered into {@code numShapelets} clusters with Weka's
     * {@link SimpleKMeans}; the centroids become the shapelets of that scale. If clustering fails, a zero matrix of
     * the correct shape is used instead.</p>
     *
     * @param trainingMatrix univariate training series, one row per instance
     * @return S indexed as {@code [scale][shapelet][position]}
     * @throws TrainingException if some scale yields fewer than one segment
     */
    public double[][][] initializeS(double[][] trainingMatrix) throws TrainingException {
        LOGGER.debug("Initializing S...");
        int scaleR = this.getConfig().scaleR();
        int seed = this.getConfig().seed();
        int minShapeLength = this.getConfig().minShapeletLength();
        int numShapelets = this.getConfig().numShapelets();
        double[][][] result = new double[scaleR][][];
        for (int r = 0; r < scaleR; ++r) {
            int numberOfSegments = getNumberOfSegments(this.q, minShapeLength, r);
            if (numberOfSegments < 1) {
                throw new TrainingException("The number of segments is lower than 1. Can not train the LearnShapelets model.");
            }
            int segmentLength = (r + 1) * minShapeLength;

            // All z-normalized windows of this scale, row = instance * numberOfSegments + windowStart.
            double[][] segments = new double[trainingMatrix.length * numberOfSegments][segmentLength];
            for (int i = 0; i < trainingMatrix.length; ++i) {
                for (int j = 0; j < numberOfSegments; ++j) {
                    int row = i * numberOfSegments + j;
                    System.arraycopy(trainingMatrix[i], j, segments[row], 0, segmentLength);
                    segments[row] = TimeSeriesUtil.zNormalize(segments[row], false);
                }
            }

            Instances wekaInstances = WekaUtil.matrixToWekaInstances(segments);
            SimpleKMeans kMeans = new SimpleKMeans();
            try {
                kMeans.setNumClusters(numShapelets);
                kMeans.setSeed(seed);
                kMeans.setMaxIterations(100);
                kMeans.buildClusterer(wekaInstances);
            }
            catch (Exception e) {
                // Weka declares the generic Exception here; fall back to zero shapelets rather than aborting.
                LOGGER.warn("Could not initialize matrix S using kMeans clustering for r={} due to the following problem: {}. Using zero matrix instead (possibly leading to a poor training performance).", r, e.getMessage());
                // The fallback must have this scale's full shapelet length; SGD indexes s[r][k][l] for l < (r + 1) * minShapeLength.
                result[r] = new double[numShapelets][segmentLength];
                continue;
            }

            Instances clusterCentroids = kMeans.getClusterCentroids();
            double[][] centroids = new double[clusterCentroids.numInstances()][clusterCentroids.numAttributes()];
            for (int c = 0; c < centroids.length; ++c) {
                System.arraycopy(clusterCentroids.get(c).toDoubleArray(), 0, centroids[c], 0, centroids[c].length);
            }
            result[r] = centroids;
        }
        LOGGER.debug("Initialized S.");
        return result;
    }

    /**
     * Trains the LearnShapelets model on the input dataset.
     *
     * <p>The method proceeds in these steps:</p>
     * <ol>
     * <li>Derives {@code minshapeletlength} from {@code relativeminshapeletlength} and the series length.</li>
     * <li>Optionally estimates {@code numshapelets}.</li>
     * <li>Initializes S via {@link #initializeS(double[][])} and W/W0 via
     * {@link #initializeWeights(double[][][], double[])}.</li>
     * <li>Runs {@link #performSGD} and stores the learned parameters in the classifier.</li>
     * </ol>
     *
     * @return the trained classifier
     * @throws AlgorithmException if S could not be initialized
     * @throws UnsupportedOperationException if the dataset is multivariate
     * @throws IllegalArgumentException if the dataset is empty or has no value matrix
     */
    public LearnShapeletsClassifier call() throws AlgorithmException {
        long beginTime = System.currentTimeMillis();
        TimeSeriesDataset data = (TimeSeriesDataset)this.getInput();
        if (data.isMultivariate()) {
            throw new UnsupportedOperationException("Multivariate datasets are not supported.");
        }
        if (data.isEmpty()) {
            throw new IllegalArgumentException("The training dataset must not be empty!");
        }
        double[][] dataMatrix = data.getValuesOrNull(0);
        if (dataMatrix == null) {
            throw new IllegalArgumentException("Timestamp matrix must be a valid 2D matrix containing the time series values for all instances!");
        }
        int[] targetMatrix = data.getTargets();
        List<Integer> occuringClasses = TimeSeriesUtil.getClassesInDataset(data);
        this.numInstances = data.getNumberOfInstances();
        this.q = dataMatrix[0].length;
        this.numClasses = occuringClasses.size();

        // minshapeletlength is read back through an int accessor, so it must be written as an integer string
        // ("3", not "3.0"). A length of 0 would make every shapelet empty, hence the lower bound of 1.
        int relativeMinLength = (int)(this.getConfig().minShapeLengthPercentage() * this.q);
        this.getConfig().setProperty(ILearnShapeletsLearningAlgorithmConfig.K_SHAPELETLENGTH_MIN, String.valueOf(Math.max(1, relativeMinLength)));
        int minShapeLength = this.getConfig().minShapeletLength();
        int scaleR = this.getConfig().scaleR();

        // One-hot targets: column c corresponds to occuringClasses.get(c).
        int[][] y = new int[this.numInstances][this.numClasses];
        for (int i = 0; i < this.numInstances; ++i) {
            y[i][occuringClasses.indexOf(targetMatrix[i])] = 1;
        }

        if (this.getConfig().estimateK()) {
            int totalSegments = 0;
            for (int r = 0; r < scaleR; ++r) {
                totalSegments += getNumberOfSegments(this.q, minShapeLength, r) * this.numInstances;
            }
            int estimatedK = (int)(Math.log(totalSegments) * (this.numClasses - 1));
            // A single class yields 0 and degenerate segment counts can yield 0 or a negative value;
            // training needs at least one shapelet.
            this.getConfig().setProperty(ILearnShapeletsLearningAlgorithmConfig.K_NUMSHAPELETS, String.valueOf(Math.max(1, estimatedK)));
        }
        int k = this.getConfig().numShapelets();
        LOGGER.info("Parameters: k={}, learningRate={}, reg={}, r={}, minShapeLength={}, maxIter={}, Q={}, C={}", k, this.getConfig().learningRate(), this.getConfig().regularization(), scaleR, minShapeLength, this.getConfig().maxIterations(), this.q, this.numClasses);

        double[][][] s;
        try {
            s = this.initializeS(dataMatrix);
        }
        catch (TrainingException e) {
            throw new AlgorithmException((Throwable)e, "Can not train LearnShapelets model due to error during initialization of S.");
        }

        // Accumulated squared steps for S, shaped exactly like S (row by row, so an empty scale does not fail).
        double[][][] sHist = new double[scaleR][][];
        for (int r = 0; r < scaleR; ++r) {
            sHist[r] = new double[s[r].length][];
            for (int shapelet = 0; shapelet < s[r].length; ++shapelet) {
                sHist[r][shapelet] = new double[s[r][shapelet].length];
            }
        }
        double[][][] w = new double[this.numClasses][scaleR][k];
        double[][][] wHist = new double[this.numClasses][scaleR][k];
        double[] w0 = new double[this.numClasses];
        double[] w0Hist = new double[this.numClasses];
        this.initializeWeights(w, w0);

        LOGGER.debug("Starting training for {} iterations...", this.getConfig().maxIterations());
        this.performSGD(w, wHist, w0, w0Hist, s, sHist, dataMatrix, y, beginTime, targetMatrix);
        LOGGER.debug("Finished training.");

        LearnShapeletsClassifier model = (LearnShapeletsClassifier)this.getClassifier();
        model.setS(s);
        model.setW(w);
        model.setW0(w0);
        model.setC(this.numClasses);
        return model;
    }

    /**
     * Fills W and W0 with tiny random values of magnitude below {@code EPS} and random sign.
     *
     * @param w weights indexed {@code [class][scale][shapelet]}; filled in place
     * @param w0 per-class biases; filled in place
     */
    public void initializeWeights(double[][][] w, double[] w0) {
        Random rand = new Random(this.getConfig().seed());
        int scaleR = this.getConfig().scaleR();
        int numShapelets = this.getConfig().numShapelets();
        for (int c = 0; c < this.numClasses; ++c) {
            w0[c] = tinySignedRandom(rand);
            for (int r = 0; r < scaleR; ++r) {
                for (int k = 0; k < numShapelets; ++k) {
                    w[c][r][k] = tinySignedRandom(rand);
                }
            }
        }
    }

    /** Draws {@code EPS * U[0,1)} (first draw) and negates it with probability 1/2 (second draw). */
    private static double tinySignedRandom(Random rand) {
        double magnitude = EPS * rand.nextDouble();
        return rand.nextInt(2) == 0 ? magnitude : -magnitude;
    }

    /**
     * Runs the stochastic gradient descent that jointly updates S, W and W0 in place.
     *
     * <p>For every instance of every epoch, the algorithm first computes the soft-minimum distance {@code mHat} of
     * every shapelet. It then computes the residuals {@code theta = y - sigmoid(W . mHat)} and updates W, S and W0
     * using momentum velocities scaled by {@code 1 / sqrt(accumulated squared steps + EPS)}. Training stops after
     * {@code maxiter} epochs or, checked every 10 epochs, once the timeout has elapsed since {@code beginTime}.</p>
     *
     * @param w weights {@code [class][scale][shapelet]}, updated in place
     * @param wHist accumulated squared W steps, updated in place
     * @param w0 per-class biases, updated in place
     * @param w0Hist accumulated squared W0 steps, updated in place
     * @param s shapelets {@code [scale][shapelet][position]}, updated in place
     * @param sHist accumulated squared S steps, updated in place
     * @param dataMatrix univariate training series, one row per instance
     * @param y one-hot targets {@code [instance][class]}
     * @param beginTime start time in epoch milliseconds used for the timeout check
     * @param targets raw class labels per instance, used for class-alternating reordering
     */
    public void performSGD(double[][][] w, double[][][] wHist, double[] w0, double[] w0Hist, double[][][] s, double[][][] sHist, double[][] dataMatrix, int[][] y, long beginTime, int[] targets) {
        int scaleR = this.getConfig().scaleR();
        int minShapeLength = this.getConfig().minShapeletLength();
        int maxIter = this.getConfig().maxIterations();
        long seed = this.getConfig().seed();
        int numShapelets = this.getConfig().numShapelets();
        double learningRate = this.getConfig().learningRate();
        double regularization = this.getConfig().regularization();
        double gamma = this.getConfig().gamma();

        // Forward-pass caches indexed [scale][instance][shapelet][segment]: distances d and weights exp(ALPHA * d).
        double[][][][] d = new double[scaleR][][][];
        double[][][][] xi = new double[scaleR][][][];
        int[] numberOfSegments = new int[scaleR];
        for (int r = 0; r < scaleR; ++r) {
            numberOfSegments[r] = getNumberOfSegments(this.q, minShapeLength, r);
            d[r] = new double[this.numInstances][numShapelets][numberOfSegments[r]];
            xi[r] = new double[this.numInstances][numShapelets][numberOfSegments[r]];
        }
        // psi = sum of weights, mHat = soft-minimum distance; both indexed [scale][instance][shapelet].
        double[][][] psi = new double[scaleR][this.numInstances][numShapelets];
        double[][][] mHat = new double[scaleR][this.numInstances][numShapelets];
        double[][] theta = new double[this.numInstances][this.numClasses];
        List<Integer> indices = IntStream.range(0, this.numInstances).boxed().collect(Collectors.toList());
        LOGGER.debug("Starting training for {} iterations...", maxIter);

        // Momentum velocities, shaped like the parameters they update.
        double[][][] velocitiesW = new double[w.length][w[0].length][w[0][0].length];
        double[] velocitiesW0 = new double[w0.length];
        double[][][] velocitiesS = new double[s.length][][];
        for (int r = 0; r < s.length; ++r) {
            velocitiesS[r] = new double[s[r].length][];
            for (int k = 0; k < s[r].length; ++k) {
                velocitiesS[r][k] = new double[s[r][k].length];
            }
        }

        for (int it = 0; it < maxIter; ++it) {
            // Deterministic per-epoch ordering derived from the configured seed.
            Random epochRandom = new Random(seed + it);
            if (this.useInstanceReordering) {
                indices = this.shuffleAccordingToAlternatingClassScheme(indices, targets, epochRandom);
            } else {
                Collections.shuffle(indices, epochRandom);
            }

            for (int idx = 0; idx < this.numInstances; ++idx) {
                int i = indices.get(idx);

                // Forward pass: segment distances and soft-minimum distance of every shapelet to instance i.
                for (int r = 0; r < scaleR; ++r) {
                    int jr = numberOfSegments[r];
                    for (int k = 0; k < s[r].length; ++k) {
                        double psiValue = 0.0;
                        double weightedDistanceSum = 0.0;
                        for (int j = 0; j < jr; ++j) {
                            double distance = calculateD(s, minShapeLength, r, dataMatrix[i], k, j);
                            double weight = Math.exp(ALPHA * distance);
                            d[r][i][k][j] = distance;
                            xi[r][i][k][j] = weight;
                            psiValue += weight;
                            weightedDistanceSum += distance * weight;
                        }
                        // If every exp(ALPHA * d) underflows to 0, dividing by psi would yield NaN (mHat) and
                        // infinity (phiDenominator below) and corrupt the model; guard like calculateMHat does.
                        double safePsi = psiValue == 0.0 ? EPS : psiValue;
                        psi[r][i][k] = safePsi;
                        mHat[r][i][k] = weightedDistanceSum / safePsi;
                    }
                }

                // Residual per class. The bias w0 is learned below but not added to this activation.
                // NOTE(review): confirm that omitting w0 here is intended.
                for (int c = 0; c < this.numClasses; ++c) {
                    double activation = 0.0;
                    for (int r = 0; r < scaleR; ++r) {
                        for (int k = 0; k < numShapelets; ++k) {
                            activation += mHat[r][i][k] * w[c][r][k];
                        }
                    }
                    theta[i][c] = y[i][c] - MathUtil.sigmoid(activation);
                }

                // Backward pass: update W, then S (using the already-updated W), then W0.
                for (int c = 0; c < this.numClasses; ++c) {
                    double gradW0 = theta[i][c];
                    for (int r = 0; r < scaleR; ++r) {
                        int jr = numberOfSegments[r];
                        int shapeletLength = (r + 1) * minShapeLength;
                        for (int k = 0; k < s[r].length; ++k) {
                            double wStep = -1.0 * gradW0 * mHat[r][i][k] + 2.0 * regularization / this.numInstances * w[c][r][k];
                            velocitiesW[c][r][k] = gamma * velocitiesW[c][r][k] + learningRate * wStep;
                            wHist[c][r][k] += wStep * wStep;
                            w[c][r][k] -= velocitiesW[c][r][k] / Math.sqrt(wHist[c][r][k] + EPS);

                            double phiDenominator = 1.0 / ((r + 1.0) * minShapeLength * psi[r][i][k]);
                            double[] distDiff = new double[jr];
                            for (int j = 0; j < jr; ++j) {
                                distDiff[j] = xi[r][i][k][j] * (1.0 + ALPHA * (d[r][i][k][j] - mHat[r][i][k]));
                            }
                            for (int l = 0; l < shapeletLength; ++l) {
                                double shapeletDiff = 0.0;
                                for (int j = 0; j < jr; ++j) {
                                    shapeletDiff += distDiff[j] * (s[r][k][l] - dataMatrix[i][j + l]);
                                }
                                double sStep = -1.0 * gradW0 * shapeletDiff * w[c][r][k] * phiDenominator;
                                velocitiesS[r][k][l] = gamma * velocitiesS[r][k][l] + learningRate * sStep;
                                sHist[r][k][l] += sStep * sStep;
                                s[r][k][l] -= velocitiesS[r][k][l] / Math.sqrt(sHist[r][k][l] + EPS);
                            }
                        }
                    }
                    velocitiesW0[c] = gamma * velocitiesW0[c] + learningRate * gradW0;
                    w0Hist[c] += gradW0 * gradW0;
                    w0[c] += velocitiesW0[c] / Math.sqrt(w0Hist[c] + EPS);
                }
            }

            if (it % 10 == 0) {
                LOGGER.debug("Iteration {}/{}", it, maxIter);
                if (System.currentTimeMillis() - beginTime > this.timeout.milliseconds()) {
                    LOGGER.debug("Stopping training due to timeout.");
                    break;
                }
            }
        }
    }

    /**
     * Produces a random permutation of the instance positions {@code 0 .. targets.length - 1} that alternates
     * between classes in round-robin fashion until a class runs out of instances.
     *
     * <p>Positions are grouped by {@code targets[position]}, and each group is shuffled with {@code random}.
     * Positions are then drawn one group at a time in a cycle, skipping exhausted groups. Only the size of
     * {@code instanceIndices} is used, not its contents.</p>
     *
     * @param instanceIndices current index list; its size must equal {@code targets.length}
     * @param targets class label of each instance position
     * @param random source of randomness for the per-class shuffles
     * @return the new ordering of instance positions
     * @throws IllegalArgumentException if the sizes of {@code instanceIndices} and {@code targets} differ
     */
    public List<Integer> shuffleAccordingToAlternatingClassScheme(List<Integer> instanceIndices, int[] targets, Random random) {
        if (instanceIndices.size() != targets.length) {
            throw new IllegalArgumentException("The number of instances must be equal to the number of available target values!");
        }
        HashMap<Integer, List<Integer>> indicesPerClass = new HashMap<>();
        for (int i = 0; i < instanceIndices.size(); ++i) {
            indicesPerClass.computeIfAbsent(targets[i], label -> new ArrayList<>()).add(i);
        }

        List<Iterator<Integer>> classIterators = new ArrayList<>();
        for (List<Integer> classIndices : indicesPerClass.values()) {
            Collections.shuffle(classIndices, random);
            classIterators.add(classIndices.iterator());
        }

        // Bound the search by the number of groups actually built (not the numClasses field, which is only set
        // by call()), so that one full round always finds a group with remaining positions.
        int numGroups = classIterators.size();
        List<Integer> resultList = new ArrayList<>(instanceIndices.size());
        Iterator<Iterator<Integer>> roundRobinIt = Iterables.cycle(classIterators).iterator();
        for (int i = 0; i < instanceIndices.size(); ++i) {
            for (int tried = 0; roundRobinIt.hasNext() && tried < numGroups; ++tried) {
                Iterator<Integer> classIt = roundRobinIt.next();
                if (classIt.hasNext()) {
                    resultList.add(classIt.next());
                    break;
                }
            }
        }
        return resultList;
    }

    /**
     * Computes the soft-minimum distance between shapelet {@code s[r][k]} and the windows of {@code instance}:
     * {@code sum(d_j * exp(alpha * d_j)) / sum(exp(alpha * d_j))} over
     * {@code j < getNumberOfSegments(Q, minShapeLength, r)}. A zero denominator is replaced by {@code EPS}.
     *
     * @param s shapelets {@code [scale][shapelet][position]}
     * @param minShapeLength shapelet length at scale 0
     * @param r scale index
     * @param instance univariate series
     * @param k shapelet index
     * @param Q series length
     * @param alpha soft-minimum sharpness (negative for a soft minimum)
     * @return the soft-minimum distance
     */
    public static double calculateMHat(double[][][] s, int minShapeLength, int r, double[] instance, int k, int Q, double alpha) {
        int segments = getNumberOfSegments(Q, minShapeLength, r);
        double nominator = 0.0;
        double denominator = 0.0;
        for (int j = 0; j < segments; ++j) {
            double d = calculateD(s, minShapeLength, r, instance, k, j);
            double expD = Math.exp(alpha * d);
            nominator += d * expD;
            denominator += expD;
        }
        if (denominator == 0.0) {
            denominator = EPS;
        }
        return nominator / denominator;
    }

    /**
     * Computes the mean squared difference between shapelet {@code s[r][k]} and the window of {@code instance}
     * starting at {@code j} with length {@code (r + 1) * minShapeLength}.
     *
     * @param s shapelets {@code [scale][shapelet][position]}
     * @param minShapeLength shapelet length at scale 0
     * @param r scale index
     * @param instance univariate series
     * @param k shapelet index
     * @param j start position of the window in {@code instance}
     * @return the mean squared difference
     */
    public static double calculateD(double[][][] s, int minShapeLength, int r, double[] instance, int k, int j) {
        int length = (r + 1) * minShapeLength;
        double result = 0.0;
        for (int l = 0; l < length; ++l) {
            result += Math.pow(instance[j + l] - s[r][k][l], 2.0);
        }
        return result / (double)length;
    }

    /**
     * Stepwise execution is not supported by this algorithm; use {@link #call()}.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public AlgorithmEvent nextWithException() {
        throw new UnsupportedOperationException("The operation to be performed is not supported.");
    }

    /**
     * Returns the number of window start positions used for scale {@code r}: {@code Q - (r + 1) * minShapeLength}.
     *
     * <p>NOTE(review): a series of length Q has {@code Q - L + 1} windows of length {@code L}, so the window
     * starting at {@code Q - L} is never considered. Confirm whether this is intentional; other code presumably
     * relies on the same count.</p>
     *
     * @param Q series length
     * @param minShapeLength shapelet length at scale 0
     * @param r scale index
     * @return the number of segments (may be zero or negative for long shapelets)
     */
    public static int getNumberOfSegments(int Q, int minShapeLength, int r) {
        return Q - (r + 1) * minShapeLength;
    }

    /** @return the configuration of this algorithm */
    @Override
    public ILearnShapeletsLearningAlgorithmConfig getConfig() {
        return (ILearnShapeletsLearningAlgorithmConfig)super.getConfig();
    }

    /** @return whether epochs use the class-alternating instance order */
    public boolean isUseInstanceReordering() {
        return this.useInstanceReordering;
    }

    /** @param useInstanceReordering whether epochs use the class-alternating instance order */
    public void setUseInstanceReordering(boolean useInstanceReordering) {
        this.useInstanceReordering = useInstanceReordering;
    }

    /** @return the number of classes */
    public int getC() {
        return this.numClasses;
    }

    /** @param c the number of classes */
    public void setC(int c) {
        this.numClasses = c;
    }

    /**
     * Configuration of {@link LearnShapeletsLearningAlgorithm}.
     */
    public static interface ILearnShapeletsLearningAlgorithmConfig
    extends IRandomAlgorithmConfig {
        /** Key of {@link #numShapelets()}. */
        public static final String K_NUMSHAPELETS = "numshapelets";
        /** Key of {@link #learningRate()}. */
        public static final String K_LEARNINGRATE = "learningrate";
        /** Key of {@link #regularization()}. */
        public static final String K_REGULARIZATION = "regularization";
        /** Key of {@link #minShapeletLength()}. */
        public static final String K_SHAPELETLENGTH_MIN = "minshapeletlength";
        /** Key of {@link #minShapeLengthPercentage()}. */
        public static final String K_SHAPELETLENGTH_RELMIN = "relativeminshapeletlength";
        /** Key of {@link #scaleR()}. */
        public static final String K_SCALER = "scaler";
        /** Key of {@link #maxIterations()}. */
        public static final String K_MAXITER = "maxiter";
        /** Key of {@link #gamma()}. */
        public static final String K_GAMMA = "gamma";
        /** Key of {@link #estimateK()}. */
        public static final String K_ESTIMATEK = "estimatek";

        /** @return number of shapelets (k-means clusters) per scale */
        @Config.Key(value="numshapelets")
        public int numShapelets();

        /** @return factor applied to every gradient step before it enters the momentum velocity */
        @Config.Key(value="learningrate")
        public double learningRate();

        /** @return L2 weight; each W step includes {@code 2 * regularization / numInstances * w} */
        @Config.Key(value="regularization")
        public double regularization();

        /** @return shapelet length at scale 0; scale r uses {@code (r + 1)} times this length */
        @Config.Key(value="minshapeletlength")
        public int minShapeletLength();

        /** @return minimum shapelet length as a fraction of the series length */
        @Config.Key(value="relativeminshapeletlength")
        public double minShapeLengthPercentage();

        /** @return number of shapelet scales */
        @Config.Key(value="scaler")
        public int scaleR();

        /** @return maximum number of training epochs */
        @Config.Key(value="maxiter")
        public int maxIterations();

        /** @return momentum factor of the velocity updates (default 0.5) */
        @Config.Key(value="gamma")
        @Config.DefaultValue(value="0.5")
        public double gamma();

        /** @return whether the number of shapelets is estimated from the data (default false) */
        @Config.Key(value="estimatek")
        @Config.DefaultValue(value="false")
        public boolean estimateK();
    }
}

