/*
 * Decompiled with CFR 0.152.
 */
package com.amazon.randomcutforest.parkservices;

import com.amazon.randomcutforest.CommonUtils;
import com.amazon.randomcutforest.parkservices.ForecastDescriptor;
import com.amazon.randomcutforest.parkservices.RCFCaster;
import com.amazon.randomcutforest.parkservices.calibration.Calibration;
import com.amazon.randomcutforest.returntypes.DiVector;
import com.amazon.randomcutforest.returntypes.RangeVector;
import java.util.Arrays;
import java.util.function.BiFunction;
import lombok.Generated;

public class ErrorHandler {
    public static int MAX_ERROR_HORIZON = 1024;
    int sequenceIndex;
    double percentile;
    int forecastHorizon;
    int errorHorizon;
    protected RangeVector[] pastForecasts;
    protected float[][] actuals;
    RangeVector errorDistribution;
    DiVector errorRMSE;
    float[] errorMean;
    float[] intervalPrecision;
    RangeVector multipliers;
    RangeVector adders;

    public ErrorHandler(RCFCaster.Builder builder) {
        CommonUtils.checkArgument((builder.errorHorizon >= builder.forecastHorizon ? 1 : 0) != 0, (String)"intervalPrecision horizon should be at least as large as forecast horizon");
        CommonUtils.checkArgument((builder.errorHorizon <= MAX_ERROR_HORIZON ? 1 : 0) != 0, (String)"reduce error horizon of change MAX");
        this.forecastHorizon = builder.forecastHorizon;
        this.errorHorizon = builder.errorHorizon;
        int inputLength = builder.dimensions / builder.shingleSize;
        int length = inputLength * this.forecastHorizon;
        this.percentile = builder.percentile;
        this.pastForecasts = new RangeVector[this.errorHorizon + this.forecastHorizon];
        for (int i = 0; i < this.errorHorizon + this.forecastHorizon; ++i) {
            this.pastForecasts[i] = new RangeVector(length);
        }
        this.actuals = new float[this.errorHorizon + this.forecastHorizon][inputLength];
        this.sequenceIndex = 0;
        this.errorMean = new float[length];
        this.errorRMSE = new DiVector(length);
        this.multipliers = new RangeVector(length);
        Arrays.fill(this.multipliers.upper, 1.0f);
        Arrays.fill(this.multipliers.lower, 1.0f);
        this.adders = new RangeVector(length);
        this.intervalPrecision = new float[length];
        this.errorDistribution = new RangeVector(length);
        Arrays.fill(this.errorDistribution.upper, Float.MAX_VALUE);
        Arrays.fill(this.errorDistribution.lower, Float.MIN_VALUE);
    }

    public ErrorHandler(int errorHorizon, int forecastHorizon, int sequenceIndex, double percentile, int inputLength, float[] actualsFlattened, float[] pastForecastsFlattened, float[] auxilliary) {
        CommonUtils.checkArgument((forecastHorizon > 0 ? 1 : 0) != 0, (String)" incorrect forecast horizon");
        CommonUtils.checkArgument((errorHorizon >= forecastHorizon ? 1 : 0) != 0, (String)"incorrect error horizon");
        CommonUtils.checkArgument((actualsFlattened != null || pastForecastsFlattened == null ? 1 : 0) != 0, (String)" actuals and forecasts are a mismatch");
        CommonUtils.checkArgument((inputLength > 0 ? 1 : 0) != 0, (String)"incorrect parameters");
        this.sequenceIndex = sequenceIndex;
        this.errorHorizon = errorHorizon;
        this.percentile = percentile;
        this.forecastHorizon = forecastHorizon;
        int currentLength = actualsFlattened == null ? 0 : actualsFlattened.length;
        CommonUtils.checkArgument((currentLength % inputLength == 0 ? 1 : 0) != 0, (String)"actuals array is incorrect");
        int forecastLength = pastForecastsFlattened == null ? 0 : pastForecastsFlattened.length;
        int arrayLength = Math.max(forecastHorizon + errorHorizon, currentLength / inputLength);
        this.pastForecasts = new RangeVector[arrayLength];
        this.actuals = new float[arrayLength][inputLength];
        int length = forecastHorizon * inputLength;
        CommonUtils.checkArgument((forecastLength == currentLength * 3 * forecastHorizon ? 1 : 0) != 0, (String)"misaligned forecasts");
        this.errorMean = new float[length];
        this.errorRMSE = new DiVector(length);
        this.intervalPrecision = new float[length];
        this.adders = new RangeVector(length);
        this.multipliers = new RangeVector(length);
        this.errorDistribution = new RangeVector(length);
        if (pastForecastsFlattened != null) {
            for (int i = 0; i < arrayLength; ++i) {
                float[] values = Arrays.copyOfRange(pastForecastsFlattened, i * 3 * length, (i * 3 + 1) * length);
                float[] upper = Arrays.copyOfRange(pastForecastsFlattened, (i * 3 + 1) * length, (i * 3 + 2) * length);
                float[] lower = Arrays.copyOfRange(pastForecastsFlattened, (i * 3 + 2) * length, (i * 3 + 3) * length);
                this.pastForecasts[i] = new RangeVector(values, upper, lower);
                System.arraycopy(actualsFlattened, i * inputLength, this.actuals[i], 0, inputLength);
            }
            this.calibrate(null);
        }
    }

    public void update(ForecastDescriptor descriptor, Calibration calibrationMethod) {
        int arrayLength = this.pastForecasts.length;
        int length = this.pastForecasts[0].values.length;
        int errorIndex = this.sequenceIndex % arrayLength;
        int inputLength = descriptor.getInputLength();
        double[] input = descriptor.getCurrentInput();
        for (int i = 0; i < inputLength; ++i) {
            this.actuals[errorIndex][i] = (float)input[i];
        }
        ++this.sequenceIndex;
        this.calibrate(descriptor.deviations);
        if (calibrationMethod != Calibration.NONE) {
            if (calibrationMethod == Calibration.SIMPLE) {
                this.adjust(descriptor.timedForecast.rangeVector, this.errorDistribution);
            }
            if (calibrationMethod == Calibration.MINIMAL) {
                this.adjustMinimal(descriptor.timedForecast.rangeVector, this.errorDistribution);
            }
        }
        descriptor.setErrorMean(this.errorMean);
        descriptor.setErrorRMSE(this.errorRMSE);
        descriptor.setObservedErrorDistribution(this.errorDistribution);
        descriptor.setCalibration(this.intervalPrecision);
        System.arraycopy(descriptor.timedForecast.rangeVector.values, 0, this.pastForecasts[errorIndex].values, 0, length);
        System.arraycopy(descriptor.timedForecast.rangeVector.upper, 0, this.pastForecasts[errorIndex].upper, 0, length);
        System.arraycopy(descriptor.timedForecast.rangeVector.lower, 0, this.pastForecasts[errorIndex].lower, 0, length);
    }

    public RangeVector getErrors() {
        return new RangeVector(this.errorDistribution);
    }

    public float[] getErrorMean() {
        return Arrays.copyOf(this.errorMean, this.errorMean.length);
    }

    public DiVector getErrorRMSE() {
        return new DiVector(this.errorRMSE);
    }

    public float[] getCalibration() {
        return Arrays.copyOf(this.intervalPrecision, this.intervalPrecision.length);
    }

    public RangeVector getMultipliers() {
        return new RangeVector(this.multipliers);
    }

    public RangeVector getAdders() {
        return new RangeVector(this.adders);
    }

    public RangeVector computeErrorPercentile(double percentile, BiFunction<Float, Float, Float> error) {
        return this.computeErrorPercentile(percentile, this.pastForecasts.length, error);
    }

    public RangeVector computeErrorPercentile(double percentile, int newHorizon, BiFunction<Float, Float, Float> error) {
        CommonUtils.checkArgument((newHorizon <= this.errorHorizon && newHorizon > 0 ? 1 : 0) != 0, (String)"incorrect horizon parameter");
        int length = this.pastForecasts[0].values.length;
        float[] lower = new float[length];
        float[] upper = new float[length];
        float[] values = new float[length];
        Arrays.fill(lower, -3.4028235E38f);
        Arrays.fill(upper, Float.MAX_VALUE);
        if (this.actuals != null) {
            int inputLength = this.actuals[0].length;
            for (int i = 0; i < this.forecastHorizon; ++i) {
                int len = this.sequenceIndex > newHorizon + i + 1 ? newHorizon : this.sequenceIndex - i - 1;
                for (int j = 0; j < inputLength; ++j) {
                    int pos = i * inputLength + j;
                    if (len <= 0) continue;
                    double[] copy = this.getErrorVector(len, i + 1, j, pos, error);
                    double fracRank = percentile * (double)len;
                    Arrays.sort(copy);
                    values[pos] = this.interpolatedMedian(copy);
                    lower[pos] = this.interpolatedLowerRank(copy, fracRank, 0.0);
                    upper[pos] = this.interpolatedUpperRank(copy, len, fracRank, 0.0);
                }
            }
        }
        return new RangeVector(values, upper, lower);
    }

    protected double[] getErrorVector(int len, int leadtime, int inputCoordinate, int position, BiFunction<Float, Float, Float> error) {
        int arrayLength = this.pastForecasts.length;
        int errorIndex = (this.sequenceIndex - 1 + arrayLength) % arrayLength;
        double[] copy = new double[len];
        for (int k = 0; k < len; ++k) {
            int pastIndex = (errorIndex - leadtime - k + arrayLength) % arrayLength;
            int index = (errorIndex - k + arrayLength) % arrayLength;
            copy[k] = error.apply(Float.valueOf(this.actuals[index][inputCoordinate]), Float.valueOf(this.pastForecasts[pastIndex].values[position])).floatValue();
        }
        return copy;
    }

    protected void calibrate(double[] errorDeviations) {
        int inputLength = this.actuals[0].length;
        int arrayLength = this.pastForecasts.length;
        int errorIndex = (this.sequenceIndex - 1 + arrayLength) % arrayLength;
        double[] medianError = new double[this.errorHorizon];
        Arrays.fill(this.intervalPrecision, 0.0f);
        for (int i = 0; i < this.forecastHorizon; ++i) {
            int len = this.sequenceIndex > this.errorHorizon + i + 1 ? this.errorHorizon : this.sequenceIndex - i - 1;
            for (int j = 0; j < inputLength; ++j) {
                int pos = i * inputLength + j;
                if (len > 0) {
                    double positiveSum = 0.0;
                    int positiveCount = 0;
                    double negativeSum = 0.0;
                    double positiveSqSum = 0.0;
                    double negativeSqSum = 0.0;
                    for (int k = 0; k < len; ++k) {
                        double error;
                        int pastIndex = (errorIndex - (i + 1) - k + arrayLength) % arrayLength;
                        int index = (errorIndex - k + arrayLength) % arrayLength;
                        medianError[k] = error = (double)(this.actuals[index][j] - this.pastForecasts[pastIndex].values[pos]);
                        int n = pos;
                        this.intervalPrecision[n] = this.intervalPrecision[n] + (this.pastForecasts[pastIndex].upper[pos] >= this.actuals[index][j] && this.actuals[index][j] >= this.pastForecasts[pastIndex].lower[pos] ? 1.0f : 0.0f);
                        if (error >= 0.0) {
                            positiveSum += error;
                            positiveSqSum += error * error;
                            ++positiveCount;
                            continue;
                        }
                        negativeSum += error;
                        negativeSqSum += error * error;
                    }
                    this.errorMean[pos] = (float)(positiveSum + negativeSum) / (float)len;
                    this.errorRMSE.high[pos] = positiveCount > 0 ? Math.sqrt(positiveSqSum / (double)positiveCount) : 0.0;
                    this.errorRMSE.low[pos] = positiveCount < len ? -Math.sqrt(negativeSqSum / (double)(len - positiveCount)) : 0.0;
                    Arrays.sort(medianError, 0, len);
                    this.errorDistribution.values[pos] = this.interpolatedMedian(medianError);
                    double deviation = errorDeviations == null ? 0.0 : errorDeviations[j];
                    this.errorDistribution.upper[pos] = this.interpolatedUpperRank(medianError, len, (double)len * this.percentile, deviation);
                    this.errorDistribution.lower[pos] = this.interpolatedLowerRank(medianError, (double)len * this.percentile, deviation);
                    this.intervalPrecision[pos] = this.intervalPrecision[pos] / (float)len;
                    continue;
                }
                this.errorMean[pos] = 0.0f;
                this.errorRMSE.low[pos] = 0.0;
                this.errorRMSE.high[pos] = 0.0;
                this.errorDistribution.values[pos] = 0.0f;
                double deviation = errorDeviations == null ? 0.0 : errorDeviations[j];
                this.errorDistribution.upper[pos] = (float)(1.3 * deviation);
                this.errorDistribution.lower[pos] = -((float)(1.3 * deviation));
                this.adders.values[pos] = 0.0f;
                this.adders.lower[pos] = 0.0f;
                this.adders.upper[pos] = 0.0f;
                this.intervalPrecision[pos] = 0.0f;
            }
        }
    }

    float interpolatedMedian(double[] array) {
        CommonUtils.checkArgument((array != null ? 1 : 0) != 0, (String)" cannot be null");
        int len = array.length;
        if (len % 2 != 0) {
            return (float)array[len / 2];
        }
        return (float)((array[len / 2 - 1] + array[len / 2]) / 2.0);
    }

    float interpolatedLowerRank(double[] ascendingArray, double fracRank, double deviation) {
        if (fracRank < 1.0) {
            return (float)(-1.3 * deviation * (1.0 - fracRank) + fracRank * ascendingArray[0]);
        }
        int rank = (int)Math.floor(fracRank);
        if (!RCFCaster.USE_INTERPOLATION_IN_DISTRIBUTION) {
            fracRank = rank;
        }
        return (float)(ascendingArray[rank - 1] + (fracRank - (double)rank) * (ascendingArray[rank] - ascendingArray[rank - 1]));
    }

    float interpolatedUpperRank(double[] ascendingArray, int len, double fracRank, double deviation) {
        if (fracRank < 1.0) {
            return (float)(1.3 * deviation * (1.0 - fracRank) + fracRank * ascendingArray[len - 1]);
        }
        int rank = (int)Math.floor(fracRank);
        if (!RCFCaster.USE_INTERPOLATION_IN_DISTRIBUTION) {
            fracRank = rank;
        }
        return (float)(ascendingArray[len - rank] + (fracRank - (double)rank) * (ascendingArray[len - rank - 1] - ascendingArray[len - rank]));
    }

    void adjust(RangeVector rangeVector, RangeVector other) {
        CommonUtils.checkArgument((other.values.length == rangeVector.values.length ? 1 : 0) != 0, (String)" mismatch in lengths");
        for (int i = 0; i < rangeVector.values.length; ++i) {
            int n = i;
            rangeVector.values[n] = rangeVector.values[n] + other.values[i];
            rangeVector.upper[i] = Math.max(rangeVector.values[i], rangeVector.upper[i] + other.upper[i]);
            rangeVector.lower[i] = Math.min(rangeVector.values[i], rangeVector.lower[i] + other.lower[i]);
        }
    }

    void adjustMinimal(RangeVector rangeVector, RangeVector other) {
        CommonUtils.checkArgument((other.values.length == rangeVector.values.length ? 1 : 0) != 0, (String)" mismatch in lengths");
        for (int i = 0; i < rangeVector.values.length; ++i) {
            float oldVal = rangeVector.values[i];
            int n = i;
            rangeVector.values[n] = rangeVector.values[n] + other.values[i];
            rangeVector.upper[i] = Math.max(rangeVector.values[i], oldVal + other.upper[i]);
            rangeVector.lower[i] = Math.min(rangeVector.values[i], oldVal + other.lower[i]);
        }
    }

    @Generated
    public int getSequenceIndex() {
        return this.sequenceIndex;
    }

    @Generated
    public double getPercentile() {
        return this.percentile;
    }

    @Generated
    public int getForecastHorizon() {
        return this.forecastHorizon;
    }

    @Generated
    public int getErrorHorizon() {
        return this.errorHorizon;
    }

    @Generated
    public RangeVector[] getPastForecasts() {
        return this.pastForecasts;
    }

    @Generated
    public float[][] getActuals() {
        return this.actuals;
    }

    @Generated
    public RangeVector getErrorDistribution() {
        return this.errorDistribution;
    }

    @Generated
    public float[] getIntervalPrecision() {
        return this.intervalPrecision;
    }

    @Generated
    public void setSequenceIndex(int sequenceIndex) {
        this.sequenceIndex = sequenceIndex;
    }

    @Generated
    public void setPercentile(double percentile) {
        this.percentile = percentile;
    }

    @Generated
    public void setForecastHorizon(int forecastHorizon) {
        this.forecastHorizon = forecastHorizon;
    }

    @Generated
    public void setErrorHorizon(int errorHorizon) {
        this.errorHorizon = errorHorizon;
    }

    @Generated
    public void setPastForecasts(RangeVector[] pastForecasts) {
        this.pastForecasts = pastForecasts;
    }

    @Generated
    public void setActuals(float[][] actuals) {
        this.actuals = actuals;
    }

    @Generated
    public void setErrorDistribution(RangeVector errorDistribution) {
        this.errorDistribution = errorDistribution;
    }

    @Generated
    public void setErrorRMSE(DiVector errorRMSE) {
        this.errorRMSE = errorRMSE;
    }

    @Generated
    public void setErrorMean(float[] errorMean) {
        this.errorMean = errorMean;
    }

    @Generated
    public void setIntervalPrecision(float[] intervalPrecision) {
        this.intervalPrecision = intervalPrecision;
    }

    @Generated
    public void setMultipliers(RangeVector multipliers) {
        this.multipliers = multipliers;
    }

    @Generated
    public void setAdders(RangeVector adders) {
        this.adders = adders;
    }
}

