/*
 * Decompiled with CFR 0.152.
 */
package com.github.signaflo.timeseries.model;

import com.github.signaflo.math.stats.distributions.Distribution;
import com.github.signaflo.math.stats.distributions.Normal;
import com.github.signaflo.timeseries.TimePeriod;
import com.github.signaflo.timeseries.TimeSeries;
import com.github.signaflo.timeseries.forecast.Forecast;
import com.github.signaflo.timeseries.model.Model;
import com.github.signaflo.timeseries.model.RandomWalkForecaster;
import java.time.OffsetDateTime;
import lombok.NonNull;

/**
 * A random walk time series model.
 *
 * <p>The fitted value at each time {@code t > 0} is the observation at time {@code t - 1}. The
 * first fitted value is the first observation itself, so the first residual is always zero.
 * Forecasting is delegated to {@link RandomWalkForecaster}, built from the observations and the
 * in-sample residuals of this model.
 */
public final class RandomWalk
implements Model {
    private final TimeSeries timeSeries;
    private final TimeSeries fittedSeries;
    private final TimeSeries residuals;

    /**
     * Fits a random walk model to the given observations.
     *
     * @param observed the observed series; must contain at least one observation.
     * @throws NullPointerException if {@code observed} is null.
     * @throws IllegalArgumentException if {@code observed} is empty.
     */
    public RandomWalk(@NonNull TimeSeries observed) {
        if (observed == null) {
            throw new NullPointerException("observed");
        }
        if (observed.size() < 1) {
            throw new IllegalArgumentException("A random walk model requires at least one observation.");
        }
        this.timeSeries = observed;
        // Order matters: calculateResiduals() reads fittedSeries.
        this.fittedSeries = this.fitSeries();
        this.residuals = this.calculateResiduals();
    }

    /**
     * Simulates a random walk of {@code n} values whose increments are drawn from {@code dist}.
     *
     * <p>The first value is a single draw from {@code dist}. Each later value is the previous
     * value plus a fresh, independent draw.
     *
     * @param dist the distribution the increments are drawn from.
     * @param n the number of values to simulate; must be positive.
     * @return the simulated series.
     * @throws NullPointerException if {@code dist} is null.
     * @throws IllegalArgumentException if {@code n} is less than 1.
     */
    public static TimeSeries simulate(@NonNull Distribution dist, int n) {
        if (dist == null) {
            throw new NullPointerException("dist");
        }
        if (n < 1) {
            throw new IllegalArgumentException("the number of observations to simulate must be a positive integer.");
        }
        double[] series = new double[n];
        series[0] = dist.rand();
        for (int t = 1; t < n; ++t) {
            series[t] = series[t - 1] + dist.rand();
        }
        return TimeSeries.from(series);
    }

    /**
     * Simulates a random walk of {@code n} values with increments drawn from
     * {@code new Normal(mean, sigma)}.
     *
     * @param mean the mean parameter passed to {@link Normal}.
     * @param sigma the scale parameter passed to {@link Normal} (presumably the standard deviation).
     * @param n the number of values to simulate; must be positive.
     * @return the simulated series.
     */
    public static TimeSeries simulate(double mean, double sigma, int n) {
        return simulate(new Normal(mean, sigma), n);
    }

    /**
     * Simulates a random walk of {@code n} values with increments drawn from
     * {@code new Normal(0.0, sigma)}.
     *
     * @param sigma the scale parameter passed to {@link Normal} (presumably the standard deviation).
     * @param n the number of values to simulate; must be positive.
     * @return the simulated series.
     */
    public static TimeSeries simulate(double sigma, int n) {
        return simulate(0.0, sigma, n);
    }

    /**
     * Simulates a random walk of {@code n} values with increments drawn from
     * {@code new Normal(0.0, 1.0)}.
     *
     * @param n the number of values to simulate; must be positive.
     * @return the simulated series.
     */
    public static TimeSeries simulate(int n) {
        return simulate(0.0, 1.0, n);
    }

    /**
     * Forecasts {@code steps} periods past the last observation.
     *
     * <p>This delegates entirely to a {@link RandomWalkForecaster} built from the observations
     * and this model's prediction errors.
     *
     * @param steps the number of periods to forecast; must be non-negative.
     * @param alpha passed through unchanged to the forecaster (presumably the significance level
     *     for prediction intervals; confirm against {@link RandomWalkForecaster}).
     * @return the forecast produced by the forecaster.
     * @throws IllegalArgumentException if {@code steps} is negative.
     */
    @Override
    public Forecast forecast(int steps, double alpha) {
        if (steps < 0) {
            throw new IllegalArgumentException("the number of steps to forecast must be non-negative, but was " + steps);
        }
        RandomWalkForecaster forecaster = new RandomWalkForecaster(this.timeSeries, this.predictionErrors());
        return forecaster.forecast(steps, alpha);
    }

    @Override
    public TimeSeries observations() {
        return this.timeSeries;
    }

    @Override
    public TimeSeries fittedSeries() {
        return this.fittedSeries;
    }

    @Override
    public TimeSeries predictionErrors() {
        return this.residuals;
    }

    /**
     * Builds the fitted series: each value is the previous observation, and the first value
     * equals the first observation. It shares the time period and start time of the observations.
     */
    private TimeSeries fitSeries() {
        int size = this.timeSeries.size();
        double[] fitted = new double[size];
        fitted[0] = this.timeSeries.at(0);
        for (int t = 1; t < size; ++t) {
            fitted[t] = this.timeSeries.at(t - 1);
        }
        return TimeSeries.from(this.timeSeries.timePeriod(), this.timeSeries.observationTimes().get(0), fitted);
    }

    /**
     * Builds the residual series, observation minus fitted value. Index 0 is left at zero
     * because the first fitted value equals the first observation.
     */
    private TimeSeries calculateResiduals() {
        int size = this.timeSeries.size();
        double[] errors = new double[size];
        for (int t = 1; t < size; ++t) {
            errors[t] = this.timeSeries.at(t) - this.fittedSeries.at(t);
        }
        return TimeSeries.from(this.timeSeries.timePeriod(), this.timeSeries.observationTimes().get(0), errors);
    }

    @Override
    public String toString() {
        return "Random walk time series model";
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        RandomWalk that = (RandomWalk) o;
        // All three fields are non-null: the constructor rejects a null series and derives the rest.
        return this.timeSeries.equals(that.timeSeries)
                && this.fittedSeries.equals(that.fittedSeries)
                && this.residuals.equals(that.residuals);
    }

    @Override
    public int hashCode() {
        int result = this.timeSeries.hashCode();
        result = 31 * result + this.fittedSeries.hashCode();
        result = 31 * result + this.residuals.hashCode();
        return result;
    }
}

