/*
 * Decompiled with CFR 0.152.
 */
package com.opengamma.strata.math.impl.statistics.leastsquare;

import com.google.common.collect.Lists;
import com.google.common.primitives.Doubles;
import com.opengamma.strata.collect.ArgChecker;
import com.opengamma.strata.collect.DoubleArrayMath;
import com.opengamma.strata.collect.array.DoubleArray;
import com.opengamma.strata.collect.array.DoubleMatrix;
import com.opengamma.strata.collect.array.Matrix;
import com.opengamma.strata.math.MathUtils;
import com.opengamma.strata.math.impl.linearalgebra.SVDecompositionCommons;
import com.opengamma.strata.math.impl.matrix.CommonsMatrixAlgebra;
import com.opengamma.strata.math.impl.matrix.MatrixAlgebra;
import com.opengamma.strata.math.impl.statistics.leastsquare.GeneralizedLeastSquareResults;
import com.opengamma.strata.math.linearalgebra.Decomposition;
import com.opengamma.strata.math.linearalgebra.DecompositionResult;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;
import org.apache.commons.math3.util.CombinatoricsUtils;

/**
 * Weighted least-squares fit of a linear combination of basis functions to data, optionally with a
 * finite-difference penalty on the fitted weights.
 * <p>
 * Given points {@code x_j}, values {@code y_j} and errors {@code sigma_j}, the weights {@code w} solve
 * {@code (A + sum lambda * D^T D) w = b}, where
 * {@code A_ik = sum_j f_i(x_j) f_k(x_j) / sigma_j^2}, {@code b_i = sum_j y_j f_i(x_j) / sigma_j^2}
 * and {@code D} is a finite-difference matrix of the requested order. A positive {@code lambda}
 * therefore penalises the squared differences of the weights. The linear system is solved with a
 * singular value decomposition.
 */
public class GeneralizedLeastSquare {
    /** Decomposition used to solve the (possibly penalised) normal equations. */
    private final Decomposition<?> _decomposition = new SVDecompositionCommons();
    /** Matrix algebra used to build and add the penalty matrices. */
    private final MatrixAlgebra _algebra = new CommonsMatrixAlgebra();

    /**
     * Fits the basis functions to the data without any penalty.
     *
     * @param <T> the type of the measurement points
     * @param x the measurement points
     * @param y the measured values, one per point
     * @param sigma the measurement errors, one per point, each strictly positive
     * @param basisFunctions the basis functions, not empty
     * @return the fit results
     * @throws IllegalArgumentException if the inputs are invalid
     */
    public <T> GeneralizedLeastSquareResults<T> solve(T[] x, double[] y, double[] sigma, List<Function<T, Double>> basisFunctions) {
        return solve(x, y, sigma, basisFunctions, 0.0, 0);
    }

    /**
     * Fits the basis functions to the data with a difference penalty on the weights.
     *
     * @param <T> the type of the measurement points
     * @param x the measurement points, not null
     * @param y the measured values, one per point, not null
     * @param sigma the measurement errors, one per point, each strictly positive, not null
     * @param basisFunctions the basis functions, not empty
     * @param lambda the penalty strength, non-negative; zero disables the penalty
     * @param differenceOrder the order of the penalised differences, non-negative and less than the
     *   number of basis functions when {@code lambda > 0}
     * @return the fit results
     * @throws IllegalArgumentException if the inputs are invalid
     */
    public <T> GeneralizedLeastSquareResults<T> solve(T[] x, double[] y, double[] sigma, List<Function<T, Double>> basisFunctions,
            double lambda, int differenceOrder) {
        ArgChecker.notNull(x, "x null");
        ArgChecker.notNull(y, "y null");
        ArgChecker.notNull(sigma, "sigma null");
        ArgChecker.notEmpty(basisFunctions, "empty basisFunctions");
        checkData(x.length, y.length, sigma.length);
        checkPenalty(lambda, differenceOrder);
        List<T> lx = Lists.newArrayList(x);
        List<Double> ly = Lists.newArrayList(Doubles.asList(y));
        List<Double> lsigma = Lists.newArrayList(Doubles.asList(sigma));
        return solveImp(lx, ly, lsigma, basisFunctions, lambda, differenceOrder);
    }

    /**
     * Fits the basis functions to one-dimensional data given as primitive arrays.
     *
     * @param x the measurement points, not null
     * @param y the measured values, one per point
     * @param sigma the measurement errors, one per point, each strictly positive
     * @param basisFunctions the basis functions, not empty
     * @param lambda the penalty strength, non-negative
     * @param differenceOrder the order of the penalised differences, non-negative
     * @return the fit results
     * @throws IllegalArgumentException if the inputs are invalid
     */
    GeneralizedLeastSquareResults<Double> solve(double[] x, double[] y, double[] sigma, List<Function<Double, Double>> basisFunctions,
            double lambda, int differenceOrder) {
        // check before boxing so a null x is reported like in the other overloads
        ArgChecker.notNull(x, "x null");
        Double[] boxed = DoubleArrayMath.toObject(x);
        return solve(boxed, y, sigma, basisFunctions, lambda, differenceOrder);
    }

    /**
     * Fits the basis functions to the data without any penalty.
     *
     * @param <T> the type of the measurement points
     * @param x the measurement points, not empty
     * @param y the measured values, one per point, not empty
     * @param sigma the measurement errors, one per point, each strictly positive, not empty
     * @param basisFunctions the basis functions, not empty
     * @return the fit results
     * @throws IllegalArgumentException if the inputs are invalid
     */
    public <T> GeneralizedLeastSquareResults<T> solve(List<T> x, List<Double> y, List<Double> sigma, List<Function<T, Double>> basisFunctions) {
        return solve(x, y, sigma, basisFunctions, 0.0, 0);
    }

    /**
     * Fits the basis functions to the data with a difference penalty on the weights.
     *
     * @param <T> the type of the measurement points
     * @param x the measurement points, not empty
     * @param y the measured values, one per point, not empty
     * @param sigma the measurement errors, one per point, each strictly positive, not empty
     * @param basisFunctions the basis functions, not empty
     * @param lambda the penalty strength, non-negative; zero disables the penalty
     * @param differenceOrder the order of the penalised differences, non-negative and less than the
     *   number of basis functions when {@code lambda > 0}
     * @return the fit results
     * @throws IllegalArgumentException if the inputs are invalid
     */
    public <T> GeneralizedLeastSquareResults<T> solve(List<T> x, List<Double> y, List<Double> sigma, List<Function<T, Double>> basisFunctions,
            double lambda, int differenceOrder) {
        ArgChecker.notEmpty(x, "empty measurement points");
        ArgChecker.notEmpty(y, "empty measurement values");
        ArgChecker.notEmpty(sigma, "empty measurement errors");
        ArgChecker.notEmpty(basisFunctions, "empty basisFunctions");
        checkData(x.size(), y.size(), sigma.size());
        checkPenalty(lambda, differenceOrder);
        return solveImp(x, y, sigma, basisFunctions, lambda, differenceOrder);
    }

    /**
     * Fits the basis functions to the data with a separate difference penalty in each direction of a
     * tensor-product basis.
     * <p>
     * The penalty for direction {@code i} is {@code I ⊗ D_i^T D_i ⊗ I}, so its size is the product of
     * {@code sizes}. That product must match the number of basis functions for the penalty to be added
     * whenever any {@code lambda[i] > 0}.
     *
     * @param <T> the type of the measurement points
     * @param x the measurement points, not empty
     * @param y the measured values, one per point, not empty
     * @param sigma the measurement errors, one per point, each strictly positive, not empty
     * @param basisFunctions the basis functions, not empty
     * @param sizes the number of basis functions in each direction, each at least one
     * @param lambda the penalty strength per direction, each non-negative
     * @param differenceOrder the order of the penalised differences per direction, each non-negative
     * @return the fit results
     * @throws IllegalArgumentException if the inputs are invalid
     */
    public <T> GeneralizedLeastSquareResults<T> solve(List<T> x, List<Double> y, List<Double> sigma, List<Function<T, Double>> basisFunctions,
            int[] sizes, double[] lambda, int[] differenceOrder) {
        ArgChecker.notEmpty(x, "empty measurement points");
        ArgChecker.notEmpty(y, "empty measurement values");
        ArgChecker.notEmpty(sigma, "empty measurement errors");
        ArgChecker.notEmpty(basisFunctions, "empty basisFunctions");
        checkData(x.size(), y.size(), sigma.size());
        ArgChecker.notNull(sizes, "sizes null");
        ArgChecker.notNull(lambda, "lambda null");
        ArgChecker.notNull(differenceOrder, "differenceOrder null");
        int dim = sizes.length;
        ArgChecker.isTrue(dim == lambda.length,
                "number of penalty functions {} must be equal to number of directions {}", lambda.length, dim);
        ArgChecker.isTrue(dim == differenceOrder.length,
                "number of difference order {} must be equal to number of directions {}", differenceOrder.length, dim);
        for (int i = 0; i < dim; ++i) {
            ArgChecker.isTrue(sizes[i] > 0, "sizes must be >= 1");
            checkPenalty(lambda[i], differenceOrder[i]);
        }
        return solveImp(x, y, sigma, basisFunctions, sizes, lambda, differenceOrder);
    }

    // checks that there is data and that values and errors have one entry per point
    private static void checkData(int n, int ySize, int sigmaSize) {
        ArgChecker.isTrue(n > 0, "no data");
        ArgChecker.isTrue(ySize == n, "y wrong length");
        ArgChecker.isTrue(sigmaSize == n, "sigma wrong length");
    }

    // checks a single penalty strength and difference order
    private static void checkPenalty(double lambda, int differenceOrder) {
        ArgChecker.isTrue(lambda >= 0.0, "negative lambda");
        ArgChecker.isTrue(differenceOrder >= 0, "difference order");
    }

    // one-directional fit: validated inputs, optional single difference penalty
    private <T> GeneralizedLeastSquareResults<T> solveImp(List<T> x, List<Double> y, List<Double> sigma,
            List<Function<T, Double>> basisFunctions, double lambda, int differenceOrder) {
        double[] invSigmaSqr = getInverseVariances(sigma);
        double[][] f = getBasisValues(x, basisFunctions);
        double[] yValues = Doubles.toArray(y);
        DoubleArray mb = getBVector(f, yValues, invSigmaSqr);
        DoubleMatrix ma = getAMatrix(f, invSigmaSqr);
        if (lambda > 0.0) {
            DoubleMatrix d = getDiffMatrix(basisFunctions.size(), differenceOrder);
            ma = addPenalty(ma, d, lambda);
        }
        return fit(basisFunctions, yValues, f, invSigmaSqr, ma, mb);
    }

    // multi-directional fit: validated inputs, one difference penalty per direction with lambda > 0
    private <T> GeneralizedLeastSquareResults<T> solveImp(List<T> x, List<Double> y, List<Double> sigma,
            List<Function<T, Double>> basisFunctions, int[] sizes, double[] lambda, int[] differenceOrder) {
        double[] invSigmaSqr = getInverseVariances(sigma);
        double[][] f = getBasisValues(x, basisFunctions);
        double[] yValues = Doubles.toArray(y);
        DoubleArray mb = getBVector(f, yValues, invSigmaSqr);
        DoubleMatrix ma = getAMatrix(f, invSigmaSqr);
        for (int i = 0; i < sizes.length; ++i) {
            if (lambda[i] > 0.0) {
                DoubleMatrix d = getDiffMatrix(sizes, differenceOrder[i], i);
                ma = addPenalty(ma, d, lambda[i]);
            }
        }
        return fit(basisFunctions, yValues, f, invSigmaSqr, ma, mb);
    }

    // returns 1 / sigma_j^2 for each point, rejecting non-positive errors
    private static double[] getInverseVariances(List<Double> sigma) {
        int n = sigma.size();
        double[] invSigmaSqr = new double[n];
        for (int i = 0; i < n; ++i) {
            double s = sigma.get(i);
            ArgChecker.isTrue(s > 0.0, "sigma must be greater than zero");
            invSigmaSqr[i] = 1.0 / s / s;
        }
        return invSigmaSqr;
    }

    // returns f[i][j] = basisFunctions[i](x[j])
    private static <T> double[][] getBasisValues(List<T> x, List<Function<T, Double>> basisFunctions) {
        int m = basisFunctions.size();
        int n = x.size();
        double[][] f = new double[m][n];
        for (int i = 0; i < m; ++i) {
            Function<T, Double> basis = basisFunctions.get(i);
            for (int j = 0; j < n; ++j) {
                f[i][j] = basis.apply(x.get(j));
            }
        }
        return f;
    }

    // returns the right-hand side b_i = sum_k y_k f_i(x_k) / sigma_k^2
    private static DoubleArray getBVector(double[][] f, double[] y, double[] invSigmaSqr) {
        int m = f.length;
        int n = invSigmaSqr.length;
        double[] b = new double[m];
        for (int i = 0; i < m; ++i) {
            double sum = 0.0;
            for (int k = 0; k < n; ++k) {
                sum += y[k] * f[i][k] * invSigmaSqr[k];
            }
            b[i] = sum;
        }
        return DoubleArray.copyOf(b);
    }

    // returns ma + lambda * penalty
    private DoubleMatrix addPenalty(DoubleMatrix ma, DoubleMatrix penalty, double lambda) {
        return (DoubleMatrix) _algebra.add(ma, _algebra.scale(penalty, lambda));
    }

    // solves ma w = mb, inverts ma and computes the unpenalised chi-squared of the fit
    private <T> GeneralizedLeastSquareResults<T> fit(List<Function<T, Double>> basisFunctions, double[] y, double[][] f,
            double[] invSigmaSqr, DoubleMatrix ma, DoubleArray mb) {
        int m = f.length;
        int n = y.length;
        DecompositionResult decmp = _decomposition.apply(ma);
        DoubleArray w = decmp.solve(mb);
        // inverse of the (possibly penalised) normal matrix
        DoubleMatrix covar = decmp.solve(DoubleMatrix.identity(m));
        double chiSq = 0.0;
        for (int i = 0; i < n; ++i) {
            double fitted = 0.0;
            for (int k = 0; k < m; ++k) {
                fitted += w.get(k) * f[k][i];
            }
            chiSq += MathUtils.pow2(y[i] - fitted) * invSigmaSqr[i];
        }
        return new GeneralizedLeastSquareResults<T>(basisFunctions, chiSq, w, covar);
    }

    // returns the symmetric normal matrix A_ij = sum_k f_i(x_k) f_j(x_k) / sigma_k^2
    private DoubleMatrix getAMatrix(double[][] funcMatrix, double[] invSigmaSqr) {
        int m = funcMatrix.length;
        int n = funcMatrix[0].length;
        double[][] a = new double[m][m];
        for (int i = 0; i < m; ++i) {
            double sum = 0.0;
            for (int k = 0; k < n; ++k) {
                sum += MathUtils.pow2(funcMatrix[i][k]) * invSigmaSqr[k];
            }
            a[i][i] = sum;
            // fill the upper triangle and mirror it, since A is symmetric
            for (int j = i + 1; j < m; ++j) {
                sum = 0.0;
                for (int k = 0; k < n; ++k) {
                    sum += funcMatrix[i][k] * funcMatrix[j][k] * invSigmaSqr[k];
                }
                a[i][j] = sum;
                a[j][i] = sum;
            }
        }
        return DoubleMatrix.copyOf(a);
    }

    /**
     * Returns {@code D^T D}, where {@code D} is the m-by-m matrix whose row {@code i >= k} holds the
     * k-th order difference stencil over columns {@code i-k..i}. Rows {@code 0..k-1} are zero.
     */
    private DoubleMatrix getDiffMatrix(int m, int k) {
        ArgChecker.isTrue(k < m, "difference order too high");
        double[][] data = new double[m][m];
        if (m == 0) {
            // defensive: with k >= 0 the check above already guarantees m >= 1
            return DoubleMatrix.copyOf(data);
        }
        // coeff[i] = (-1)^(k-i) * C(k, i); kept as double to avoid silent int overflow for large k
        double[] coeff = new double[k + 1];
        double sign = 1.0;
        for (int i = k; i >= 0; --i) {
            coeff[i] = sign * CombinatoricsUtils.binomialCoefficient(k, i);
            sign = -sign;
        }
        for (int i = k; i < m; ++i) {
            for (int j = 0; j < k + 1; ++j) {
                data[i][j + i - k] = coeff[j];
            }
        }
        DoubleMatrix d = DoubleMatrix.copyOf(data);
        DoubleMatrix dt = _algebra.getTranspose(d);
        return (DoubleMatrix) _algebra.multiply(dt, d);
    }

    /**
     * Returns the penalty for direction {@code index} of a tensor-product basis:
     * {@code I_pre ⊗ (D^T D) ⊗ I_post}. Here {@code pre} is the product of the sizes after
     * {@code index} and {@code post} is the product of the sizes before it. Direction {@code index}
     * therefore has stride {@code post} in the weight vector, i.e. direction 0 varies fastest.
     */
    private DoubleMatrix getDiffMatrix(int[] size, int k, int index) {
        DoubleMatrix d = getDiffMatrix(size[index], k);
        int preProduct = 1;
        for (int j = index + 1; j < size.length; ++j) {
            preProduct *= size[j];
        }
        int postProduct = 1;
        for (int j = 0; j < index; ++j) {
            postProduct *= size[j];
        }
        DoubleMatrix temp = d;
        if (preProduct != 1) {
            temp = (DoubleMatrix) _algebra.kroneckerProduct(DoubleMatrix.identity(preProduct), temp);
        }
        if (postProduct != 1) {
            temp = (DoubleMatrix) _algebra.kroneckerProduct(temp, DoubleMatrix.identity(postProduct));
        }
        return temp;
    }
}

