/*
 * Decompiled with CFR 0.152.
 */
package hivemall.anomaly;

import hivemall.utils.lang.Preconditions;
import hivemall.utils.math.MatrixUtils;
import hivemall.utils.math.StatsUtils;
import java.util.Arrays;
import javax.annotation.Nonnull;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.BlockRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;

/**
 * Sequentially discounting autoregressive (SDAR) estimator for vector-valued time series.
 *
 * <p>Each non-initial call to {@link #update(ArrayRealVector[], int)} does the following:
 * <ol>
 * <li>Discounts the running mean, the lag covariance matrices and the residual covariance by
 * {@code (1 - r)}, and blends in the newest observation with weight {@code r}.</li>
 * <li>Predicts the newest observation from an order-{@code k} AR model fitted to the discounted
 * lag covariances.</li>
 * </ol>
 *
 * <p>Instances hold mutable state and perform no synchronization.
 */
public final class SDAR2D {
    /** Discount ("forgetfulness") rate, validated to lie in (0, 1). */
    private final double _r;
    /** Discounted lag covariance matrices C_0..C_k; an entry stays null until first filled. */
    private final RealMatrix[] _C;
    /** Current discounted mean; null until the first update. */
    private RealVector _mu;
    /** Current discounted covariance of the prediction residual; null until the first update. */
    private RealMatrix _sigma;
    /** Mean before the most recent non-initial update; null until the second update. */
    private RealVector _muOld;
    /** Residual covariance before the most recent non-initial update; null until the second update. */
    private RealMatrix _sigmaOld;
    /** Whether the first observation has been consumed. */
    private boolean _initialized;

    /**
     * Creates an estimator.
     *
     * @param r discount rate; rejected unless {@code 0 < r < 1}
     * @param k maximum AR order; rejected unless {@code k >= 1}. The lag-covariance array is
     *        sized to {@code k + 1}, so later calls must use an order strictly below that.
     */
    public SDAR2D(final double r, final int k) {
        Preconditions.checkArgument(0.0 < r && r < 1.0, "Invalid forgetfullness parameter r: " + r);
        Preconditions.checkArgument(k >= 1, "Invalid smoothing parameter k: " + k);
        this._r = r;
        this._C = new RealMatrix[k + 1];
        this._initialized = false;
    }

    /**
     * Feeds a new observation and returns the model's prediction for it.
     *
     * <p>The first call only seeds the mean with {@code x[0]}, sets the residual covariance to a
     * zero matrix and returns a zero vector. Every later call updates all parameters and returns
     * the AR prediction.
     *
     * @param x input series. {@code x[0]} is the current observation; {@code x[1..k]} are used as
     *        the lagged observations (presumably most-recent first, so x[j] = x_{t-j} — confirm
     *        against callers). After the first call it must hold at least {@code k + 1} entries.
     * @param k AR order used for this step. {@code 0 <= k < |C|} is required; after the first
     *        call {@code k >= 1} is also required.
     * @return the predicted vector {@code x_hat}, or a zero vector of {@code x[0]}'s dimension on
     *         the first call
     */
    @Nonnull
    public RealVector update(@Nonnull final ArrayRealVector[] x, final int k) {
        Preconditions.checkArgument(x.length >= 1,
            "x.length MUST be greater than or equal to 1: " + x.length);
        Preconditions.checkArgument(k >= 0, "k MUST be greater than or equal to 0: " + k);
        Preconditions.checkArgument(k < _C.length,
            "k MUST be less than |C| but k=" + k + ", |C|=" + _C.length);

        final ArrayRealVector x_t = x[0];
        final int dims = x_t.getDimension();

        if (!_initialized) {
            // The first observation only seeds the mean; the residual covariance starts at zero.
            this._mu = x_t.copy();
            this._sigma = new BlockRealMatrix(dims, dims);
            assert (_sigma.isSquare());
            this._initialized = true;
            return new ArrayRealVector(dims);
        }
        Preconditions.checkArgument(k >= 1, "k MUST be greater than 0: " + k);
        // x[0..k] are all read below. Validate before mutating any field, so a rejected call
        // leaves the model exactly as it was.
        Preconditions.checkArgument(x.length > k,
            "x.length MUST be greater than k but x.length=" + x.length + ", k=" + k);

        // Keep the previous parameters so hellingerDistance() can compare old vs. new.
        this._muOld = _mu.copy();
        this._sigmaOld = _sigma.copy();

        // mu := (1 - r) mu + r x_t
        this._mu = _mu.mapMultiply(1.0 - _r).add(x_t.mapMultiply(_r));

        // Residuals against the updated mean: xResidual[j] = x[j] - mu, for j = 0..k.
        final RealVector[] xResidual = new RealVector[k + 1];
        for (int j = 0; j <= k; j++) {
            xResidual[j] = x[j].subtract(_mu);
        }

        // C_j := (1 - r) C_j + r (x[0] - mu)(x[j] - mu)'.
        // A C_j that was never filled starts from the new term alone.
        final RealMatrix[] C = _C;
        final RealVector rxResidual0 = xResidual[0].mapMultiply(_r);
        for (int j = 0; j <= k; j++) {
            // xResidual[j] already equals x[j] - mu; reuse it rather than recomputing it.
            final RealMatrix term = rxResidual0.outerProduct(xResidual[j]);
            final RealMatrix Cj = C[j];
            C[j] = (Cj == null) ? term : Cj.scalarMultiply(1.0 - _r).add(term);
        }

        // Fit the AR coefficients:
        // - build a system from MatrixUtils.toeplitz(C, k) and the stacked C_1..C_k;
        // - solve for A, which holds the k coefficient blocks stacked vertically.
        // This is presumably the multivariate Yule-Walker system; see MatrixUtils for the exact
        // block layout.
        final RealMatrix[][] rhs = MatrixUtils.toeplitz(C, k);
        final RealMatrix[] lhs = Arrays.copyOfRange(C, 1, k + 1);
        final RealMatrix R = MatrixUtils.combinedMatrices(rhs, dims);
        final RealMatrix L = MatrixUtils.combinedMatrices(lhs);
        final RealMatrix A = MatrixUtils.solve(L, R, false);

        // x_hat := mu + sum_{i=1..k} A_i (x[i] - mu).
        // A_i is the dims x dims block at rows [(i-1)*dims, i*dims) of A.
        RealVector x_hat = _mu.copy();
        for (int i = 0; i < k; i++) {
            final int offset = i * dims;
            final RealMatrix Ai = A.getSubMatrix(offset, offset + dims - 1, 0, dims - 1);
            x_hat = x_hat.add(Ai.operate(xResidual[i + 1]));
        }

        // sigma := (1 - r) sigma + r (x_t - x_hat)(x_t - x_hat)'
        final ArrayRealVector xEstimateResidual = x_t.subtract(x_hat);
        this._sigma = _sigma.scalarMultiply(1.0 - _r)
            .add(xEstimateResidual.mapMultiply(_r).outerProduct(xEstimateResidual));
        return x_hat;
    }

    /**
     * Evaluates {@code StatsUtils.logLoss} of {@code actual} against {@code predicted} under the
     * current residual covariance.
     *
     * @param actual observed vector
     * @param predicted vector returned by {@link #update(ArrayRealVector[], int)}
     * @return the log loss computed by {@code StatsUtils.logLoss}
     * @throws IllegalStateException if {@link #update(ArrayRealVector[], int)} was never called
     */
    public double logLoss(@Nonnull final RealVector actual, @Nonnull final RealVector predicted) {
        if (_sigma == null) {
            throw new IllegalStateException("update() MUST be called before logLoss()");
        }
        return StatsUtils.logLoss(actual, predicted, _sigma);
    }

    /**
     * Computes {@code StatsUtils.hellingerDistance} between the (mean, covariance) pair before
     * the most recent non-initial update and the current pair.
     *
     * @return the Hellinger distance computed by {@code StatsUtils.hellingerDistance}
     * @throws IllegalStateException if fewer than two updates have been performed
     */
    public double hellingerDistance() {
        if (_muOld == null || _sigmaOld == null) {
            throw new IllegalStateException(
                "update() MUST be called at least twice before hellingerDistance()");
        }
        return StatsUtils.hellingerDistance(_muOld, _sigmaOld, _mu, _sigma);
    }
}

