/*
 * Decompiled with CFR 0.152.
 */
package net.finmath.optimizer;

import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import net.finmath.montecarlo.automaticdifferentiation.RandomVariableDifferentiableInterface;
import net.finmath.optimizer.SolverException;
import net.finmath.optimizer.StochasticLevenbergMarquardt;
import net.finmath.stochastic.RandomVariableInterface;

/**
 * A stochastic Levenberg-Marquardt solver that obtains the derivatives of the
 * model values with respect to the parameters from automatic differentiation
 * whenever possible.
 *
 * <p>If every parameter and every value is a {@link RandomVariableDifferentiableInterface},
 * the derivatives are read from the values' {@code getGradient()} maps (keyed by the
 * parameters' {@code getID()}). Otherwise {@code setDerivatives} is called instead.</p>
 */
public abstract class StochasticLevenbergMarquardtAD
extends StochasticLevenbergMarquardt {
    private static final long serialVersionUID = -8852002990042152135L;

    /** If {@code true}, the gradients of the individual values are evaluated via a parallel stream. */
    private final boolean isGradientValuationParallel;

    /**
     * Creates the solver.
     *
     * @param regularizationMethod forwarded to {@link StochasticLevenbergMarquardt}.
     * @param initialParameters forwarded to {@link StochasticLevenbergMarquardt}.
     * @param targetValues forwarded to {@link StochasticLevenbergMarquardt}.
     * @param parameterSteps forwarded to {@link StochasticLevenbergMarquardt}.
     * @param maxIteration forwarded to {@link StochasticLevenbergMarquardt}.
     * @param errorTolerance forwarded to {@link StochasticLevenbergMarquardt}.
     * @param executorService forwarded to {@link StochasticLevenbergMarquardt}.
     * @param isGradientValuationParallel if {@code true}, the gradients of all values are
     *        computed concurrently (parallel stream) before the derivative matrix is filled.
     */
    public StochasticLevenbergMarquardtAD(StochasticLevenbergMarquardt.RegularizationMethod regularizationMethod, RandomVariableInterface[] initialParameters, RandomVariableInterface[] targetValues, RandomVariableInterface[] parameterSteps, int maxIteration, double errorTolerance, ExecutorService executorService, boolean isGradientValuationParallel) {
        super(regularizationMethod, initialParameters, targetValues, parameterSteps, maxIteration, errorTolerance, executorService);
        this.isGradientValuationParallel = isGradientValuationParallel;
    }

    /**
     * Creates the solver with sequential gradient valuation
     * ({@code isGradientValuationParallel = false}).
     *
     * @param regularizationMethod forwarded to {@link StochasticLevenbergMarquardt}.
     * @param initialParameters forwarded to {@link StochasticLevenbergMarquardt}.
     * @param targetValues forwarded to {@link StochasticLevenbergMarquardt}.
     * @param parameterSteps forwarded to {@link StochasticLevenbergMarquardt}.
     * @param maxIteration forwarded to {@link StochasticLevenbergMarquardt}.
     * @param errorTolerance forwarded to {@link StochasticLevenbergMarquardt}.
     * @param executorService forwarded to {@link StochasticLevenbergMarquardt}.
     */
    public StochasticLevenbergMarquardtAD(StochasticLevenbergMarquardt.RegularizationMethod regularizationMethod, RandomVariableInterface[] initialParameters, RandomVariableInterface[] targetValues, RandomVariableInterface[] parameterSteps, int maxIteration, double errorTolerance, ExecutorService executorService) {
        this(regularizationMethod, initialParameters, targetValues, parameterSteps, maxIteration, errorTolerance, executorService, false);
    }

    /**
     * Replaces every differentiable parameter with an independent clone and then
     * calls {@code setValues(parameters, values)}.
     *
     * <p>The replacement is done <em>in place</em> in {@code parameters}, so the caller's array
     * holds the new independent variables afterwards. This matters because
     * {@link #prepareAndSetDerivatives} looks up gradients by the parameters' {@code getID()}.
     * Presumably the caller passes the same array to both methods; verify against the superclass.</p>
     *
     * @param parameters the parameters; differentiable entries are overwritten with
     *        {@code getCloneIndependent()}.
     * @param values the array passed on to {@code setValues}.
     * @throws SolverException if thrown by {@code setValues}.
     */
    protected void prepareAndSetValues(RandomVariableInterface[] parameters, RandomVariableInterface[] values) throws SolverException {
        for (int parameterIndex = 0; parameterIndex < parameters.length; ++parameterIndex) {
            if (parameters[parameterIndex] instanceof RandomVariableDifferentiableInterface) {
                parameters[parameterIndex] = ((RandomVariableDifferentiableInterface) parameters[parameterIndex]).getCloneIndependent();
            }
        }
        this.setValues(parameters, values);
    }

    /**
     * Fills {@code derivatives[parameterIndex][valueIndex]} with the derivative of
     * {@code values[valueIndex]} with respect to {@code parameters[parameterIndex]}.
     *
     * <p>If all parameters and all values are {@link RandomVariableDifferentiableInterface},
     * each entry is taken from the value's {@code getGradient()} map under the parameter's
     * {@code getID()} and reduced to its {@code average()}. Entries missing from the gradient
     * are set to {@code null}. Otherwise this delegates to {@code setDerivatives(parameters, derivatives)}.</p>
     *
     * @param parameters the current parameters.
     * @param values the current values evaluated at {@code parameters}.
     * @param derivatives output matrix indexed {@code [parameterIndex][valueIndex]}.
     * @throws SolverException if thrown by {@code setDerivatives}.
     */
    protected void prepareAndSetDerivatives(RandomVariableInterface[] parameters, RandomVariableInterface[] values, RandomVariableInterface[][] derivatives) throws SolverException {
        // Automatic differentiation is only possible if every participant carries AD information.
        if (!(isAllDifferentiable(parameters) && isAllDifferentiable(values))) {
            this.setDerivatives(parameters, derivatives);
            return;
        }

        // Optionally pre-compute all gradients concurrently. Keyed by value index.
        Map<Integer, Map<Long, RandomVariableInterface>> gradients = null;
        if (this.isGradientValuationParallel) {
            gradients = IntStream.range(0, values.length).parallel().boxed()
                    .collect(Collectors.toConcurrentMap(Function.identity(),
                            index -> ((RandomVariableDifferentiableInterface) values[index]).getGradient()));
        }

        for (int valueIndex = 0; valueIndex < values.length; ++valueIndex) {
            Map<Long, RandomVariableInterface> gradient = gradients != null
                    ? gradients.get(valueIndex)
                    : ((RandomVariableDifferentiableInterface) values[valueIndex]).getGradient();
            for (int parameterIndex = 0; parameterIndex < parameters.length; ++parameterIndex) {
                RandomVariableInterface derivative = gradient.get(((RandomVariableDifferentiableInterface) parameters[parameterIndex]).getID());
                // A missing gradient entry is left as null; present entries are reduced to their average.
                derivatives[parameterIndex][valueIndex] = derivative != null ? derivative.average() : null;
            }
        }
    }

    /**
     * Returns whether every element of {@code randomVariables} is a
     * {@link RandomVariableDifferentiableInterface}. Returns {@code true} for an empty array.
     */
    private static boolean isAllDifferentiable(RandomVariableInterface[] randomVariables) {
        for (RandomVariableInterface randomVariable : randomVariables) {
            if (!(randomVariable instanceof RandomVariableDifferentiableInterface)) {
                return false;
            }
        }
        return true;
    }
}

