/*
 * Decompiled with CFR 0.152.
 */
package org.openimaj.ml.linear.learner;

import gov.sandia.cognition.math.Ring;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.mtj.AbstractSparseMatrix;
import gov.sandia.cognition.math.matrix.mtj.SparseMatrix;
import gov.sandia.cognition.math.matrix.mtj.SparseMatrixFactoryMTJ;
import gov.sandia.cognition.math.matrix.mtj.SparseRowMatrix;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.openimaj.io.ReadWriteableBinary;
import org.openimaj.math.matrix.CFMatrixUtils;
import org.openimaj.ml.linear.learner.BilinearLearnerParameters;
import org.openimaj.ml.linear.learner.OnlineLearner;
import org.openimaj.ml.linear.learner.init.ContextAwareInitStrategy;
import org.openimaj.ml.linear.learner.init.InitStrategy;
import org.openimaj.ml.linear.learner.init.SparseSingleValueInitStrat;
import org.openimaj.ml.linear.learner.loss.LossFunction;
import org.openimaj.ml.linear.learner.loss.MatLossFunction;
import org.openimaj.ml.linear.learner.regul.Regulariser;

/**
 * Online learner for a bilinear model with one weight column per task.
 * <p>
 * The model is made of:
 * <ul>
 * <li>a matrix {@code w} with one row per row of the input {@code X} and one
 * column per task;</li>
 * <li>a matrix {@code u} with one row per column of {@code X} and one column per
 * task;</li>
 * <li>when the {@code "bias"} parameter is set, a task-by-task {@code bias}
 * matrix.</li>
 * </ul>
 * {@link #predict(Matrix)} returns the diagonal of
 * {@code u^T x^T w (+ bias)} as a single-row matrix.
 * <p>
 * Each call to {@link #process(Matrix, Matrix)} runs alternating rounds. A
 * round makes a gradient step with a proximal (regulariser) step on {@code w},
 * then on {@code u}, then a plain gradient step on the bias. Every step is
 * chosen by a backtracking line search. Rounds stop when the mean relative
 * change of the parameters drops below {@code "biconvex_tol"}, or after
 * {@code "biconvex_maxiter"} rounds. All hyper-parameters are read from a
 * {@link BilinearLearnerParameters}.
 */
public class BilinearSparseOnlineLearner
implements OnlineLearner<Matrix, Matrix>,
ReadWriteableBinary {
    static Logger logger = LogManager.getLogger(BilinearSparseOnlineLearner.class);

    /** Upper bound on the number of step sizes tried by one backtracking line search. */
    private static final int MAX_LINE_SEARCH_STEPS = 1000;

    protected BilinearLearnerParameters params;
    /** Weights over the rows of X; shape (X rows) x (tasks). */
    protected Matrix w;
    /** Weights over the columns of X; shape (X columns) x (tasks). */
    protected Matrix u;
    protected SparseMatrixFactoryMTJ smf = SparseMatrixFactoryMTJ.INSTANCE;
    protected LossFunction loss;
    protected Regulariser regul;
    protected Double lambda_w;
    protected Double lambda_u;
    protected Boolean biasMode;
    /** Task-by-task bias matrix (always allocated by initParams, only trained in bias mode). */
    protected Matrix bias;
    /** Identity matrix used as the "X" of the loss when computing the bias gradient. */
    protected Matrix diagX;
    protected Double eta0_u;
    protected Double eta0_w;
    private Boolean forceSparcity;
    private Boolean zStandardise;
    /** True until the first W update; that update uses an all-ones stand-in for U^T. */
    private boolean nodataseen;
    private double eta_gamma;
    private double biasEta0;

    /** Creates a learner with default {@link BilinearLearnerParameters}. */
    public BilinearSparseOnlineLearner() {
        this(new BilinearLearnerParameters());
    }

    /**
     * Creates a learner driven by the given parameters.
     *
     * @param params hyper-parameters; read immediately by {@link #reinitParams()}
     */
    public BilinearSparseOnlineLearner(BilinearLearnerParameters params) {
        this.params = params;
        this.reinitParams();
    }

    /**
     * Re-reads every hyper-parameter from {@link #params}.
     * <p>
     * A loss that is not a matrix loss is wrapped in a {@link MatLossFunction}.
     * The "no data seen" flag is reset, so the next W update again uses the
     * all-ones stand-in for U. Learned {@code w}, {@code u} and {@code bias}
     * are left untouched.
     */
    public void reinitParams() {
        this.loss = (LossFunction) this.params.getTyped("loss");
        this.regul = (Regulariser) this.params.getTyped("regul");
        this.lambda_w = (Double) this.params.getTyped("lambda_w");
        this.lambda_u = (Double) this.params.getTyped("lambda_u");
        this.biasMode = (Boolean) this.params.getTyped("bias");
        this.eta0_u = (Double) this.params.getTyped("eta0u");
        this.eta0_w = (Double) this.params.getTyped("eta0w");
        this.biasEta0 = (Double) this.params.getTyped("biaseta0");
        this.eta_gamma = (Double) this.params.getTyped("gamma");
        this.forceSparcity = (Boolean) this.params.getTyped("forcesparcity");
        this.zStandardise = (Boolean) this.params.getTyped("z_standardise");
        if (!this.loss.isMatrixLoss()) {
            this.loss = new MatLossFunction(this.loss);
        }
        this.nodataseen = true;
    }

    /**
     * Allocates w, u, bias (and diagX in bias mode) from the configured init strategies.
     *
     * @param x first input, passed to context-aware strategies
     * @param y first target, passed to context-aware strategies
     * @param xrows number of rows of w
     * @param xcols number of rows of u
     * @param ycols number of tasks (columns of w and u; both dimensions of bias)
     */
    private void initParams(Matrix x, Matrix y, int xrows, int xcols, int ycols) {
        // Strategies are resolved in this order (w then u) because context-aware
        // strategies are handed this learner and may inspect its state.
        final InitStrategy wstrat = this.getInitStrat("winitstrat", x, y);
        final InitStrategy ustrat = this.getInitStrat("uinitstrat", x, y);
        this.w = wstrat.init(xrows, ycols);
        this.u = ustrat.init(xcols, ycols);
        if (this.forceSparcity) {
            this.u = CFMatrixUtils.asSparseColumn(this.u);
            this.w = CFMatrixUtils.asSparseColumn(this.w);
        }
        // Always allocate a bias so writeBinary has something to serialise.
        this.bias = this.smf.createMatrix(ycols, ycols);
        if (this.biasMode) {
            final InitStrategy bstrat = this.getInitStrat("biasinitstrat", x, y);
            this.bias = bstrat.init(ycols, ycols);
            this.diagX = this.smf.createIdentity(ycols, ycols);
        }
    }

    /**
     * Looks up an init strategy by parameter key and, if it is context aware,
     * hands it this learner and the (x, y) context.
     *
     * @param initstrat parameter key of the strategy
     * @param x context input (may be null)
     * @param y context target (may be null)
     * @return the configured strategy
     */
    private InitStrategy getInitStrat(String initstrat, Matrix x, Matrix y) {
        final InitStrategy strat = (InitStrategy) this.params.getTyped(initstrat);
        if (strat instanceof ContextAwareInitStrategy) {
            final ContextAwareInitStrategy cwStrat = (ContextAwareInitStrategy) strat;
            cwStrat.setLearner(this);
            cwStrat.setContext(x, y);
        }
        return strat;
    }

    /**
     * Updates the model with one (X, Y) example.
     * <p>
     * Runs alternating W / U / (bias) updates until the mean relative change of
     * the parameters is below {@code "biconvex_tol"}, or {@code "biconvex_maxiter"}
     * rounds have run. A negative tolerance means exactly one round.
     *
     * @param X input; its rows index the rows of w and its columns index the rows of u
     * @param Y target; its columns are the tasks, and its first row holds the
     *          per-task targets (see {@link #expandY(Matrix)})
     */
    @Override
    public void process(Matrix X, Matrix Y) {
        this.prepareNextRound(X, Y);
        final Matrix xt = X.transpose();
        // The U update multiplies by X^T row by row, so give it a row-sparse copy when X is sparse.
        Matrix xtrows = xt;
        if (xt instanceof AbstractSparseMatrix) {
            xtrows = CFMatrixUtils.asSparseRow(xt);
        }
        final double biconvextol = (Double) this.params.getTyped("biconvex_tol");
        final int maxiter = (Integer) this.params.getTyped("biconvex_maxiter");

        int iter = 0;
        double ratio;
        double totalw;
        double totalu;
        double totalbias;
        boolean converged;
        do {
            if (this.biasMode) {
                // The W and U losses must see the current bias; updateBias clears it again.
                this.loss.setBias(this.bias);
            }
            ++iter;

            final Matrix neww = this.updateW(xt, this.eta0_w, this.lambda_w);
            final Matrix newu = this.updateU(xtrows, neww, this.eta0_u, this.lambda_u);
            Matrix newbias = null;
            if (this.biasMode) {
                newbias = this.updateBias(xt, newu, neww, this.biasEta0);
            }

            // Convergence measure: mean of |new - old| / |old| over the trained parameters.
            totalw = CFMatrixUtils.absSum(this.w);
            totalu = CFMatrixUtils.absSum(this.u);
            totalbias = 0.0;
            final double ratioW = relativeChange(neww, this.w, totalw);
            final double ratioU = relativeChange(newu, this.u, totalu);
            ratio = ratioU + ratioW;
            if (this.biasMode) {
                totalbias = CFMatrixUtils.absSum(this.bias);
                ratio += relativeChange(newbias, this.bias, totalbias);
                ratio /= 3.0;
            } else {
                ratio /= 2.0;
            }

            if (this.forceSparcity) {
                this.u = CFMatrixUtils.asSparseColumn(newu);
                this.w = CFMatrixUtils.asSparseColumn(neww);
            } else {
                this.w = neww;
                this.u = newu;
            }
            if (this.biasMode) {
                this.bias = newbias;
            }

            // Written with '<' so a NaN ratio or tolerance keeps iterating, as before.
            converged = biconvextol < 0.0 || ratio < biconvextol || iter >= maxiter;

            if (iter % 3 == 0 && logger.isDebugEnabled()) {
                logger.debug(String.format("Iter: %d. Last Ratio: %2.3f", iter, ratio));
                this.logModelState(totalu, totalw, totalbias);
            }
        } while (!converged);

        if (logger.isDebugEnabled()) {
            logger.debug("tolerance reached after iteration: " + iter);
            this.logModelState(totalu, totalw, totalbias);
        }
    }

    /**
     * Returns {@code |updated - previous|_1 / previousTotal}, or 0 when
     * {@code previousTotal} is 0.
     */
    private static double relativeChange(Matrix updated, Matrix previous, double previousTotal) {
        if (previousTotal == 0.0) {
            return 0.0;
        }
        return CFMatrixUtils.absSum(updated.minus(previous)) / previousTotal;
    }

    /** Emits the debug diagnostics shared by the periodic and the final log. */
    private void logModelState(double totalu, double totalw, double totalbias) {
        logger.debug("W row sparcity: " + CFMatrixUtils.rowSparsity(this.w));
        logger.debug("U row sparcity: " + CFMatrixUtils.rowSparsity(this.u));
        logger.debug("Total U magnitude: " + totalu);
        logger.debug("Total W magnitude: " + totalw);
        logger.debug("Total Bias: " + totalbias);
    }

    /**
     * Prepares for a new example.
     * <p>
     * On the first call it lazily initialises the parameters. It then scales
     * w, u and (in bias mode) bias by {@code 1 - dampening}, and sets the
     * loss target to {@link #expandY(Matrix)}.
     */
    private void prepareNextRound(Matrix X, Matrix Y) {
        final int nfeatures = X.getNumRows();
        final int nusers = X.getNumColumns();
        final int ntasks = Y.getNumColumns();
        if (this.w == null) {
            this.initParams(X, Y, nfeatures, nusers, ntasks);
        }
        if (this.biasMode && this.diagX == null) {
            // initParams is skipped when w was supplied by readBinary/setW/clone,
            // which would otherwise leave updateBias handing a null X to the loss.
            this.diagX = this.smf.createIdentity(ntasks, ntasks);
        }
        final Double dampening = (Double) this.params.getTyped("dampening");
        final double weighting = 1.0 - dampening;
        logger.debug("... dampening w, u and bias by: " + weighting);
        this.w.scaleEquals(weighting);
        this.u.scaleEquals(weighting);
        if (this.biasMode) {
            this.bias.scaleEquals(weighting);
        }
        final SparseMatrix Yexp = BilinearSparseOnlineLearner.expandY(Y);
        this.loss.setY((Matrix) Yexp);
    }

    /**
     * Computes an updated bias by a line-searched gradient step (no regulariser).
     *
     * @param xt X transposed
     * @param nu the freshly updated u
     * @param nw the freshly updated w
     * @param biasLossWeight initial step weight; the step taken is gradient / weight
     * @return the candidate bias accepted by the line search (or the last one tried)
     */
    protected Matrix updateBias(Matrix xt, Matrix nu, Matrix nw, double biasLossWeight) {
        final Matrix utxt = CFMatrixUtils.fastdot(nu.transpose(), xt);
        final Matrix utxtw = CFMatrixUtils.fastdot(utxt, nw);
        final Matrix mult = utxtw.plus(this.bias);
        // The prediction (including bias) is passed directly, with identity as X,
        // so the loss must not add the bias a second time.
        this.loss.setBias(null);
        this.loss.setX(this.diagX);
        final Matrix biasGrad = this.loss.gradient(mult);
        return this.backtrackingStep(this.bias, biasGrad, biasLossWeight, false, 0.0, "etab");
    }

    /**
     * Computes an updated w by a line-searched proximal gradient step.
     *
     * @param xt X transposed
     * @param wLossWeighted initial step weight; the step taken is gradient / weight
     * @param weightedLambda regulariser strength (scaled by 1 / weight in the prox)
     * @return the candidate w accepted by the line search (or the last one tried)
     */
    protected Matrix updateW(Matrix xt, double wLossWeighted, double weightedLambda) {
        final Matrix Dprime;
        if (this.nodataseen) {
            // Before any data, U carries no information: use an all-ones U^T instead.
            this.nodataseen = false;
            final Matrix fakeu = new SparseSingleValueInitStrat(1.0).init(this.u.getNumColumns(), this.u.getNumRows());
            Dprime = CFMatrixUtils.fastdot(fakeu, xt);
        } else {
            Dprime = CFMatrixUtils.fastdot(this.u.transpose(), xt);
        }
        if (this.zStandardise) {
            final Vector rowMean = CFMatrixUtils.rowMean(Dprime);
            CFMatrixUtils.minusEqualsCol(Dprime, rowMean);
        }
        this.loss.setX(Dprime);
        final Matrix gradW = this.loss.gradient(this.w);
        if (logger.isDebugEnabled()) {
            logger.debug("Abs w_grad: " + CFMatrixUtils.absSum(gradW));
        }
        return this.backtrackingStep(this.w, gradW, wLossWeighted, true, weightedLambda, "etaw");
    }

    /**
     * Computes an updated u by a line-searched proximal gradient step.
     *
     * @param xtrows X transposed (row-sparse when X is sparse)
     * @param neww the freshly updated w
     * @param uLossWeight initial step weight; the step taken is gradient / weight
     * @param uWeightedLambda regulariser strength (scaled by 1 / weight in the prox)
     * @return the candidate u accepted by the line search (or the last one tried)
     */
    protected Matrix updateU(Matrix xtrows, Matrix neww, double uLossWeight, double uWeightedLambda) {
        final Matrix Vprime = CFMatrixUtils.fastdot(xtrows, neww);
        final SparseRowMatrix Vt = CFMatrixUtils.asSparseRow(Vprime.transpose());
        if (this.zStandardise) {
            final Vector rowMean = CFMatrixUtils.rowMean((Matrix) Vt);
            CFMatrixUtils.minusEqualsCol((Matrix) Vt, rowMean);
        }
        this.loss.setX((Matrix) Vt);
        final Matrix gradU = this.loss.gradient(this.u);
        if (logger.isDebugEnabled()) {
            logger.debug("Abs u_grad: " + CFMatrixUtils.absSum(gradU));
        }
        return this.backtrackingStep(this.u, gradU, uLossWeight, true, uWeightedLambda, "etau");
    }

    /**
     * Backtracking line search shared by the w, u and bias updates.
     * <p>
     * Each candidate is {@code current - grad / eta}, optionally followed by
     * {@code regul.prox(candidate, lambda / eta)}. The first candidate accepted
     * by {@code loss.test_backtrack} is returned; otherwise eta is multiplied
     * by {@code eta_gamma} and the search retries, up to
     * {@link #MAX_LINE_SEARCH_STEPS} attempts.
     *
     * @param current parameter being updated (not modified)
     * @param grad loss gradient at {@code current}
     * @param eta initial step weight
     * @param proximal whether to apply the regulariser's prox operator
     * @param lambda regulariser strength (ignored when {@code proximal} is false)
     * @param etaName label used in the debug log
     * @return the accepted candidate, or the last one tried
     */
    private Matrix backtrackingStep(Matrix current, Matrix grad, double eta, boolean proximal, double lambda, String etaName) {
        Matrix candidate = null;
        for (int attempt = 0; attempt < MAX_LINE_SEARCH_STEPS; ++attempt) {
            logger.debug("... Line searching " + etaName + " = " + eta);
            // Clone first: fastminus is handed its own copy so 'current' stays intact.
            candidate = CFMatrixUtils.fastminus(current.clone(), grad.scale(1.0 / eta));
            if (proximal) {
                candidate = this.regul.prox(candidate, lambda / eta);
            }
            if (this.loss.test_backtrack(current, grad, candidate, eta)) {
                break;
            }
            eta *= this.eta_gamma;
        }
        return candidate;
    }

    /** Returns {@code lambda / iter}. Not called within this class. */
    private double lambdat(int iter, double lambda) {
        return lambda / (double) iter;
    }

    /**
     * Expands a per-task target row into a task-by-task matrix.
     * <p>
     * The diagonal holds {@code Y(0, t)}; every off-diagonal entry is NaN.
     *
     * @param Y target whose first row is read (one column per task)
     * @return an ntasks x ntasks sparse matrix
     */
    public static SparseMatrix expandY(Matrix Y) {
        final int ntasks = Y.getNumColumns();
        final SparseMatrix Yexp = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(ntasks, ntasks);
        for (int touter = 0; touter < ntasks; ++touter) {
            for (int tinner = 0; tinner < ntasks; ++tinner) {
                final double value = tinner == touter ? Y.getElement(0, tinner) : Double.NaN;
                Yexp.setElement(touter, tinner, value);
            }
        }
        return Yexp;
    }

    /**
     * Step-size schedule {@code eta0 / sqrt(ceil(iter / etasteps))}.
     * Not called within this class; kept for subclasses.
     */
    protected double etat(int iter, double eta0) {
        final Integer etaSteps = (Integer) this.params.getTyped("etasteps");
        final double sqrtCeil = Math.sqrt(Math.ceil((double) iter / (double) etaSteps.intValue()));
        return this.eta(eta0) / sqrtCeil;
    }

    private double eta(double eta0) {
        return eta0;
    }

    /** @return the parameters object this learner reads from */
    public BilinearLearnerParameters getParams() {
        return this.params;
    }

    /** @return the current u (null before the first update) */
    public Matrix getU() {
        return this.u;
    }

    /** @return the current w (null before the first update) */
    public Matrix getW() {
        return this.w;
    }

    /** @return the bias in bias mode, otherwise null */
    public Matrix getBias() {
        if (this.biasMode) {
            return this.bias;
        }
        return null;
    }

    /**
     * Appends {@code newUsers} rows to u, initialised by {@code "expandeduinitstrat"}.
     * Does nothing when u is not yet initialised.
     */
    public void addU(int newUsers) {
        if (this.u == null) {
            return;
        }
        final InitStrategy ustrat = this.getInitStrat("expandeduinitstrat", null, null);
        final Matrix newU = ustrat.init(newUsers, this.u.getNumColumns());
        this.u = CFMatrixUtils.vstack(new Matrix[] { this.u, newU });
    }

    /**
     * Appends {@code newWords} rows to w, initialised by {@code "expandedwinitstrat"}.
     * Does nothing when w is not yet initialised.
     */
    public void addW(int newWords) {
        if (this.w == null) {
            return;
        }
        final InitStrategy wstrat = this.getInitStrat("expandedwinitstrat", null, null);
        final Matrix newW = wstrat.init(newWords, this.w.getNumColumns());
        this.w = CFMatrixUtils.vstack(new Matrix[] { this.w, newW });
    }

    /**
     * Returns a copy of this learner.
     * <p>
     * The copy shares the same parameters object. It gets its own copies of
     * u, w, bias and diagX, and keeps the same "no data seen" state. A trained
     * clone therefore continues training exactly like the original, and can be
     * serialised even when bias mode is off.
     */
    @Override
    public BilinearSparseOnlineLearner clone() {
        final BilinearSparseOnlineLearner ret = new BilinearSparseOnlineLearner(this.getParams());
        ret.u = copyOrNull(this.u);
        ret.w = copyOrNull(this.w);
        ret.bias = copyOrNull(this.bias);
        ret.diagX = copyOrNull(this.diagX);
        ret.nodataseen = this.nodataseen;
        return ret;
    }

    /** Returns {@code m.clone()}, or null when {@code m} is null. */
    private static Matrix copyOrNull(Matrix m) {
        return m == null ? null : m.clone();
    }

    /** Replaces u directly (no shape checks). */
    public void setU(Matrix newu) {
        this.u = newu;
    }

    /** Replaces w directly (no shape checks). */
    public void setW(Matrix neww) {
        this.w = neww;
    }

    /**
     * Reads a model in the layout produced by {@link #writeBinary(DataOutput)}.
     * <p>
     * The layout is three ints (rows of w, rows of u, number of tasks), then
     * the values of w column by column, then u column by column, then the
     * tasks x tasks bias row by row. Zero values are not stored into the sparse
     * matrices.
     */
    @Override
    public void readBinary(DataInput in) throws IOException {
        final int nwords = in.readInt();
        final int nusers = in.readInt();
        final int ntasks = in.readInt();
        this.w = readColumnMajor(in, nwords, ntasks);
        this.u = readColumnMajor(in, nusers, ntasks);
        // NOTE(review): bias is read row by row while w/u are read column by column.
        // If CFMatrixUtils.getData emits column-major data (as the w/u reads assume),
        // the bias comes back transposed; predict only uses its diagonal — confirm.
        this.bias = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(ntasks, ntasks);
        for (int t1 = 0; t1 < ntasks; ++t1) {
            for (int t2 = 0; t2 < ntasks; ++t2) {
                final double value = in.readDouble();
                if (value != 0.0) {
                    this.bias.setElement(t1, t2, value);
                }
            }
        }
    }

    /** Reads rows x cols doubles, outer loop over columns, into a new sparse matrix. */
    private static Matrix readColumnMajor(DataInput in, int rows, int cols) throws IOException {
        final Matrix m = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(rows, cols);
        for (int c = 0; c < cols; ++c) {
            for (int r = 0; r < rows; ++r) {
                final double value = in.readDouble();
                if (value != 0.0) {
                    m.setElement(r, c, value);
                }
            }
        }
        return m;
    }

    /** This format has no header: returns an empty array. */
    @Override
    public byte[] binaryHeader() {
        return new byte[0];
    }

    /**
     * Writes the dimensions, then the data of w, u and bias as returned by
     * {@code CFMatrixUtils.getData}. Requires w, u and bias to be non-null.
     */
    @Override
    public void writeBinary(DataOutput out) throws IOException {
        out.writeInt(this.w.getNumRows());
        out.writeInt(this.u.getNumRows());
        out.writeInt(this.u.getNumColumns());
        writeAll(out, CFMatrixUtils.getData(this.w));
        writeAll(out, CFMatrixUtils.getData(this.u));
        writeAll(out, CFMatrixUtils.getData(this.bias));
    }

    /** Writes every value of {@code data} in order. */
    private static void writeAll(DataOutput out, double[] data) throws IOException {
        for (final double value : data) {
            out.writeDouble(value);
        }
    }

    /**
     * Predicts one value per task.
     *
     * @param x input with the same orientation as the X given to {@link #process}
     * @return a 1 x ntasks matrix holding the diagonal of {@code u^T x^T w} (plus bias in bias mode)
     */
    @Override
    public Matrix predict(Matrix x) {
        final Matrix mult = this.u.transpose().times(x.transpose()).times(this.w);
        if (this.biasMode) {
            mult.plusEquals(this.bias);
        }
        final Vector ydiag = CFMatrixUtils.diag(mult);
        // Start from an empty matrix (not an identity) so no stray 1 can survive setRow.
        final Matrix prediction = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(1, ydiag.getDimensionality());
        prediction.setRow(0, ydiag);
        return prediction;
    }
}

