/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.learningcurve.extrapolation.lc;

import ai.libs.jaicore.ml.interfaces.AnalyticalLearningCurve;
import ai.libs.jaicore.ml.learningcurve.extrapolation.lc.LinearCombinationFunction;
import ai.libs.jaicore.ml.learningcurve.extrapolation.lc.LinearCombinationLearningCurveConfiguration;
import ai.libs.jaicore.ml.learningcurve.extrapolation.lc.LinearCombinationParameterSet;
import ai.libs.jaicore.ml.learningcurve.extrapolation.lc.ParametricFunction;
import java.util.ArrayList;
import org.apache.commons.math3.analysis.UnivariateFunction;
import org.apache.commons.math3.analysis.solvers.BrentSolver;
import org.apache.commons.math3.exception.NoBracketingException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class LinearCombinationLearningCurve
implements AnalyticalLearningCurve {
    private static final Logger LOG = LoggerFactory.getLogger(LinearCombinationLearningCurve.class);
    private static final int ROOT_COMPUTATION_RETIRES = 8;
    private static final double SLOPE_SATURATION_POINT = 1.0E-4;
    private static final double TOLERANCE_CONVERGENCE_VALUE = 1.0;
    private static final double SLOPE_CONVERGENCE_VALUE = 1.0E-7;
    private LinearCombinationFunction learningCurve;
    private LinearCombinationFunction derivative;
    private int dataSetSize;

    public LinearCombinationLearningCurve(LinearCombinationLearningCurveConfiguration configuration, int dataSetSize) {
        ArrayList<UnivariateFunction> learningCurves = new ArrayList<UnivariateFunction>();
        ArrayList<UnivariateFunction> derivatives = new ArrayList<UnivariateFunction>();
        for (LinearCombinationParameterSet parameterSet : configuration.getParameterSets()) {
            learningCurves.add(this.generateLearningCurve(parameterSet));
            derivatives.add(this.generateDerivative(parameterSet));
        }
        ArrayList<Double> weights = new ArrayList<Double>();
        for (int i = 0; i < configuration.getParameterSets().size(); ++i) {
            weights.add(1.0 / (double)configuration.getParameterSets().size());
        }
        this.learningCurve = new LinearCombinationFunction(learningCurves, weights);
        this.derivative = new LinearCombinationFunction(derivatives, weights);
        this.dataSetSize = dataSetSize;
    }

    private LinearCombinationFunction generateLearningCurve(LinearCombinationParameterSet parameterSet) {
        ArrayList<UnivariateFunction> functions = new ArrayList<UnivariateFunction>();
        ArrayList<Double> weights = new ArrayList<Double>();
        if (parameterSet.getParameters().containsKey("vapor_pressure")) {
            ParametricFunction vaporPressure = new ParametricFunction(parameterSet.getParameters().get("vapor_pressure")){

                public double value(double x) {
                    double a = this.getParams().get("a");
                    double b = this.getParams().get("b");
                    double c = this.getParams().get("c");
                    return Math.exp(a + b / x + c * Math.log(x));
                }
            };
            functions.add(vaporPressure);
            weights.add(parameterSet.getWeights().get("vapor_pressure"));
        }
        if (parameterSet.getParameters().containsKey("pow_3")) {
            ParametricFunction pow3 = new ParametricFunction(parameterSet.getParameters().get("pow_3")){

                public double value(double x) {
                    double alpha = this.getParams().get("alpha");
                    double a = this.getParams().get("a");
                    double c = this.getParams().get("c");
                    return c - a * Math.pow(x, -1.0 * alpha);
                }
            };
            functions.add(pow3);
            weights.add(parameterSet.getWeights().get("pow_3"));
        }
        if (parameterSet.getParameters().containsKey("log_log_linear")) {
            ParametricFunction logLogLinear = new ParametricFunction(parameterSet.getParameters().get("log_log_linear")){

                public double value(double x) {
                    double a = this.getParams().get("a");
                    double b = this.getParams().get("b");
                    return Math.log(a * Math.log(x) + b);
                }
            };
            functions.add(logLogLinear);
            weights.add(parameterSet.getWeights().get("log_log_linear"));
        }
        if (parameterSet.getParameters().containsKey("hill_3")) {
            ParametricFunction hill3 = new ParametricFunction(parameterSet.getParameters().get("hill_3")){

                public double value(double x) {
                    double y = this.getParams().get("y");
                    double eta = this.getParams().get("eta");
                    double kappa = this.getParams().get("kappa");
                    return y * Math.pow(x, eta) / (Math.pow(kappa, eta) + Math.pow(x, eta));
                }
            };
            functions.add(hill3);
            weights.add(parameterSet.getWeights().get("hill_3"));
        }
        if (parameterSet.getParameters().containsKey("log_power")) {
            ParametricFunction logPower = new ParametricFunction(parameterSet.getParameters().get("log_power")){

                public double value(double x) {
                    double a = this.getParams().get("a");
                    double b = this.getParams().get("b");
                    double c = this.getParams().get("c");
                    return a / (1.0 + Math.pow(x / Math.exp(b), c));
                }
            };
            functions.add(logPower);
            weights.add(parameterSet.getWeights().get("log_power"));
        }
        if (parameterSet.getParameters().containsKey("pow_4")) {
            ParametricFunction pow4 = new ParametricFunction(parameterSet.getParameters().get("pow_4")){

                public double value(double x) {
                    double a = this.getParams().get("a");
                    double b = this.getParams().get("b");
                    double c = this.getParams().get("c");
                    double alpha = this.getParams().get("alpha");
                    return c - Math.pow(a * x + b, -alpha);
                }
            };
            functions.add(pow4);
            weights.add(parameterSet.getWeights().get("pow_4"));
        }
        if (parameterSet.getParameters().containsKey("mmf")) {
            ParametricFunction mmf = new ParametricFunction(parameterSet.getParameters().get("mmf")){

                public double value(double x) {
                    double alpha = this.getParams().get("alpha");
                    double beta = this.getParams().get("beta");
                    double delta = this.getParams().get("delta");
                    double kappa = this.getParams().get("kappa");
                    return alpha - (alpha - beta) / (1.0 + Math.pow(kappa * x, delta));
                }
            };
            functions.add(mmf);
            weights.add(parameterSet.getWeights().get("mmf"));
        }
        if (parameterSet.getParameters().containsKey("exp_4")) {
            ParametricFunction exp4 = new ParametricFunction(parameterSet.getParameters().get("exp_4")){

                public double value(double x) {
                    double a = this.getParams().get("a");
                    double b = this.getParams().get("b");
                    double c = this.getParams().get("c");
                    double alpha = this.getParams().get("alpha");
                    return c - Math.exp(-a * Math.pow(x, alpha) + b);
                }
            };
            functions.add(exp4);
            weights.add(parameterSet.getWeights().get("exp_4"));
        }
        if (parameterSet.getParameters().containsKey("janoschek")) {
            ParametricFunction janoscheck = new ParametricFunction(parameterSet.getParameters().get("janoschek")){

                public double value(double x) {
                    double alpha = this.getParams().get("alpha");
                    double beta = this.getParams().get("beta");
                    double delta = this.getParams().get("delta");
                    double kappa = this.getParams().get("kappa");
                    return alpha - (alpha - beta) * Math.exp(-kappa * Math.pow(x, delta));
                }
            };
            functions.add(janoscheck);
            weights.add(parameterSet.getWeights().get("janoschek"));
        }
        if (parameterSet.getParameters().containsKey("weibull")) {
            ParametricFunction weibull = new ParametricFunction(parameterSet.getParameters().get("weibull")){

                public double value(double x) {
                    double alpha = this.getParams().get("alpha");
                    double beta = this.getParams().get("beta");
                    double delta = this.getParams().get("delta");
                    double kappa = this.getParams().get("kappa");
                    return alpha - (alpha - beta) * Math.exp(-1.0 * Math.pow(kappa * x, delta));
                }
            };
            functions.add(weibull);
            weights.add(parameterSet.getWeights().get("weibull"));
        }
        if (parameterSet.getParameters().containsKey("ilog_2")) {
            ParametricFunction ilog2 = new ParametricFunction(parameterSet.getParameters().get("ilog_2")){

                public double value(double x) {
                    double a = this.getParams().get("a");
                    double c = this.getParams().get("c");
                    return c - a / Math.log(x);
                }
            };
            functions.add(ilog2);
            weights.add(parameterSet.getWeights().get("ilog_2"));
        }
        return new LinearCombinationFunction(functions, weights);
    }

    private LinearCombinationFunction generateDerivative(LinearCombinationParameterSet parameterSet) {
        ArrayList<UnivariateFunction> functions = new ArrayList<UnivariateFunction>();
        ArrayList<Double> weights = new ArrayList<Double>();
        if (parameterSet.getParameters().containsKey("vapor_pressure")) {
            ParametricFunction vaporPressure = new ParametricFunction(parameterSet.getParameters().get("vapor_pressure")){

                public double value(double x) {
                    double a = this.getParams().get("a");
                    double b = this.getParams().get("b");
                    double c = this.getParams().get("c");
                    return Math.pow(x, c - 2.0) * Math.exp(a + b / x) * (c * x - b);
                }
            };
            functions.add(vaporPressure);
            weights.add(parameterSet.getWeights().get("vapor_pressure"));
        }
        if (parameterSet.getParameters().containsKey("pow_3")) {
            ParametricFunction pow3 = new ParametricFunction(parameterSet.getParameters().get("pow_3")){

                public double value(double x) {
                    double alpha = this.getParams().get("alpha");
                    double a = this.getParams().get("a");
                    return a * alpha * Math.pow(x, -alpha - 1.0);
                }
            };
            functions.add(pow3);
            weights.add(parameterSet.getWeights().get("pow_3"));
        }
        if (parameterSet.getParameters().containsKey("log_log_linear")) {
            ParametricFunction logLogLinear = new ParametricFunction(parameterSet.getParameters().get("log_log_linear")){

                public double value(double x) {
                    double a = this.getParams().get("a");
                    double b = this.getParams().get("b");
                    return a / (a * x * Math.log(x) + b * x);
                }
            };
            functions.add(logLogLinear);
            weights.add(parameterSet.getWeights().get("log_log_linear"));
        }
        if (parameterSet.getParameters().containsKey("hill_3")) {
            ParametricFunction hill3 = new ParametricFunction(parameterSet.getParameters().get("hill_3")){

                public double value(double x) {
                    double y = this.getParams().get("y");
                    double eta = this.getParams().get("eta");
                    double kappa = this.getParams().get("kappa");
                    return y * eta * Math.pow(kappa, eta) * Math.pow(x, eta - 1.0) / Math.pow(Math.pow(kappa, eta) + Math.pow(x, eta), 2.0);
                }
            };
            functions.add(hill3);
            weights.add(parameterSet.getWeights().get("hill_3"));
        }
        if (parameterSet.getParameters().containsKey("log_power")) {
            ParametricFunction logPower = new ParametricFunction(parameterSet.getParameters().get("log_power")){

                public double value(double x) {
                    double a = this.getParams().get("a");
                    double b = this.getParams().get("b");
                    double c = this.getParams().get("c");
                    return -1.0 * (a * c * Math.pow(Math.exp(-b) * x, c)) / (x * Math.pow(Math.pow(Math.exp(-b) * x, c) + 1.0, 2.0));
                }
            };
            functions.add(logPower);
            weights.add(parameterSet.getWeights().get("log_power"));
        }
        if (parameterSet.getParameters().containsKey("pow_4")) {
            ParametricFunction pow4 = new ParametricFunction(parameterSet.getParameters().get("pow_4")){

                public double value(double x) {
                    double a = this.getParams().get("a");
                    double b = this.getParams().get("b");
                    double alpha = this.getParams().get("alpha");
                    return a * alpha * Math.pow(a * x + b, -alpha - 1.0);
                }
            };
            functions.add(pow4);
            weights.add(parameterSet.getWeights().get("pow_4"));
        }
        if (parameterSet.getParameters().containsKey("mmf")) {
            ParametricFunction mmf = new ParametricFunction(parameterSet.getParameters().get("mmf")){

                public double value(double x) {
                    double alpha = this.getParams().get("alpha");
                    double beta = this.getParams().get("beta");
                    double delta = this.getParams().get("delta");
                    double kappa = this.getParams().get("kappa");
                    return delta * (alpha - beta) * Math.pow(kappa * x, delta) / (x * Math.pow(1.0 + Math.pow(kappa * x, delta), 2.0));
                }
            };
            functions.add(mmf);
            weights.add(parameterSet.getWeights().get("mmf"));
        }
        if (parameterSet.getParameters().containsKey("exp_4")) {
            ParametricFunction exp4 = new ParametricFunction(parameterSet.getParameters().get("exp_4")){

                public double value(double x) {
                    double a = this.getParams().get("a");
                    double b = this.getParams().get("b");
                    double alpha = this.getParams().get("alpha");
                    return a * alpha * Math.pow(x, alpha - 1.0) * Math.exp(b - a * Math.pow(x, alpha));
                }
            };
            functions.add(exp4);
            weights.add(parameterSet.getWeights().get("exp_4"));
        }
        if (parameterSet.getParameters().containsKey("janoschek")) {
            ParametricFunction janoscheck = new ParametricFunction(parameterSet.getParameters().get("janoschek")){

                public double value(double x) {
                    double alpha = this.getParams().get("alpha");
                    double beta = this.getParams().get("beta");
                    double delta = this.getParams().get("delta");
                    double kappa = this.getParams().get("kappa");
                    return kappa * delta * (alpha - beta) * Math.pow(x, delta - 1.0) * Math.exp(-kappa * Math.pow(x, delta));
                }
            };
            functions.add(janoscheck);
            weights.add(parameterSet.getWeights().get("janoschek"));
        }
        if (parameterSet.getParameters().containsKey("weibull")) {
            ParametricFunction weibull = new ParametricFunction(parameterSet.getParameters().get("weibull")){

                public double value(double x) {
                    double alpha = this.getParams().get("alpha");
                    double beta = this.getParams().get("beta");
                    double delta = this.getParams().get("delta");
                    double kappa = this.getParams().get("kappa");
                    return delta * (alpha - beta) * Math.exp(-1.0 * Math.pow(kappa * x, delta)) * Math.pow(kappa * x, delta) / x;
                }
            };
            functions.add(weibull);
            weights.add(parameterSet.getWeights().get("weibull"));
        }
        if (parameterSet.getParameters().containsKey("ilog_2")) {
            ParametricFunction ilog2 = new ParametricFunction(parameterSet.getParameters().get("ilog_2")){

                public double value(double x) {
                    double a = this.getParams().get("a");
                    return a / (x * Math.pow(Math.log(x), 2.0));
                }
            };
            functions.add(ilog2);
            weights.add(parameterSet.getWeights().get("ilog_2"));
        }
        return new LinearCombinationFunction(functions, weights);
    }

    @Override
    public double getCurveValue(double x) {
        return this.learningCurve.value(x);
    }

    @Override
    public double getSaturationPoint(double epsilon) {
        return this.computeDerivativeRoot(epsilon, -1.0E-4, this.dataSetSize);
    }

    @Override
    public double getDerivativeCurveValue(double x) {
        this.derivative.setOffset(0.0);
        return this.derivative.value(x);
    }

    @Override
    public double getConvergenceValue() {
        int x = (int)this.computeDerivativeRoot(1.0, -1.0E-7, this.dataSetSize * 100);
        return this.getCurveValue(x);
    }

    private double computeDerivativeRoot(double epsilon, double offset, int upperIntervalBoundStart) {
        BrentSolver solver = new BrentSolver(0.0, epsilon);
        double result = -1.0;
        int lowerIntervalBound = 1;
        int upperIntervalBound = upperIntervalBoundStart;
        int retriesLeft = 8;
        this.derivative.setOffset(offset);
        while (retriesLeft > 0 && result == -1.0) {
            try {
                LOG.info("Trying to find root with offset {} in interval [{}/{}]", new Object[]{offset, lowerIntervalBound, upperIntervalBound});
                result = solver.solve(1000, (UnivariateFunction)this.derivative, (double)lowerIntervalBound, (double)upperIntervalBound);
            }
            catch (NoBracketingException e) {
                LOG.warn("Cannot find root in interval [{},{}]: {}", new Object[]{lowerIntervalBound, upperIntervalBound, e.getMessage()});
                LOG.warn("Retries left: {} / {}", (Object)(--retriesLeft), (Object)8);
                upperIntervalBound *= 2;
                lowerIntervalBound *= 2;
            }
        }
        if (result == -1.0) {
            try {
                LOG.info("Trying to find root with offset {} in interval [{}/{}]", new Object[]{offset, lowerIntervalBound, upperIntervalBound});
                result = solver.solve(1000, (UnivariateFunction)this.derivative, 50.0, (double)upperIntervalBound);
            }
            catch (NoBracketingException e) {
                LOG.warn("Cannot find root in interval [{},{}]: {}", new Object[]{lowerIntervalBound, upperIntervalBound, e.getMessage()});
            }
        }
        if (result == -1.0) {
            throw new RuntimeException(String.format("No solution could be found in interval [1,%d]", upperIntervalBound));
        }
        return result;
    }
}

