/*
 * Decompiled with CFR 0.152.
 */
package net.finmath.marketdata.model.curves.locallinearregression;

import java.time.LocalDate;
import net.finmath.marketdata.model.curves.Curve;
import net.finmath.marketdata.model.curves.CurveInterpolation;
import net.finmath.marketdata.model.curves.DiscountCurveInterpolation;
import net.finmath.marketdata.model.curves.locallinearregression.Partition;
import org.apache.commons.math3.distribution.AbstractRealDistribution;
import org.apache.commons.math3.distribution.CauchyDistribution;
import org.apache.commons.math3.distribution.LaplaceDistribution;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.jblas.DoubleMatrix;
import org.jblas.Solve;

/**
 * Builds a piecewise linear curve from observations {@code (independentValues[i], dependentValues[i])}
 * by assembling kernel-weighted sums over a {@link Partition} and solving the resulting
 * linear system {@code M a = R}.
 *
 * <p>The solution vector {@code a} is read as follows: {@code a[0]} is the curve value at the
 * first partition point and {@code a[i]} (for {@code i >= 1}) is the slope on interval
 * {@code i - 1}; the curve values at the partition points are obtained by accumulating
 * {@code slope * intervalLength}.
 *
 * <p>The input arrays are copied at construction time, so later changes made by the caller
 * to those arrays do not affect this object. The regression curve is computed on first
 * request and then reused.
 */
public class CurveEstimation {
    private final LocalDate referenceDate;
    private final double bandwidth;
    private final double[] independentValues;
    private final double[] dependentValues;
    private final Partition partition;
    private final AbstractRealDistribution kernel;

    /**
     * Lazily computed result of {@link #getRegressionCurve()}. Declared {@code volatile} so that
     * a curve built by one thread is safely published to others. Concurrent first calls may
     * each build a curve; the last write wins.
     */
    private volatile Curve regressionCurve;

    /**
     * Creates an estimator with an explicitly chosen kernel.
     *
     * @param referenceDate     reference date handed to the resulting curve
     * @param bandwidth         divisor applied to the distance between an interval reference point
     *                          and an observation before the kernel density is evaluated
     * @param independentValues observed x-values; copied
     * @param dependentValues   observed y-values, read at the same indices as
     *                          {@code independentValues}; copied
     * @param partitionValues   points handed to the {@link Partition}; copied
     * @param weight            passed through unchanged to the {@link Partition} constructor
     * @param distribution      kernel selector; {@code LAPLACE} and {@code CAUCHY} select the
     *                          corresponding distribution, any other value selects the normal one
     * @throws NullPointerException if any array argument or {@code distribution} is {@code null}
     */
    public CurveEstimation(LocalDate referenceDate, double bandwidth, double[] independentValues, double[] dependentValues, double[] partitionValues, double weight, Distribution distribution) {
        this.referenceDate = referenceDate;
        this.bandwidth = bandwidth;
        // Defensive copies, consistent with the treatment of partitionValues below.
        this.independentValues = independentValues.clone();
        this.dependentValues = dependentValues.clone();
        this.partition = new Partition(partitionValues.clone(), weight);
        this.kernel = createKernel(distribution);
    }

    /**
     * Creates an estimator that uses the normal distribution as kernel.
     *
     * @param referenceDate     reference date handed to the resulting curve
     * @param bandwidth         kernel bandwidth
     * @param independentValues observed x-values; copied
     * @param dependentValues   observed y-values; copied
     * @param partitionValues   points handed to the {@link Partition}; copied
     * @param weight            passed through unchanged to the {@link Partition} constructor
     */
    public CurveEstimation(LocalDate referenceDate, double bandwidth, double[] independentValues, double[] dependentValues, double[] partitionValues, double weight) {
        this(referenceDate, bandwidth, independentValues, dependentValues, partitionValues, weight, Distribution.NORMAL);
    }

    /**
     * Returns the regression curve, computing it on the first call and returning the same
     * instance afterwards.
     *
     * <p>NOTE(review): the returned instance is shared between callers; this presumes callers
     * do not mutate it - confirm against {@code CurveInterpolation}.
     *
     * @return a linearly interpolated, constantly extrapolated curve named
     *         {@code "RegressionCurve"} on the partition points
     */
    public Curve getRegressionCurve() {
        Curve curve = this.regressionCurve;
        if (curve == null) {
            curve = buildRegressionCurve();
            this.regressionCurve = curve;
        }
        return curve;
    }

    /** Solves the equation system and converts the solution into curve values at the partition points. */
    private Curve buildRegressionCurve() {
        DoubleMatrix a = solveEquationSystem();
        double[] curvePoints = new double[this.partition.getLength()];
        // a[0] is the level at the first point; a[i] is the slope on interval i - 1.
        curvePoints[0] = a.get(0);
        for (int i = 1; i < curvePoints.length; ++i) {
            curvePoints[i] = curvePoints[i - 1] + a.get(i) * this.partition.getIntervalLength(i - 1);
        }
        return new CurveInterpolation("RegressionCurve", this.referenceDate, CurveInterpolation.InterpolationMethod.LINEAR, CurveInterpolation.ExtrapolationMethod.CONSTANT, CurveInterpolation.InterpolationEntity.VALUE, this.partition.getPoints(), curvePoints);
    }

    /** Maps the kernel selector to a distribution instance; unknown values fall back to the normal distribution. */
    private static AbstractRealDistribution createKernel(Distribution distribution) {
        switch (distribution) {
            case LAPLACE:
                return new LaplaceDistribution(0.0, 1.0);
            case CAUCHY:
                return new CauchyDistribution();
            default:
                return new NormalDistribution();
        }
    }

    /**
     * Assembles the symmetric system matrix {@code M} and right-hand side {@code R} from
     * kernel-weighted sums over all observations and solves {@code M a = R}.
     *
     * <p>Assumes {@code partition.getPoints().length == partition.getLength()} - TODO confirm
     * against {@code Partition}.
     *
     * @return the solution vector {@code a} with {@code partition.getLength()} entries
     */
    private DoubleMatrix solveEquationSystem() {
        final int n = this.partition.getLength();

        DoubleMatrix rightHandSide = new DoubleMatrix(n);
        DoubleMatrix systemMatrix = new DoubleMatrix(n, n);
        DoubleMatrix partitionAsVector = new DoubleMatrix(this.partition.getPoints());

        // shiftedPartition[j] holds point j - 1 for j >= 1; entry 0 stays zero.
        DoubleMatrix shiftedPartition = new DoubleMatrix(n);
        for (int j = 1; j < shiftedPartition.length; ++j) {
            shiftedPartition.put(j, this.partition.getPoint(j - 1));
        }
        // Differences between consecutive points; the first entry is overwritten with 1.
        DoubleMatrix partitionIncrements = partitionAsVector.sub(shiftedPartition).put(0, 1.0);

        // Loop-invariant slices: left end points of the intervals and the increments without entry 0.
        DoubleMatrix leftPoints = partitionAsVector.getRange(0, partitionAsVector.length - 1);
        DoubleMatrix tailIncrements = partitionIncrements.getRange(1, partitionAsVector.length);

        DoubleMatrix kernelValues = new DoubleMatrix(n - 1);
        double topLeftEntry = 0.0;
        // Accumulates the vector used both as first column/row of M and as row scaling of the sub-matrix.
        DoubleMatrix firstOrderSums = new DoubleMatrix(n - 1);
        // Accumulates the diagonal of the sub-matrix.
        DoubleMatrix secondOrderSums = new DoubleMatrix(n - 1);

        for (int i = 0; i < this.independentValues.length; ++i) {
            final double x = this.independentValues[i];
            final double y = this.dependentValues[i];

            DoubleMatrix oneZeroVector = new DoubleMatrix(n);
            DoubleMatrix kernelSum = new DoubleMatrix(n);
            DoubleMatrix shiftedKernelVector = new DoubleMatrix(n);
            // After this loop kernelSum[k] is the sum of kernelValues[r] over r >= k (last entry stays 0).
            // The accumulation order is kept as is so floating point results do not change.
            for (int r = 0; r < n - 1; ++r) {
                oneZeroVector.put(r, 1.0);
                kernelValues.put(r, this.kernel.density((this.partition.getIntervalReferencePoint(r) - x) / this.bandwidth));
                shiftedKernelVector.put(r + 1, kernelValues.get(r));
                kernelSum = kernelSum.add(oneZeroVector.mmul(kernelValues.get(r)));
            }

            rightHandSide = rightHandSide.add(shiftedPartition.neg().add(x).mul(shiftedKernelVector).add(partitionIncrements.mul(kernelSum)).mul(y));
            topLeftEntry += kernelSum.get(0);

            // x minus the left end point of each interval.
            DoubleMatrix distances = leftPoints.neg().add(x);
            DoubleMatrix weightedIncrements = tailIncrements.mul(kernelSum.getRange(1, kernelSum.length));
            firstOrderSums = firstOrderSums.add(distances.mul(kernelValues).add(weightedIncrements));
            secondOrderSums = secondOrderSums.add(distances.mul(distances).mul(kernelValues).add(tailIncrements.mul(weightedIncrements)));
        }

        // Strictly lower triangular matrix: column m carries increment m + 1 in the rows below the diagonal.
        DoubleMatrix partitionIncrementMatrix = new DoubleMatrix(n - 1, n - 1);
        DoubleMatrix matrixDefine = DoubleMatrix.ones(n - 1);
        for (int m = 0; m < matrixDefine.length - 1; ++m) {
            matrixDefine.put(m, 0.0);
            partitionIncrementMatrix.putColumn(m, matrixDefine.mul(partitionIncrements.get(m + 1)));
        }
        // Scale, symmetrise and add the diagonal.
        DoubleMatrix subMatrix = partitionIncrementMatrix.mulColumnVector(firstOrderSums);
        subMatrix = subMatrix.add(subMatrix.transpose()).add(DoubleMatrix.diag(secondOrderSums));

        // Indices 1..n-1 address everything except the first row/column of M.
        int[] rowColIndex = new int[n - 1];
        for (int k = 0; k < rowColIndex.length; ++k) {
            rowColIndex[k] = k + 1;
        }
        systemMatrix.put(0, 0, topLeftEntry);
        systemMatrix.put(rowColIndex, 0, firstOrderSums);
        systemMatrix.put(0, rowColIndex, firstOrderSums.transpose());
        systemMatrix.put(rowColIndex, rowColIndex, subMatrix);
        return Solve.solve(systemMatrix, rightHandSide);
    }

    /** Kernel types that can be selected at construction time. */
    public enum Distribution {
        NORMAL,
        LAPLACE,
        CAUCHY
    }
}

