/*
 * Decompiled with CFR 0.152.
 */
package net.finmath.montecarlo.conditionalexpectation;

import java.util.Arrays;
import java.util.Objects;
import net.finmath.stochastic.ConditionalExpectationEstimator;
import net.finmath.stochastic.RandomVariable;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.DecompositionSolver;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.linear.SingularValueDecomposition;

/**
 * Conditional expectation estimator based on linear least-squares regression of a random variable
 * on a set of basis functions (themselves random variables).
 *
 * <p>With estimator basis functions X_0, ..., X_{k-1} and a dependent random variable y, the
 * regression coefficients x solve the normal equations (X^T X) x = X^T y. Every entry of X^T X and
 * X^T y is computed as {@code getAverage()} of the product (via {@code mult}) of two random
 * variables. The conditional expectation is the linear combination of the predictor basis functions
 * with these coefficients.
 *
 * <p>The decomposition of X^T X depends only on the estimator basis functions, so it is created
 * lazily on first use and cached; creation is guarded by an internal lock.
 */
public class MonteCarloConditionalExpectationRegression
implements ConditionalExpectationEstimator {
    /** Basis functions used to estimate the coefficients; {@code null} for the no-arg constructor. */
    private final RegressionBasisFunctions basisFunctionsEstimator;
    /** Basis functions combined with the coefficients; {@code null} for the no-arg constructor. */
    private final RegressionBasisFunctions basisFunctionsPredictor;
    /** Lazily created solver for the normal equations; guarded by {@link #solverLock}. */
    private transient DecompositionSolver solver;
    private final transient Object solverLock = new Object();

    /**
     * Creates an instance without any basis functions.
     *
     * <p>The estimation methods of this class dereference the basis functions, so calling them on an
     * instance created this way throws a {@link NullPointerException}.
     */
    public MonteCarloConditionalExpectationRegression() {
        this.basisFunctionsEstimator = null;
        this.basisFunctionsPredictor = null;
    }

    /**
     * Creates an estimator that uses the same basis functions for estimation and prediction.
     *
     * @param basisFunctions the basis functions; {@code null} entries are ignored.
     */
    public MonteCarloConditionalExpectationRegression(RandomVariable[] basisFunctions) {
        final RegressionBasisFunctions given = new RegressionBasisFunctionsGiven(getNonZeroBasisFunctions(basisFunctions));
        this.basisFunctionsEstimator = given;
        this.basisFunctionsPredictor = given;
    }

    /**
     * Creates an estimator that estimates the coefficients with one set of basis functions and
     * applies them to another.
     *
     * <p>{@code null} entries are removed from each array independently. After that, both arrays
     * must contain the same number of basis functions, matched element by element. Otherwise
     * {@link #getConditionalExpectation(RandomVariable)} throws an {@link IllegalStateException}.
     *
     * @param basisFunctionsEstimator basis functions used to estimate the regression coefficients.
     * @param basisFunctionsPredictor basis functions the coefficients are applied to.
     */
    public MonteCarloConditionalExpectationRegression(RandomVariable[] basisFunctionsEstimator, RandomVariable[] basisFunctionsPredictor) {
        this.basisFunctionsEstimator = new RegressionBasisFunctionsGiven(getNonZeroBasisFunctions(basisFunctionsEstimator));
        this.basisFunctionsPredictor = new RegressionBasisFunctionsGiven(getNonZeroBasisFunctions(basisFunctionsPredictor));
    }

    /**
     * Returns the regression estimate of the conditional expectation of {@code randomVariable}.
     *
     * <p>The result is the sum over i of coefficient_i * predictorBasisFunction_i, where the
     * coefficients come from {@link #getLinearRegressionParameters(RandomVariable)}.
     *
     * @param randomVariable the dependent random variable y.
     * @return the linear combination of the predictor basis functions.
     * @throws IllegalStateException if the number of predictor basis functions differs from the
     *     number of estimator basis functions (i.e. the number of coefficients).
     */
    @Override
    public RandomVariable getConditionalExpectation(RandomVariable randomVariable) {
        // Coefficients x solving (X^T X) x = X^T y.
        final double[] linearRegressionParameters = getLinearRegressionParameters(randomVariable);

        final RandomVariable[] basisFunctions = basisFunctionsPredictor.getBasisFunctions();
        // Coefficients are paired with predictors by index. A count mismatch would either overrun
        // the coefficient array or silently drop coefficients, so reject it explicitly.
        if (basisFunctions.length != linearRegressionParameters.length) {
            throw new IllegalStateException("Number of predictor basis functions (" + basisFunctions.length
                    + ") does not match number of estimator basis functions (" + linearRegressionParameters.length + ").");
        }

        // Prediction X x.
        RandomVariable conditionalExpectation = basisFunctions[0].mult(linearRegressionParameters[0]);
        for (int i = 1; i < basisFunctions.length; ++i) {
            conditionalExpectation = conditionalExpectation.addProduct(basisFunctions[i], linearRegressionParameters[i]);
        }
        return conditionalExpectation;
    }

    /**
     * Computes the regression coefficients of {@code dependents} on the estimator basis functions.
     *
     * <p>The coefficients solve (X^T X) x = X^T y. The solver for X^T X is built from a
     * {@link SingularValueDecomposition} on the first call and reused afterwards.
     *
     * @param dependents the dependent random variable y.
     * @return one coefficient per (non-null) estimator basis function, in the same order.
     */
    public double[] getLinearRegressionParameters(RandomVariable dependents) {
        final RandomVariable[] basisFunctions = basisFunctionsEstimator.getBasisFunctions();

        // Capture the cached solver in a local while holding the lock, so the field is not
        // re-read after the lock is released.
        final DecompositionSolver normalEquationsSolver;
        synchronized (solverLock) {
            if (solver == null) {
                solver = new SingularValueDecomposition(new Array2DRowRealMatrix(buildGramMatrix(basisFunctions), false)).getSolver();
            }
            normalEquationsSolver = solver;
        }

        // X^T y: projection of the dependent variable on each basis function.
        final double[] XTy = new double[basisFunctions.length];
        for (int i = 0; i < basisFunctions.length; ++i) {
            XTy[i] = dependents.mult(basisFunctions[i]).getAverage();
        }

        return normalEquationsSolver.solve(new ArrayRealVector(XTy)).toArray();
    }

    /**
     * Returns the basis functions used to estimate the regression coefficients.
     *
     * @return the estimator basis functions, or {@code null} if the no-arg constructor was used.
     */
    public RegressionBasisFunctions getBasisFunctionsEstimator() {
        return basisFunctionsEstimator;
    }

    /**
     * Returns the basis functions the regression coefficients are applied to.
     *
     * @return the predictor basis functions, or {@code null} if the no-arg constructor was used.
     */
    public RegressionBasisFunctions getBasisFunctionsPredictor() {
        return basisFunctionsPredictor;
    }

    /**
     * Builds the symmetric matrix X^T X of pairwise scalar products of the basis functions.
     *
     * <p>Only the upper triangle is computed; it is mirrored into the lower triangle.
     */
    private static double[][] buildGramMatrix(RandomVariable[] basisFunctions) {
        final int size = basisFunctions.length;
        final double[][] XTX = new double[size][size];
        for (int i = 0; i < size; ++i) {
            for (int j = i; j < size; ++j) {
                XTX[i][j] = basisFunctions[i].mult(basisFunctions[j]).getAverage();
                XTX[j][i] = XTX[i][j];
            }
        }
        return XTX;
    }

    /**
     * Returns a new array holding the non-null entries of {@code basisFunctions}, in their original order.
     *
     * @param basisFunctions the input array (must not be {@code null}).
     * @return a fresh array without {@code null} entries.
     */
    private static RandomVariable[] getNonZeroBasisFunctions(RandomVariable[] basisFunctions) {
        return Arrays.stream(basisFunctions)
                .filter(Objects::nonNull)
                .toArray(RandomVariable[]::new);
    }

    /** Supplies the basis functions (random variables) used in a regression. */
    public interface RegressionBasisFunctions {
        /**
         * Returns the basis functions.
         *
         * @return the basis functions.
         */
        RandomVariable[] getBasisFunctions();
    }

    /** {@link RegressionBasisFunctions} backed by a fixed array given at construction. */
    public static class RegressionBasisFunctionsGiven
    implements RegressionBasisFunctions {
        private final RandomVariable[] basisFunctions;

        /**
         * Creates an instance wrapping {@code basisFunctions}.
         *
         * <p>The array is stored by reference; no copy is made.
         *
         * @param basisFunctions the basis functions.
         */
        public RegressionBasisFunctionsGiven(RandomVariable[] basisFunctions) {
            this.basisFunctions = basisFunctions;
        }

        /**
         * Returns the array passed to the constructor (the same instance, not a copy).
         *
         * @return the basis functions.
         */
        @Override
        public RandomVariable[] getBasisFunctions() {
            return basisFunctions;
        }
    }
}

