/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.classification.singlelabel.timeseries.learner.shapelets;

import ai.libs.jaicore.ml.classification.singlelabel.timeseries.dataset.TimeSeriesDataset2;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.learner.ASimplifiedTSClassifier;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.learner.shapelets.LearnShapeletsLearningAlgorithm;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.util.MathUtil;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.util.TimeSeriesUtil;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.aeonbits.owner.ConfigCache;
import org.api4.java.ai.ml.core.exception.PredictionException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class LearnShapeletsClassifier
extends ASimplifiedTSClassifier<Integer> {
    private static final Logger LOGGER = LoggerFactory.getLogger(LearnShapeletsClassifier.class);
    private double[][][] s;
    private double[][][] w;
    private double[] w0;
    private int c;
    private final LearnShapeletsLearningAlgorithm.ILearnShapeletsLearningAlgorithmConfig config = (LearnShapeletsLearningAlgorithm.ILearnShapeletsLearningAlgorithmConfig)ConfigCache.getOrCreate(LearnShapeletsLearningAlgorithm.ILearnShapeletsLearningAlgorithmConfig.class, (Map[])new Map[0]);

    public LearnShapeletsClassifier(int K, double learningRate, double regularization, int scaleR, double minShapeLengthPercentage, int maxIter, int seed) {
        this(K, learningRate, regularization, scaleR, minShapeLengthPercentage, maxIter, 0.5, seed);
    }

    public LearnShapeletsClassifier(int K, double learningRate, double regularization, int scaleR, double minShapeLengthPercentage, int maxIter, double gamma, int seed) {
        this.config.setProperty("numshapelets", "" + K);
        this.config.setProperty("regularization", "" + regularization);
        this.config.setProperty("scaler", "" + scaleR);
        this.config.setProperty("relativeminshapeletlength", "" + minShapeLengthPercentage);
        this.config.setProperty("seed", "" + seed);
        this.config.setProperty("maxiter", "" + maxIter);
        this.config.setProperty("learningrate", "" + learningRate);
        this.config.setProperty("gamma", "" + gamma);
    }

    public void setEstimateK(boolean estimateK) {
        this.config.setProperty("estimatek", "" + estimateK);
    }

    public double[][][] getS() {
        return this.s;
    }

    public void setS(double[][][] s) {
        this.s = s;
    }

    public double[][][] getW() {
        return this.w;
    }

    public void setW(double[][][] w) {
        this.w = w;
    }

    public double[] getW0() {
        return this.w0;
    }

    public void setW0(double[] w0) {
        this.w0 = w0;
    }

    public void setC(int c) {
        this.c = c;
    }

    public void setMinShapeLength(int minShapeLength) {
        this.config.setProperty("minshapeletlength", "" + minShapeLength);
    }

    public Integer predict(double[] univInstance) throws PredictionException {
        if (!this.isTrained()) {
            throw new PredictionException("Model has not been built before!");
        }
        HashMap<Integer, Double> scoring = new HashMap<Integer, Double>();
        univInstance = TimeSeriesUtil.zNormalize((double[])univInstance, (boolean)false);
        for (int i = 0; i < this.c; ++i) {
            double tmpScore = this.w0[i];
            for (int r = 0; r < this.config.scaleR(); ++r) {
                for (int k = 0; k < this.s[r].length; ++k) {
                    tmpScore += LearnShapeletsLearningAlgorithm.calculateMHat(this.s, this.config.minShapeletLength(), r, univInstance, k, univInstance.length, -30.0) * this.w[i][r][k];
                }
            }
            scoring.put(i, MathUtil.sigmoid((double)tmpScore));
        }
        return (Integer)Collections.max(scoring.entrySet(), Map.Entry.comparingByValue()).getKey();
    }

    public Integer predict(List<double[]> multivInstance) throws PredictionException {
        LOGGER.warn("Dataset to be predicted is multivariate but only first time series (univariate) will be considered.");
        return this.predict(multivInstance.get(0));
    }

    public List<Integer> predict(TimeSeriesDataset2 dataset) throws PredictionException {
        double[][] timeSeries;
        if (!this.isTrained()) {
            throw new PredictionException("Model has not been built before!");
        }
        if (dataset.isMultivariate()) {
            LOGGER.warn("Dataset to be predicted is multivariate but only first time series (univariate) will be considered.");
        }
        if ((timeSeries = dataset.getValuesOrNull(0)) == null) {
            throw new IllegalArgumentException("Dataset matrix of the instances to be predicted must not be null!");
        }
        ArrayList<Integer> predictions = new ArrayList<Integer>();
        LOGGER.debug("Starting prediction...");
        for (int inst = 0; inst < timeSeries.length; ++inst) {
            double[] instanceValues = timeSeries[inst];
            predictions.add(this.predict(instanceValues));
        }
        LOGGER.debug("Finished prediction.");
        return predictions;
    }

    public LearnShapeletsLearningAlgorithm getLearningAlgorithm(TimeSeriesDataset2 dataset) {
        return new LearnShapeletsLearningAlgorithm(this.config, this, dataset);
    }
}

