/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.utils.svd;

import java.util.Arrays;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Matrices;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.SingularValueDecomposition;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.linalg.distributed.RowMatrix;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.spark.SparkConverter;
import org.broadinstitute.hellbender.utils.svd.SVD;
import org.broadinstitute.hellbender.utils.svd.SimpleSVD;
import org.broadinstitute.hellbender.utils.svd.SingularValueDecomposer;

/**
 * {@link SingularValueDecomposer} that computes the singular value decomposition of a local
 * {@link RealMatrix} using Spark MLLib's distributed {@link RowMatrix#computeSVD}, and additionally
 * derives the Moore-Penrose pseudoinverse from the resulting factors.
 */
public final class SparkSingularValueDecomposer
implements SingularValueDecomposer {
    /**
     * Scale factor of the pseudoinverse cutoff {@code max(rows, cols) * realMat.getNorm() * EPS};
     * singular values at or below that cutoff are inverted to zero instead of {@code 1/sv}.
     * NOTE(review): 1e-32 is far below double machine epsilon (~2.2e-16), so effectively only
     * singular values that are (numerically) zero are discarded — confirm this is intended.
     */
    private static final double EPS = 1.0E-32;

    private static final Logger logger = LogManager.getLogger(SparkSingularValueDecomposer.class);

    /** Passed to {@link SparkConverter} when distributing the input; presumably the RDD partition count. */
    private static final int NUM_SLICES = 60;

    /** The {@code rCond} argument handed to Spark's {@link RowMatrix#computeSVD}. */
    private static final double RCOND = 1.0E-9;

    private final JavaSparkContext sc;

    /**
     * @param sc Spark context used to distribute the input matrix; must not be null
     */
    public SparkSingularValueDecomposer(final JavaSparkContext sc) {
        Utils.nonNull(sc, "Cannot perform Spark MLLib SVD using a null JavaSparkContext.");
        this.sc = sc;
    }

    /**
     * Computes the SVD of {@code realMat} on Spark and returns the factors, converted back to local
     * Apache Commons matrices, together with the pseudoinverse of {@code realMat}.
     *
     * @param realMat matrix to decompose; must not be null
     * @return a {@link SimpleSVD} holding U, the singular values, the V factor (see note in the body)
     *         and the pseudoinverse
     */
    @Override
    public SVD createSVD(final RealMatrix realMat) {
        Utils.nonNull(realMat, "Cannot perform Spark MLLib SVD on a null matrix.");

        final RowMatrix mat = SparkConverter.convertRealMatrixToSparkRowMatrix(sc, realMat, NUM_SLICES);

        // Ask for every singular value (k = number of columns) and for U as well.
        // NOTE(review): Spark may return fewer than k values when some are negligible relative to
        // the largest one (controlled by RCOND); everything below sizes itself from svd.s().
        final SingularValueDecomposition<RowMatrix, Matrix> svd =
                mat.computeSVD((int) mat.numCols(), true, RCOND);

        final RowMatrix u = svd.U();
        final Vector s = svd.s();
        // The transpose of Spark's V factor; it serves both as the right operand of the pinv product
        // and as the "V" handed to SimpleSVD.
        // NOTE(review): this means SimpleSVD receives V^T, not V — confirm against the SVD#getV contract.
        final Matrix v = svd.V().transpose();

        // Bring the factors back from Spark/distributed space into Apache Commons space.
        logger.info("Converting distributed Spark matrix to local matrix...");
        final RealMatrix uReal = SparkConverter.convertSparkRowMatrixToRealMatrix(u, realMat.getRowDimension());
        logger.info("Done converting distributed Spark matrix to local matrix...");
        logger.info("Converting Spark matrix to local matrix...");
        final RealMatrix vReal = SparkConverter.convertSparkMatrixToRealMatrix(v);
        logger.info("Done converting Spark matrix to local matrix...");
        final double[] singularValues = s.toArray();

        logger.info("Calculating the pseudoinverse...");
        logger.info("Pinv: calculating tolerance...");
        final double tolerance = Math.max(realMat.getColumnDimension(), realMat.getRowDimension())
                * realMat.getNorm() * EPS;
        logger.info("Pinv: inverting the singular values (with tolerance) and creating a diagonal matrix...");
        final double[] invS = Arrays.stream(singularValues)
                .map(sv -> invertSVWithTolerance(sv, tolerance))
                .toArray();
        final Matrix invSMat = Matrices.diag(Vectors.dense(invS));

        // pinv = V * invS * U'; computing its transpose U * invS * V' keeps the product a RowMatrix
        // with one row per row of realMat, which is then transposed locally.
        logger.info("Pinv: Multiplying V * invS * U' to get the pinv (using pinv transpose = U * invS' * V') ...");
        final RowMatrix pinvT = u.multiply(invSMat).multiply(v);
        logger.info("Pinv: Converting back to local matrix ...");
        final RealMatrix pinv = SparkConverter.convertSparkRowMatrixToRealMatrix(pinvT, realMat.getRowDimension()).transpose();
        logger.info("Done calculating the pseudoinverse and converting it...");

        return new SimpleSVD(uReal, singularValues, vReal, pinv);
    }

    /**
     * Inverts a single singular value for the pseudoinverse.
     *
     * @param sv  singular value
     * @param tol cutoff; values at or below it are treated as zero
     * @return {@code 0.0} if {@code sv <= tol}, otherwise {@code 1.0 / sv}
     */
    private static double invertSVWithTolerance(final double sv, final double tol) {
        return sv <= tol ? 0.0 : 1.0 / sv;
    }
}

