/*
 * Decompiled with CFR 0.152.
 */
package smile.math;

import java.util.Arrays;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.MathEx;
import smile.math.matrix.Matrix;
import smile.util.function.DifferentiableMultivariateFunction;

/**
 * Nonlinear least-squares curve fitting with a Levenberg-Marquardt style
 * iteration.
 *
 * <p>The model function is evaluated on a single array that holds the
 * {@code d} parameters being fitted followed by the coordinates of one
 * sample: indices {@code [0, d)} are the parameters and indices
 * {@code [d, d + m)} are the sample, where {@code m} is 1 for the
 * {@code double[] x} overloads and {@code x[0].length} for the
 * {@code double[][] x} overloads. {@code func.g(array, gradient)} is called
 * with a gradient array of length {@code d}; its return value is used as the
 * fitted value and the array it fills is used as the sample's Jacobian row.
 *
 * <p>An instance is the immutable-by-convention result of a fit. The arrays
 * are exposed directly (not copied), so callers should not modify them.
 */
public class LevenbergMarquardt {
    private static final Logger logger = LoggerFactory.getLogger(LevenbergMarquardt.class);

    /** Default tolerance on the fractional improvement of the sum of squares. */
    private static final double DEFAULT_STOL = 1.0E-4;

    /** Default maximum number of outer iterations. */
    private static final int DEFAULT_MAX_ITER = 20;

    /** Lower bound applied to the damping term of every trial step. */
    private static final double MIN_DAMPING = 1.0E-7;

    /**
     * Multipliers applied, in order, to the last accepted damping term when
     * searching for a trial step that reaches the improvement goal.
     * Never modified.
     */
    private static final double[] DAMPING_FACTORS = {0.1, 1.0, 100.0, 10000.0, 1000000.0};

    /** The fitted parameters (length {@code p.length}). */
    public final double[] parameters;

    /** The model values at the fitted parameters, one per sample. */
    public final double[] fittedValues;

    /** The residuals {@code y[i] - fittedValues[i]}, one per sample. */
    public final double[] residuals;

    /** The sum of squared residuals at the fitted parameters. */
    public final double sse;

    /**
     * Creates a fit result. The arrays are stored as given, without copying.
     *
     * @param parameters the fitted parameters.
     * @param fittedValues the model values at the fitted parameters.
     * @param residuals the residuals of the fit.
     * @param sse the sum of squared residuals.
     */
    LevenbergMarquardt(double[] parameters, double[] fittedValues, double[] residuals, double sse) {
        this.parameters = parameters;
        this.fittedValues = fittedValues;
        this.residuals = residuals;
        this.sse = sse;
    }

    /**
     * Fits a univariate model with a tolerance of {@code 1E-4} and at most
     * 20 iterations.
     *
     * @param func the model; see the class documentation for the argument layout.
     * @param x the sample abscissas.
     * @param y the observed values, one per sample.
     * @param p the initial parameters; not modified.
     * @return the fit result.
     * @throws IllegalArgumentException if {@code x} is empty or {@code x} and
     *         {@code y} differ in length.
     */
    public static LevenbergMarquardt fit(DifferentiableMultivariateFunction func, double[] x, double[] y, double[] p) {
        return LevenbergMarquardt.fit(func, x, y, p, DEFAULT_STOL, DEFAULT_MAX_ITER);
    }

    /**
     * Fits a univariate model.
     *
     * @param func the model; see the class documentation for the argument layout.
     * @param x the sample abscissas.
     * @param y the observed values, one per sample.
     * @param p the initial parameters; not modified.
     * @param stol the required fractional improvement of the sum of squares
     *        per iteration; must be positive.
     * @param maxIter the maximum number of iterations; must be positive.
     * @return the fit result.
     * @throws IllegalArgumentException if {@code stol <= 0}, {@code maxIter <= 0},
     *         {@code x} is empty, or {@code x} and {@code y} differ in length.
     */
    public static LevenbergMarquardt fit(DifferentiableMultivariateFunction func, double[] x, double[] y, double[] p, double stol, int maxIter) {
        checkControls(stol, maxIter);
        checkSamples(x.length, y.length);

        // Present every scalar sample as a one-element row so that the shared
        // solver can treat both overloads uniformly. The value still lands at
        // index d of the argument array, exactly as a direct assignment would.
        double[][] rows = new double[x.length][];
        for (int i = 0; i < x.length; ++i) {
            rows[i] = new double[]{x[i]};
        }
        return solve(func, rows, 1, y, p, stol, maxIter);
    }

    /**
     * Fits a multivariate model with a tolerance of {@code 1E-4} and at most
     * 20 iterations.
     *
     * @param func the model; see the class documentation for the argument layout.
     * @param x the samples, one row per sample.
     * @param y the observed values, one per sample.
     * @param p the initial parameters; not modified.
     * @return the fit result.
     * @throws IllegalArgumentException if {@code x} is empty or {@code x} and
     *         {@code y} differ in length.
     */
    public static LevenbergMarquardt fit(DifferentiableMultivariateFunction func, double[][] x, double[] y, double[] p) {
        return LevenbergMarquardt.fit(func, x, y, p, DEFAULT_STOL, DEFAULT_MAX_ITER);
    }

    /**
     * Fits a multivariate model. The number of coordinates per sample is
     * taken from the first row of {@code x}.
     *
     * @param func the model; see the class documentation for the argument layout.
     * @param x the samples, one row per sample.
     * @param y the observed values, one per sample.
     * @param p the initial parameters; not modified.
     * @param stol the required fractional improvement of the sum of squares
     *        per iteration; must be positive.
     * @param maxIter the maximum number of iterations; must be positive.
     * @return the fit result.
     * @throws IllegalArgumentException if {@code stol <= 0}, {@code maxIter <= 0},
     *         {@code x} is empty, or {@code x} and {@code y} differ in length.
     */
    public static LevenbergMarquardt fit(DifferentiableMultivariateFunction func, double[][] x, double[] y, double[] p, double stol, int maxIter) {
        checkControls(stol, maxIter);
        checkSamples(x.length, y.length);
        return solve(func, x, x[0].length, y, p, stol, maxIter);
    }

    /** Rejects non-positive iteration controls. */
    private static void checkControls(double stol, int maxIter) {
        if (stol <= 0.0) {
            throw new IllegalArgumentException("Invalid gradient tolerance: " + stol);
        }
        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }
    }

    /** Rejects an empty sample set and mismatched sample/response counts. */
    private static void checkSamples(int nx, int ny) {
        if (nx == 0) {
            throw new IllegalArgumentException("Empty sample set");
        }
        if (nx != ny) {
            throw new IllegalArgumentException("The sizes of x and y don't match: " + nx + " != " + ny);
        }
    }

    /**
     * Shared solver behind all public {@code fit} overloads.
     *
     * @param m the number of coordinates copied from each row of {@code x}
     *        into the argument array after the {@code d} parameters.
     */
    private static LevenbergMarquardt solve(DifferentiableMultivariateFunction func, double[][] x, int m, double[] y, double[] p, double stol, int maxIter) {
        int n = x.length;
        int d = p.length;

        // Argument arrays: d parameters followed by the m sample coordinates.
        double[] pbest = new double[d + m];
        double[] pprev = new double[d + m];
        double[] pnew = new double[d + m];
        System.arraycopy(p, 0, pbest, 0, d);

        double[] r = new double[n];
        double[] g = new double[d];
        double[] gse = new double[d];
        double[] dp = new double[d];
        double[] chg = new double[d];
        double[] norm = new double[d];
        Matrix J = new Matrix(n, d);
        double epsLlast = 1.0;

        for (int iter = 1; iter <= maxIter; ++iter) {
            System.arraycopy(pbest, 0, pprev, 0, d);

            // Residuals and Jacobian at the current best parameters.
            double ss = 0.0;
            for (int i = 0; i < n; ++i) {
                System.arraycopy(x[i], 0, pprev, d, m);
                double fi = func.g(pprev, dp);
                r[i] = y[i] - fi;
                ss += r[i] * r[i];
                for (int j = 0; j < d; ++j) {
                    J.set(i, j, dp[j]);
                }
            }

            double sbest = ss;
            // A trial step is accepted once its SSE drops to this value.
            double sgoal = (1.0 - stol) * sbest;

            // Scale every non-zero Jacobian column to unit Euclidean length;
            // the step is scaled back with the same factors below.
            for (int j = 0; j < d; ++j) {
                double nrm = 0.0;
                for (int i = 0; i < n; ++i) {
                    double jij = J.get(i, j);
                    nrm += jij * jij;
                }
                norm[j] = nrm > 0.0 ? 1.0 / Math.sqrt(nrm) : 1.0;
                for (int i = 0; i < n; ++i) {
                    J.mul(i, j, norm[j]);
                }
            }

            Matrix.SVD svd = J.svd(true, true);
            double[] s = svd.s;
            // NOTE(review): one scalar built from MathEx.dot(s, s) divides
            // every component of g below. Reference implementations such as
            // Octave's leasqr damp each singular value separately
            // (g[j] / sqrt(s[j]^2 + epsL)) -- confirm this is intended.
            double s2 = MathEx.dot(s, s);
            Matrix U = svd.U;
            Matrix V = svd.V;
            U.tv(r, g);

            // Try increasingly damped steps until one reaches the goal.
            for (double eps : DAMPING_FACTORS) {
                double epsL = Math.max(epsLlast * eps, MIN_DAMPING);
                double se = Math.sqrt(s2 + epsL);
                for (int j = 0; j < d; ++j) {
                    gse[j] = g[j] / se;
                }
                V.mv(gse, chg);
                for (int j = 0; j < d; ++j) {
                    chg[j] *= norm[j];
                    pnew[j] = chg[j] + pprev[j];
                }

                ss = sumOfSquares(func, pnew, d, x, m, y);
                if (ss < sbest) {
                    System.arraycopy(pnew, 0, pbest, 0, d);
                    sbest = ss;
                }
                if (ss <= sgoal) {
                    epsLlast = epsL;
                    break;
                }
            }

            logger.info("SSE after {} iterations: {}", iter, sbest);
            // Stop when the last trial is an essentially exact fit, or when
            // no trial step achieved the requested fractional improvement.
            if (ss < MathEx.EPSILON || ss > sgoal) {
                logger.info("converges on SSE after {} iterations", iter);
                break;
            }
        }

        double[] pfit = new double[d];
        System.arraycopy(pbest, 0, pfit, 0, d);

        // Final fitted values and residuals at the best parameters found.
        double[] f = new double[n];
        double sse = 0.0;
        for (int i = 0; i < n; ++i) {
            System.arraycopy(x[i], 0, pbest, d, m);
            f[i] = func.f(pbest);
            r[i] = y[i] - f[i];
            sse += r[i] * r[i];
        }
        return new LevenbergMarquardt(pfit, f, r, sse);
    }

    /**
     * Returns the sum of squared residuals of the model over all samples.
     * The sample slots {@code [d, d + m)} of {@code args} are overwritten.
     */
    private static double sumOfSquares(DifferentiableMultivariateFunction func, double[] args, int d, double[][] x, int m, double[] y) {
        double ss = 0.0;
        for (int i = 0; i < x.length; ++i) {
            System.arraycopy(x[i], 0, args, d, m);
            double fi = func.f(args);
            double ri = y[i] - fi;
            ss += ri * ri;
        }
        return ss;
    }
}

