/*
 * Decompiled with CFR 0.152.
 */
package org.tech.vineyard.linear.algebra;

import org.tech.vineyard.linear.algebra.LinearEquationSystem;
import org.tech.vineyard.linear.algebra.Matrix;
import org.tech.vineyard.linear.algebra.Vector;

public class LUDecomposition
implements LinearEquationSystem {
    /** Unit lower-triangular factor of the last decomposition (null before the first one). */
    private Matrix L;
    /** Upper-triangular factor of the last decomposition (null before the first one). */
    private Matrix U;
    /** Row-permutation matrix of the last decomposition (null before the first one). */
    private Matrix P;

    /**
     * Solves {@code A x = b} by factoring {@code P A = L U} and back-substituting.
     *
     * <p>This call replaces any factors stored by a previous decomposition.
     *
     * @param A square coefficient matrix; it is not modified
     * @param b right-hand side
     * @return the solution vector {@code x}
     * @throws ArithmeticException if {@code A} is singular (a zero ends up on the diagonal of U)
     */
    @Override
    public Vector solution(Matrix A, Vector b) {
        this.decompose(A);
        for (int i = 0; i < this.U.n; ++i) {
            if (this.U.M[i][i] == 0.0) {
                throw new ArithmeticException("Matrix is singular");
            }
        }
        // P A x = L U x = P b  =>  U x = L^-1 (P b).
        // Applying P to b first avoids forming the matrix product L^-1 * P.
        return this.U.upperTriangularSolve(
                this.inverseLowerTriangular(this.L).multiply(this.P.multiply(b)));
    }

    /**
     * Factors {@code A} as {@code P A = L U} and stores the factors, retrievable through
     * {@link #L()}, {@link #U()} and {@link #P()}.
     *
     * <p>Elimination runs on a row-wise copy of {@code A}, so the argument is left untouched.
     * Rows are exchanged only when the current pivot is exactly zero. There is no magnitude-based
     * partial pivoting, so ill-conditioned input may lose accuracy. A column with no non-zero
     * pivot candidate is skipped, which leaves a zero on the diagonal of U. Singular matrices
     * therefore still decompose.
     *
     * @param A square matrix to factor (assumed to be {@code A.n x A.n})
     */
    public void decompose(Matrix A) {
        int size = A.n;
        this.P = Matrix.identity(size);
        this.L = Matrix.identity(size);
        Matrix work = Matrix.identity(size);
        for (int i = 0; i < size; ++i) {
            // Copy each row so that row swaps and updates never touch the caller's matrix.
            work.M[i] = A.M[i].clone();
        }
        for (int col = 0; col < size; ++col) {
            this.eliminateColumn(work, col);
        }
        this.U = work;
    }

    /**
     * Returns the unit lower-triangular factor of the last decomposition.
     *
     * @return L, or null if nothing has been decomposed yet
     */
    public Matrix L() {
        return this.L;
    }

    /**
     * Returns the upper-triangular factor of the last decomposition.
     *
     * @return U, or null if nothing has been decomposed yet
     */
    public Matrix U() {
        return this.U;
    }

    /**
     * Returns the row-permutation matrix of the last decomposition, satisfying {@code P A = L U}.
     *
     * @return P, or null if nothing has been decomposed yet
     */
    public Matrix P() {
        return this.P;
    }

    /**
     * Eliminates the entries of column {@code col} below the diagonal of {@code work} in place
     * and records the multipliers in {@link #L}.
     *
     * <p>If the pivot {@code work[col][col]} is exactly zero, the first lower row with a non-zero
     * entry in this column is swapped in first. The swap is applied to {@code work}, to {@link #P}
     * and to the multipliers already stored in {@link #L}. If no such row exists, the column is
     * already eliminated and is skipped.
     *
     * @param work partially reduced matrix; its columns {@code 0..col-1} are already eliminated
     * @param col  index of the column to eliminate
     */
    private void eliminateColumn(Matrix work, int col) {
        if (work.M[col][col] == 0.0) {
            int pivotRow = this.findPivot(work, col);
            if (pivotRow < 0) {
                // Every entry at or below the diagonal is zero: nothing to eliminate.
                return;
            }
            this.swapRows(work, col, pivotRow);
            this.swapRows(this.P, col, pivotRow);
            // Multipliers from earlier columns belong to the rows they were computed for,
            // so they must follow the row exchange. The diagonal 1s stay in place.
            for (int k = 0; k < col; ++k) {
                double tmp = this.L.M[col][k];
                this.L.M[col][k] = this.L.M[pivotRow][k];
                this.L.M[pivotRow][k] = tmp;
            }
        }
        double[] pivot = work.M[col];
        for (int i = col + 1; i < work.n; ++i) {
            double[] row = work.M[i];
            double factor = row[col] / pivot[col];
            this.L.M[i][col] = factor;
            for (int j = col + 1; j < row.length; ++j) {
                row[j] -= factor * pivot[j];
            }
            // Store an exact zero so that U stays strictly upper triangular despite rounding.
            row[col] = 0.0;
        }
    }

    /**
     * Exchanges two rows of {@code m} in place by swapping the row array references.
     *
     * @param m  matrix whose rows are swapped
     * @param i1 first row index
     * @param i2 second row index
     */
    private void swapRows(Matrix m, int i1, int i2) {
        double[] tmp = m.M[i1];
        m.M[i1] = m.M[i2];
        m.M[i2] = tmp;
    }

    /**
     * Finds the first row below {@code n} whose entry in column {@code n} is non-zero.
     *
     * @param m matrix to search
     * @param n pivot row and column index
     * @return the row index, or -1 if every entry below the diagonal in column {@code n} is zero
     */
    private int findPivot(Matrix m, int n) {
        for (int i = n + 1; i < m.n; ++i) {
            if (m.M[i][n] != 0.0) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Inverts a unit lower-triangular matrix.
     *
     * <p>The diagonal of {@code m} is assumed to be all ones and is never read. Entries are filled
     * one sub-diagonal at a time, using {@code X[i][j] = -sum_{k=j}^{i-1} m[i][k] * X[k][j]}.
     * Every {@code X[k][j]} on the right lies on an earlier sub-diagonal and is already known.
     *
     * @param m unit lower-triangular matrix
     * @return its inverse (also unit lower triangular)
     */
    private Matrix inverseLowerTriangular(Matrix m) {
        Matrix inverse = Matrix.identity(m.n);
        for (int k = 1; k < m.n; ++k) {
            int j = 0;
            int i = k;
            while (i < m.n) {
                inverse.M[i][j] = -this.dotProduct(m, inverse, i, j, j, i - 1);
                ++i;
                ++j;
            }
        }
        return inverse;
    }

    /**
     * Computes the partial dot product {@code sum_{k=start}^{end} m[i][k] * n[k][j]}.
     *
     * @param m     left matrix (row {@code i} is read)
     * @param n     right matrix (column {@code j} is read)
     * @param i     row of {@code m}
     * @param j     column of {@code n}
     * @param start first summation index (inclusive)
     * @param end   last summation index (inclusive)
     * @return the partial dot product; 0.0 if {@code start > end}
     */
    private double dotProduct(Matrix m, Matrix n, int i, int j, int start, int end) {
        double sum = 0.0;
        for (int k = start; k <= end; ++k) {
            sum += m.M[i][k] * n.M[k][j];
        }
        return sum;
    }
}

