/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.mlplan.metamining.similaritymeasures;

import ai.libs.jaicore.basic.sets.Pair;
import ai.libs.mlplan.metamining.similaritymeasures.IHeterogenousSimilarityMeasureComputer;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.math.minimize.CostFunction;
import de.jungblut.math.minimize.CostGradientTuple;
import de.jungblut.math.minimize.GradientDescent;
import java.util.ArrayList;
import java.util.Random;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Learns a bilinear similarity measure between two heterogeneous feature spaces.
 *
 * <p>Given feature matrices {@code X} (n x d) and {@code W} (m x l) and a target matrix {@code R} (n x m),
 * {@link #build(INDArray, INDArray, INDArray)} searches for latent factor matrices {@code U} (d x k) and
 * {@code V} (l x k), with k = {@value #NUMBER_OF_IMPLICIT_FEATURES}, that minimise
 *
 * <pre>
 *   ||R - X U (W V)^T||_F^2 + mu * (||U||_F^2 + ||V||_F^2)
 * </pre>
 *
 * <p>The search runs repeated gradient-descent probes with an adaptive learning rate. The similarity of
 * a new pair {@code (x, w)} is then read off {@code x U V^T w^T}.
 *
 * <p>Instances hold mutable state and perform no synchronisation.
 */
public class F3Optimizer implements IHeterogenousSimilarityMeasureComputer {
    private static final Logger logger = LoggerFactory.getLogger(F3Optimizer.class);

    /** Initial learning rate of the gradient-descent probes; also used after every random restart. */
    private static final double ALPHA_START = 1.0E-9;
    /** Upper bound for the adaptive learning rate. */
    private static final double ALPHA_MAX = 1.0E-6;
    /** The learning rate is not halved any further once it is at or below this value. */
    private static final double ALPHA_MIN = 1.0E-20;
    /** Number of gradient-descent iterations per probe. */
    private static final int ITERATIONS_PER_PROBE = 100;
    /** Limit argument passed to {@link GradientDescent}; see that class for its exact meaning. */
    private static final int LIMIT = 1;
    /** Optimisation stops as soon as the cost is at or below this value. */
    private static final double MAX_DESIRED_ERROR = 0.0;
    /** Optimisation is aborted once more than this many consecutive probes failed to improve the cost. */
    private static final int MAX_TURNS_WITHOUT_IMPROVEMENT = 10;
    /** Number of latent (implicit) features, i.e. the number of columns of U and V. */
    private static final int NUMBER_OF_IMPLICIT_FEATURES = 1;

    private final double mu;
    private INDArray r;
    private INDArray x;
    private INDArray w;
    private INDArray u;
    private INDArray v;
    private final Random rand = new Random();

    /**
     * Creates a new optimizer.
     *
     * @param mu weight of the squared-Frobenius-norm regularisation term on U and V
     */
    public F3Optimizer(double mu) {
        this.mu = mu;
    }

    /**
     * Learns U and V for the given data by minimising {@link #getCost(INDArray, INDArray)}.
     *
     * <p>Each probe runs {@value #ITERATIONS_PER_PROBE} gradient-descent iterations from the best
     * solution found so far. How a probe is handled depends on its outcome:
     * <ul>
     * <li>If it improves the cost, it is accepted and the learning rate is doubled.</li>
     * <li>If it leaves the cost unchanged, it is rejected and the learning rate is doubled.</li>
     * <li>If it makes the cost worse or diverges to NaN/Infinity, it is rejected and the learning rate
     * is halved.</li>
     * </ul>
     * Until the first improvement, every probe restarts from a fresh random solution. Optimisation stops
     * when the cost reaches {@value #MAX_DESIRED_ERROR}, or after more than
     * {@value #MAX_TURNS_WITHOUT_IMPROVEMENT} consecutive unsuccessful probes.
     *
     * @param x first feature matrix (n x d)
     * @param w second feature matrix (m x l)
     * @param r target matrix (n x m)
     * @throws IllegalArgumentException if the shape of {@code r} does not match {@code x} and {@code w}
     */
    @Override
    public void build(INDArray x, INDArray w, INDArray r) {
        final int n = x.rows();
        final int d = x.columns();
        final int m = w.rows();
        final int l = w.columns();
        final int k = NUMBER_OF_IMPLICIT_FEATURES;
        if (r.rows() != n || r.columns() != m) {
            throw new IllegalArgumentException("R must be of shape (" + n + " x " + m + ") to match X and W, but is (" + r.rows() + " x " + r.columns() + ")");
        }
        this.r = r;
        this.w = w;
        this.x = x;
        logger.debug("X = ( {} x {} )", n, d);
        logger.debug("W = ( {} x {} )", m, l);

        DoubleVector currentSolutionAsVector = this.getRandomInitSolution(d, l, k);
        Pair<INDArray, INDArray> currentUAndVAsMatrix = this.vector2matrices(currentSolutionAsVector, d, k, l, k);
        logger.debug("randomly initialized U = {} ({} x {})", currentUAndVAsMatrix.getX(), d, k);
        logger.debug("randomly initialized V = {} ({} x {})", currentUAndVAsMatrix.getY(), l, k);
        double currentCost = this.getCost(currentUAndVAsMatrix.getX(), currentUAndVAsMatrix.getY());
        logger.debug("loss of randomly initialized U and V: {}", currentCost);

        CostFunction cf = this.createCostFunction(d, l, k);
        double alpha = ALPHA_START;
        int turnsWithoutImprovement = 0;
        boolean successfullyBooted = false;
        while (currentCost > MAX_DESIRED_ERROR) {
            GradientDescent gd = new GradientDescent(alpha, LIMIT);
            DoubleVector candidate = gd.minimize(cf, currentSolutionAsVector, ITERATIONS_PER_PROBE, false);
            logger.debug("Produced gd solution vector {}", candidate);

            // Only evaluate candidates with finite entries; everything else counts as divergence.
            Pair<INDArray, INDArray> candidateUAndV = null;
            double candidateCost = Double.NaN;
            if (!containsNonFiniteEntry(candidate)) {
                candidateUAndV = this.vector2matrices(candidate, d, k, l, k);
                candidateCost = this.getCost(candidateUAndV.getX(), candidateUAndV.getY());
            }

            if (!Double.isFinite(candidateCost)) {
                // Diverged: discard the probe and shrink the step size.
                ++turnsWithoutImprovement;
                if (alpha > ALPHA_MIN) {
                    alpha /= 2.0;
                }
            } else if (candidateCost < currentCost) {
                // Improvement: accept the probe. U/V, vector and cost are updated together so they stay consistent.
                currentSolutionAsVector = candidate;
                currentUAndVAsMatrix = candidateUAndV;
                currentCost = candidateCost;
                successfullyBooted = true;
                turnsWithoutImprovement = 0;
                alpha *= 2.0;
            } else {
                // No improvement: keep the previous solution.
                ++turnsWithoutImprovement;
                if (candidateCost == currentCost) {
                    alpha *= 2.0; // stagnation: try a larger step
                } else if (alpha > ALPHA_MIN) {
                    alpha /= 2.0; // overshoot: try a smaller step
                }
            }
            alpha = Math.min(alpha, ALPHA_MAX);
            logger.debug("Current cost: {} (alpha= {})", currentCost, alpha);

            if (turnsWithoutImprovement > MAX_TURNS_WITHOUT_IMPROVEMENT) {
                logger.debug("No further improvement, canceling");
                break;
            }
            if (!successfullyBooted) {
                // No improvement has been achieved yet: restart from a fresh random point.
                currentSolutionAsVector = this.getRandomInitSolution(d, l, k);
                currentUAndVAsMatrix = this.vector2matrices(currentSolutionAsVector, d, k, l, k);
                currentCost = this.getCost(currentUAndVAsMatrix.getX(), currentUAndVAsMatrix.getY());
                alpha = ALPHA_START;
                logger.info("Rebooting approach with solution vector {} that has cost {}", currentSolutionAsVector, currentCost);
            }
        }
        this.u = currentUAndVAsMatrix.getX();
        this.v = currentUAndVAsMatrix.getY();
        logger.info("Finished learning");
        logger.debug("U = {}", this.u);
        logger.debug("V = {}", this.v);
    }

    /**
     * Creates the cost function (cost together with its gradient) over flattened (U, V) solution vectors.
     *
     * @param d number of rows of U
     * @param l number of rows of V
     * @param k number of columns of U and V
     * @return cost function usable by {@link GradientDescent}
     */
    private CostFunction createCostFunction(final int d, final int l, final int k) {
        return input -> {
            Pair<INDArray, INDArray> uAndV = this.vector2matrices(input, d, k, l, k);
            INDArray uIntermediate = uAndV.getX();
            INDArray vIntermediate = uAndV.getY();
            assert (uIntermediate.rows() == d && uIntermediate.columns() == k) : "Incorrect shape of U: (" + uIntermediate.rows() + " x " + uIntermediate.columns() + ") instead of (" + d + " x " + k + ")";
            assert (vIntermediate.rows() == l && vIntermediate.columns() == k) : "Incorrect shape of V: (" + vIntermediate.rows() + " x " + vIntermediate.columns() + ") instead of (" + l + " x " + k + ")";
            double cost = this.getCost(uIntermediate, vIntermediate);
            INDArray gradientMatrixForU = this.getGradientAsMatrix(uIntermediate, vIntermediate, true);
            INDArray gradientMatrixForV = this.getGradientAsMatrix(uIntermediate, vIntermediate, false);
            return new CostGradientTuple(cost, this.matrices2vector(gradientMatrixForU, gradientMatrixForV));
        };
    }

    /**
     * Returns whether the given vector contains a NaN or infinite entry.
     */
    private static boolean containsNonFiniteEntry(DoubleVector vector) {
        for (int i = 0; i < vector.getLength(); ++i) {
            if (!Double.isFinite(vector.get(i))) {
                return true;
            }
        }
        return false;
    }

    /**
     * Draws a random flattened solution with (d + l) * numberOfImplicitFeatures entries.
     * Each entry is drawn uniformly from [-50, 50).
     */
    private DoubleVector getRandomInitSolution(int d, int l, int numberOfImplicitFeatures) {
        double[] denseVector = new double[(d + l) * numberOfImplicitFeatures];
        for (int c = 0; c < denseVector.length; ++c) {
            denseVector[c] = (this.rand.nextDouble() - 0.5) * 100.0;
        }
        return new DenseDoubleVector(denseVector);
    }

    /**
     * Reshapes a vector into an m x n matrix using Nd4j's default ordering.
     *
     * @param vector source values; presumably of length m * n (Nd4j rejects other lengths)
     * @param m number of rows
     * @param n number of columns
     * @return the reshaped matrix
     */
    public INDArray vector2matrix(DoubleVector vector, int m, int n) {
        double[] inputs = new double[vector.getLength()];
        for (int i = 0; i < inputs.length; ++i) {
            inputs[i] = vector.get(i);
        }
        return Nd4j.create(inputs, new int[] { m, n });
    }

    /**
     * Splits a flattened solution vector into two matrices.
     * The first n * d entries form an n x d matrix; the remaining entries form an m x l matrix.
     */
    public Pair<INDArray, INDArray> vector2matrices(DoubleVector vector, int n, int d, int m, int l) {
        DoubleVector inputForU = vector.sliceByLength(0, n * d);
        DoubleVector inputForV = vector.sliceByLength(n * d, vector.getLength() - inputForU.getLength());
        INDArray uIntermediate = this.vector2matrix(inputForU, n, d);
        INDArray vIntermediate = this.vector2matrix(inputForV, m, l);
        return new Pair<>(uIntermediate, vIntermediate);
    }

    /**
     * Flattens a matrix into a vector in row-major order.
     */
    public DoubleVector matrix2vector(INDArray matrix) {
        int m = matrix.rows();
        int n = matrix.columns();
        double[] denseVector = new double[m * n];
        int c = 0;
        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < n; ++j) {
                denseVector[c++] = matrix.getDouble(i, j);
            }
        }
        return new DenseDoubleVector(denseVector);
    }

    /**
     * Flattens each matrix in row-major order and concatenates the results in argument order.
     */
    public DoubleVector matrices2vector(INDArray... matrices) {
        ArrayList<DoubleVector> vectors = new ArrayList<>(matrices.length);
        int length = 0;
        for (INDArray matrix : matrices) {
            DoubleVector vector = this.matrix2vector(matrix);
            vectors.add(vector);
            length += vector.getLength();
        }
        double[] collapsed = new double[length];
        int c = 0;
        for (DoubleVector vector : vectors) {
            for (int i = 0; i < vector.getLength(); ++i) {
                collapsed[c++] = vector.get(i);
            }
        }
        return new DenseDoubleVector(collapsed);
    }

    /**
     * Computes the regularised reconstruction cost
     * {@code ||R - X U (W V)^T||_F^2 + mu * (||U||_F^2 + ||V||_F^2)}.
     */
    public double getCost(INDArray u, INDArray v) {
        INDArray q = this.getResidual(this.x.mmul(u), this.w.mmul(v));
        return this.getSquaredFrobeniusNorm(q) + this.mu * (this.getSquaredFrobeniusNorm(u) + this.getSquaredFrobeniusNorm(v));
    }

    /**
     * Returns the residual {@code R - (X U) (W V)^T}, given the precomputed products {@code xu = X U} and {@code wv = W V}.
     */
    private INDArray getResidual(INDArray xu, INDArray wv) {
        return this.r.sub(xu.mmul(wv.transpose()));
    }

    /**
     * Returns the sum of the squared entries of a matrix.
     */
    public double getSquaredFrobeniusNorm(INDArray matrix) {
        double norm = 0.0;
        int m = matrix.rows();
        int n = matrix.columns();
        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < n; ++j) {
                double value = matrix.getDouble(i, j);
                norm += value * value;
            }
        }
        return norm;
    }

    /**
     * Computes the gradient of {@link #getCost(INDArray, INDArray)} with respect to U or V.
     * The residual is computed once and shared by all entries.
     *
     * @param u current U
     * @param v current V
     * @param computeDerivationsOfU {@code true} for the gradient w.r.t. U, {@code false} for V
     * @return gradient matrix with the shape of U or V respectively
     */
    public INDArray getGradientAsMatrix(INDArray u, INDArray v, boolean computeDerivationsOfU) {
        INDArray xu = this.x.mmul(u);
        INDArray wv = this.w.mmul(v);
        INDArray q = this.getResidual(xu, wv);
        INDArray target = computeDerivationsOfU ? u : v;
        int rows = target.rows();
        int cols = target.columns();
        float[][] derivatives = new float[rows][cols];
        for (int s = 0; s < rows; ++s) {
            for (int t = 0; t < cols; ++t) {
                derivatives[s][t] = (float) this.partialDerivative(u, v, xu, wv, q, s, t, computeDerivationsOfU);
            }
        }
        return Nd4j.create(derivatives);
    }

    /**
     * Computes a single partial derivative of {@link #getCost(INDArray, INDArray)} w.r.t. U[s][t] or V[s][t].
     *
     * @throws IllegalArgumentException if {@code t} is not a valid column index of {@code v}
     */
    public float getFirstDerivative(INDArray u, INDArray v, int s, int t, boolean deriveForU) {
        if (t >= v.columns()) {
            throw new IllegalArgumentException("V has only " + v.columns() + " but would have to have " + (t + 1) + " columns to proceed! I.e. deriving a derivative for t = " + t + " is not possible.");
        }
        INDArray xu = this.x.mmul(u);
        INDArray wv = this.w.mmul(v);
        INDArray q = this.getResidual(xu, wv);
        assert (q.columns() == this.w.rows()) : "W has " + this.w.rows() + " rows but is expected to have m = " + q.columns() + " rows";
        return (float) this.partialDerivative(u, v, xu, wv, q, s, t, deriveForU);
    }

    /**
     * Computes one partial derivative from the precomputed products {@code xu = X U}, {@code wv = W V}
     * and the residual {@code q = R - xu wv^T}. The gradient formulas are:
     * <pre>
     *   dC/dU[s][t] = -2 * sum_ij q[i][j] * X[i][s] * (W V)[j][t] + 2 * mu * U[s][t]
     *   dC/dV[s][t] = -2 * sum_ij q[i][j] * (X U)[i][t] * W[j][s] + 2 * mu * V[s][t]
     * </pre>
     */
    private double partialDerivative(INDArray u, INDArray v, INDArray xu, INDArray wv, INDArray q, int s, int t, boolean deriveForU) {
        int n = q.rows();
        int m = q.columns();
        double sum = 0.0;
        if (deriveForU) {
            for (int i = 0; i < n; ++i) {
                double xis = this.x.getDouble(i, s);
                for (int j = 0; j < m; ++j) {
                    sum += q.getDouble(i, j) * xis * wv.getDouble(j, t);
                }
            }
            return -2.0 * sum + 2.0 * this.mu * u.getDouble(s, t);
        }
        for (int i = 0; i < n; ++i) {
            double xuit = xu.getDouble(i, t);
            for (int j = 0; j < m; ++j) {
                sum += q.getDouble(i, j) * xuit * this.w.getDouble(j, s);
            }
        }
        return -2.0 * sum + 2.0 * this.mu * v.getDouble(s, t);
    }

    /**
     * Returns the first entry of {@code x U V^T w^T}. Presumably {@code x} and {@code w} are single row vectors.
     *
     * @throws IllegalStateException if {@link #build(INDArray, INDArray, INDArray)} has not been called yet
     */
    @Override
    public double computeSimilarity(INDArray x, INDArray w) {
        if (this.u == null || this.v == null) {
            throw new IllegalStateException("The similarity measure has not been built yet; call build(X, W, R) first.");
        }
        return x.mmul(this.u).mmul(this.v.transpose()).mmul(w.transpose()).getDouble(0L);
    }
}

