/*
 * Decompiled with CFR 0.152.
 */
package org.openimaj.ml.linear.evaluation;

import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.mtj.SparseMatrix;
import java.util.List;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.openimaj.ml.linear.evaluation.BilinearEvaluator;
import org.openimaj.ml.linear.learner.BilinearLearnerParameters;
import org.openimaj.ml.linear.learner.BilinearSparseOnlineLearner;
import org.openimaj.ml.linear.learner.loss.LossFunction;
import org.openimaj.ml.linear.learner.loss.MatLossFunction;
import org.openimaj.util.pair.Pair;

/**
 * A {@link BilinearEvaluator} that scores the evaluator's learner by summing the
 * configured loss over every (X, Y) pair in a data set.
 */
public class SumLossEvaluator
extends BilinearEvaluator {
    // One logger per class rather than one per instance; name and visibility unchanged.
    static final Logger logger = LogManager.getLogger(SumLossEvaluator.class);

    /**
     * Evaluates the learner's current U, W and bias on {@code data}.
     *
     * @param data the (X, Y) pairs to evaluate against
     * @return the summed loss over all pairs, as computed by
     *         {@link #sumLoss(List, Matrix, Matrix, Matrix, BilinearLearnerParameters)}
     */
    @Override
    public double evaluate(List<Pair<Matrix>> data) {
        // Arguments are evaluated left to right: getU, getW, getBias, getParams.
        return this.sumLoss(data, this.learner.getU(), this.learner.getW(), this.learner.getBias(),
                this.learner.getParams());
    }

    /**
     * Sums the loss over every pair in {@code pairs}.
     *
     * <p>For each pair, the first element X is projected as {@code u^T * X^T * w}.
     * The second element Y is expanded with {@link BilinearSparseOnlineLearner#expandY(Matrix)}.
     * Both are fed to the loss named {@code "loss"} in {@code params}, wrapped in a
     * {@link MatLossFunction}. The result of {@code eval(null)} is then accumulated.
     *
     * @param pairs  the (X, Y) pairs to evaluate
     * @param u      the U matrix
     * @param w      the W matrix
     * @param bias   the bias matrix; if {@code null}, no bias is set on the loss
     * @param params parameters from which the {@code "loss"} function is read
     * @return the sum of the per-pair losses, or {@code 0.0} if {@code pairs} is empty
     */
    public double sumLoss(List<Pair<Matrix>> pairs, Matrix u, Matrix w, Matrix bias, BilinearLearnerParameters params) {
        final LossFunction loss = new MatLossFunction((LossFunction) params.getTyped("loss"));
        if (pairs.isEmpty()) {
            // Nothing to evaluate; do not touch the model matrices at all.
            return 0.0;
        }
        // u^T is the same for every pair, so compute it once instead of per iteration.
        final Matrix uT = u.transpose();
        double total = 0.0;
        int i = 0;
        for (final Pair<Matrix> pair : pairs) {
            final Matrix X = pair.firstObject();
            final Matrix Y = pair.secondObject();
            final SparseMatrix Yexp = BilinearSparseOnlineLearner.expandY(Y);
            final Matrix expectedAll = uT.times(X.transpose()).times(w);
            // Keep the original set order (Y, X, then bias) in case the setters interact.
            loss.setY(Yexp);
            loss.setX(expectedAll);
            if (bias != null) {
                loss.setBias(bias);
            }
            // Parameterized message: no string building when debug is disabled.
            logger.debug("Testing pair: {}", i);
            total += loss.eval(null);
            ++i;
        }
        return total;
    }
}

