/*
 * Decompiled with CFR 0.152.
 */
package com.opengamma.strata.math.impl.statistics.leastsquare;

import com.opengamma.strata.collect.ArgChecker;
import com.opengamma.strata.collect.array.DoubleArray;
import com.opengamma.strata.collect.array.DoubleMatrix;
import com.opengamma.strata.collect.array.Matrix;
import com.opengamma.strata.math.MathException;
import com.opengamma.strata.math.impl.differentiation.VectorFieldFirstOrderDifferentiator;
import com.opengamma.strata.math.impl.linearalgebra.DecompositionFactory;
import com.opengamma.strata.math.impl.matrix.MatrixAlgebra;
import com.opengamma.strata.math.impl.matrix.OGMatrixAlgebra;
import com.opengamma.strata.math.impl.statistics.leastsquare.LeastSquareWithPenaltyResults;
import com.opengamma.strata.math.linearalgebra.Decomposition;
import com.opengamma.strata.math.linearalgebra.DecompositionResult;
import java.util.function.Function;

/**
 * Non-linear least-squares fitter with a quadratic penalty on the fit parameters.
 * <p>
 * Given observations {@code y}, their standard deviations {@code sigma}, a model {@code f(theta)} and a
 * penalty matrix {@code P}, this searches for the {@code theta} that minimises
 * <pre>
 *   sum_i ((y_i - f_i(theta)) / sigma_i)^2 + theta^T P theta
 * </pre>
 * using a Levenberg-Marquardt iteration. Each step solves {@code alpha * delta = J^T e - P theta}, where
 * {@code J} and {@code e} are the sigma-weighted Jacobian and residuals, and {@code alpha} is
 * {@code J^T J + P} with its diagonal scaled by {@code (1 + lambda)}. The damping {@code lambda} starts at
 * zero, is divided by 10 after an accepted step, and is multiplied by 10 (or set to 0.1 from zero) after a
 * rejected step.
 */
public class NonLinearLeastSquareWithPenalty {
    /** Maximum number of trial steps before giving up. */
    private static final int MAX_ATTEMPTS = 100000;
    private static final Decomposition<?> DEFAULT_DECOMP = DecompositionFactory.SV_COMMONS;
    private static final OGMatrixAlgebra MA = new OGMatrixAlgebra();
    /** Default convergence tolerance on the relative change of the penalised chi-square. */
    private static final double EPS = 1.0E-8;
    /** Constraint function that accepts every parameter vector. */
    public static final Function<DoubleArray, Boolean> UNCONSTRAINED = x -> true;

    private final double _eps;
    private final Decomposition<?> _decomposition;
    private final MatrixAlgebra _algebra;

    /** Creates a fitter using singular value decomposition and the default tolerance. */
    public NonLinearLeastSquareWithPenalty() {
        this(DEFAULT_DECOMP, MA, EPS);
    }

    /**
     * Creates a fitter using the given decomposition and the default tolerance.
     *
     * @param decomposition  the decomposition used to solve each step's linear system
     */
    public NonLinearLeastSquareWithPenalty(Decomposition<?> decomposition) {
        this(decomposition, MA, EPS);
    }

    /**
     * Creates a fitter using singular value decomposition and the given tolerance.
     *
     * @param eps  the convergence tolerance, must be positive
     */
    public NonLinearLeastSquareWithPenalty(double eps) {
        this(DEFAULT_DECOMP, MA, eps);
    }

    /**
     * Creates a fitter using the given decomposition and tolerance.
     *
     * @param decomposition  the decomposition used to solve each step's linear system
     * @param eps  the convergence tolerance, must be positive
     */
    public NonLinearLeastSquareWithPenalty(Decomposition<?> decomposition, double eps) {
        this(decomposition, MA, eps);
    }

    /**
     * Creates a fitter.
     *
     * @param decomposition  the decomposition used to solve each step's linear system
     * @param algebra  the matrix algebra used for vector/matrix arithmetic
     * @param eps  the convergence tolerance: the fit stops once the change in the penalised chi-square,
     *   relative to {@code 1 + previous value}, is below this; must be positive
     * @throws IllegalArgumentException if an argument is null or {@code eps} is not positive
     */
    public NonLinearLeastSquareWithPenalty(Decomposition<?> decomposition, MatrixAlgebra algebra, double eps) {
        ArgChecker.notNull(decomposition, "decomposition");
        ArgChecker.notNull(algebra, "algebra");
        ArgChecker.isTrue(eps > 0.0, "must have positive eps");
        this._decomposition = decomposition;
        this._algebra = algebra;
        this._eps = eps;
    }

    /**
     * Fits with unit sigmas and a finite-difference Jacobian.
     *
     * @param observedValues  the observed values
     * @param func  the model, mapping parameters to model values
     * @param startPos  the initial parameter guess
     * @param penalty  the penalty matrix {@code P}
     * @return the fit results
     */
    public LeastSquareWithPenaltyResults solve(DoubleArray observedValues, Function<DoubleArray, DoubleArray> func, DoubleArray startPos, DoubleMatrix penalty) {
        ArgChecker.notNull(observedValues, "observedValues");
        int n = observedValues.size();
        VectorFieldFirstOrderDifferentiator jac = new VectorFieldFirstOrderDifferentiator();
        return solve(observedValues, DoubleArray.filled(n, 1.0), func, jac.differentiate(func), startPos, penalty);
    }

    /**
     * Fits with a finite-difference Jacobian and no parameter constraint.
     *
     * @param observedValues  the observed values
     * @param sigma  the standard deviation of each observation
     * @param func  the model, mapping parameters to model values
     * @param startPos  the initial parameter guess
     * @param penalty  the penalty matrix {@code P}
     * @return the fit results
     */
    public LeastSquareWithPenaltyResults solve(DoubleArray observedValues, DoubleArray sigma, Function<DoubleArray, DoubleArray> func, DoubleArray startPos, DoubleMatrix penalty) {
        VectorFieldFirstOrderDifferentiator jac = new VectorFieldFirstOrderDifferentiator();
        return solve(observedValues, sigma, func, jac.differentiate(func), startPos, penalty);
    }

    /**
     * Fits with a finite-difference Jacobian and a parameter constraint.
     *
     * @param observedValues  the observed values
     * @param sigma  the standard deviation of each observation
     * @param func  the model, mapping parameters to model values
     * @param startPos  the initial parameter guess
     * @param penalty  the penalty matrix {@code P}
     * @param allowedValue  returns true for parameter vectors the model accepts
     * @return the fit results
     */
    public LeastSquareWithPenaltyResults solve(DoubleArray observedValues, DoubleArray sigma, Function<DoubleArray, DoubleArray> func, DoubleArray startPos, DoubleMatrix penalty, Function<DoubleArray, Boolean> allowedValue) {
        VectorFieldFirstOrderDifferentiator jac = new VectorFieldFirstOrderDifferentiator();
        return solve(observedValues, sigma, func, jac.differentiate(func), startPos, penalty, allowedValue);
    }

    /**
     * Fits with an analytic Jacobian and no parameter constraint.
     *
     * @param observedValues  the observed values
     * @param sigma  the standard deviation of each observation
     * @param func  the model, mapping parameters to model values
     * @param jac  the model Jacobian, rows per observation and columns per parameter
     * @param startPos  the initial parameter guess
     * @param penalty  the penalty matrix {@code P}
     * @return the fit results
     */
    public LeastSquareWithPenaltyResults solve(DoubleArray observedValues, DoubleArray sigma, Function<DoubleArray, DoubleArray> func, Function<DoubleArray, DoubleMatrix> jac, DoubleArray startPos, DoubleMatrix penalty) {
        return solve(observedValues, sigma, func, jac, startPos, penalty, UNCONSTRAINED);
    }

    /**
     * Fits with an analytic Jacobian and a parameter constraint.
     * <p>
     * A trial step that lands on a parameter vector rejected by {@code allowedValue} is discarded and the
     * damping is increased, so the search stays inside the allowed region.
     *
     * @param observedValues  the observed values
     * @param sigma  the standard deviation of each observation
     * @param func  the model, mapping parameters to model values
     * @param jac  the model Jacobian, rows per observation and columns per parameter
     * @param startPos  the initial parameter guess, which must satisfy {@code allowedValue}
     * @param penalty  the penalty matrix {@code P}, square with the same size as {@code startPos}
     * @param allowedValue  returns true for parameter vectors the model accepts
     * @return the fit results; the reported chi-square excludes the penalty, which is reported separately
     * @throws IllegalArgumentException if an argument is null or has inconsistent dimensions
     * @throws MathException if a linear solve fails or the fit does not converge in {@value #MAX_ATTEMPTS} steps
     */
    public LeastSquareWithPenaltyResults solve(DoubleArray observedValues, DoubleArray sigma, Function<DoubleArray, DoubleArray> func, Function<DoubleArray, DoubleMatrix> jac, DoubleArray startPos, DoubleMatrix penalty, Function<DoubleArray, Boolean> allowedValue) {
        ArgChecker.notNull(observedValues, "observedValues");
        ArgChecker.notNull(sigma, "sigma");
        ArgChecker.notNull(func, "func");
        ArgChecker.notNull(jac, "jac");
        ArgChecker.notNull(startPos, "startPos");
        ArgChecker.notNull(penalty, "penalty");
        ArgChecker.notNull(allowedValue, "allowedValue");
        int nObs = observedValues.size();
        ArgChecker.isTrue(nObs == sigma.size(), "observedValues and sigma must be same length");
        int nParms = startPos.size();
        ArgChecker.isTrue(penalty.rowCount() == nParms && penalty.columnCount() == nParms,
                "penalty must be a {}x{} matrix to match startPos, but was {}x{}",
                nParms, nParms, penalty.rowCount(), penalty.columnCount());
        ArgChecker.isTrue(allowedValue.apply(startPos),
                "The start position {} is not valid for this model. Please choose a valid start position", startPos);

        DoubleArray theta = startPos;
        double lambda = 0.0;
        DoubleArray error = getError(func, observedValues, sigma, theta);
        DoubleMatrix jacobian = getJacobian(jac, sigma, theta);
        double oldChiSqr = getChiSqr(error) + getANorm(penalty, theta);
        DoubleArray beta = getPenalisedGradient(error, jacobian, penalty, theta);

        for (int count = 0; count < MAX_ATTEMPTS; ++count) {
            DoubleMatrix alpha = getModifiedCurvatureMatrix(jacobian, lambda, penalty);
            DecompositionResult decmp;
            DoubleArray deltaTheta;
            try {
                decmp = _decomposition.apply(alpha);
                deltaTheta = decmp.solve(beta);
            } catch (Exception e) {
                throw new MathException(e);
            }
            DoubleArray trialTheta = (DoubleArray) _algebra.add(theta, deltaTheta);
            if (!allowedValue.apply(trialTheta)) {
                // step left the allowed region: damp harder and retry from the same point
                lambda = increaseLambda(lambda);
                continue;
            }

            DoubleArray newError = getError(func, observedValues, sigma, trialTheta);
            double p = getANorm(penalty, trialTheta);
            double newChiSqr = getChiSqr(newError) + p;

            if (Math.abs(newChiSqr - oldChiSqr) / (1.0 + oldChiSqr) < _eps) {
                // converged: the covariance must come from the undamped curvature matrix, so rebuild
                // and re-decompose it when damping was in effect for the final step
                DoubleMatrix alpha0 = alpha;
                DecompositionResult decmp0 = decmp;
                if (lambda > 0.0) {
                    alpha0 = getModifiedCurvatureMatrix(jacobian, 0.0, penalty);
                    decmp0 = decompose(alpha0);
                }
                return finish(alpha0, decmp0, newChiSqr - p, p, jacobian, trialTheta, sigma);
            }

            if (newChiSqr < oldChiSqr) {
                // accept the step and move towards a Gauss-Newton step
                lambda = decreaseLambda(lambda);
                theta = trialTheta;
                jacobian = getJacobian(jac, sigma, theta);
                beta = getPenalisedGradient(newError, jacobian, penalty, theta);
                oldChiSqr = newChiSqr;
            } else {
                // reject the step and move towards a (shorter) gradient-descent step
                lambda = increaseLambda(lambda);
            }
        }
        throw new MathException("Could not converge in " + MAX_ATTEMPTS + " attempts");
    }

    /** Reduces the damping after an accepted step. */
    private double decreaseLambda(double lambda) {
        return lambda / 10.0;
    }

    /** Increases the damping after a rejected step, starting from 0.1 when there was none. */
    private double increaseLambda(double lambda) {
        return lambda == 0.0 ? 0.1 : lambda * 10.0;
    }

    /** Decomposes {@code matrix}, wrapping any failure in a {@link MathException}. */
    private DecompositionResult decompose(DoubleMatrix matrix) {
        try {
            return _decomposition.apply(matrix);
        } catch (Exception e) {
            throw new MathException(e);
        }
    }

    /**
     * Builds the results from the undamped curvature matrix.
     * <p>
     * The covariance is {@code alpha^-1}. The inverse Jacobian (sensitivity of the parameters to the
     * observations) is {@code alpha^-1 * B^T}. Note that {@code jacobian} is the one evaluated at the last
     * accepted point, not at {@code newTheta}.
     */
    private LeastSquareWithPenaltyResults finish(DoubleMatrix alpha, DecompositionResult decmp, double chiSqr, double penalty, DoubleMatrix jacobian, DoubleArray newTheta, DoubleArray sigma) {
        DoubleMatrix covariance = decmp.solve(DoubleMatrix.identity(alpha.rowCount()));
        DoubleMatrix bT = getBTranspose(jacobian, sigma);
        DoubleMatrix inverseJacobian = decmp.solve(bT);
        return new LeastSquareWithPenaltyResults(chiSqr, penalty, newTheta, covariance, inverseJacobian);
    }

    /**
     * Returns the sigma-weighted residuals {@code (observed_i - model_i) / sigma_i}.
     *
     * @throws IllegalArgumentException if the model returns a different number of values than observed
     */
    private DoubleArray getError(Function<DoubleArray, DoubleArray> func, DoubleArray observedValues, DoubleArray sigma, DoubleArray theta) {
        int n = observedValues.size();
        DoubleArray modelValues = func.apply(theta);
        ArgChecker.isTrue(n == modelValues.size(),
                "Number of data points different between model ({}) and observed ({})", modelValues.size(), n);
        return DoubleArray.of(n, i -> (observedValues.get(i) - modelValues.get(i)) / sigma.get(i));
    }

    /**
     * Returns the transpose of {@code jacobian} with each row {@code i} divided by {@code sigma_i}.
     * <p>
     * Because {@code jacobian} is already sigma-weighted (see {@link #getJacobian}), the result is the raw
     * Jacobian transposed and divided by {@code sigma_i^2}, i.e. {@code J^T W} with {@code W = diag(1/sigma^2)}.
     */
    private DoubleMatrix getBTranspose(DoubleMatrix jacobian, DoubleArray sigma) {
        int n = jacobian.rowCount();
        int m = jacobian.columnCount();
        double[][] data = new double[m][n];
        for (int i = 0; i < n; ++i) {
            double sigmaInv = 1.0 / sigma.get(i);
            for (int k = 0; k < m; ++k) {
                data[k][i] = jacobian.get(i, k) * sigmaInv;
            }
        }
        return DoubleMatrix.ofUnsafe(data);
    }

    /**
     * Evaluates the model Jacobian at {@code theta} and divides each row {@code i} by {@code sigma_i}.
     *
     * @throws IllegalArgumentException if the Jacobian dimensions do not match {@code theta} and {@code sigma}
     */
    private DoubleMatrix getJacobian(Function<DoubleArray, DoubleMatrix> jac, DoubleArray sigma, DoubleArray theta) {
        DoubleMatrix res = jac.apply(theta);
        int n = res.rowCount();
        int m = res.columnCount();
        ArgChecker.isTrue(theta.size() == m, "Jacobian is wrong size");
        ArgChecker.isTrue(sigma.size() == n, "Jacobian is wrong size");
        double[][] data = res.toArray();
        for (int i = 0; i < n; ++i) {
            double sigmaInv = 1.0 / sigma.get(i);
            for (int j = 0; j < m; ++j) {
                data[i][j] *= sigmaInv;
            }
        }
        return DoubleMatrix.ofUnsafe(data);
    }

    /** Returns the chi-square, the sum of squared weighted residuals. */
    private double getChiSqr(DoubleArray error) {
        return _algebra.getInnerProduct(error, error);
    }

    /** Returns {@code J^T e}, computed as the row vector {@code e^T J}. */
    private DoubleArray getChiSqrGrad(DoubleArray error, DoubleMatrix jacobian) {
        return (DoubleArray) _algebra.multiply(error, jacobian);
    }

    /** Returns the step right-hand side {@code J^T e - P theta}. */
    private DoubleArray getPenalisedGradient(DoubleArray error, DoubleMatrix jacobian, DoubleMatrix penalty, DoubleArray theta) {
        DoubleArray grad = getChiSqrGrad(error, jacobian);
        DoubleArray penaltyGrad = (DoubleArray) _algebra.multiply(penalty, theta);
        return (DoubleArray) _algebra.subtract(grad, penaltyGrad);
    }

    /** Returns {@code J^T J + P} with its diagonal multiplied by {@code 1 + lambda}. */
    private DoubleMatrix getModifiedCurvatureMatrix(DoubleMatrix jacobian, double lambda, DoubleMatrix penalty) {
        double onePLambda = 1.0 + lambda;
        int m = jacobian.columnCount();
        DoubleMatrix alpha = (DoubleMatrix) MA.add(MA.matrixTransposeMultiplyMatrix(jacobian), penalty);
        double[][] data = alpha.toArray();
        for (int i = 0; i < m; ++i) {
            data[i][i] *= onePLambda;
        }
        return DoubleMatrix.ofUnsafe(data);
    }

    /** Returns the quadratic form {@code x^T a x}, reading only the leading {@code n x n} entries of {@code a}. */
    private double getANorm(DoubleMatrix a, DoubleArray x) {
        int n = x.size();
        double sum = 0.0;
        for (int i = 0; i < n; ++i) {
            double xi = x.get(i);
            for (int j = 0; j < n; ++j) {
                sum += a.get(i, j) * xi * x.get(j);
            }
        }
        return sum;
    }
}

