/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.mlplan.metamining.similaritymeasures;

import ai.libs.jaicore.basic.sets.SetUtil;
import ai.libs.mlplan.metamining.similaritymeasures.IHeterogenousSimilarityMeasureComputer;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.math.minimize.CostFunction;
import de.jungblut.math.minimize.CostGradientTuple;
import de.jungblut.math.minimize.GradientDescent;
import java.util.ArrayList;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Learns two factor matrices {@code U} (d x k) and {@code V} (l x k) for a target matrix
 * {@code R} (n x m), a left feature matrix {@code X} (n x d) and a right feature matrix
 * {@code W} (m x l) by minimizing
 * {@code ||R - (X U)(W V)^T||_F^2 + mu * (||U||_F^2 + ||V||_F^2)} with restarted gradient
 * descent. The number of latent columns {@code k} is fixed to
 * {@link #NUMBER_OF_IMPLICIT_FEATURES}.
 *
 * <p>Instances are stateful: {@link #build} stores the inputs and the learned factors in fields.
 */
public class F3Optimizer
implements IHeterogenousSimilarityMeasureComputer {
    private static final Logger logger = LoggerFactory.getLogger(F3Optimizer.class);
    /** Initial gradient-descent step size; also restored after every reboot. */
    private static final double ALPHA_START = 1.0E-9;
    /** Upper bound for the adaptive step size. */
    private static final double ALPHA_MAX = 1.0E-7;
    /** Step size at or below which alpha is no longer halved. */
    private static final double ALPHA_MIN = 1.0E-20;
    /** Number of gradient-descent iterations per probe. */
    private static final int ITERATIONS_PER_PROBE = 100;
    /** Limit argument passed to the {@code GradientDescent} constructor. */
    private static final int LIMIT = 1;
    /** Optimization stops as soon as the cost is at or below this value. */
    private static final double MAX_DESIRED_ERROR = 0.0;
    /** Number of latent columns k of U and V. */
    private static final int NUMBER_OF_IMPLICIT_FEATURES = 1;
    /** Optimization is canceled once more than this many consecutive probes did not improve. */
    private static final int MAX_TURNS_WITHOUT_IMPROVEMENT = 10;
    private final double mu;
    private INDArray R;
    private INDArray X;
    private INDArray W;
    private INDArray U;
    private INDArray V;

    /**
     * @param mu weight of the Frobenius-norm regularization of U and V in the cost
     */
    public F3Optimizer(double mu) {
        this.mu = mu;
    }

    /**
     * Learns U and V for the given matrices and stores them in this instance.
     *
     * <p>Starts from a random solution and repeatedly runs short gradient-descent probes with an
     * adaptive step size alpha. A probe is accepted only if it strictly lowers the cost. Until the
     * first accepted probe, the solution is re-drawn at random after every probe ("reboot"). The
     * loop ends when the cost reaches {@code MAX_DESIRED_ERROR} or after more than
     * {@code MAX_TURNS_WITHOUT_IMPROVEMENT} consecutive non-improving probes.
     *
     * @param X left feature matrix (n x d)
     * @param W right feature matrix (m x l)
     * @param R target matrix (n x m)
     */
    @Override
    public void build(final INDArray X, final INDArray W, final INDArray R) {
        this.R = R;
        this.W = W;
        this.X = X;
        final int d = X.columns();
        final int l = W.columns();
        final int k = NUMBER_OF_IMPLICIT_FEATURES;
        logger.debug("X = {} ({} x {})", X, X.rows(), d);
        logger.debug("W = {} ({} x {})", W, W.rows(), l);

        // Cost and gradient of a flattened solution vector [row-major U, row-major V].
        CostFunction cf = new CostFunction() {

            @Override
            public CostGradientTuple evaluateCost(final DoubleVector input) {
                SetUtil.Pair<INDArray, INDArray> uAndV = F3Optimizer.this.vector2matrices(input, d, k, l, k);
                INDArray u = uAndV.getX();
                INDArray v = uAndV.getY();
                assert (u.rows() == d && u.columns() == k) : "Incorrect shape of U: (" + u.rows() + " x " + u.columns() + ") instead of (" + d + " x " + k + ")";
                assert (v.rows() == l && v.columns() == k) : "Incorrect shape of V: (" + v.rows() + " x " + v.columns() + ") instead of (" + l + " x " + k + ")";
                double cost = F3Optimizer.this.getCost(u, v);
                INDArray gradientForU = F3Optimizer.this.getGradientAsMatrix(u, v, true);
                INDArray gradientForV = F3Optimizer.this.getGradientAsMatrix(u, v, false);
                return new CostGradientTuple(cost, F3Optimizer.this.matrices2vector(gradientForU, gradientForV));
            }
        };

        DoubleVector currentSolution = this.getRandomInitSolution(d, l, k);
        double currentCost = this.getCostOfSolution(currentSolution, d, l);
        logger.debug("Loss of randomly initialized U and V: {}", currentCost);

        boolean successfullyBooted = false;
        double alpha = ALPHA_START;
        int turnsWithoutImprovement = 0;
        while (currentCost > MAX_DESIRED_ERROR) {
            GradientDescent gd = new GradientDescent(alpha, LIMIT);
            DoubleVector candidate = gd.minimize(cf, currentSolution, ITERATIONS_PER_PROBE, false);
            logger.debug("Produced gd solution vector {}", candidate);

            double candidateCost = containsNonFiniteEntry(candidate) ? Double.NaN : this.getCostOfSolution(candidate, d, l);
            if (Double.isNaN(candidateCost) || Double.isInfinite(candidateCost)) {
                // The probe diverged: keep the previous solution and take smaller steps.
                if (alpha > ALPHA_MIN) {
                    alpha /= 2.0;
                }
            } else {
                if (candidateCost < currentCost) {
                    currentSolution = candidate;
                    currentCost = candidateCost;
                    successfullyBooted = true;
                    turnsWithoutImprovement = 0;
                    alpha *= 2.0;
                } else {
                    // No improvement: keep the previous solution. An unchanged cost suggests the
                    // steps are too small; a worse cost suggests the probe overshot.
                    ++turnsWithoutImprovement;
                    if (candidateCost == currentCost) {
                        alpha *= 2.0;
                    } else if (alpha > ALPHA_MIN) {
                        alpha /= 2.0;
                    }
                    if (turnsWithoutImprovement > MAX_TURNS_WITHOUT_IMPROVEMENT) {
                        logger.info("No further improvement, canceling");
                        break;
                    }
                }
                alpha = Math.min(alpha, ALPHA_MAX);
                logger.debug("{} (alpha = {})", currentCost, alpha);
            }

            if (!successfullyBooted) {
                currentSolution = this.getRandomInitSolution(d, l, k);
                currentCost = this.getCostOfSolution(currentSolution, d, l);
                alpha = ALPHA_START;
                logger.info("Rebooting approach with solution vector {} that has cost {}", currentSolution, currentCost);
            }
        }

        // Derive the factors from the accepted solution, never from a rejected candidate.
        SetUtil.Pair<INDArray, INDArray> uAndV = this.vector2matrices(currentSolution, d, k, l, k);
        this.U = uAndV.getX();
        this.V = uAndV.getY();
        logger.info("Finished learning");
        logger.debug("U = {}", this.U);
        logger.debug("V = {}", this.V);
    }

    /** Returns true if any entry of the vector is NaN or infinite. */
    private static boolean containsNonFiniteEntry(final DoubleVector vector) {
        for (int i = 0; i < vector.getLength(); ++i) {
            double value = vector.get(i);
            if (Double.isNaN(value) || Double.isInfinite(value)) {
                return true;
            }
        }
        return false;
    }

    /** Evaluates {@link #getCost} for a flattened solution vector holding U (d x k) followed by V (l x k). */
    private double getCostOfSolution(final DoubleVector solution, final int d, final int l) {
        SetUtil.Pair<INDArray, INDArray> uAndV = this.vector2matrices(solution, d, NUMBER_OF_IMPLICIT_FEATURES, l, NUMBER_OF_IMPLICIT_FEATURES);
        return this.getCost(uAndV.getX(), uAndV.getY());
    }

    /**
     * Draws a random flattened solution of length {@code (d + l) * numberOfImplicitFeatures}
     * with entries uniformly distributed in [-50, 50).
     */
    private DoubleVector getRandomInitSolution(int d, int l, int numberOfImplicitFeatures) {
        double[] denseVector = new double[(d + l) * numberOfImplicitFeatures];
        for (int c = 0; c < denseVector.length; ++c) {
            denseVector[c] = (Math.random() - 0.5) * 100.0;
        }
        return new DenseDoubleVector(denseVector);
    }

    /**
     * Reshapes a vector into an {@code m x n} matrix, filling it in Nd4j's default ordering
     * (assumed row-major, matching {@link #matrix2vector} — confirm against the Nd4j version in use).
     */
    public INDArray vector2matrix(DoubleVector vector, int m, int n) {
        double[] inputs = new double[vector.getLength()];
        for (int i = 0; i < vector.getLength(); ++i) {
            inputs[i] = vector.get(i);
        }
        return Nd4j.create(inputs, new int[]{m, n});
    }

    /**
     * Splits a flattened vector into two matrices.
     *
     * @param n rows of the first matrix
     * @param d columns of the first matrix
     * @param m rows of the second matrix
     * @param l columns of the second matrix
     * @return pair of the first (n x d) and second (m x l) matrix
     * @throws IllegalArgumentException if the vector length is not {@code n * d + m * l}
     */
    public SetUtil.Pair<INDArray, INDArray> vector2matrices(DoubleVector vector, int n, int d, int m, int l) {
        int expectedLength = n * d + m * l;
        if (vector.getLength() != expectedLength) {
            throw new IllegalArgumentException("Vector has length " + vector.getLength() + " but " + expectedLength + " entries are needed for a (" + n + " x " + d + ") and a (" + m + " x " + l + ") matrix");
        }
        DoubleVector inputForFirst = vector.sliceByLength(0, n * d);
        DoubleVector inputForSecond = vector.sliceByLength(n * d, m * l);
        INDArray first = this.vector2matrix(inputForFirst, n, d);
        INDArray second = this.vector2matrix(inputForSecond, m, l);
        return new SetUtil.Pair<>(first, second);
    }

    /** Flattens a matrix row by row into a dense vector. */
    public DoubleVector matrix2vector(INDArray matrix) {
        int m = matrix.rows();
        int n = matrix.columns();
        double[] denseVector = new double[m * n];
        int c = 0;
        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < n; ++j) {
                denseVector[c++] = matrix.getDouble(i, j);
            }
        }
        return new DenseDoubleVector(denseVector);
    }

    /** Flattens each matrix row by row and concatenates the results in argument order. */
    public DoubleVector matrices2vector(INDArray ... matrices) {
        ArrayList<DoubleVector> vectors = new ArrayList<DoubleVector>(matrices.length);
        int length = 0;
        for (INDArray matrix : matrices) {
            DoubleVector vector = this.matrix2vector(matrix);
            vectors.add(vector);
            length += vector.getLength();
        }
        double[] collapsed = new double[length];
        int c = 0;
        for (DoubleVector vector : vectors) {
            for (int i = 0; i < vector.getLength(); ++i) {
                collapsed[c++] = vector.get(i);
            }
        }
        return new DenseDoubleVector(collapsed);
    }

    /**
     * Returns {@code ||R - (X U)(W V)^T||_F^2 + mu * (||U||_F^2 + ||V||_F^2)} for the matrices
     * passed to {@link #build}.
     */
    public double getCost(INDArray U, INDArray V) {
        INDArray Q = this.getResidual(this.X.mmul(U), this.W.mmul(V));
        return this.getSquaredFrobeniusNorm(Q) + this.mu * (this.getSquaredFrobeniusNorm(U) + this.getSquaredFrobeniusNorm(V));
    }

    /** Returns the sum of squares of all entries of the matrix. */
    public double getSquaredFrobeniusNorm(INDArray matrix) {
        double norm = 0.0;
        int m = matrix.rows();
        int n = matrix.columns();
        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < n; ++j) {
                double value = matrix.getDouble(i, j);
                norm += value * value;
            }
        }
        return norm;
    }

    /**
     * Returns the gradient of {@link #getCost} with respect to U (if
     * {@code computeDerivationsOfU}) or V, shaped like that matrix. The residual and the products
     * X*U and W*V are computed once and shared by all entries.
     */
    public INDArray getGradientAsMatrix(INDArray U, INDArray V, boolean computeDerivationsOfU) {
        INDArray XU = this.X.mmul(U);
        INDArray WV = this.W.mmul(V);
        INDArray Q = this.getResidual(XU, WV);
        INDArray target = computeDerivationsOfU ? U : V;
        int rows = target.rows();
        int cols = target.columns();
        float[][] derivatives = new float[rows][cols];
        for (int s = 0; s < rows; ++s) {
            for (int t = 0; t < cols; ++t) {
                derivatives[s][t] = (float) this.computePartialDerivative(Q, XU, WV, U, V, s, t, computeDerivationsOfU);
            }
        }
        return Nd4j.create(derivatives);
    }

    /**
     * Returns the partial derivative of {@link #getCost} with respect to {@code U[s][t]} (if
     * {@code deriveForU}) or {@code V[s][t]}.
     */
    public float getFirstDerivative(INDArray U, INDArray V, int s, int t, boolean deriveForU) {
        INDArray XU = this.X.mmul(U);
        INDArray WV = this.W.mmul(V);
        INDArray Q = this.getResidual(XU, WV);
        int m = Q.columns();
        assert (m == this.W.rows()) : "W has " + this.W.rows() + " but is expected to have m = " + m + " rows";
        assert (t < V.columns()) : "V has only " + V.columns() + " but would have to have " + (t + 1) + " columns to proceed! I.e. deriving a derivative for t = " + t + " is not possible.";
        return (float) this.computePartialDerivative(Q, XU, WV, U, V, s, t, deriveForU);
    }

    /** Computes the residual {@code R - XU * WV^T} from the precomputed products X*U and W*V. */
    private INDArray getResidual(INDArray XU, INDArray WV) {
        return this.R.sub(XU.mmul(WV.transpose()));
    }

    /**
     * Partial derivative of the cost at entry (s, t) of U or V, accumulated in double precision.
     *
     * <p>With {@code Z = (X U)(W V)^T}: {@code dZ_ij/dU_st = X_is * (W V)_jt} and
     * {@code dZ_ij/dV_st = (X U)_it * W_js}; the regularizer contributes {@code 2 mu U_st}
     * (resp. {@code 2 mu V_st}).
     */
    private double computePartialDerivative(INDArray Q, INDArray XU, INDArray WV, INDArray U, INDArray V, int s, int t, boolean deriveForU) {
        int n = Q.rows();
        int m = Q.columns();
        double derivative = 0.0;
        if (deriveForU) {
            for (int i = 0; i < n; ++i) {
                double xis = this.X.getDouble(i, s);
                for (int j = 0; j < m; ++j) {
                    derivative -= 2.0 * Q.getDouble(i, j) * xis * WV.getDouble(j, t);
                }
            }
            return derivative + 2.0 * this.mu * U.getDouble(s, t);
        }
        for (int i = 0; i < n; ++i) {
            // (X U)_it -- the original code wrongly used V's column here.
            double xuit = XU.getDouble(i, t);
            for (int j = 0; j < m; ++j) {
                derivative -= 2.0 * Q.getDouble(i, j) * this.W.getDouble(j, s) * xuit;
            }
        }
        return derivative + 2.0 * this.mu * V.getDouble(s, t);
    }

    /**
     * NOTE(review): stub — always returns 0.0 and does not use the learned U and V. Presumably it
     * should evaluate the bilinear form built from U and V; confirm the intended definition before
     * implementing.
     */
    @Override
    public double computeSimilarity(INDArray x, INDArray w) {
        return 0.0;
    }
}

