/*
 * Decompiled with CFR 0.152.
 */
package com.amazon.randomcutforest.examples.parkservices;

import com.amazon.randomcutforest.config.Precision;
import com.amazon.randomcutforest.config.TransformMethod;
import com.amazon.randomcutforest.examples.Example;
import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest;
import com.amazon.randomcutforest.returntypes.RangeVector;
import com.amazon.randomcutforest.testutils.MultiDimDataWithKey;
import com.amazon.randomcutforest.testutils.ShingledMultiDimDataWithKeys;
import java.util.Random;

/**
 * Example that forecasts the sum of several synthetic time series and reports when that sum is
 * predicted to cross, and when it actually crosses, a fixed alert threshold.
 *
 * <p>Each series is modelled by its own {@link ThresholdedRandomCutForest}. At every step the
 * per-model forecasts are added together to obtain a forecast of the aggregate series.
 */
public class ThresholdedPredictive implements Example {

    /** Runs the example from the command line. */
    public static void main(String[] args) throws Exception {
        new ThresholdedPredictive().run();
    }

    @Override
    public String command() {
        return "Thresholded_Predictive_example";
    }

    @Override
    public String description() {
        return "Example of predictive forecast across multiple time series using ThresholdedRCF";
    }

    /**
     * Generates the data and builds one forest per series. It then streams every point through
     * the forests and prints predicted and observed threshold crossings of the summed series.
     */
    @Override
    public void run() throws Exception {
        int sampleSize = 256;
        int baseDimensions = 1;
        int length = 4 * sampleSize;
        int outputAfter = 128;
        long seed = 2022L;
        int numberOfModels = 10;
        double alertThreshold = 300.0;
        int shingleSize = 10;
        int horizon = 20;

        boolean printPredictions = false;
        boolean printEvents = true;

        MultiDimDataWithKey[] dataWithKeys = generateData(numberOfModels, length, baseDimensions, seed);
        int anomalies = 0;
        for (MultiDimDataWithKey series : dataWithKeys) {
            anomalies += series.changes.length;
        }
        System.out.println(anomalies + " anomalies injected ");

        ThresholdedRandomCutForest[] forests =
                buildForests(numberOfModels, baseDimensions, shingleSize, outputAfter, seed);

        // Edge-triggered alert state: each message is printed once per excursion above the
        // threshold, and is re-armed only after the value drops strictly below it.
        boolean predictNextCrossing = true;
        boolean actualCrossingAlerted = false;
        double lastActualSum = 0.0;

        for (int i = 0; i < length; ++i) {
            if (i > sampleSize) {
                // Forecast before the forests see point i, so prediction[t] refers to index i + t.
                double[] prediction = sumForecasts(forests, horizon);
                double predictedAtHorizon = prediction[horizon - 1];
                if (predictNextCrossing && predictedAtHorizon > alertThreshold) {
                    if (printEvents) {
                        System.out.println("Currently at " + i + ", should cross " + alertThreshold
                                + " at sequence " + (i + horizon - 1));
                    }
                    predictNextCrossing = false;
                } else if (!predictNextCrossing && predictedAtHorizon < alertThreshold) {
                    predictNextCrossing = true;
                }
                if (printPredictions) {
                    for (int t = 0; t < horizon; ++t) {
                        System.out.println(i + t + " " + prediction[t]);
                    }
                    System.out.println();
                    System.out.println();
                }
            }

            double sumValue = sumObserved(dataWithKeys, i);
            // An actual crossing requires two consecutive observations above the threshold.
            if (lastActualSum > alertThreshold && sumValue > alertThreshold) {
                if (!actualCrossingAlerted) {
                    if (printEvents) {
                        System.out.println(" Crossing " + alertThreshold
                                + " at consecutive sequence indices " + (i - 1) + " " + i);
                    }
                    actualCrossingAlerted = true;
                }
            } else if (sumValue < alertThreshold) {
                actualCrossingAlerted = false;
            }
            lastActualSum = sumValue;

            // Each forest only ever sees its own series; every point is submitted with timestamp 0.
            for (int k = 0; k < numberOfModels; ++k) {
                forests[k].process(dataWithKeys[k].data[i], 0L);
            }
        }
    }

    /**
     * Creates {@code numberOfModels} synthetic series of {@code length} points each.
     *
     * <p>Each series uses a period drawn uniformly (then rounded) from [40, 70].
     */
    private static MultiDimDataWithKey[] generateData(int numberOfModels, int length, int baseDimensions,
            long seed) {
        Random random = new Random(seed);
        MultiDimDataWithKey[] dataWithKeys = new MultiDimDataWithKey[numberOfModels];
        for (int k = 0; k < numberOfModels; ++k) {
            int period = (int) Math.round(40.0 + 30.0 * random.nextDouble());
            // NOTE(review): every series is generated with the same seed; only the period differs
            // between them. Confirm this correlation is intended.
            dataWithKeys[k] = ShingledMultiDimDataWithKeys.getMultiDimData(length, period, 100.0, 10.0, seed,
                    baseDimensions, false);
        }
        return dataWithKeys;
    }

    /**
     * Builds one internally-shingled, normalizing forest per series.
     *
     * <p>Forest k is seeded with {@code seed + k}.
     */
    private static ThresholdedRandomCutForest[] buildForests(int numberOfModels, int baseDimensions,
            int shingleSize, int outputAfter, long seed) {
        ThresholdedRandomCutForest[] forests = new ThresholdedRandomCutForest[numberOfModels];
        for (int k = 0; k < numberOfModels; ++k) {
            forests[k] = new ThresholdedRandomCutForest.Builder()
                    .compact(true)
                    .dimensions(baseDimensions * shingleSize)
                    .precision(Precision.FLOAT_32)
                    .randomSeed(seed + k)
                    .internalShinglingEnabled(true)
                    .shingleSize(shingleSize)
                    .outputAfter(outputAfter)
                    .transformMethod(TransformMethod.NORMALIZE)
                    .build();
        }
        return forests;
    }

    /**
     * Extrapolates {@code horizon} steps with every forest and returns the element-wise sum of the
     * forecast values.
     */
    private static double[] sumForecasts(ThresholdedRandomCutForest[] forests, int horizon) {
        double[] total = new double[horizon];
        for (ThresholdedRandomCutForest forest : forests) {
            RangeVector forecast = forest.extrapolate(horizon).rangeVector;
            for (int t = 0; t < horizon; ++t) {
                total[t] += forecast.values[t];
            }
        }
        return total;
    }

    /** Returns the sum over all series of the first coordinate at {@code index}. */
    private static double sumObserved(MultiDimDataWithKey[] dataWithKeys, int index) {
        double total = 0.0;
        for (MultiDimDataWithKey series : dataWithKeys) {
            total += series.data[index][0];
        }
        return total;
    }
}

