/*
 * Decompiled with CFR 0.152.
 */
package net.finmath.optimizer;

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import java.util.Vector;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.FutureTask;
import java.util.logging.Level;
import java.util.logging.Logger;
import net.finmath.functions.LinearAlgebra;
import net.finmath.montecarlo.RandomVariableFromDoubleArray;
import net.finmath.optimizer.SolverException;
import net.finmath.optimizer.StochasticOptimizer;
import net.finmath.stochastic.RandomVariable;
import net.finmath.stochastic.Scalar;

public abstract class StochasticPathwiseLevenbergMarquardt
implements Serializable,
Cloneable,
StochasticOptimizer {
    private static final long serialVersionUID = 4560864869394838155L;
    private RandomVariable[] initialParameters = null;
    private RandomVariable[] parameterSteps = null;
    private RandomVariable[] targetValues = null;
    private RandomVariable[] weights = null;
    private final int maxIteration;
    private double[] lambda;
    private final double lambdaInitialValue = 0.001;
    private double lambdaDivisor = 1.3;
    private double lambdaMultiplicator = 2.0;
    private int numberOfPaths;
    private final RandomVariable errorTolerance;
    private int iteration = 0;
    private RandomVariable[] parameterTest = null;
    private RandomVariable[] valueTest = null;
    private RandomVariable[] parameterCurrent = null;
    private RandomVariable[] valueCurrent = null;
    private RandomVariable[][] derivativeCurrent = null;
    private RandomVariable errorMeanSquaredCurrent = new RandomVariableFromDoubleArray(Double.POSITIVE_INFINITY);
    private RandomVariable errorRootMeanSquaredChange = new RandomVariableFromDoubleArray(Double.POSITIVE_INFINITY);
    private boolean[] isParameterCurrentDerivativeValid;
    private ExecutorService executor = null;
    private boolean executorShutdownWhenDone = true;
    private final Logger logger = Logger.getLogger("net.finmath");

    public static void main(String[] args) throws SolverException {
        RandomVariable[] initialParameters = new RandomVariable[]{new RandomVariableFromDoubleArray(2.0), new RandomVariableFromDoubleArray(2.0)};
        RandomVariable[] weights = new RandomVariable[]{new RandomVariableFromDoubleArray(1.0), new RandomVariableFromDoubleArray(1.0)};
        RandomVariable[] parameterSteps = new RandomVariable[]{new RandomVariableFromDoubleArray(1.0), new RandomVariableFromDoubleArray(1.0)};
        int maxIteration = 100;
        RandomVariable[] targetValues = new RandomVariable[]{new RandomVariableFromDoubleArray(25.0), new RandomVariableFromDoubleArray(100.0)};
        StochasticPathwiseLevenbergMarquardt optimizer = new StochasticPathwiseLevenbergMarquardt(initialParameters, targetValues, weights, parameterSteps, 100, null, null){
            private static final long serialVersionUID = -282626938650139518L;

            @Override
            public void setValues(RandomVariable[] parameters, RandomVariable[] values) {
                values[0] = parameters[0].mult(0.0).add(parameters[1]).squared();
                values[1] = parameters[0].mult(2.0).add(parameters[1]).squared();
            }
        };
        optimizer.run();
        RandomVariable[] bestParameters = optimizer.getBestFitParameters();
        System.out.println("The solver for problem 1 required " + optimizer.getIterations() + " iterations. The best fit parameters are:");
        for (int i = 0; i < bestParameters.length; ++i) {
            System.out.println("\tparameter[" + i + "]: " + bestParameters[i]);
        }
        System.out.println("The solver accuracy is " + optimizer.getRootMeanSquaredError());
    }

    public StochasticPathwiseLevenbergMarquardt(RandomVariable[] initialParameters, RandomVariable[] targetValues, RandomVariable[] weights, RandomVariable[] parameterSteps, int maxIteration, RandomVariable errorTolerance, ExecutorService executorService) {
        this.initialParameters = initialParameters;
        this.targetValues = targetValues;
        this.weights = weights;
        this.parameterSteps = parameterSteps;
        this.maxIteration = maxIteration;
        RandomVariable randomVariable = this.errorTolerance = errorTolerance != null ? errorTolerance : new RandomVariableFromDoubleArray(0.0);
        if (weights == null) {
            this.weights = new RandomVariable[targetValues.length];
            for (int i = 0; i < targetValues.length; ++i) {
                this.weights[i] = new RandomVariableFromDoubleArray(1.0);
            }
        }
        this.executor = executorService;
        this.executorShutdownWhenDone = executorService == null;
    }

    public StochasticPathwiseLevenbergMarquardt(RandomVariable[] initialParameters, RandomVariable[] targetValues, int maxIteration, int numberOfThreads) {
        this(initialParameters, targetValues, null, null, maxIteration, null, numberOfThreads > 1 ? Executors.newFixedThreadPool(numberOfThreads) : null);
    }

    public StochasticPathwiseLevenbergMarquardt(List<RandomVariable> initialParameters, List<RandomVariable> targetValues, int maxIteration, ExecutorService executorService) {
        this(StochasticPathwiseLevenbergMarquardt.numberListToDoubleArray(initialParameters), StochasticPathwiseLevenbergMarquardt.numberListToDoubleArray(targetValues), null, null, maxIteration, null, executorService);
    }

    public StochasticPathwiseLevenbergMarquardt(List<RandomVariable> initialParameters, List<RandomVariable> targetValues, int maxIteration, int numberOfThreads) {
        this(StochasticPathwiseLevenbergMarquardt.numberListToDoubleArray(initialParameters), StochasticPathwiseLevenbergMarquardt.numberListToDoubleArray(targetValues), maxIteration, numberOfThreads);
    }

    private static RandomVariable[] numberListToDoubleArray(List<RandomVariable> listOfNumbers) {
        RandomVariable[] array = new RandomVariable[listOfNumbers.size()];
        for (int i = 0; i < array.length; ++i) {
            array[i] = listOfNumbers.get(i);
        }
        return array;
    }

    public double[] getLambda() {
        return this.lambda;
    }

    public void setLambda(double[] lambda) {
        this.lambda = lambda;
    }

    public double getLambdaMultiplicator() {
        return this.lambdaMultiplicator;
    }

    public void setLambdaMultiplicator(double lambdaMultiplicator) {
        if (lambdaMultiplicator <= 1.0) {
            throw new IllegalArgumentException("Parameter lambdaMultiplicator is required to be > 1.");
        }
        this.lambdaMultiplicator = lambdaMultiplicator;
    }

    public double getLambdaDivisor() {
        return this.lambdaDivisor;
    }

    public void setLambdaDivisor(double lambdaDivisor) {
        if (lambdaDivisor <= 1.0) {
            throw new IllegalArgumentException("Parameter lambdaDivisor is required to be > 1.");
        }
        this.lambdaDivisor = lambdaDivisor;
    }

    @Override
    public RandomVariable[] getBestFitParameters() {
        return this.parameterCurrent;
    }

    @Override
    public double getRootMeanSquaredError() {
        return this.errorMeanSquaredCurrent.average().sqrt().doubleValue();
    }

    public void setErrorMeanSquaredCurrent(RandomVariable errorMeanSquaredCurrent) {
        this.errorMeanSquaredCurrent = errorMeanSquaredCurrent;
    }

    @Override
    public int getIterations() {
        return this.iteration;
    }

    protected void prepareAndSetValues(RandomVariable[] parameters, RandomVariable[] values) throws SolverException {
        this.setValues(parameters, values);
    }

    protected void prepareAndSetDerivatives(RandomVariable[] parameters, RandomVariable[] values, RandomVariable[][] derivatives) throws SolverException {
        this.setDerivatives(parameters, derivatives);
    }

    public abstract void setValues(RandomVariable[] var1, RandomVariable[] var2) throws SolverException;

    public void setDerivatives(RandomVariable[] parameters, RandomVariable[][] derivatives) throws SolverException {
        int parameterIndex;
        parameters = this.parameterCurrent;
        Vector<Future<RandomVariable[]>> valueFutures = new Vector<Future<RandomVariable[]>>(this.parameterCurrent.length);
        for (parameterIndex = 0; parameterIndex < this.parameterCurrent.length; ++parameterIndex) {
            final RandomVariable[] parametersNew = (RandomVariable[])parameters.clone();
            final RandomVariable[] derivative = derivatives[parameterIndex];
            final int workerParameterIndex = parameterIndex;
            Callable<RandomVariable[]> worker = new Callable<RandomVariable[]>(){

                @Override
                public RandomVariable[] call() {
                    RandomVariable parameterFiniteDifference = StochasticPathwiseLevenbergMarquardt.this.parameterSteps != null ? StochasticPathwiseLevenbergMarquardt.this.parameterSteps[workerParameterIndex] : parametersNew[workerParameterIndex].abs().add(1.0).mult(1.0E-8);
                    parametersNew[workerParameterIndex] = parametersNew[workerParameterIndex].add(parameterFiniteDifference);
                    try {
                        StochasticPathwiseLevenbergMarquardt.this.prepareAndSetValues(parametersNew, derivative);
                    }
                    catch (Exception e) {
                        Arrays.fill(derivative, new RandomVariableFromDoubleArray(Double.NaN));
                    }
                    for (int valueIndex = 0; valueIndex < StochasticPathwiseLevenbergMarquardt.this.valueCurrent.length; ++valueIndex) {
                        derivative[valueIndex] = derivative[valueIndex].sub(StochasticPathwiseLevenbergMarquardt.this.valueCurrent[valueIndex]).div(parameterFiniteDifference);
                        derivative[valueIndex] = derivative[valueIndex].isNaN().sub(0.5).mult(-1.0).choose(derivative[valueIndex], new Scalar(0.0));
                    }
                    return derivative;
                }
            };
            if (this.executor != null) {
                Future<RandomVariable[]> valueFuture = this.executor.submit(worker);
                valueFutures.add(parameterIndex, valueFuture);
                continue;
            }
            FutureTask<RandomVariable[]> valueFutureTask = new FutureTask<RandomVariable[]>(worker);
            valueFutureTask.run();
            valueFutures.add(parameterIndex, valueFutureTask);
        }
        for (parameterIndex = 0; parameterIndex < this.parameterCurrent.length; ++parameterIndex) {
            try {
                derivatives[parameterIndex] = (RandomVariable[])((Future)valueFutures.get(parameterIndex)).get();
                continue;
            }
            catch (InterruptedException | ExecutionException e) {
                throw new SolverException(e);
            }
        }
    }

    boolean done() {
        return this.iteration > this.maxIteration || this.errorRootMeanSquaredChange.sub(this.errorTolerance).getMax() <= 0.0;
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    @Override
    public void run() throws SolverException {
        try {
            int numberOfParameters = this.initialParameters.length;
            int numberOfValues = this.targetValues.length;
            this.parameterTest = (RandomVariable[])this.initialParameters.clone();
            this.parameterCurrent = (RandomVariable[])this.initialParameters.clone();
            this.valueTest = new RandomVariable[numberOfValues];
            this.valueCurrent = new RandomVariable[numberOfValues];
            Arrays.fill(this.valueCurrent, new RandomVariableFromDoubleArray(Double.NaN));
            this.derivativeCurrent = new RandomVariable[numberOfParameters][numberOfValues];
            this.iteration = 0;
            while (true) {
                ++this.iteration;
                this.prepareAndSetValues(this.parameterTest, this.valueTest);
                RandomVariable errorMeanSquaredTest = this.getMeanSquaredError(this.valueTest);
                RandomVariable isPointAccepted = this.errorMeanSquaredCurrent.sub(errorMeanSquaredTest);
                for (int parameterIndex = 0; parameterIndex < this.parameterCurrent.length; ++parameterIndex) {
                    this.parameterCurrent[parameterIndex] = isPointAccepted.choose(this.parameterTest[parameterIndex], this.parameterCurrent[parameterIndex]);
                }
                for (int valueIndex = 0; valueIndex < this.valueCurrent.length; ++valueIndex) {
                    this.valueCurrent[valueIndex] = isPointAccepted.choose(this.valueTest[valueIndex], this.valueCurrent[valueIndex]);
                }
                this.errorRootMeanSquaredChange = isPointAccepted.choose(this.errorMeanSquaredCurrent.sqrt().sub(errorMeanSquaredTest.sqrt()), this.errorRootMeanSquaredChange);
                this.errorMeanSquaredCurrent = errorMeanSquaredTest.cap(this.errorMeanSquaredCurrent);
                if (this.done()) {
                    break;
                }
                this.numberOfPaths = isPointAccepted.size();
                if (this.lambda == null) {
                    this.lambda = new double[this.numberOfPaths];
                    Arrays.fill(this.lambda, 0.001);
                }
                if (this.isParameterCurrentDerivativeValid == null) {
                    this.isParameterCurrentDerivativeValid = new boolean[this.numberOfPaths];
                    Arrays.fill(this.isParameterCurrentDerivativeValid, false);
                }
                for (int pathIndex = 0; pathIndex < isPointAccepted.size(); ++pathIndex) {
                    this.isParameterCurrentDerivativeValid[pathIndex] = isPointAccepted.get(pathIndex) <= 0.0;
                    this.lambda[pathIndex] = isPointAccepted.get(pathIndex) >= 0.0 ? this.lambda[pathIndex] / this.lambdaDivisor : this.lambda[pathIndex] * this.lambdaMultiplicator;
                }
                this.prepareAndSetDerivatives(this.parameterTest, this.valueTest, this.derivativeCurrent);
                double[][] parameterIncrement = new double[this.parameterCurrent.length][this.numberOfPaths];
                for (int pathIndex = 0; pathIndex < this.numberOfPaths; ++pathIndex) {
                    double[][] hessianMatrix = new double[this.parameterCurrent.length][this.parameterCurrent.length];
                    double[] beta = new double[this.parameterCurrent.length];
                    boolean hessianInvalid = true;
                    while (hessianInvalid) {
                        int i;
                        for (i = 0; i < this.parameterCurrent.length; ++i) {
                            for (int j = i; j < this.parameterCurrent.length; ++j) {
                                double alphaElement = 0.0;
                                for (int valueIndex = 0; valueIndex < this.valueCurrent.length; ++valueIndex) {
                                    alphaElement += this.weights[valueIndex].get(pathIndex) * this.derivativeCurrent[i][valueIndex].get(pathIndex) * this.derivativeCurrent[j][valueIndex].get(pathIndex);
                                }
                                if (i == j) {
                                    alphaElement = alphaElement == 0.0 ? 1.0 : (alphaElement *= 1.0 + this.lambda[pathIndex]);
                                }
                                hessianMatrix[i][j] = alphaElement;
                                hessianMatrix[j][i] = alphaElement;
                            }
                        }
                        for (i = 0; i < this.parameterCurrent.length; ++i) {
                            double betaElement = 0.0;
                            for (int k = 0; k < this.valueCurrent.length; ++k) {
                                betaElement += this.weights[k].get(pathIndex) * (this.targetValues[k].get(pathIndex) - this.valueCurrent[k].get(pathIndex)) * this.derivativeCurrent[i][k].get(pathIndex);
                            }
                            beta[i] = betaElement;
                        }
                        try {
                            double[] parameterIncrementOnPath = LinearAlgebra.solveLinearEquationSymmetric(hessianMatrix, beta);
                            for (int i2 = 0; i2 < parameterIncrementOnPath.length; ++i2) {
                                parameterIncrement[i2][pathIndex] = parameterIncrementOnPath[i2];
                            }
                            hessianInvalid = false;
                        }
                        catch (Exception e) {
                            hessianInvalid = true;
                            int n = pathIndex;
                            this.lambda[n] = this.lambda[n] * 16.0;
                        }
                    }
                }
                for (int i = 0; i < this.parameterCurrent.length; ++i) {
                    this.parameterTest[i] = this.parameterCurrent[i].add(this.numberOfPaths == 1 ? new Scalar(parameterIncrement[i][0]) : new RandomVariableFromDoubleArray(0.0, parameterIncrement[i]));
                }
                if (!this.logger.isLoggable(Level.FINE)) continue;
                String logString = "Iteration: " + this.iteration + "\tLambda=" + this.lambda + "\tError Current:" + this.errorMeanSquaredCurrent + "\tError Change:" + this.errorRootMeanSquaredChange + "\t";
                for (int i = 0; i < this.parameterCurrent.length; ++i) {
                    logString = logString + "[" + i + "] = " + this.parameterCurrent[i] + "\t";
                }
                this.logger.fine(logString);
            }
        }
        finally {
            if (this.executor != null && this.executorShutdownWhenDone) {
                this.executor.shutdown();
                this.executor = null;
            }
        }
    }

    public RandomVariable getMeanSquaredError(RandomVariable[] value) {
        RandomVariable error = new RandomVariableFromDoubleArray(0.0);
        for (int valueIndex = 0; valueIndex < value.length; ++valueIndex) {
            RandomVariable deviation = value[valueIndex].sub(this.targetValues[valueIndex]);
            error = error.addProduct(this.weights[valueIndex], deviation.squared());
        }
        return error.div(value.length);
    }

    public StochasticPathwiseLevenbergMarquardt clone() {
        return null;
    }

    public StochasticPathwiseLevenbergMarquardt getCloneWithModifiedTargetValues(RandomVariable[] newTargetVaues, RandomVariable[] newWeights, boolean isUseBestParametersAsInitialParameters) throws CloneNotSupportedException {
        StochasticPathwiseLevenbergMarquardt clonedOptimizer = this.clone();
        clonedOptimizer.targetValues = (RandomVariable[])newTargetVaues.clone();
        clonedOptimizer.weights = (RandomVariable[])newWeights.clone();
        if (isUseBestParametersAsInitialParameters && this.done()) {
            clonedOptimizer.initialParameters = this.getBestFitParameters();
        }
        return clonedOptimizer;
    }

    public StochasticPathwiseLevenbergMarquardt getCloneWithModifiedTargetValues(List<RandomVariable> newTargetVaues, List<RandomVariable> newWeights, boolean isUseBestParametersAsInitialParameters) throws CloneNotSupportedException {
        StochasticPathwiseLevenbergMarquardt clonedOptimizer = this.clone();
        clonedOptimizer.targetValues = StochasticPathwiseLevenbergMarquardt.numberListToDoubleArray(newTargetVaues);
        clonedOptimizer.weights = StochasticPathwiseLevenbergMarquardt.numberListToDoubleArray(newWeights);
        if (isUseBestParametersAsInitialParameters && this.done()) {
            clonedOptimizer.initialParameters = this.getBestFitParameters();
        }
        return clonedOptimizer;
    }
}

