/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.classification.singlelabel.timeseries.util;

import ai.libs.jaicore.basic.sets.Pair;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.dataset.TimeSeriesDataset2;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.exception.TimeSeriesLengthException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Static helpers for validating, transforming, normalizing, differentiating and splitting time series.
 * A time series is represented either as an {@link INDArray} of rank 1 or as a plain {@code double[]}.
 */
public final class TimeSeriesUtil {
    private TimeSeriesUtil() {
        // Utility class: not meant to be instantiated.
    }

    /**
     * Checks whether all given arrays are time series, i.e. have rank 1.
     *
     * @param array the arrays to check
     * @return {@code true} if every array has rank 1 (also for an empty argument list), {@code false} otherwise
     */
    public static boolean isTimeSeries(INDArray ... array) {
        for (INDArray a : array) {
            if (a.rank() != 1) {
                return false;
            }
        }
        return true;
    }

    /**
     * Checks whether all given arrays are time series of the given length, i.e. have rank 1 and exactly
     * {@code length} elements.
     *
     * @param length the required number of elements
     * @param array  the arrays to check
     * @return {@code true} if every array has rank 1 and the requested length, {@code false} otherwise
     */
    public static boolean isTimeSeries(int length, INDArray ... array) {
        for (INDArray a : array) {
            // Both the rank and the length must match for an array to qualify.
            if (a.rank() != 1 || a.length() != length) {
                return false;
            }
        }
        return true;
    }

    /**
     * Checks whether all given arrays have the given length.
     *
     * @param length the required number of elements
     * @param array  the arrays to check
     * @return {@code true} if every array has exactly {@code length} elements, {@code false} otherwise
     */
    public static boolean isTimeSeries(int length, double[] ... array) {
        for (double[] a : array) {
            if (a.length != length) {
                return false;
            }
        }
        return true;
    }

    /**
     * Ensures that all given arrays are time series (rank 1).
     *
     * @param array the arrays to check
     * @throws IllegalArgumentException for the first array that does not have rank 1
     */
    public static void isTimeSeriesOrException(INDArray ... array) {
        for (INDArray a : array) {
            // Check the current element only, so the reported rank belongs to the offending array.
            if (!TimeSeriesUtil.isTimeSeries(a)) {
                String message = String.format("The given INDArray is no time series. It should have rank 1, but has a rank of %d.", a.rank());
                throw new IllegalArgumentException(message);
            }
        }
    }

    /**
     * Ensures that all given arrays are time series (rank 1) of the given length.
     *
     * @param length the required number of elements
     * @param array  the arrays to check
     * @throws IllegalArgumentException for the first array with a rank other than 1 or a length other than {@code length}
     */
    public static void isTimeSeriesOrException(int length, INDArray ... array) {
        for (INDArray a : array) {
            if (!TimeSeriesUtil.isTimeSeries(a)) {
                String message = String.format("The given INDArray is no time series. It should have rank 1, but has a rank of %d.", a.rank());
                throw new IllegalArgumentException(message);
            }
            if (!TimeSeriesUtil.isTimeSeries(length, a)) {
                String message = String.format("The given time series should have a length of %d, but has a length of %d.", length, a.length());
                throw new IllegalArgumentException(message);
            }
        }
    }

    /**
     * Ensures that all given arrays have the given length.
     *
     * @param length the required number of elements
     * @param array  the arrays to check
     * @throws IllegalArgumentException for the first array whose length differs from {@code length}
     */
    public static void isTimeSeriesOrException(int length, double[] ... array) {
        for (double[] a : array) {
            if (a.length != length) {
                String message = String.format("The given time series should have a length of %d, but has a length of %d.", length, a.length);
                throw new IllegalArgumentException(message);
            }
        }
    }

    /**
     * Checks whether all given time series have the same length as {@code timeSeries1}.
     *
     * @param timeSeries1 the reference time series
     * @param timeSeries  the time series to compare against the reference
     * @return {@code true} if all lengths match the reference length, {@code false} otherwise
     */
    public static boolean isSameLength(INDArray timeSeries1, INDArray ... timeSeries) {
        for (INDArray t : timeSeries) {
            if (timeSeries1.length() != t.length()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Checks whether all given time series have the same length as {@code timeSeries1}.
     *
     * @param timeSeries1 the reference time series
     * @param timeSeries  the time series to compare against the reference
     * @return {@code true} if all lengths match the reference length, {@code false} otherwise
     */
    public static boolean isSameLength(double[] timeSeries1, double[] ... timeSeries) {
        for (double[] t : timeSeries) {
            if (timeSeries1.length != t.length) {
                return false;
            }
        }
        return true;
    }

    /**
     * Ensures that all given time series have the same length as {@code timeSeries1}.
     *
     * @param timeSeries1 the reference time series
     * @param timeSeries  the time series to compare against the reference
     * @throws TimeSeriesLengthException for the first time series whose length differs from the reference
     */
    public static void isSameLengthOrException(INDArray timeSeries1, INDArray ... timeSeries) {
        for (INDArray t : timeSeries) {
            if (!TimeSeriesUtil.isSameLength(timeSeries1, t)) {
                String message = String.format("Length of the given time series are not equal: Length first time series: (%d). Length of seconds time series: (%d)", timeSeries1.length(), t.length());
                throw new TimeSeriesLengthException(message);
            }
        }
    }

    /**
     * Ensures that all given time series have the same length as {@code timeSeries1}.
     *
     * @param timeSeries1 the reference time series
     * @param timeSeries  the time series to compare against the reference
     * @throws TimeSeriesLengthException for the first time series whose length differs from the reference
     */
    public static void isSameLengthOrException(double[] timeSeries1, double[] ... timeSeries) {
        for (double[] t : timeSeries) {
            if (timeSeries1.length != t.length) {
                String message = String.format("Length of the given time series are not equal: Length first time series: (%d). Length of seconds time series: (%d)", timeSeries1.length, t.length);
                throw new TimeSeriesLengthException(message);
            }
        }
    }

    /**
     * Creates equidistant timestamps {@code 0, 1, ..., n-1} for a time series with {@code n} elements.
     *
     * @param timeSeries the time series whose length determines the number of timestamps
     * @return a new vector of shape {@code [n]} containing the timestamps
     */
    public static INDArray createEquidistantTimestamps(INDArray timeSeries) {
        int n = (int) timeSeries.length();
        double[] timestamps = IntStream.range(0, n).asDoubleStream().toArray();
        int[] shape = new int[] { n };
        return Nd4j.create(timestamps, shape);
    }

    /**
     * Creates equidistant timestamps {@code 0, 1, ..., n-1} for a time series with {@code n} elements.
     *
     * @param timeSeries the time series whose length determines the number of timestamps
     * @return a new array containing the timestamps
     */
    public static double[] createEquidistantTimestamps(double[] timeSeries) {
        return IntStream.range(0, timeSeries.length).asDoubleStream().toArray();
    }

    /**
     * Copies the interval {@code [start, end)} of the given time series.
     *
     * @param timeSeries the source time series
     * @param start      first index (inclusive)
     * @param end        last index (exclusive)
     * @return a new array with {@code end - start} elements
     * @throws IllegalArgumentException  if {@code end <= start}
     * @throws IndexOutOfBoundsException if the interval exceeds the bounds of {@code timeSeries}
     */
    public static double[] getInterval(double[] timeSeries, int start, int end) {
        if (end <= start) {
            throw new IllegalArgumentException("The end index must be greater than the start index.");
        }
        double[] result = new double[end - start];
        System.arraycopy(timeSeries, start, result, 0, result.length);
        return result;
    }

    /**
     * Z-normalizes the given array by subtracting its mean and dividing by its standard deviation.
     * A small epsilon ({@link Nd4j#EPS_THRESHOLD}) is added to the standard deviation to avoid a division by zero.
     *
     * @param array   the array to normalize
     * @param inplace whether {@code array} itself is modified ({@code true}) or a normalized copy is returned
     * @return the normalized array ({@code array} itself if {@code inplace} is set)
     * @throws IllegalArgumentException if the array has more than two dimensions and a first dimension other than 1
     */
    public static INDArray normalizeINDArray(INDArray array, boolean inplace) {
        if (array.shape().length > 2 && array.shape()[0] != 1L) {
            throw new IllegalArgumentException(String.format("Input INDArray object must be a vector with shape size 1. Actual shape: (%s)", Arrays.toString(array.shape())));
        }
        // NOTE(review): mean/std reduce along dimension 1, which assumes a row-vector layout - confirm for rank-1 input.
        double mean = array.mean(new int[] { 1 }).getDouble(0L);
        double std = array.std(new int[] { 1 }).getDouble(0L);
        INDArray result = inplace ? array.subi(mean) : array.sub(mean);
        // The epsilon belongs in the divisor: it protects against dividing by a zero standard deviation.
        return result.divi(std + Nd4j.EPS_THRESHOLD);
    }

    /**
     * Determines the most frequent value (mode) of the given array.
     *
     * @param array the values
     * @return the most frequent value; ties are resolved by {@link HashMap} iteration order. Returns {@code -1} for
     *         an empty array (note that this is indistinguishable from a mode of {@code -1}).
     */
    public static int getMode(int[] array) {
        Map<Integer, Integer> statistics = new HashMap<>();
        for (int value : array) {
            statistics.merge(value, 1, Integer::sum);
        }
        Integer mode = TimeSeriesUtil.getMaximumKeyByValue(statistics);
        return mode != null ? mode : -1;
    }

    /**
     * Returns the key mapped to the largest value.
     *
     * @param map the map to search
     * @param <T> the key type
     * @return the key with the largest strictly positive value; on ties, the first such key in the map's iteration
     *         order. Returns {@code null} if the map is empty or contains no positive value.
     */
    public static <T> T getMaximumKeyByValue(Map<T, Integer> map) {
        T maxKey = null;
        int maxCount = 0;
        for (Map.Entry<T, Integer> entry : map.entrySet()) {
            int val = entry.getValue();
            if (val > maxCount) {
                maxCount = val;
                maxKey = entry.getKey();
            }
        }
        return maxKey;
    }

    /**
     * Z-normalizes the given vector: {@code (x - mean) / stddev}.
     *
     * @param dataVector        the values to normalize
     * @param besselsCorrection whether the standard deviation divides by {@code n - 1} instead of {@code n}
     * @return a new array with the normalized values; all zeros if the standard deviation is zero or undefined
     *         (e.g. a single value with Bessel's correction)
     */
    public static double[] zNormalize(double[] dataVector, boolean besselsCorrection) {
        double[] result = new double[dataVector.length];
        int n = dataVector.length - (besselsCorrection ? 1 : 0);
        if (n <= 0) {
            // Standard deviation is undefined (would be 0/0); treat as a constant series.
            return result;
        }
        double mean = 0.0;
        for (double value : dataVector) {
            mean += value;
        }
        mean /= dataVector.length;
        double stddev = 0.0;
        for (double value : dataVector) {
            stddev += Math.pow(value - mean, 2.0);
        }
        stddev = Math.sqrt(stddev / n);
        if (stddev == 0.0) {
            return result;
        }
        for (int i = 0; i < result.length; ++i) {
            result[i] = (dataVector[i] - mean) / stddev;
        }
        return result;
    }

    /**
     * Returns the indices of {@code vector} sorted by the ABSOLUTE values of the referenced elements (argsort on |x|).
     *
     * @param vector    the values to sort by
     * @param ascending {@code true} for ascending, {@code false} for descending order of absolute values
     * @return a fixed-size list of indices; equal absolute values keep their original relative order
     */
    public static List<Integer> sortIndexes(double[] vector, boolean ascending) {
        Integer[] indexes = IntStream.range(0, vector.length).boxed().toArray(Integer[]::new);
        int direction = ascending ? 1 : -1;
        Arrays.sort(indexes, (i1, i2) -> direction * Double.compare(Math.abs(vector[i1]), Math.abs(vector[i2])));
        return Arrays.asList(indexes);
    }

    /**
     * Counts the distinct class labels in the targets of the given dataset.
     *
     * @param dataset the dataset
     * @return the number of distinct target values
     * @throws IllegalArgumentException if the dataset or its targets are {@code null}
     */
    public static int getNumberOfClasses(TimeSeriesDataset2 dataset) {
        if (dataset == null || dataset.getTargets() == null) {
            throw new IllegalArgumentException("Given parameter 'dataset' must not be null and must contain a target matrix!");
        }
        return TimeSeriesUtil.getClassesInDataset(dataset).size();
    }

    /**
     * Returns the distinct class labels in the targets of the given dataset.
     *
     * @param dataset the dataset
     * @return the distinct target values in ascending order
     * @throws IllegalArgumentException if the dataset or its targets are {@code null}
     */
    public static List<Integer> getClassesInDataset(TimeSeriesDataset2 dataset) {
        if (dataset == null || dataset.getTargets() == null) {
            throw new IllegalArgumentException("Given parameter 'dataset' must not be null and must contain a target matrix!");
        }
        // Sorted for a deterministic, documented order (instead of hash-set iteration order).
        return IntStream.of(dataset.getTargets()).distinct().sorted().boxed().collect(Collectors.toList());
    }

    /**
     * Shuffles the instances of the given dataset in place, applying the same permutation to all value matrices,
     * timestamp matrices and targets that are present (non-{@code null}).
     *
     * @param dataset the dataset to shuffle
     * @param seed    the seed of the random permutation
     * @throws IllegalArgumentException if a present matrix is empty or does not match the number of instances
     */
    public static void shuffleTimeSeriesDataset(TimeSeriesDataset2 dataset, int seed) {
        List<Integer> indices = IntStream.range(0, dataset.getNumberOfInstances()).boxed().collect(Collectors.toList());
        Collections.shuffle(indices, new Random(seed));
        List<double[][]> valueMatrices = dataset.getValueMatrices();
        List<double[][]> timestampMatrices = dataset.getTimestampMatrices();
        int[] targets = dataset.getTargets();
        if (valueMatrices != null) {
            dataset.setValueMatrices(TimeSeriesUtil.shuffleMatrices(valueMatrices, indices));
        }
        if (timestampMatrices != null) {
            dataset.setTimestampMatrices(TimeSeriesUtil.shuffleMatrices(timestampMatrices, indices));
        }
        if (targets != null) {
            dataset.setTargets(TimeSeriesUtil.shuffleMatrix(targets, indices));
        }
    }

    /** Applies the same row permutation to every matrix in the list. */
    private static List<double[][]> shuffleMatrices(List<double[][]> matrices, List<Integer> indices) {
        List<double[][]> shuffled = new ArrayList<>(matrices.size());
        for (double[][] matrix : matrices) {
            shuffled.add(TimeSeriesUtil.shuffleMatrix(matrix, indices));
        }
        return shuffled;
    }

    /** Returns a new matrix whose i-th row is row {@code indices.get(i)} of {@code srcMatrix} (rows are shared, not copied). */
    private static double[][] shuffleMatrix(double[][] srcMatrix, List<Integer> indices) {
        if (srcMatrix == null || srcMatrix.length < 1) {
            throw new IllegalArgumentException("Parameter 'srcMatrix' must not be null or empty!");
        }
        if (indices == null || indices.size() != srcMatrix.length) {
            throw new IllegalArgumentException("Parameter 'indices' must not be null and must have the same length as the number of instances in the source matrix!");
        }
        double[][] result = new double[srcMatrix.length][];
        for (int i = 0; i < result.length; ++i) {
            result[i] = srcMatrix[indices.get(i)];
        }
        return result;
    }

    /** Returns a new array whose i-th element is element {@code indices.get(i)} of {@code srcMatrix}. */
    private static int[] shuffleMatrix(int[] srcMatrix, List<Integer> indices) {
        if (srcMatrix == null || srcMatrix.length < 1) {
            throw new IllegalArgumentException("Parameter 'srcMatrix' must not be null or empty!");
        }
        if (indices == null || indices.size() != srcMatrix.length) {
            throw new IllegalArgumentException("Parameter 'indices' must not be null and must have the same length as the number of instances in the source matrix!");
        }
        int[] result = new int[srcMatrix.length];
        for (int i = 0; i < result.length; ++i) {
            result[i] = srcMatrix[indices.get(i)];
        }
        return result;
    }

    /**
     * Splits the given data into training and test data for one fold of a k-fold cross validation (without
     * shuffling). Each fold's test set is a contiguous block of {@code n / numFolds} instances; the last fold
     * additionally takes the remainder. The training set consists of ALL remaining instances.
     *
     * @param fold            index of the fold, in {@code [0, numFolds)}
     * @param numFolds        number of folds, at least 1
     * @param srcValueMatrix  the instances (one row per instance)
     * @param srcTargetMatrix the targets (one per instance)
     * @return a pair of (training dataset, test dataset)
     * @throws IllegalArgumentException on invalid fold parameters or mismatching / {@code null} inputs
     */
    public static Pair<TimeSeriesDataset2, TimeSeriesDataset2> getTrainingAndTestDataForFold(int fold, int numFolds, double[][] srcValueMatrix, int[] srcTargetMatrix) {
        if (srcValueMatrix == null || srcTargetMatrix == null || srcValueMatrix.length != srcTargetMatrix.length) {
            throw new IllegalArgumentException("Parameters 'srcValueMatrix' and 'srcTargetMatrix' must not be null and must contain the same number of instances!");
        }
        TimeSeriesDataset2 trainingData = TimeSeriesUtil.selectTrainingDataForFold(fold, numFolds, srcValueMatrix, srcTargetMatrix);
        TimeSeriesDataset2 testData = TimeSeriesUtil.selectTestDataForFold(fold, numFolds, srcValueMatrix, srcTargetMatrix);
        return new Pair<>(trainingData, testData);
    }

    /** Computes the test range {@code [start, end)} of the given fold; the last fold also takes the remainder. */
    private static int[] getTestRangeForFold(int fold, int numFolds, int numInstances) {
        if (numFolds < 1) {
            throw new IllegalArgumentException(String.format("The number of folds must be positive, but is %d.", numFolds));
        }
        if (fold < 0 || fold >= numFolds) {
            throw new IllegalArgumentException(String.format("The fold index must be in [0, %d), but is %d.", numFolds, fold));
        }
        int foldSize = numInstances / numFolds;
        int start = fold * foldSize;
        int end = fold == numFolds - 1 ? numInstances : start + foldSize;
        return new int[] { start, end };
    }

    /** Selects every instance outside the fold's test range as training data. */
    private static TimeSeriesDataset2 selectTrainingDataForFold(int fold, int numFolds, double[][] srcValueMatrix, int[] srcTargetMatrix) {
        int numInstances = srcValueMatrix.length;
        int[] testRange = TimeSeriesUtil.getTestRangeForFold(fold, numFolds, numInstances);
        int testStart = testRange[0];
        int testEnd = testRange[1];
        int numTrainInsts = numInstances - (testEnd - testStart);
        double[][] destValueMatrix = new double[numTrainInsts][];
        int[] destTargetMatrix = new int[numTrainInsts];
        // Instances before the test block, then instances after it (including any remainder instances).
        System.arraycopy(srcValueMatrix, 0, destValueMatrix, 0, testStart);
        System.arraycopy(srcValueMatrix, testEnd, destValueMatrix, testStart, numInstances - testEnd);
        System.arraycopy(srcTargetMatrix, 0, destTargetMatrix, 0, testStart);
        System.arraycopy(srcTargetMatrix, testEnd, destTargetMatrix, testStart, numInstances - testEnd);
        List<double[][]> valueMatrices = new ArrayList<>();
        valueMatrices.add(destValueMatrix);
        return new TimeSeriesDataset2(valueMatrices, destTargetMatrix);
    }

    /** Selects the instances inside the fold's test range as test data. */
    private static TimeSeriesDataset2 selectTestDataForFold(int fold, int numFolds, double[][] srcValueMatrix, int[] srcTargetMatrix) {
        int[] testRange = TimeSeriesUtil.getTestRangeForFold(fold, numFolds, srcValueMatrix.length);
        int testStart = testRange[0];
        int numTestInsts = testRange[1] - testStart;
        double[][] currTestMatrix = new double[numTestInsts][];
        int[] currTestTargetMatrix = new int[numTestInsts];
        System.arraycopy(srcValueMatrix, testStart, currTestMatrix, 0, numTestInsts);
        System.arraycopy(srcTargetMatrix, testStart, currTestTargetMatrix, 0, numTestInsts);
        List<double[][]> testValueMatrices = new ArrayList<>();
        testValueMatrices.add(currTestMatrix);
        return new TimeSeriesDataset2(testValueMatrices, currTestTargetMatrix);
    }

    /**
     * Creates a dataset from the given value matrices and (optional) targets.
     *
     * @param targets       the targets, or {@code null} for an unlabeled dataset
     * @param valueMatrices one or more value matrices
     * @return the created dataset
     * @throws IllegalArgumentException if no value matrix is given
     */
    public static TimeSeriesDataset2 createDatasetForMatrix(int[] targets, double[][] ... valueMatrices) {
        if (valueMatrices.length == 0) {
            throw new IllegalArgumentException("There must be at least one value matrix to generate a TimeSeriesDataset object!");
        }
        // Mutable copy: Arrays.asList is a fixed-size view backed by the varargs array.
        List<double[][]> values = new ArrayList<>(Arrays.asList(valueMatrices));
        return targets == null ? new TimeSeriesDataset2(values) : new TimeSeriesDataset2(values, targets);
    }

    /**
     * Creates an unlabeled dataset from the given value matrices.
     *
     * @param valueMatrices one or more value matrices
     * @return the created dataset
     * @throws IllegalArgumentException if no value matrix is given
     */
    public static TimeSeriesDataset2 createDatasetForMatrix(double[][] ... valueMatrices) {
        return TimeSeriesUtil.createDatasetForMatrix(null, valueMatrices);
    }

    /**
     * Formats a time series as {@code {v1, v2, ...}}.
     *
     * @param timeSeries the time series
     * @return the string representation; {@code "{}"} for an empty series
     */
    public static String toString(double[] timeSeries) {
        return Arrays.stream(timeSeries).mapToObj(String::valueOf).collect(Collectors.joining(", ", "{", "}"));
    }

    /** Keogh's derivative estimate at interior index {@code i}: ((t[i]-t[i-1]) + (t[i+1]-t[i-1])/2) / 2. */
    private static double keoghEstimate(double[] t, int i) {
        return (t[i] - t[i - 1] + (t[i + 1] - t[i - 1]) / 2.0) / 2.0;
    }

    /**
     * Computes Keogh's derivative estimate for all interior points.
     *
     * @param t the time series
     * @return an array of length {@code t.length - 2} (empty for series shorter than 3)
     */
    public static double[] keoghDerivate(double[] t) {
        double[] derivate = new double[Math.max(0, t.length - 2)];
        for (int i = 1; i < t.length - 1; ++i) {
            derivate[i - 1] = TimeSeriesUtil.keoghEstimate(t, i);
        }
        return derivate;
    }

    /**
     * Computes Keogh's derivative estimate, copying the neighbouring values to the two boundary points.
     *
     * @param t the time series
     * @return an array of length {@code t.length}; all zeros for series shorter than 3
     */
    public static double[] keoghDerivateWithBoundaries(double[] t) {
        double[] derivate = new double[t.length];
        if (t.length < 3) {
            return derivate;
        }
        for (int i = 1; i < t.length - 1; ++i) {
            derivate[i] = TimeSeriesUtil.keoghEstimate(t, i);
        }
        derivate[0] = derivate[1];
        derivate[t.length - 1] = derivate[t.length - 2];
        return derivate;
    }

    /**
     * Computes backward differences {@code t[i] - t[i-1]} for {@code i >= 1}.
     *
     * @param t the time series
     * @return an array of length {@code t.length - 1} (empty for series shorter than 2)
     */
    public static double[] backwardDifferenceDerivate(double[] t) {
        double[] derivate = new double[Math.max(0, t.length - 1)];
        for (int i = 1; i < t.length; ++i) {
            derivate[i - 1] = t[i] - t[i - 1];
        }
        return derivate;
    }

    /**
     * Computes backward differences; the first point receives the value of the second.
     *
     * @param t the time series
     * @return an array of length {@code t.length}; all zeros for series shorter than 2
     */
    public static double[] backwardDifferenceDerivateWithBoundaries(double[] t) {
        double[] derivate = new double[t.length];
        if (t.length < 2) {
            return derivate;
        }
        for (int i = 1; i < t.length; ++i) {
            derivate[i] = t[i] - t[i - 1];
        }
        derivate[0] = derivate[1];
        return derivate;
    }

    /**
     * Computes forward differences {@code t[i+1] - t[i]}.
     *
     * @param t the time series
     * @return an array of length {@code t.length - 1} (empty for series shorter than 2)
     */
    public static double[] forwardDifferenceDerivate(double[] t) {
        double[] derivate = new double[Math.max(0, t.length - 1)];
        for (int i = 0; i < t.length - 1; ++i) {
            derivate[i] = t[i + 1] - t[i];
        }
        return derivate;
    }

    /**
     * Computes forward differences; the last point receives the value of the second-to-last.
     *
     * @param t the time series
     * @return an array of length {@code t.length}; all zeros for series shorter than 2
     */
    public static double[] forwardDifferenceDerivateWithBoundaries(double[] t) {
        double[] derivate = new double[t.length];
        if (t.length < 2) {
            return derivate;
        }
        for (int i = 0; i < t.length - 1; ++i) {
            derivate[i] = t[i + 1] - t[i];
        }
        derivate[t.length - 1] = derivate[t.length - 2];
        return derivate;
    }

    /**
     * Computes the central-difference (Gullo) derivative {@code (t[i+1] - t[i-1]) / 2} for all interior points.
     *
     * @param t the time series
     * @return an array of length {@code t.length - 2} (empty for series shorter than 3)
     */
    public static double[] gulloDerivate(double[] t) {
        double[] derivate = new double[Math.max(0, t.length - 2)];
        // Interior points only: t[i + 1] must stay in bounds; parentheses give the central difference.
        for (int i = 1; i < t.length - 1; ++i) {
            derivate[i - 1] = (t[i + 1] - t[i - 1]) / 2.0;
        }
        return derivate;
    }

    /**
     * Computes the central-difference (Gullo) derivative, copying the neighbouring values to the two boundary points.
     *
     * @param t the time series
     * @return an array of length {@code t.length}; all zeros for series shorter than 3
     */
    public static double[] gulloDerivateWithBoundaries(double[] t) {
        double[] derivate = new double[t.length];
        if (t.length < 3) {
            return derivate;
        }
        for (int i = 1; i < t.length - 1; ++i) {
            derivate[i] = (t[i + 1] - t[i - 1]) / 2.0;
        }
        derivate[0] = derivate[1];
        derivate[t.length - 1] = derivate[t.length - 2];
        return derivate;
    }

    /**
     * Sums all values of the time series.
     *
     * @param t the time series
     * @return the sum ({@code 0.0} for an empty series)
     */
    public static double sum(double[] t) {
        double sum = 0.0;
        for (double value : t) {
            sum += value;
        }
        return sum;
    }

    /**
     * Computes the arithmetic mean of the time series.
     *
     * @param t the time series
     * @return the mean ({@code NaN} for an empty series)
     */
    public static double mean(double[] t) {
        return TimeSeriesUtil.sum(t) / t.length;
    }

    /**
     * Computes the population variance (divides by {@code n}) of the time series.
     *
     * @param t the time series
     * @return the variance ({@code NaN} for an empty series)
     */
    public static double variance(double[] t) {
        double mean = TimeSeriesUtil.mean(t);
        double squaredDeviations = 0.0;
        for (double value : t) {
            squaredDeviations += (value - mean) * (value - mean);
        }
        return squaredDeviations / t.length;
    }

    /**
     * Computes the population standard deviation of the time series.
     *
     * @param t the time series
     * @return the standard deviation ({@code NaN} for an empty series)
     */
    public static double standardDeviation(double[] t) {
        return Math.sqrt(TimeSeriesUtil.variance(t));
    }

    /**
     * Divides every value by the population standard deviation of the series (no mean centering).
     *
     * @param t the time series
     * @return a new array with the scaled values; all zeros if the standard deviation is zero
     */
    public static double[] normalizeByStandardDeviation(double[] t) {
        double standardDeviation = TimeSeriesUtil.standardDeviation(t);
        double[] normalizedT = new double[t.length];
        if (standardDeviation == 0.0) {
            return normalizedT;
        }
        for (int i = 0; i < t.length; ++i) {
            normalizedT[i] = t[i] / standardDeviation;
        }
        return normalizedT;
    }
}

