/*
 * Decompiled with CFR 0.152.
 */
package com.opengamma.strata.math.impl.rootfinding.newton;

import com.google.common.primitives.Doubles;
import com.opengamma.strata.collect.ArgChecker;
import com.opengamma.strata.collect.array.DoubleArray;
import com.opengamma.strata.collect.array.DoubleMatrix;
import com.opengamma.strata.collect.array.Matrix;
import com.opengamma.strata.math.MathException;
import com.opengamma.strata.math.impl.differentiation.VectorFieldFirstOrderDifferentiator;
import com.opengamma.strata.math.impl.matrix.MatrixAlgebra;
import com.opengamma.strata.math.impl.matrix.OGMatrixAlgebra;
import com.opengamma.strata.math.impl.rootfinding.VectorRootFinder;
import com.opengamma.strata.math.impl.rootfinding.newton.NewtonRootFinderDirectionFunction;
import com.opengamma.strata.math.impl.rootfinding.newton.NewtonRootFinderMatrixInitializationFunction;
import com.opengamma.strata.math.impl.rootfinding.newton.NewtonRootFinderMatrixUpdateFunction;
import com.opengamma.strata.math.rootfind.NewtonVectorRootFinder;
import java.util.function.Function;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base implementation of a Newton-type vector root finder with a backtracking line search.
 * <p>
 * Each iteration asks the direction function for a direction {@code p} from the current
 * Jacobian estimate and residual {@code y}, and tries the step {@code x -> x - lambda * p}.
 * The merit function is the squared residual norm {@code g = y.y}.
 * <p>
 * A trial step is accepted when {@code g1 <= g0 / (1 + ALPHA * lambda)}. Otherwise the step
 * length is reduced in stages:
 * <ul>
 *   <li>first by repeated scaling while the residual is non-finite;</li>
 *   <li>then by one quadratic interpolation;</li>
 *   <li>then by up to six cubic interpolations.</li>
 * </ul>
 * <p>
 * The Jacobian estimate is built by the initialization function and refreshed by the update
 * function between iterations. It is fully re-initialized every {@code FULL_RECALC_FREQ}-th
 * iteration, and whenever the line search fails.
 */
public class BaseNewtonVectorRootFinder
    extends VectorRootFinder
    implements NewtonVectorRootFinder {

    private static final Logger log = LoggerFactory.getLogger(BaseNewtonVectorRootFinder.class);
    /** Sufficient-decrease parameter: a step is accepted when g1 <= g0 / (1 + ALPHA * lambda). */
    private static final double ALPHA = 1.0E-4;
    /** Expansion factor applied to a previous step length of at least one at the start of a line search. */
    private static final double BETA = 1.5;
    /** Number of iterations between full re-initializations of the Jacobian estimate. */
    private static final int FULL_RECALC_FREQ = 20;
    /** Maximum index (inclusive) of cubic backtracking attempts, i.e. at most six attempts. */
    private static final int MAX_CUBIC_BACKTRACK_INDEX = 5;

    private final double _absoluteTol;
    private final double _relativeTol;
    private final int _maxSteps;
    private final NewtonRootFinderDirectionFunction _directionFunction;
    private final NewtonRootFinderMatrixInitializationFunction _initializationFunction;
    private final NewtonRootFinderMatrixUpdateFunction _updateFunction;
    private final MatrixAlgebra _algebra = new OGMatrixAlgebra();

    /**
     * Creates an instance.
     *
     * @param absoluteTol  the absolute tolerance on step size and residual norm, not negative
     * @param relativeTol  the tolerance on step size relative to |x|, not negative
     * @param maxSteps  the maximum number of iterations after the first step, not negative
     * @param directionFunction  computes the search direction from the Jacobian estimate and residual
     * @param initializationFunction  builds a fresh Jacobian estimate at a point
     * @param updateFunction  updates the Jacobian estimate after an accepted step
     */
    public BaseNewtonVectorRootFinder(
            double absoluteTol,
            double relativeTol,
            int maxSteps,
            NewtonRootFinderDirectionFunction directionFunction,
            NewtonRootFinderMatrixInitializationFunction initializationFunction,
            NewtonRootFinderMatrixUpdateFunction updateFunction) {
        ArgChecker.notNegative(absoluteTol, "absolute tolerance");
        ArgChecker.notNegative(relativeTol, "relative tolerance");
        ArgChecker.notNegative(maxSteps, "maxSteps");
        this._absoluteTol = absoluteTol;
        this._relativeTol = relativeTol;
        this._maxSteps = maxSteps;
        this._directionFunction = directionFunction;
        this._initializationFunction = initializationFunction;
        this._updateFunction = updateFunction;
    }

    /**
     * Finds a root, approximating the Jacobian by finite differences.
     * Delegates to {@link #findRoot(Function, DoubleArray)}.
     */
    @Override
    public DoubleArray getRoot(Function<DoubleArray, DoubleArray> function, DoubleArray startPosition) {
        return findRoot(function, startPosition);
    }

    /**
     * Finds a root, using a {@link VectorFieldFirstOrderDifferentiator} to build the Jacobian function.
     */
    @Override
    public DoubleArray findRoot(Function<DoubleArray, DoubleArray> function, DoubleArray startPosition) {
        VectorFieldFirstOrderDifferentiator jac = new VectorFieldFirstOrderDifferentiator();
        return findRoot(function, jac.differentiate(function), startPosition);
    }

    /**
     * Finds a root of {@code function} starting from {@code startPosition}.
     *
     * @param function  the vector function whose root is sought
     * @param jacobianFunction  the Jacobian of {@code function}
     * @param startPosition  the initial guess
     * @return the root
     * @throws MathException if the start position yields no acceptable first step, if the line search
     *   fails even after a Jacobian re-initialization, or if {@code maxSteps} is exceeded
     */
    @Override
    public DoubleArray findRoot(
            Function<DoubleArray, DoubleArray> function,
            Function<DoubleArray, DoubleMatrix> jacobianFunction,
            DoubleArray startPosition) {

        DataBundle data = new DataBundle();
        DoubleArray y = checkInputsAndApplyFunction(function, startPosition);
        data.setX(startPosition);
        data.setY(y);
        data.setG0(_algebra.getInnerProduct(y, y));
        DoubleMatrix estimate = _initializationFunction.getInitializedMatrix(jacobianFunction, startPosition);
        if (!getNextPosition(function, estimate, data)) {
            if (isConverged(data)) {
                return data.getX();
            }
            throw new MathException("Cannot work with this starting position. Please choose another point");
        }

        int count = 0;
        int jacReconCount = 1;
        while (!isConverged(data)) {
            // Periodically rebuild the Jacobian from scratch so update errors do not accumulate.
            if (jacReconCount % FULL_RECALC_FREQ == 0) {
                estimate = _initializationFunction.getInitializedMatrix(jacobianFunction, data.getX());
                jacReconCount = 1;
            } else {
                estimate = _updateFunction.getUpdatedMatrix(
                    jacobianFunction, data.getX(), data.getDeltaX(), data.getDeltaY(), estimate);
                jacReconCount++;
            }
            if (!getNextPosition(function, estimate, data)) {
                // The line search failed with the updated estimate: retry once with a fresh Jacobian.
                estimate = _initializationFunction.getInitializedMatrix(jacobianFunction, data.getX());
                jacReconCount = 1;
                if (!getNextPosition(function, estimate, data)) {
                    if (isConverged(data)) {
                        return data.getX();
                    }
                    String msg = "Failed to converge in backtracking, even after a Jacobian recalculation." +
                        getErrorMessage(data, jacobianFunction);
                    log.info(msg);
                    throw new MathException(msg);
                }
            }
            count++;
            if (count > _maxSteps) {
                throw new MathException("Failed to converge - maximum iterations of " + _maxSteps + " reached." +
                    getErrorMessage(data, jacobianFunction));
            }
        }
        return data.getX();
    }

    /** Builds diagnostic text: final position, last trial step, residual and Jacobian at that position. */
    private String getErrorMessage(DataBundle data, Function<DoubleArray, DoubleMatrix> jacobianFunction) {
        return "Final position:" + data.getX() + "\nlast deltaX:" + data.getDeltaX() + "\n function value:" +
            data.getY() + "\nJacobian: \n" + jacobianFunction.apply(data.getX());
    }

    /**
     * Performs one line search along the current direction and, on success, commits the step.
     *
     * @return true if a step with sufficient decrease was found and applied; false if cubic
     *   backtracking was exhausted (the trial step is left in deltaX/deltaY, x and y are unchanged)
     */
    private boolean getNextPosition(
            Function<DoubleArray, DoubleArray> function,
            DoubleMatrix estimate,
            DataBundle data) {

        DoubleArray p = _directionFunction.getDirection(estimate, data.getY());
        // Start from a full step, or expand the previous step if it was already at least a full step.
        if (data.getLambda0() < 1.0) {
            data.setLambda0(1.0);
        } else {
            data.setLambda0(data.getLambda0() * BETA);
        }
        updatePosition(p, function, data);
        if (!Doubles.isFinite(data.getG1())) {
            bisectBacktrack(p, function, data);
        }
        if (needsBacktrack(data)) {
            quadraticBacktrack(p, function, data);
            int count = 0;
            while (needsBacktrack(data)) {
                if (count > MAX_CUBIC_BACKTRACK_INDEX) {
                    return false;
                }
                cubicBacktrack(p, function, data);
                count++;
            }
        }
        data.setG0(data.getG1());
        data.setX((DoubleArray) _algebra.add(data.getX(), data.getDeltaX()));
        data.setY((DoubleArray) _algebra.add(data.getY(), data.getDeltaY()));
        return true;
    }

    /**
     * Returns true if the latest trial point fails the sufficient-decrease test.
     * A NaN trial residual counts as a failure; a plain {@code g1 > threshold} would accept it,
     * because every comparison with NaN is false.
     */
    private static boolean needsBacktrack(DataBundle data) {
        double g1 = data.getG1();
        return Double.isNaN(g1) || g1 > data.getG0() / (1.0 + ALPHA * data.getLambda0());
    }

    /**
     * Evaluates the trial point {@code x - lambda0 * p}.
     * Stores the trial deltaX/deltaY, moves the old g1 to g2, and stores the new squared residual in g1.
     */
    protected void updatePosition(DoubleArray p, Function<DoubleArray, DoubleArray> function, DataBundle data) {
        double lambda0 = data.getLambda0();
        DoubleArray deltaX = (DoubleArray) _algebra.scale(p, -lambda0);
        DoubleArray xNew = (DoubleArray) _algebra.add(data.getX(), deltaX);
        DoubleArray yNew = function.apply(xNew);
        data.setDeltaX(deltaX);
        data.setDeltaY((DoubleArray) _algebra.subtract(yNew, data.getY()));
        data.setG2(data.getG1());
        data.setG1(_algebra.getInnerProduct(yNew, yNew));
    }

    /**
     * Shrinks the step by a factor of 10 until both g1 and g2 are finite.
     * Despite the name, this is geometric shrinking by 0.1, not bisection.
     * Both values must be finite because the subsequent cubic interpolation uses g2.
     *
     * @throws MathException if the step length underflows to zero
     */
    private void bisectBacktrack(DoubleArray p, Function<DoubleArray, DoubleArray> function, DataBundle data) {
        do {
            data.setLambda0(data.getLambda0() * 0.1);
            updatePosition(p, function, data);
            if (data.getLambda0() == 0.0) {
                throw new MathException("Failed to converge");
            }
        } while (!Doubles.isFinite(data.getG1()) || !Doubles.isFinite(data.getG2()));
    }

    /**
     * Replaces the step length by the minimiser of a quadratic model of g(lambda).
     * The model satisfies g(0) = g0, g'(0) = -2 * g0 and g(lambda0) = g1.
     * The slope -2 * g0 is exact when p solves J p = y, presumably what the direction function
     * returns. The result is floored at 1% of the current step.
     */
    private void quadraticBacktrack(DoubleArray p, Function<DoubleArray, DoubleArray> function, DataBundle data) {
        double lambda0 = data.getLambda0();
        double g0 = data.getG0();
        double lambda = Math.max(0.01 * lambda0, g0 * lambda0 * lambda0 / (data.getG1() + g0 * (2.0 * lambda0 - 1.0)));
        data.swapLambdaAndReplace(lambda);
        updatePosition(p, function, data);
    }

    /**
     * Replaces the step length by the minimiser of a cubic model of g(lambda).
     * The model is {@code a*l^3 + b*l^2 - 2*g0*l + g0}, fitted through (lambda0, g1) and (lambda1, g2).
     * The result is clamped to [1% of lambda0, 75% of lambda1].
     */
    private void cubicBacktrack(DoubleArray p, Function<DoubleArray, DoubleArray> function, DataBundle data) {
        double lambda0 = data.getLambda0();
        double lambda1 = data.getLambda1();
        double g0 = data.getG0();
        double invLambda0Sq = 1.0 / lambda0 / lambda0;
        double invLambda1Sq = 1.0 / lambda1 / lambda1;
        double excess0 = data.getG1() + g0 * (2.0 * lambda0 - 1.0);
        double excess1 = data.getG2() + g0 * (2.0 * lambda1 - 1.0);
        double invSpan = 1.0 / (lambda0 - lambda1);
        double a = invSpan * (invLambda0Sq * excess0 - invLambda1Sq * excess1);
        double b = invSpan * (-lambda1 * invLambda0Sq * excess0 + lambda0 * invLambda1Sq * excess1);
        // Root of g'(l) = 3a l^2 + 2b l - 2 g0 = 0.
        double lambda = (-b + Math.sqrt(b * b + 6.0 * a * g0)) / 3.0 / a;
        if (Double.isNaN(lambda)) {
            // The model has no usable minimiser, for example because of a negative discriminant,
            // a == 0 with b >= 0, or a NaN residual. Math.max/min would propagate the NaN into the
            // step, so halve the current step instead.
            lambda = 0.5 * lambda0;
        }
        lambda = Math.min(Math.max(lambda, 0.01 * lambda0), 0.75 * lambda1);
        data.swapLambdaAndReplace(lambda);
        updatePosition(p, function, data);
    }

    /**
     * Returns true when every component of the last step is within
     * {@code absoluteTol + |x_i| * relativeTol} and the residual norm sqrt(g0) is below absoluteTol.
     */
    private boolean isConverged(DataBundle data) {
        DoubleArray deltaX = data.getDeltaX();
        DoubleArray x = data.getX();
        for (int i = 0; i < deltaX.size(); i++) {
            double tolerance = _absoluteTol + Math.abs(x.get(i)) * _relativeTol;
            if (Math.abs(deltaX.get(i)) > tolerance) {
                return false;
            }
        }
        return Math.sqrt(data.getG0()) < _absoluteTol;
    }

    /**
     * Mutable state shared by the iteration and the line search.
     * <ul>
     *   <li>g0: squared residual at the current accepted point.</li>
     *   <li>g1: squared residual at the latest trial point; g2 is the previous g1.</li>
     *   <li>lambda0: current trial step length; lambda1 is the previous one.</li>
     *   <li>x, y: current accepted point and its function value.</li>
     *   <li>deltaX, deltaY: the latest trial step in x and the corresponding change in y.</li>
     * </ul>
     */
    private static class DataBundle {
        private double _g0;
        private double _g1;
        private double _g2;
        private double _lambda0;
        private double _lambda1;
        private DoubleArray _deltaY;
        private DoubleArray _y;
        private DoubleArray _deltaX;
        private DoubleArray _x;

        private DataBundle() {
        }

        public double getG0() {
            return _g0;
        }

        public double getG1() {
            return _g1;
        }

        public double getG2() {
            return _g2;
        }

        public double getLambda0() {
            return _lambda0;
        }

        public double getLambda1() {
            return _lambda1;
        }

        public DoubleArray getDeltaY() {
            return _deltaY;
        }

        public DoubleArray getY() {
            return _y;
        }

        public DoubleArray getDeltaX() {
            return _deltaX;
        }

        public DoubleArray getX() {
            return _x;
        }

        public void setG0(double g0) {
            _g0 = g0;
        }

        public void setG1(double g1) {
            _g1 = g1;
        }

        public void setG2(double g2) {
            _g2 = g2;
        }

        public void setLambda0(double lambda0) {
            _lambda0 = lambda0;
        }

        public void setDeltaY(DoubleArray deltaY) {
            _deltaY = deltaY;
        }

        public void setY(DoubleArray y) {
            _y = y;
        }

        public void setDeltaX(DoubleArray deltaX) {
            _deltaX = deltaX;
        }

        public void setX(DoubleArray x) {
            _x = x;
        }

        /** Moves the current step length to lambda1 and installs a new current step length. */
        public void swapLambdaAndReplace(double lambda0) {
            _lambda1 = _lambda0;
            _lambda0 = lambda0;
        }
    }
}

