/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.mlplan.metamining.similaritymeasures;

import ai.libs.mlplan.metamining.similaritymeasures.IHeterogenousSimilarityMeasureComputer;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.math.minimize.CostFunction;
import de.jungblut.math.minimize.CostGradientTuple;
import de.jungblut.math.minimize.GradientDescent;
import java.util.Random;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Learns a projection matrix {@code U} by gradient descent.
 *
 * <p>The goal is for {@code (X U)(X U)^T} to approximate {@code R R^T}. The loss minimised is the squared
 * Frobenius norm of the residual {@code R R^T - (X U)(X U)^T}. {@code U} has {@code x.columns()} rows and
 * {@link #NUMBER_OF_IMPLICIT_FEATURES} columns.
 *
 * <p>Instances hold mutable state set by {@link #build(INDArray, INDArray, INDArray)}. They are not safe for
 * concurrent use.
 */
public class F1Optimizer
implements IHeterogenousSimilarityMeasureComputer {
    private static final Logger LOGGER = LoggerFactory.getLogger(F1Optimizer.class);

    /** Initial learning rate handed to {@link GradientDescent}. */
    private static final double ALPHA_START = 1.0E-9;
    /** Upper bound for the adaptively doubled learning rate. */
    private static final double ALPHA_MAX = 1.0E-5;
    /** Number of gradient-descent iterations per probe (one outer-loop step). */
    private static final int ITERATIONS_PER_PROBE = 100;
    /** Value passed as the {@code limit} argument of {@link GradientDescent}. */
    private static final int LIMIT = 1;
    /** The outer loop stops as soon as the cost drops to or below this value. */
    private static final double MAX_DESIRED_ERROR = 0.0;
    /** Number of columns of the learned matrix {@code U}. */
    private static final int NUMBER_OF_IMPLICIT_FEATURES = 1;
    /** Random initial entries of {@code U} are drawn uniformly from [-RANGE/2, RANGE/2). */
    private static final double INITIAL_VALUE_RANGE = 100.0;

    private INDArray rrt;
    private INDArray x;
    private INDArray u;
    private final Random rand;

    /** Creates an optimizer whose random initialisation of {@code U} is not reproducible. */
    public F1Optimizer() {
        this.rand = new Random();
    }

    /**
     * Creates an optimizer with a reproducible random initialisation of {@code U}.
     *
     * @param seed seed for the random number generator used to initialise {@code U}
     */
    public F1Optimizer(long seed) {
        this.rand = new Random(seed);
    }

    /**
     * Fits {@code U} so that {@code (x U)(x U)^T} approximates {@code r r^T}.
     *
     * <p>Optimisation runs in probes of {@link #ITERATIONS_PER_PROBE} gradient-descent steps. After each probe:
     * <ul>
     * <li>If the cost decreased, the learning rate is doubled, capped at {@link #ALPHA_MAX}.</li>
     * <li>If the cost increased or became NaN, the probe is rolled back and the learning rate is halved.</li>
     * <li>If the cost is unchanged, the loop stops.</li>
     * </ul>
     * The loop also stops once the cost reaches {@link #MAX_DESIRED_ERROR}, or when the learning rate underflows
     * to zero.
     *
     * @param x matrix whose column count defines the number of rows of {@code U}
     * @param w not used by this implementation
     * @param r matrix whose Gram matrix {@code r r^T} is the approximation target
     */
    @Override
    public void build(INDArray x, INDArray w, INDArray r) {
        this.rrt = r.mmul(r.transpose());
        this.x = x;
        final int m = x.columns();
        final int k = NUMBER_OF_IMPLICIT_FEATURES;

        // Row-major random initialisation of the m x k matrix U.
        double[] initialValues = new double[m * k];
        for (int i = 0; i < initialValues.length; ++i) {
            initialValues[i] = (this.rand.nextDouble() - 0.5) * INITIAL_VALUE_RANGE;
        }
        DoubleVector currentSolutionAsVector = new DenseDoubleVector(initialValues);
        INDArray currentSolutionAsMatrix = this.vector2matrix(currentSolutionAsVector, m, k);
        double currentCost = this.getCost(currentSolutionAsMatrix);
        LOGGER.debug("X = {}", x);
        LOGGER.debug("randomly initialized U = {}", currentSolutionAsMatrix);
        LOGGER.debug("loss of randomly initialized U: {}", currentCost);

        CostFunction cf = input -> {
            INDArray uIntermediate = this.vector2matrix(input, m, k);
            double cost = this.getCost(uIntermediate);
            INDArray gradientMatrix = this.getGradientAsMatrix(uIntermediate);
            return new CostGradientTuple(cost, this.matrix2vector(gradientMatrix));
        };

        double alpha = ALPHA_START;
        // alpha > 0 guards against endlessly halving once the learning rate has underflowed.
        while (currentCost > MAX_DESIRED_ERROR && alpha > 0.0) {
            final double lastCost = currentCost;
            final DoubleVector lastSolutionAsVector = currentSolutionAsVector;
            final INDArray lastSolutionAsMatrix = currentSolutionAsMatrix;

            GradientDescent gd = new GradientDescent(alpha, LIMIT);
            currentSolutionAsVector = gd.minimize(cf, currentSolutionAsVector, ITERATIONS_PER_PROBE, false);
            currentSolutionAsMatrix = this.vector2matrix(currentSolutionAsVector, m, k);
            currentCost = this.getCost(currentSolutionAsMatrix);

            if (Double.isNaN(currentCost) || currentCost > lastCost) {
                // Probe diverged: restore vector AND matrix so U stays consistent, then slow down.
                currentSolutionAsVector = lastSolutionAsVector;
                currentSolutionAsMatrix = lastSolutionAsMatrix;
                currentCost = lastCost;
                alpha /= 2.0;
            } else if (currentCost < lastCost) {
                alpha *= 2.0;
            } else {
                // No progress at all: converged (or stuck).
                break;
            }
            alpha = Math.min(alpha, ALPHA_MAX);
            LOGGER.debug("Current Cost {} (alpha = {})", currentCost, alpha);
        }
        this.u = currentSolutionAsMatrix;
    }

    /**
     * Reshapes a flat vector into an {@code m x n} matrix, reading the vector in row-major order.
     *
     * @param vector flat vector with exactly {@code m * n} entries
     * @param m number of rows
     * @param n number of columns
     * @return new {@code m x n} matrix
     * @throws IllegalArgumentException if the vector length does not equal {@code m * n}
     */
    public INDArray vector2matrix(DoubleVector vector, int m, int n) {
        int length = vector.getLength();
        if (length != m * n) {
            throw new IllegalArgumentException("Vector of length " + length + " cannot be reshaped to " + m + "x" + n);
        }
        double[] inputs = new double[length];
        for (int i = 0; i < length; ++i) {
            inputs[i] = vector.get(i);
        }
        return Nd4j.create(inputs, new int[] { m, n });
    }

    /**
     * Flattens a matrix into a dense vector in row-major order. This is the inverse of
     * {@link #vector2matrix(DoubleVector, int, int)}.
     *
     * @param matrix the matrix to flatten
     * @return dense vector of length {@code rows * columns}
     */
    public DoubleVector matrix2vector(INDArray matrix) {
        int m = matrix.rows();
        int n = matrix.columns();
        double[] denseVector = new double[m * n];
        int c = 0;
        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < n; ++j) {
                denseVector[c++] = matrix.getDouble((long) i, (long) j);
            }
        }
        return new DenseDoubleVector(denseVector);
    }

    /**
     * Computes the loss {@code ||R R^T - (X U)(X U)^T||_F^2}.
     *
     * <p>Requires {@link #build(INDArray, INDArray, INDArray)} to have set {@code X} and {@code R R^T}.
     *
     * @param u candidate matrix {@code U}
     * @return sum of squared residual entries
     */
    public double getCost(INDArray u) {
        INDArray q = this.getResidual(u);
        double cost = 0.0;
        int rows = q.rows();
        int cols = q.columns();
        for (int i = 0; i < rows; ++i) {
            for (int j = 0; j < cols; ++j) {
                cost += Math.pow(q.getDouble((long) i, (long) j), 2.0);
            }
        }
        return cost;
    }

    /**
     * Computes the gradient of {@link #getCost(INDArray)} with respect to every entry of {@code u}.
     *
     * <p>The residual and the per-column projections are computed once, not once per entry.
     *
     * @param u candidate matrix {@code U}
     * @return float matrix of the same shape as {@code u}, holding the partial derivatives
     */
    public INDArray getGradientAsMatrix(INDArray u) {
        int m = this.x.columns();
        int n = u.columns();
        INDArray q = this.getResidual(u);
        float[][] derivatives = new float[m][n];
        for (int l = 0; l < n; ++l) {
            float[] projections = this.getProjections(u, l);
            for (int k = 0; k < m; ++k) {
                derivatives[k][l] = this.computeDerivative(q, projections, k);
            }
        }
        return Nd4j.create(derivatives);
    }

    /**
     * Computes the partial derivative of the loss with respect to {@code U[k][l]}.
     *
     * @param u candidate matrix {@code U}
     * @param k row index into {@code U}
     * @param l column index into {@code U}
     * @return the partial derivative, in single precision
     */
    public float getFirstDerivative(INDArray u, int k, int l) {
        return this.computeDerivative(this.getResidual(u), this.getProjections(u, l), k);
    }

    /** Returns the residual {@code R R^T - (X U)(X U)^T}. */
    private INDArray getResidual(INDArray u) {
        INDArray z1 = this.x.mmul(u);
        INDArray z = z1.mmul(z1.transpose());
        return this.rrt.sub(z);
    }

    /** Returns, for every row i of X, the dot product of that row with column l of U. */
    private float[] getProjections(INDArray u, int l) {
        int n = this.x.rows();
        INDArray column = u.getColumn((long) l);
        float[] sums = new float[n];
        for (int i = 0; i < n; ++i) {
            sums[i] = this.x.getRow((long) i).mmul(column).getFloat(0L, 0L);
        }
        return sums;
    }

    /**
     * Computes {@code sum_ij -2 q_ij (x_ik s_j + x_jk s_i)}, where s holds the projections for column l.
     * This is the partial derivative of the loss with respect to {@code U[k][l]}.
     */
    private float computeDerivative(INDArray q, float[] projections, int k) {
        int n = this.x.rows();
        float derivative = 0.0f;
        for (int i = 0; i < n; ++i) {
            float xik = this.x.getFloat((long) i, (long) k);
            for (int j = 0; j < n; ++j) {
                float sumA = xik * projections[j];
                float sumB = this.x.getFloat((long) j, (long) k) * projections[i];
                derivative += -2.0f * q.getFloat((long) i, (long) j) * (sumA + sumB);
            }
        }
        return derivative;
    }

    /**
     * Not implemented by this class: always returns {@code 0.0}.
     *
     * @param x ignored
     * @param w ignored
     * @return always {@code 0.0}
     */
    @Override
    public double computeSimilarity(INDArray x, INDArray w) {
        return 0.0;
    }

    /** Returns the matrix X stored by the last call to {@code build}, or {@code null} before that. */
    public INDArray getX() {
        return this.x;
    }

    /** Returns the learned matrix U, or {@code null} if {@code build} has not been called. */
    public INDArray getU() {
        return this.u;
    }
}

