/*
 * Decompiled with CFR 0.152.
 */
package com.github.signaflo.timeseries.model.arima;

import com.github.signaflo.data.Range;
import com.github.signaflo.math.linear.doubles.Matrix;
import com.github.signaflo.math.linear.doubles.Vector;
import com.github.signaflo.math.operations.DoubleFunctions;
import com.github.signaflo.math.stats.distributions.Normal;
import com.github.signaflo.timeseries.TimePeriod;
import com.github.signaflo.timeseries.TimeSeries;
import com.github.signaflo.timeseries.forecast.Forecast;
import com.github.signaflo.timeseries.forecast.Forecaster;
import com.github.signaflo.timeseries.model.arima.ArimaCoefficients;
import com.github.signaflo.timeseries.model.arima.ArimaForecast;
import com.github.signaflo.timeseries.model.arima.ArimaOrder;
import com.github.signaflo.timeseries.operators.LagPolynomial;
import java.time.OffsetDateTime;

class ArimaForecaster
implements Forecaster {
    private final TimeSeries observations;
    private final ArimaCoefficients coefficients;
    private final ArimaOrder order;
    private final TimeSeries differencedSeries;
    private final TimeSeries residuals;
    private final Matrix regressionMatrix;
    private final double sigma2;

    ArimaForecaster(TimeSeries observations, ArimaCoefficients coefficients, ArimaOrder order, TimeSeries differencedSeries, TimeSeries residuals, Matrix regressionMatrix, double sigma2) {
        this.observations = observations;
        this.coefficients = coefficients;
        this.order = order;
        this.differencedSeries = differencedSeries;
        this.residuals = residuals;
        this.regressionMatrix = regressionMatrix;
        this.sigma2 = sigma2;
    }

    @Override
    public Forecast forecast(int steps, double alpha) {
        TimeSeries pointForecasts = this.computePointForecasts(steps);
        TimeSeries lowerValues = this.computeLowerPredictionBounds(pointForecasts, steps, alpha);
        TimeSeries upperValues = this.computeUpperPredictionBounds(pointForecasts, steps, alpha);
        return new ArimaForecast(pointForecasts, lowerValues, upperValues, alpha);
    }

    @Override
    public TimeSeries computeUpperPredictionBounds(TimeSeries observations, int steps, double alpha) {
        TimeSeries forecast = this.computePointForecasts(steps);
        double criticalValue = new Normal().quantile(1.0 - alpha / 2.0);
        double[] upperPredictionValues = new double[steps];
        double[] errors = this.getStdErrors(forecast, criticalValue);
        for (int t = 0; t < steps; ++t) {
            upperPredictionValues[t] = forecast.at(t) + errors[t];
        }
        return TimeSeries.from(forecast.timePeriod(), forecast.observationTimes().get(0), upperPredictionValues);
    }

    @Override
    public TimeSeries computeLowerPredictionBounds(TimeSeries forecast, int steps, double alpha) {
        double criticalValue = new Normal().quantile(alpha / 2.0);
        double[] lowerPredictionValues = new double[steps];
        double[] errors = this.getStdErrors(forecast, criticalValue);
        for (int t = 0; t < steps; ++t) {
            lowerPredictionValues[t] = forecast.at(t) + errors[t];
        }
        return TimeSeries.from(forecast.timePeriod(), forecast.observationTimes().get(0), lowerPredictionValues);
    }

    @Override
    public TimeSeries computePointForecasts(int steps) {
        int n = this.observations.size();
        double[] fcst = this.fcst(steps);
        TimePeriod timePeriod = this.observations.timePeriod();
        OffsetDateTime startTime = this.observations.observationTimes().get(n - 1).plus(timePeriod.periodLength() * timePeriod.timeUnit().unitLength(), timePeriod.timeUnit().temporalUnit());
        return TimeSeries.from(timePeriod, startTime, fcst);
    }

    public double[] fcst(int steps) {
        int d = this.order.d();
        int D = this.order.D();
        int n = this.differencedSeries.size();
        int m = this.observations.size();
        int seasonalFrequency = this.coefficients.seasonalFrequency();
        double[] arSarCoeffs = this.coefficients.getAllAutoRegressiveCoefficients();
        double[] maSmaCoeffs = this.coefficients.getAllMovingAverageCoefficients();
        double[] resid = this.residuals.asArray();
        double[] diffedFcst = new double[n + steps];
        double[] fcst = new double[m + steps];
        Vector regressionParameters = Vector.from((double[])this.coefficients.getRegressors(this.order));
        Vector regressionEffects = this.regressionMatrix.times(regressionParameters);
        TimeSeries armaSeries = this.observations.minus(regressionEffects.elements());
        TimeSeries differencedSeries = armaSeries.difference(1, this.order.d()).difference(seasonalFrequency, this.order.D());
        System.arraycopy(differencedSeries.asArray(), 0, diffedFcst, 0, n);
        System.arraycopy(armaSeries.asArray(), 0, fcst, 0, m);
        LagPolynomial diffPolynomial = LagPolynomial.differences(d);
        LagPolynomial seasDiffPolynomial = LagPolynomial.seasonalDifferences(seasonalFrequency, D);
        LagPolynomial lagPolynomial = diffPolynomial.times(seasDiffPolynomial);
        for (int t = 0; t < steps; ++t) {
            fcst[m + t] = lagPolynomial.solve(fcst, m + t);
            for (int i = 0; i < arSarCoeffs.length; ++i) {
                int n2 = n + t;
                diffedFcst[n2] = diffedFcst[n2] + arSarCoeffs[i] * diffedFcst[n + t - i - 1];
                int n3 = m + t;
                fcst[n3] = fcst[n3] + arSarCoeffs[i] * diffedFcst[n + t - i - 1];
            }
            for (int j = maSmaCoeffs.length; j > 0 && t < j; --j) {
                int n4 = n + t;
                diffedFcst[n4] = diffedFcst[n4] + maSmaCoeffs[j - 1] * resid[m + t - j];
                int n5 = m + t;
                fcst[n5] = fcst[n5] + maSmaCoeffs[j - 1] * resid[m + t - j];
            }
        }
        Matrix forecastRegressionMatrix = this.getForecastRegressionMatrix(steps, this.order);
        Vector forecastRegressionEffects = forecastRegressionMatrix.times(regressionParameters);
        Vector forecast = Vector.from((double[])DoubleFunctions.slice((double[])fcst, (int)m, (int)(m + steps)));
        return forecast.plus(forecastRegressionEffects).elements();
    }

    private double[] getStdErrors(TimeSeries forecast, double criticalValue) {
        double[] psiCoeffs = this.getPsiCoefficients(forecast);
        double[] stdErrors = new double[forecast.size()];
        double sigma = Math.sqrt(this.sigma2);
        double psiWeightSum = 0.0;
        for (int i = 0; i < stdErrors.length; ++i) {
            double sd = sigma * Math.sqrt(psiWeightSum += psiCoeffs[i] * psiCoeffs[i]);
            stdErrors[i] = criticalValue * sd;
        }
        return stdErrors;
    }

    private double[] getPsiCoefficients(TimeSeries forecast) {
        int steps = forecast.size();
        LagPolynomial arPoly = LagPolynomial.autoRegressive(this.coefficients.getAllAutoRegressiveCoefficients());
        LagPolynomial diffPoly = LagPolynomial.differences(this.order.d());
        LagPolynomial seasDiffPoly = LagPolynomial.seasonalDifferences(this.coefficients.seasonalFrequency(), this.order.D());
        double[] phi = diffPoly.times(seasDiffPoly).times(arPoly).inverseParams();
        double[] theta = this.coefficients.getAllMovingAverageCoefficients();
        double[] psi = new double[steps];
        psi[0] = 1.0;
        System.arraycopy(theta, 0, psi, 1, Math.min(steps - 1, theta.length));
        for (int j = 1; j < psi.length; ++j) {
            for (int i = 0; i < Math.min(j, phi.length); ++i) {
                int n = j;
                psi[n] = psi[n] + psi[j - i - 1] * phi[i];
            }
        }
        return psi;
    }

    private Matrix getForecastRegressionMatrix(int steps, ArimaOrder order) {
        double[][] matrix = new double[order.numRegressors()][steps];
        if (order.constant().include()) {
            matrix[0] = DoubleFunctions.fill((int)steps, (double)1.0);
        }
        if (order.drift().include()) {
            int startTime = this.observations.size() + 1;
            matrix[order.constant().asInt()] = Range.inclusiveRange(startTime, startTime + steps).asArray();
        }
        return Matrix.create((Matrix.Layout)Matrix.Layout.BY_COLUMN, (double[][])matrix);
    }

    static Builder builder() {
        return new Builder();
    }

    static class Builder {
        private TimeSeries observations;
        private ArimaCoefficients coefficients;
        private ArimaOrder order;
        private TimeSeries differencedSeries;
        private TimeSeries residuals;
        private Matrix regressionMatrix;
        private double sigma2;

        Builder() {
        }

        public Builder setObservations(TimeSeries observations) {
            this.observations = observations;
            return this;
        }

        public Builder setCoefficients(ArimaCoefficients coefficients) {
            this.coefficients = coefficients;
            return this;
        }

        public Builder setOrder(ArimaOrder order) {
            this.order = order;
            return this;
        }

        public Builder setDifferencedSeries(TimeSeries differencedSeries) {
            this.differencedSeries = differencedSeries;
            return this;
        }

        public Builder setResiduals(TimeSeries residuals) {
            this.residuals = residuals;
            return this;
        }

        public Builder setRegressionMatrix(Matrix regressionMatrix) {
            this.regressionMatrix = regressionMatrix;
            return this;
        }

        public Builder setSigma2(double sigma2) {
            this.sigma2 = sigma2;
            return this;
        }

        public ArimaForecaster build() {
            return new ArimaForecaster(this.observations, this.coefficients, this.order, this.differencedSeries, this.residuals, this.regressionMatrix, this.sigma2);
        }
    }
}

