/*
 * Decompiled with CFR 0.152.
 */
package org.openimaj.ml.linear.experiments.sinabill;

import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import org.openimaj.io.IOUtils;
import org.openimaj.io.WriteableBinary;
import org.openimaj.math.matrix.CFMatrixUtils;
import org.openimaj.ml.linear.data.BillMatlabFileDataGenerator;
import org.openimaj.ml.linear.evaluation.BilinearEvaluator;
import org.openimaj.ml.linear.evaluation.RootMeanSumLossEvaluator;
import org.openimaj.ml.linear.experiments.sinabill.BilinearExperiment;
import org.openimaj.ml.linear.learner.BilinearLearnerParameters;
import org.openimaj.ml.linear.learner.BilinearSparseOnlineLearner;
import org.openimaj.ml.linear.learner.init.SingleValueInitStrat;
import org.openimaj.ml.linear.learner.init.SparseZerosInitStrategy;
import org.openimaj.util.pair.Pair;

/**
 * Experiment that sweeps the {@code "dampening"} learner parameter over a
 * fixed range for a single fold of the Matlab data set.
 *
 * <p>For every dampening value a fresh {@link BilinearSparseOnlineLearner} is
 * trained item by item on the fold's training data. After each training item
 * the learner is evaluated against the fold's test pairs, written to disk, and
 * summary statistics of its {@code U}, {@code W} and bias matrices are logged.
 */
public class BillAustrianDampeningExperiments
extends BilinearExperiment {
    /**
     * Command line entry point: creates the experiment and runs it.
     *
     * @param args ignored
     * @throws Exception if the experiment fails
     */
    public static void main(String[] args) throws Exception {
        BillAustrianDampeningExperiments exp = new BillAustrianDampeningExperiments();
        exp.performExperiment();
    }

    /**
     * Runs the dampening sweep.
     *
     * <p>The test pairs of the chosen fold are read once up front and reused
     * for every evaluation. The data generator is then switched to training
     * mode at the start of each dampening value.
     *
     * @throws IOException if the experiment log cannot be prepared or a
     *             learner cannot be written to disk
     */
    @Override
    public void performExperiment() throws IOException {
        BilinearLearnerParameters params = new BilinearLearnerParameters();
        params.put("eta0u", 0.02);
        params.put("eta0w", 0.02);
        params.put("lambda", 0.001);
        params.put("biconvex_tol", 0.01);
        params.put("biconvex_maxiter", 10);
        params.put("bias", true);
        params.put("biaseta0", 0.5);
        params.put("winitstrat", new SingleValueInitStrat(0.1));
        params.put("uinitstrat", new SparseZerosInitStrategy());
        // NOTE(review): the meaning of the 98 / true constructor arguments is
        // defined by BillMatlabFileDataGenerator and is not visible here.
        BillMatlabFileDataGenerator bmfdg = new BillMatlabFileDataGenerator(new File(this.MATLAB_DATA()), 98, true);
        this.prepareExperimentLog(params);
        int foldNumber = 5;
        this.logger.debug("Starting dampening experiments");
        this.logger.debug("Fold: " + foldNumber);

        // Collect the held-out pairs once; generate() returning null ends the loop.
        bmfdg.setFold(foldNumber, BillMatlabFileDataGenerator.Mode.TEST);
        ArrayList<Pair<Matrix>> testpairs = new ArrayList<Pair<Matrix>>();
        Pair<Matrix> testPair;
        while ((testPair = bmfdg.generate()) != null) {
            testpairs.add(testPair);
        }

        // The sweep start is given an explicit value so that it is definitely
        // assigned before it is logged; previously the loop variable was read
        // in the log statement before its first assignment.
        double dampeningMin = 0.0;
        double dampeningIncr = 1.0E-4;
        double dampeningMax = 0.02;
        this.logger.debug(String.format("Beggining dampening experiments: min=%2.5f,max=%2.5f,incr=%2.5f", dampeningMin, dampeningMax, dampeningIncr));

        // NOTE(review): the loop accumulates a double step, so the exact
        // number of iterations near dampeningMax is subject to rounding.
        for (double dampening = dampeningMin; dampening < dampeningMax; dampening += dampeningIncr) {
            params.put("dampening", dampening);
            BilinearSparseOnlineLearner learner = new BilinearSparseOnlineLearner(params);
            learner.reinitParams();
            this.logger.debug("Dampening is now: " + dampening);
            this.logger.debug("...training");
            bmfdg.setFold(foldNumber, BillMatlabFileDataGenerator.Mode.TRAINING);

            int j = 0;
            Pair<Matrix> trainPair;
            while ((trainPair = bmfdg.generate()) != null) {
                // j is logged before the increment here and used after it
                // below, so the saved learner files are numbered from 1.
                this.logger.debug("...trying item " + j++);
                learner.process(trainPair.firstObject(), trainPair.secondObject());

                Matrix u = learner.getU();
                Matrix w = learner.getW();
                // Dense copy of the bias, used only for the diagonal log below.
                Matrix bias = MatrixFactory.getDenseDefault().copyMatrix(learner.getBias());

                RootMeanSumLossEvaluator eval = new RootMeanSumLossEvaluator();
                eval.setLearner(learner);
                double loss = eval.evaluate(testpairs);

                this.logger.debug(String.format("Saving learner, Fold %d, Item %d", foldNumber, j));
                File learnerOut = new File(this.FOLD_ROOT(foldNumber), String.format("learner_%d_dampening=%2.5f", j, dampening));
                IOUtils.writeBinary(learnerOut, learner);

                this.logger.debug("W row sparcity: " + CFMatrixUtils.rowSparsity(w));
                this.logger.debug(String.format("W range: %2.5f -> %2.5f", CFMatrixUtils.min(w), CFMatrixUtils.max(w)));
                this.logger.debug("U row sparcity: " + CFMatrixUtils.rowSparsity(u));
                this.logger.debug(String.format("U range: %2.5f -> %2.5f", CFMatrixUtils.min(u), CFMatrixUtils.max(u)));
                Boolean biasMode = (Boolean)learner.getParams().getTyped("bias");
                if (biasMode.booleanValue()) {
                    this.logger.debug("Bias: " + CFMatrixUtils.diag(bias));
                }
                this.logger.debug(String.format("... loss: %f", loss));
            }
        }
    }
}

