/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.test.unit.matrix;

import com.aliasi.matrix.SvdMatrix;
import java.util.Random;
import junit.framework.Assert;
import org.junit.Test;

/**
 * Unit tests for {@link SvdMatrix#partialSvd}: fits a factorization to a small
 * fixed matrix and checks the singular values, singular vectors and
 * reconstructed entries. Also contains generators for random dense and sparse
 * input fixtures.
 */
public class SvdMatrixTest {
    /** Shared source of randomness for the generated fixtures (unseeded). */
    static final Random RANDOM = new Random();
    /** Row count of the dense random fixture. */
    static final int M = 5;
    /** Column count of the dense random fixture. */
    static final int N = 10;
    /** Row count of the sparse random fixture. */
    static final int M2 = 1000;
    /** Column count (exclusive upper bound on column ids) of the sparse random fixture. */
    static final int N2 = 500;
    /** Exclusive upper bound on the gap between consecutive column ids in a sparse row. */
    static final int MAX_INCR2 = 100;

    /**
     * Factors a fixed, fully populated 4x3 matrix with order 3 and requires
     * every entry to be reconstructed to within 0.001.
     */
    @Test
    public void testFixed() {
        double[][] values = {
            {5.0, 9.0, 2.0},
            {3.0, -4.0, 5.0},
            {2.0, 5.0, 1.0},
            {-8.0, 3.0, 3.0}
        };
        int m = values.length;
        int n = values[0].length;
        // Dense input: every row lists every column index in order.
        int[][] columnIds = new int[m][n];
        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < n; ++j) {
                columnIds[i][j] = j;
            }
        }
        assertConverge(m, n, columnIds, values, 3, 0.001);
    }

    /**
     * Builds a dense M x N fixture with entries drawn by {@code random(1.0, 5.0)}.
     *
     * NOTE(review): no factorization is run on this fixture, so the test
     * currently cannot fail on SVD behavior -- confirm whether a call to
     * {@code assertConverge} was intended here.
     */
    @Test
    public void testFull() {
        int[][] columnIds = new int[M][N];
        double[][] values = new double[M][N];
        for (int i = 0; i < M; ++i) {
            for (int j = 0; j < N; ++j) {
                columnIds[i][j] = j;
            }
        }
        // Values are drawn in a separate pass after the column ids are filled in.
        for (int i = 0; i < M; ++i) {
            for (int j = 0; j < N; ++j) {
                values[i][j] = random(1.0, 5.0);
            }
        }
    }

    /**
     * Builds a sparse fixture of M2 rows whose column ids are strictly
     * increasing values below N2, with entries drawn by
     * {@code random(1.0, 5.0)}, and checks each generated row is well formed.
     *
     * NOTE(review): no factorization is run on this fixture -- confirm
     * whether a call to {@code assertConverge} was intended here.
     */
    @Test
    public void testPartial() {
        int[][] columnIds = new int[M2][];
        double[][] values = new double[M2][];
        // Scratch space reused for every row; a row holds at most N2 distinct ids.
        int[] columnBuf = new int[N2];
        for (int i = 0; i < M2; ++i) {
            int pos = 0;
            int j = 0;
            while (true) {
                int incr = RANDOM.nextInt(MAX_INCR2);
                // A zero step is only legal before anything has been emitted
                // (it selects column 0). Checking pos rather than j matters:
                // after column 0 is emitted j is still 0, and a second zero
                // step would otherwise emit column 0 twice.
                if (incr == 0 && pos != 0) {
                    incr = 1;
                }
                j += incr;
                if (j >= N2) {
                    break;
                }
                columnBuf[pos++] = j;
            }
            columnIds[i] = new int[pos];
            System.arraycopy(columnBuf, 0, columnIds[i], 0, pos);
            values[i] = new double[pos];
            for (int k = 0; k < pos; ++k) {
                values[i][k] = random(1.0, 5.0);
            }
            assertStrictlyIncreasingColumns("row=" + i, columnIds[i], N2);
        }
    }

    /**
     * Asserts that the ids are strictly increasing and lie in {@code [0, numCols)}.
     *
     * @param msg     prefix for failure messages
     * @param ids     column ids of one sparse row
     * @param numCols exclusive upper bound on a valid column id
     */
    private void assertStrictlyIncreasingColumns(String msg, int[] ids, int numCols) {
        for (int k = 0; k < ids.length; ++k) {
            Assert.assertTrue(msg + " column id in range at k=" + k,
                              ids[k] >= 0 && ids[k] < numCols);
            if (k > 0) {
                Assert.assertTrue(msg + " column ids strictly increasing at k=" + k,
                                  ids[k] > ids[k - 1]);
            }
        }
    }

    /**
     * Fits a partial SVD to the given sparse matrix and asserts that the
     * singular values are non-negative and non-increasing, that the left and
     * right singular vector matrices have orthonormal columns (to within
     * 0.01), and that every supplied entry is reproduced within
     * {@code tolerance}.
     *
     * @param numRows   number of rows; not used by the checks, kept for callers
     * @param numCols   number of columns; not used by the checks, kept for callers
     * @param columnIds for each row, the column index of each supplied entry
     * @param values    for each row, the entry values parallel to {@code columnIds}
     * @param maxOrder  maximum order passed to {@link SvdMatrix#partialSvd}
     * @param tolerance maximum absolute error allowed for a reconstructed entry
     */
    void assertConverge(int numRows, int numCols, int[][] columnIds, double[][] values, int maxOrder, double tolerance) {
        double featureInit = 0.1;
        double initialLearningRate = 0.001;
        double annealingRate = 100000.0;
        double regularization = 0.0;
        double minImprovement = 0.0;
        int minEpochs = 1000;
        int maxEpochs = 1000000;
        // null is passed for the eighth argument (presumably an optional
        // progress reporter -- confirm against SvdMatrix.partialSvd).
        SvdMatrix matrix = SvdMatrix.partialSvd(columnIds, values, maxOrder, featureInit, initialLearningRate, annealingRate, regularization, null, minImprovement, minEpochs, maxEpochs);

        double[] singularValues = matrix.singularValues();
        Assert.assertTrue(singularValues[0] >= 0.0);
        for (int i = 1; i < singularValues.length; ++i) {
            Assert.assertTrue(singularValues[i] <= singularValues[i - 1]);
        }

        assertOrthonormal(matrix.leftSingularVectors());
        assertOrthonormal(matrix.rightSingularVectors());

        // Only the entries that were supplied are compared with the estimate.
        for (int row = 0; row < columnIds.length; ++row) {
            for (int j = 0; j < columnIds[row].length; ++j) {
                int column = columnIds[row][j];
                double expected = values[row][j];
                double estimated = matrix.value(row, column);
                Assert.assertEquals(expected, estimated, tolerance);
            }
        }
    }

    /**
     * Asserts that every column of {@code xs} has unit length and that every
     * pair of distinct columns is orthogonal, each to within 0.01.
     *
     * @param xs matrix indexed as {@code xs[row][column]}; must have at least one row
     */
    void assertOrthonormal(double[][] xs) {
        int numCols = xs[0].length;
        for (int j = 0; j < numCols; ++j) {
            assertUnitLengthColumn(xs, j);
            // Each unordered pair of columns is checked once.
            for (int k = j + 1; k < numCols; ++k) {
                assertOrthogonalColumns("col=" + j + " col2=" + k, xs, j, k);
            }
        }
    }

    /**
     * Asserts that the sum of squares of column {@code j} is 1 to within 0.01.
     *
     * @param xs matrix indexed as {@code xs[row][column]}
     * @param j  column to check
     */
    void assertUnitLengthColumn(double[][] xs, int j) {
        double sum = 0.0;
        for (double[] row : xs) {
            sum += row[j] * row[j];
        }
        Assert.assertEquals("unit columns", 1.0, sum, 0.01);
    }

    /**
     * Asserts that the dot product of columns {@code i} and {@code j} is 0 to
     * within 0.01.
     *
     * @param msg suffix appended to the failure message
     * @param xs  matrix indexed as {@code xs[row][column]}
     * @param i   first column
     * @param j   second column
     */
    void assertOrthogonalColumns(String msg, double[][] xs, int i, int j) {
        double sum = 0.0;
        for (double[] row : xs) {
            sum += row[i] * row[j];
        }
        Assert.assertEquals("ortho columns " + msg, 0.0, sum, 0.01);
    }

    /**
     * Asserts that the dot product of {@code xs} with itself is 1 to within 0.01.
     *
     * @param xs vector to check
     */
    void assertUnitLength(double[] xs) {
        assertProduct(xs, xs, 1.0);
    }

    /**
     * Asserts that the dot product of {@code xs} and {@code ys} is 0 to within 0.01.
     *
     * @param msg failure message
     * @param xs  first vector
     * @param ys  second vector; indexed over the length of {@code xs}
     */
    void assertOrthogonal(String msg, double[] xs, double[] ys) {
        assertProduct(msg, xs, ys, 0.0);
    }

    /**
     * Asserts, with an empty failure message, that the dot product of
     * {@code xs} and {@code ys} equals {@code expected} to within 0.01.
     *
     * @param xs       first vector
     * @param ys       second vector; indexed over the length of {@code xs}
     * @param expected expected dot product
     */
    void assertProduct(double[] xs, double[] ys, double expected) {
        assertProduct("", xs, ys, expected);
    }

    /**
     * Asserts that the dot product of {@code xs} and {@code ys} equals
     * {@code expected} to within 0.01.
     *
     * @param msg      failure message
     * @param xs       first vector; its length bounds the summation
     * @param ys       second vector; indexed over the length of {@code xs}
     * @param expected expected dot product
     */
    void assertProduct(String msg, double[] xs, double[] ys, double expected) {
        double sum = 0.0;
        for (int i = 0; i < xs.length; ++i) {
            sum += xs[i] * ys[i];
        }
        Assert.assertEquals(msg, expected, sum, 0.01);
    }

    /**
     * Returns {@code min + (max - min) * RANDOM.nextDouble()}.
     *
     * @param min value returned when the underlying draw is 0
     * @param max value approached as the underlying draw approaches 1
     * @return the scaled random draw
     */
    double random(double min, double max) {
        return min + (max - min) * RANDOM.nextDouble();
    }

    /**
     * Prints the matrix to standard output, one row per line, with entries
     * separated by {@code ", "}.
     *
     * @param xs matrix to print
     */
    void printMatrix(double[][] xs) {
        for (double[] row : xs) {
            for (int j = 0; j < row.length; ++j) {
                if (j > 0) {
                    System.out.print(", ");
                }
                printNumber(row[j]);
            }
            System.out.println();
        }
    }

    /**
     * Prints {@code x} to standard output using the format {@code "% 7.3f"}.
     *
     * @param x number to print
     */
    void printNumber(double x) {
        System.out.printf("% 7.3f", x);
    }
}

