/*
 * Decompiled with CFR 0.152.
 */
package com.opengamma.strata.math.impl.statistics.leastsquare;

import com.opengamma.strata.collect.ArgChecker;
import com.opengamma.strata.collect.array.DoubleArray;
import com.opengamma.strata.collect.array.DoubleMatrix;
import com.opengamma.strata.collect.array.Matrix;
import com.opengamma.strata.math.MathException;
import com.opengamma.strata.math.MathUtils;
import com.opengamma.strata.math.impl.differentiation.VectorFieldFirstOrderDifferentiator;
import com.opengamma.strata.math.impl.differentiation.VectorFieldSecondOrderDifferentiator;
import com.opengamma.strata.math.impl.function.ParameterizedFunction;
import com.opengamma.strata.math.impl.linearalgebra.DecompositionFactory;
import com.opengamma.strata.math.impl.linearalgebra.SVDecompositionCommons;
import com.opengamma.strata.math.impl.linearalgebra.SVDecompositionResult;
import com.opengamma.strata.math.impl.matrix.MatrixAlgebra;
import com.opengamma.strata.math.impl.matrix.MatrixAlgebraFactory;
import com.opengamma.strata.math.impl.statistics.leastsquare.LeastSquareResults;
import com.opengamma.strata.math.linearalgebra.Decomposition;
import com.opengamma.strata.math.linearalgebra.DecompositionResult;
import java.util.function.Function;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Non-linear least-squares fitter.
 * <p>
 * Given observed values, per-observation scale factors {@code sigma} and a model function of a
 * parameter vector, the solver iteratively searches for the parameters that minimise
 * {@code chi^2 = sum(((observed_i - model_i) / sigma_i)^2)}.
 * <p>
 * Each iteration solves {@code alpha * deltaTheta = beta}, where {@code alpha} is
 * {@code J^T J} (with {@code J} the sigma-scaled Jacobian) whose diagonal is multiplied by
 * {@code (1 + lambda)}, and {@code beta = error^T J}. The damping factor {@code lambda} is
 * decreased after a successful step and increased after a rejected one, in the style of the
 * Levenberg-Marquardt method.
 */
public class NonLinearLeastSquare {
    private static final Logger LOGGER = LoggerFactory.getLogger(NonLinearLeastSquare.class);
    /** Maximum number of passes through the main iteration loop before giving up. */
    private static final int MAX_ATTEMPTS = 10000;
    /** Constraint function that accepts every parameter vector. */
    private static final Function<DoubleArray, Boolean> UNCONSTRAINED = x -> true;

    /** Relative change in chi^2 (and absolute chi^2 level) below which the fit is treated as converged. */
    private final double _eps;
    /** Decomposition used to solve the linear system for each parameter step. */
    private final Decomposition<?> _decomposition;
    /** Matrix algebra used for vector/matrix arithmetic. */
    private final MatrixAlgebra _algebra;

    /**
     * Creates a solver using {@code DecompositionFactory.SV_COMMONS},
     * {@code MatrixAlgebraFactory.OG_ALGEBRA} and a convergence tolerance of {@code 1e-8}.
     */
    public NonLinearLeastSquare() {
        this(DecompositionFactory.SV_COMMONS, MatrixAlgebraFactory.OG_ALGEBRA, 1.0E-8);
    }

    /**
     * Creates a solver.
     *
     * @param decomposition  the decomposition used to solve the linear system at each step
     * @param algebra  the matrix algebra implementation
     * @param eps  the convergence tolerance
     */
    public NonLinearLeastSquare(Decomposition<?> decomposition, MatrixAlgebra algebra, double eps) {
        _decomposition = decomposition;
        _algebra = algebra;
        _eps = eps;
    }

    /**
     * Fits {@code y = func(x; theta)} with every sigma set to 1, using a finite-difference Jacobian.
     *
     * @param x  the independent values, not null
     * @param y  the observed values, same length as {@code x}, not null
     * @param func  the model, evaluated as {@code func.evaluate(x_i, theta)}
     * @param startPos  the initial parameter vector
     * @return the fit results
     */
    public LeastSquareResults solve(DoubleArray x, DoubleArray y, ParameterizedFunction<Double, DoubleArray, Double> func, DoubleArray startPos) {
        ArgChecker.notNull(x, "x");
        ArgChecker.notNull(y, "y");
        int n = x.size();
        ArgChecker.isTrue(y.size() == n, "y wrong length");
        return solve(x, y, DoubleArray.filled(n, 1.0), func, startPos);
    }

    /**
     * Fits {@code y = func(x; theta)} with the same sigma for every observation, using a
     * finite-difference Jacobian.
     *
     * @param x  the independent values, not null
     * @param y  the observed values, same length as {@code x}, not null
     * @param sigma  the scale applied to every residual
     * @param func  the model, evaluated as {@code func.evaluate(x_i, theta)}
     * @param startPos  the initial parameter vector
     * @return the fit results
     */
    public LeastSquareResults solve(DoubleArray x, DoubleArray y, double sigma, ParameterizedFunction<Double, DoubleArray, Double> func, DoubleArray startPos) {
        ArgChecker.notNull(x, "x");
        ArgChecker.notNull(y, "y");
        // sigma is a primitive here, so there is nothing to null-check.
        int n = x.size();
        ArgChecker.isTrue(y.size() == n, "y wrong length");
        return solve(x, y, DoubleArray.filled(n, sigma), func, startPos);
    }

    /**
     * Fits {@code y = func(x; theta)} with a per-observation sigma, using a finite-difference
     * Jacobian.
     *
     * @param x  the independent values, not null
     * @param y  the observed values, same length as {@code x}, not null
     * @param sigma  the scale applied to each residual, same length as {@code x}, not null
     * @param func  the model, evaluated as {@code func.evaluate(x_i, theta)}
     * @param startPos  the initial parameter vector
     * @return the fit results
     */
    public LeastSquareResults solve(DoubleArray x, DoubleArray y, DoubleArray sigma, ParameterizedFunction<Double, DoubleArray, Double> func, DoubleArray startPos) {
        ArgChecker.notNull(x, "x");
        ArgChecker.notNull(y, "y");
        ArgChecker.notNull(sigma, "sigma");
        int n = x.size();
        ArgChecker.isTrue(y.size() == n, "y wrong length");
        ArgChecker.isTrue(sigma.size() == n, "sigma wrong length");
        // Adapt the scalar model into a vector-valued function of theta only.
        Function<DoubleArray, DoubleArray> func1D =
            theta -> DoubleArray.of(x.size(), i -> func.evaluate(x.get(i), theta));
        return solve(y, sigma, func1D, startPos, null);
    }

    /**
     * Fits {@code y = func(x; theta)} with every sigma set to 1, using an analytic parameter gradient.
     *
     * @param x  the independent values, not null
     * @param y  the observed values, same length as {@code x}, not null
     * @param func  the model, evaluated as {@code func.evaluate(x_i, theta)}
     * @param grad  the gradient of the model with respect to theta, evaluated as {@code grad.evaluate(x_i, theta)}
     * @param startPos  the initial parameter vector
     * @return the fit results
     */
    public LeastSquareResults solve(DoubleArray x, DoubleArray y, ParameterizedFunction<Double, DoubleArray, Double> func, ParameterizedFunction<Double, DoubleArray, DoubleArray> grad, DoubleArray startPos) {
        ArgChecker.notNull(x, "x");
        ArgChecker.notNull(y, "y");
        int n = x.size();
        ArgChecker.isTrue(y.size() == n, "y wrong length");
        return solve(x, y, DoubleArray.filled(n, 1.0), func, grad, startPos);
    }

    /**
     * Fits {@code y = func(x; theta)} with the same sigma for every observation, using an analytic
     * parameter gradient.
     *
     * @param x  the independent values, not null
     * @param y  the observed values, same length as {@code x}, not null
     * @param sigma  the scale applied to every residual
     * @param func  the model, evaluated as {@code func.evaluate(x_i, theta)}
     * @param grad  the gradient of the model with respect to theta, evaluated as {@code grad.evaluate(x_i, theta)}
     * @param startPos  the initial parameter vector
     * @return the fit results
     */
    public LeastSquareResults solve(DoubleArray x, DoubleArray y, double sigma, ParameterizedFunction<Double, DoubleArray, Double> func, ParameterizedFunction<Double, DoubleArray, DoubleArray> grad, DoubleArray startPos) {
        ArgChecker.notNull(x, "x");
        ArgChecker.notNull(y, "y");
        int n = x.size();
        ArgChecker.isTrue(y.size() == n, "y wrong length");
        return solve(x, y, DoubleArray.filled(n, sigma), func, grad, startPos);
    }

    /**
     * Fits {@code y = func(x; theta)} with a per-observation sigma, using an analytic parameter
     * gradient.
     *
     * @param x  the independent values, not null
     * @param y  the observed values, same length as {@code x}, not null
     * @param sigma  the scale applied to each residual, same length as {@code x}, not null
     * @param func  the model, evaluated as {@code func.evaluate(x_i, theta)}
     * @param grad  the gradient of the model with respect to theta, evaluated as {@code grad.evaluate(x_i, theta)}
     * @param startPos  the initial parameter vector
     * @return the fit results
     */
    public LeastSquareResults solve(DoubleArray x, DoubleArray y, DoubleArray sigma, ParameterizedFunction<Double, DoubleArray, Double> func, ParameterizedFunction<Double, DoubleArray, DoubleArray> grad, DoubleArray startPos) {
        ArgChecker.notNull(x, "x");
        ArgChecker.notNull(y, "y");
        ArgChecker.notNull(sigma, "sigma");
        int n = x.size();
        ArgChecker.isTrue(y.size() == n, "y wrong length");
        ArgChecker.isTrue(sigma.size() == n, "sigma wrong length");
        // Adapt the scalar model into a vector-valued function of theta only.
        Function<DoubleArray, DoubleArray> func1D =
            theta -> DoubleArray.of(x.size(), i -> func.evaluate(x.get(i), theta));
        // One Jacobian row per x value, each row being the parameter gradient at that x.
        Function<DoubleArray, DoubleMatrix> jac = theta -> {
            int m = x.size();
            double[][] rows = new double[m][];
            for (int i = 0; i < m; ++i) {
                rows[i] = grad.evaluate(x.get(i), theta).toArray();
            }
            return DoubleMatrix.copyOf(rows);
        };
        return solve(y, sigma, func1D, jac, startPos, null);
    }

    /**
     * Fits a vector-valued model to the observed values with every sigma set to 1, using a
     * finite-difference Jacobian.
     *
     * @param observedValues  the observed values, not null
     * @param func  the model, mapping a parameter vector to model values
     * @param startPos  the initial parameter vector
     * @return the fit results
     */
    public LeastSquareResults solve(DoubleArray observedValues, Function<DoubleArray, DoubleArray> func, DoubleArray startPos) {
        ArgChecker.notNull(observedValues, "observedValues");
        int n = observedValues.size();
        VectorFieldFirstOrderDifferentiator jac = new VectorFieldFirstOrderDifferentiator();
        return solve(observedValues, DoubleArray.filled(n, 1.0), func, jac.differentiate(func), startPos, null);
    }

    /**
     * Fits a vector-valued model to the observed values using a finite-difference Jacobian.
     *
     * @param observedValues  the observed values
     * @param sigma  the scale applied to each residual
     * @param func  the model, mapping a parameter vector to model values
     * @param startPos  the initial parameter vector
     * @return the fit results
     */
    public LeastSquareResults solve(DoubleArray observedValues, DoubleArray sigma, Function<DoubleArray, DoubleArray> func, DoubleArray startPos) {
        VectorFieldFirstOrderDifferentiator jac = new VectorFieldFirstOrderDifferentiator();
        return solve(observedValues, sigma, func, jac.differentiate(func), startPos, null);
    }

    /**
     * Fits a vector-valued model to the observed values using a finite-difference Jacobian and an
     * optional per-parameter limit on the step size.
     *
     * @param observedValues  the observed values
     * @param sigma  the scale applied to each residual
     * @param func  the model, mapping a parameter vector to model values
     * @param startPos  the initial parameter vector
     * @param maxJumps  the largest absolute step allowed for each parameter, or null for no limit
     * @return the fit results
     */
    public LeastSquareResults solve(DoubleArray observedValues, DoubleArray sigma, Function<DoubleArray, DoubleArray> func, DoubleArray startPos, DoubleArray maxJumps) {
        VectorFieldFirstOrderDifferentiator jac = new VectorFieldFirstOrderDifferentiator();
        return solve(observedValues, sigma, func, jac.differentiate(func), startPos, maxJumps);
    }

    /**
     * Fits a vector-valued model to the observed values using the supplied Jacobian, with no
     * constraints and no step-size limit.
     *
     * @param observedValues  the observed values
     * @param sigma  the scale applied to each residual
     * @param func  the model, mapping a parameter vector to model values
     * @param jac  the Jacobian of the model with respect to the parameters
     * @param startPos  the initial parameter vector
     * @return the fit results
     */
    public LeastSquareResults solve(DoubleArray observedValues, DoubleArray sigma, Function<DoubleArray, DoubleArray> func, Function<DoubleArray, DoubleMatrix> jac, DoubleArray startPos) {
        return solve(observedValues, sigma, func, jac, startPos, UNCONSTRAINED, null);
    }

    /**
     * Fits a vector-valued model to the observed values using the supplied Jacobian and an
     * optional per-parameter limit on the step size, with no constraints.
     *
     * @param observedValues  the observed values
     * @param sigma  the scale applied to each residual
     * @param func  the model, mapping a parameter vector to model values
     * @param jac  the Jacobian of the model with respect to the parameters
     * @param startPos  the initial parameter vector
     * @param maxJumps  the largest absolute step allowed for each parameter, or null for no limit
     * @return the fit results
     */
    public LeastSquareResults solve(DoubleArray observedValues, DoubleArray sigma, Function<DoubleArray, DoubleArray> func, Function<DoubleArray, DoubleMatrix> jac, DoubleArray startPos, DoubleArray maxJumps) {
        return solve(observedValues, sigma, func, jac, startPos, UNCONSTRAINED, maxJumps);
    }

    /**
     * Core solver: fits a vector-valued model to the observed values.
     * <p>
     * Trial steps that violate {@code constraints} or exceed {@code maxJumps} are rejected and the
     * damping is increased. When the relative change in chi^2 drops below the tolerance, a
     * second-derivative check is made; if it indicates a saddle point the solver tries to step
     * away from it, otherwise the current trial point is returned.
     *
     * @param observedValues  the observed values, not null
     * @param sigma  the scale applied to each residual, same length as {@code observedValues}, not null
     * @param func  the model, mapping a parameter vector to model values, not null
     * @param jac  the Jacobian of the model with respect to the parameters, not null
     * @param startPos  the initial parameter vector, which must satisfy {@code constraints}, not null
     * @param constraints  returns true for parameter vectors that are allowed, not null
     * @param maxJumps  the largest absolute step allowed for each parameter, or null for no limit
     * @return the fit results
     * @throws IllegalArgumentException if the inputs are inconsistent (as raised by {@code ArgChecker})
     * @throws MathException if the linear solve fails, a constraint cannot be satisfied while moving
     *   away from a saddle point, or the iteration limit is reached
     */
    public LeastSquareResults solve(DoubleArray observedValues, DoubleArray sigma, Function<DoubleArray, DoubleArray> func, Function<DoubleArray, DoubleMatrix> jac, DoubleArray startPos, Function<DoubleArray, Boolean> constraints, DoubleArray maxJumps) {
        ArgChecker.notNull(observedValues, "observedValues");
        ArgChecker.notNull(sigma, " sigma");
        ArgChecker.notNull(func, " func");
        ArgChecker.notNull(jac, " jac");
        ArgChecker.notNull(startPos, "startPos");
        ArgChecker.notNull(constraints, "constraints");
        int nObs = observedValues.size();
        int nParms = startPos.size();
        ArgChecker.isTrue(nObs == sigma.size(), "observedValues and sigma must be same length");
        ArgChecker.isTrue(nObs >= nParms, "must have data points greater or equal to number of parameters. #date points = {}, #parameters = {}", nObs, nParms);
        ArgChecker.isTrue(constraints.apply(startPos), "The inital value of the parameters (startPos) is {} - this is not an allowed value", startPos);

        DoubleArray theta = startPos;
        double lambda = 0.0;
        DoubleArray error = getError(func, observedValues, sigma, theta);
        DoubleMatrix jacobian = getJacobian(jac, sigma, theta);
        double oldChiSqr = getChiSqr(error);
        if (oldChiSqr == 0.0) {
            // Exact fit at the starting point: nothing to iterate on.
            return finish(oldChiSqr, jacobian, theta, sigma);
        }
        DoubleArray beta = getChiSqrGrad(error, jacobian);

        for (int count = 0; count < MAX_ATTEMPTS; ++count) {
            DoubleMatrix alpha = getModifiedCurvatureMatrix(jacobian, lambda);
            DecompositionResult decmp;
            DoubleArray deltaTheta;
            try {
                decmp = _decomposition.apply(alpha);
                deltaTheta = decmp.solve(beta);
            } catch (Exception e) {
                throw new MathException(e);
            }
            DoubleArray trialTheta = (DoubleArray) _algebra.add(theta, deltaTheta);

            // Reject steps that leave the allowed region or are too large; damp harder and retry.
            if (!constraints.apply(trialTheta) || !allowJump(deltaTheta, maxJumps)) {
                lambda = increaseLambda(lambda);
                continue;
            }

            DoubleArray newError = getError(func, observedValues, sigma, trialTheta);
            double newChiSqr = getChiSqr(newError);

            // chi^2 has stopped changing: either we are done or we are stuck at a saddle point.
            if (Math.abs(newChiSqr - oldChiSqr) / (1.0 + oldChiSqr) < _eps) {
                // Undamped curvature matrix; reuse alpha when no damping was applied.
                DoubleMatrix alpha0 = lambda == 0.0 ? alpha : getModifiedCurvatureMatrix(jacobian, 0.0);

                if (newChiSqr < _eps) {
                    // Effectively an exact fit, so skip the second-derivative check.
                    if (lambda > 0.0) {
                        decmp = _decomposition.apply(alpha0);
                    }
                    return finish(alpha0, decmp, newChiSqr, jacobian, trialTheta, sigma);
                }

                SVDecompositionCommons svd = (SVDecompositionCommons) DecompositionFactory.SV_COMMONS;
                VectorFieldSecondOrderDifferentiator diff = new VectorFieldSecondOrderDifferentiator();
                Function<DoubleArray, DoubleMatrix[]> secDivFunc = diff.differentiate(func, constraints);
                DoubleMatrix[] secDiv = secDivFunc.apply(trialTheta);

                // Second-derivative correction: -sum_i error_i * d2(model_i)/dtheta_j dtheta_k / sigma_i.
                double[][] correction = new double[nParms][nParms];
                for (int i = 0; i < nObs; ++i) {
                    for (int j = 0; j < nParms; ++j) {
                        for (int k = 0; k < nParms; ++k) {
                            correction[j][k] -= newError.get(i) * secDiv[i].get(j, k) / sigma.get(i);
                        }
                    }
                }
                DoubleMatrix newAlpha = (DoubleMatrix) _algebra.add(alpha0, DoubleMatrix.copyOf(correction));

                SVDecompositionResult svdRes = svd.apply(newAlpha);
                double[] w = svdRes.getSingularValues();
                DoubleMatrix u = svdRes.getU();
                DoubleMatrix v = svdRes.getV();
                double[] p = new double[nParms];
                boolean saddle = false;
                double sum = 0.0;
                for (int i = 0; i < nParms; ++i) {
                    // Dot product of the i-th columns of U and V; a non-positive value flags this
                    // direction (presumably a negative eigenvalue of the symmetric matrix - TODO confirm).
                    double a = 0.0;
                    for (int j = 0; j < nParms; ++j) {
                        a += u.get(j, i) * v.get(j, i);
                    }
                    int sign = a > 0.0 ? 1 : -1;
                    if (w[i] * sign < 0.0) {
                        sum += w[i];
                        w[i] = -w[i];
                        saddle = true;
                    }
                }

                if (!saddle) {
                    return finish(newAlpha, decmp, newChiSqr, jacobian, trialTheta, sigma);
                }

                lambda = increaseLambda(lambda);
                // Build a step along the flagged directions, weighted by their singular values.
                for (int i = 0; i < nParms; ++i) {
                    if (w[i] < 0.0) {
                        double weight = 0.5 * Math.sqrt(-oldChiSqr * w[i]) / sum;
                        for (int j = 0; j < nParms; ++j) {
                            p[j] += weight * u.get(j, i);
                        }
                    }
                }
                DoubleArray direction = DoubleArray.copyOf(p);
                deltaTheta = direction;
                trialTheta = (DoubleArray) _algebra.add(theta, deltaTheta);

                // If the step breaks the constraints, flip its sign and halve it until it fits.
                int constraintTries = 0;
                double step = 1.0;
                while (!constraints.apply(trialTheta)) {
                    step *= -0.5;
                    deltaTheta = (DoubleArray) _algebra.scale(direction, step);
                    trialTheta = (DoubleArray) _algebra.add(theta, deltaTheta);
                    if (++constraintTries > 10) {
                        throw new MathException("Could not satify constraint");
                    }
                }

                newError = getError(func, observedValues, sigma, trialTheta);
                newChiSqr = getChiSqr(newError);
                // Shrink the step until chi^2 no longer gets worse, giving up after a few tries.
                int counter = 0;
                while (newChiSqr > oldChiSqr) {
                    if (counter > 10 || Math.abs(newChiSqr - oldChiSqr) / (1.0 + oldChiSqr) < _eps) {
                        LOGGER.warn("Saddle point detected, but no improvement to chi^2 possible by moving away. It is recommended that a different starting point is used.");
                        return finish(newAlpha, decmp, oldChiSqr, jacobian, theta, sigma);
                    }
                    step /= 2.0;
                    deltaTheta = (DoubleArray) _algebra.scale(direction, step);
                    trialTheta = (DoubleArray) _algebra.add(theta, deltaTheta);
                    newError = getError(func, observedValues, sigma, trialTheta);
                    newChiSqr = getChiSqr(newError);
                    ++counter;
                }
            }

            if (newChiSqr < oldChiSqr) {
                // Accept the step and relax the damping.
                lambda = decreaseLambda(lambda);
                theta = trialTheta;
                error = newError;
                jacobian = getJacobian(jac, sigma, trialTheta);
                beta = getChiSqrGrad(error, jacobian);
                oldChiSqr = newChiSqr;
            } else {
                // Reject the step and damp harder.
                lambda = increaseLambda(lambda);
            }
        }
        throw new MathException("Could not converge in " + MAX_ATTEMPTS + " attempts");
    }

    /** Reduces the damping factor by a factor of ten. */
    private double decreaseLambda(double lambda) {
        return lambda / 10.0;
    }

    /** Raises the damping factor by a factor of ten, or to 0.1 when it is currently zero. */
    private double increaseLambda(double lambda) {
        if (lambda == 0.0) {
            return 0.1;
        }
        return lambda * 10.0;
    }

    /**
     * Checks a proposed step against the per-parameter limits.
     *
     * @param deltaTheta  the proposed step
     * @param maxJumps  the largest absolute step per parameter, or null for no limit
     * @return false if any component of the step exceeds its limit in absolute value
     */
    private boolean allowJump(DoubleArray deltaTheta, DoubleArray maxJumps) {
        if (maxJumps == null) {
            return true;
        }
        int n = deltaTheta.size();
        for (int i = 0; i < n; ++i) {
            if (Math.abs(deltaTheta.get(i)) > maxJumps.get(i)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Computes the "inverse Jacobian" at a given solution by solving
     * {@code (J^T J) X = B^T}, where {@code J} is the sigma-scaled Jacobian and {@code B^T} is
     * built by {@code getBTranspose}.
     *
     * @param sigma  the scale applied to each residual
     * @param func  the model function (not used by this method)
     * @param jac  the Jacobian of the model with respect to the parameters
     * @param originalSolution  the parameter vector at which to evaluate
     * @return the solution of the linear system
     */
    public DoubleMatrix calInverseJacobian(DoubleArray sigma, Function<DoubleArray, DoubleArray> func, Function<DoubleArray, DoubleMatrix> jac, DoubleArray originalSolution) {
        DoubleMatrix jacobian = getJacobian(jac, sigma, originalSolution);
        DoubleMatrix a = getModifiedCurvatureMatrix(jacobian, 0.0);
        DoubleMatrix bT = getBTranspose(jacobian, sigma);
        DecompositionResult decRes = _decomposition.apply(a);
        return decRes.solve(bT);
    }

    /** Builds the results from the undamped curvature matrix of the given Jacobian. */
    private LeastSquareResults finish(double newChiSqr, DoubleMatrix jacobian, DoubleArray newTheta, DoubleArray sigma) {
        DoubleMatrix alpha = getModifiedCurvatureMatrix(jacobian, 0.0);
        DecompositionResult decmp = _decomposition.apply(alpha);
        return finish(alpha, decmp, newChiSqr, jacobian, newTheta, sigma);
    }

    /**
     * Builds the results: the covariance is obtained by solving against the identity and the
     * inverse Jacobian by solving against {@code B^T}, both with the supplied decomposition.
     * {@code alpha} is only used for its dimension.
     */
    private LeastSquareResults finish(DoubleMatrix alpha, DecompositionResult decmp, double newChiSqr, DoubleMatrix jacobian, DoubleArray newTheta, DoubleArray sigma) {
        DoubleMatrix covariance = decmp.solve(DoubleMatrix.identity(alpha.rowCount()));
        DoubleMatrix bT = getBTranspose(jacobian, sigma);
        DoubleMatrix inverseJacobian = decmp.solve(bT);
        return new LeastSquareResults(newChiSqr, newTheta, covariance, inverseJacobian);
    }

    /** Returns the scaled residuals {@code (observed_i - model_i) / sigma_i} at {@code theta}. */
    private DoubleArray getError(Function<DoubleArray, DoubleArray> func, DoubleArray observedValues, DoubleArray sigma, DoubleArray theta) {
        int n = observedValues.size();
        DoubleArray modelValues = func.apply(theta);
        ArgChecker.isTrue(n == modelValues.size(), "Number of data points different between model (" + modelValues.size() + ") and observed (" + n + ")");
        return DoubleArray.of(n, i -> (observedValues.get(i) - modelValues.get(i)) / sigma.get(i));
    }

    /** Returns the transpose of the given Jacobian with each original row {@code i} divided by {@code sigma_i}. */
    private DoubleMatrix getBTranspose(DoubleMatrix jacobian, DoubleArray sigma) {
        int n = jacobian.rowCount();
        int m = jacobian.columnCount();
        double[][] res = new double[m][n];
        for (int i = 0; i < n; ++i) {
            double sigmaInv = 1.0 / sigma.get(i);
            for (int k = 0; k < m; ++k) {
                res[k][i] = jacobian.get(i, k) * sigmaInv;
            }
        }
        return DoubleMatrix.copyOf(res);
    }

    /** Evaluates the Jacobian at {@code theta} and divides each row {@code i} by {@code sigma_i}. */
    private DoubleMatrix getJacobian(Function<DoubleArray, DoubleMatrix> jac, DoubleArray sigma, DoubleArray theta) {
        DoubleMatrix res = jac.apply(theta);
        double[][] data = res.toArray();
        int n = res.rowCount();
        int m = res.columnCount();
        ArgChecker.isTrue(theta.size() == m, "Jacobian is wrong size");
        ArgChecker.isTrue(sigma.size() == n, "Jacobian is wrong size");
        for (int i = 0; i < n; ++i) {
            double sigmaInv = 1.0 / sigma.get(i);
            for (int j = 0; j < m; ++j) {
                data[i][j] *= sigmaInv;
            }
        }
        return DoubleMatrix.ofUnsafe(data);
    }

    /** Returns chi^2, the inner product of the scaled residual vector with itself. */
    private double getChiSqr(DoubleArray error) {
        return _algebra.getInnerProduct(error, error);
    }

    /** Returns {@code error^T J}, the right-hand side of the linear system solved at each step. */
    private DoubleArray getChiSqrGrad(DoubleArray error, DoubleMatrix jacobian) {
        return (DoubleArray) _algebra.multiply(error, jacobian);
    }

    /** Returns the diagonal of {@code J^T J}; not currently called within this class. */
    @SuppressWarnings("unused")
    private DoubleArray getDiagonalCurvatureMatrix(DoubleMatrix jacobian) {
        int n = jacobian.rowCount();
        int m = jacobian.columnCount();
        double[] alpha = new double[m];
        for (int i = 0; i < m; ++i) {
            double sum = 0.0;
            for (int k = 0; k < n; ++k) {
                sum += MathUtils.pow2(jacobian.get(k, i));
            }
            alpha[i] = sum;
        }
        return DoubleArray.copyOf(alpha);
    }

    /** Returns {@code J^T J} with every diagonal entry multiplied by {@code (1 + lambda)}. */
    private DoubleMatrix getModifiedCurvatureMatrix(DoubleMatrix jacobian, double lambda) {
        int m = jacobian.columnCount();
        double onePLambda = 1.0 + lambda;
        DoubleMatrix alpha = _algebra.matrixTransposeMultiplyMatrix(jacobian);
        double[][] data = alpha.toArray();
        for (int i = 0; i < m; ++i) {
            data[i][i] *= onePLambda;
        }
        return DoubleMatrix.ofUnsafe(data);
    }
}

