/*
 * Decompiled with CFR 0.152.
 */
package org.openimaj.ml.linear.learner;

import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.mtj.SparseMatrix;
import gov.sandia.cognition.math.matrix.mtj.SparseMatrixFactoryMTJ;
import java.util.HashMap;
import java.util.Map;
import org.openimaj.ml.linear.learner.BilinearLearnerParameters;
import org.openimaj.ml.linear.learner.BilinearSparseOnlineLearner;
import org.openimaj.ml.linear.learner.OnlineLearner;
import org.openimaj.util.pair.IndependentPair;
import org.openimaj.util.pair.Pair;

/**
 * An {@link OnlineLearner} that wraps a {@link BilinearSparseOnlineLearner} so that
 * callers can work with string-keyed maps instead of matrices.
 * <p>
 * Users, words (features) and dependent values (tasks) are each assigned a
 * consecutive integer index the first time they are seen by
 * {@link #process(Map, Map)}. When new users or words appear, the wrapped learner
 * is grown via {@code addU}/{@code addW}.
 */
public class IncrementalBilinearSparseOnlineLearner
implements OnlineLearner<Map<String, Map<String, Double>>, Map<String, Double>> {
    /** Word (feature) name to row index of the X matrix. */
    private BiMap<String, Integer> vocabulary;
    /** User name to column index of the X matrix. */
    private BiMap<String, Integer> users;
    /** Dependent-value (task) name to column index of the Y matrix. */
    private BiMap<String, Integer> values;
    private BilinearSparseOnlineLearner bilinearLearner;
    private BilinearLearnerParameters params;

    /**
     * Creates a learner with default {@link BilinearLearnerParameters}.
     */
    public IncrementalBilinearSparseOnlineLearner() {
        this.init(new IncrementalBilinearSparseOnlineLearnerParams());
    }

    /**
     * Creates a learner with the given parameters.
     *
     * @param params the parameters passed to the wrapped bilinear learner
     */
    public IncrementalBilinearSparseOnlineLearner(BilinearLearnerParameters params) {
        this.init(params);
    }

    /**
     * Discards all learned state (index mappings and the wrapped learner) and
     * starts again using the current parameters.
     */
    public void reinitParams() {
        this.init(this.params);
    }

    /** Resets the index mappings and builds a fresh wrapped learner. */
    private void init(BilinearLearnerParameters params) {
        this.vocabulary = HashBiMap.create();
        this.users = HashBiMap.create();
        this.values = HashBiMap.create();
        this.params = params;
        this.bilinearLearner = new BilinearSparseOnlineLearner(params);
    }

    /**
     * @return the parameters this learner was (re)initialised with
     */
    public BilinearLearnerParameters getParams() {
        return this.params;
    }

    /**
     * Registers any unseen users, words and dependent values, then trains the
     * wrapped learner on the resulting X (words x users) and Y (1 x values)
     * matrices.
     *
     * @param x map from user to that user's word weights
     * @param y map from dependent-value name to its observed value
     */
    @Override
    public void process(Map<String, Map<String, Double>> x, Map<String, Double> y) {
        // Must happen first so that every key in x and y has an index below.
        this.updateUserValues(x, y);
        final Matrix yMat = this.constructYMatrix(y);
        final Matrix xMat = this.constructXMatrix(x);
        this.bilinearLearner.process(xMat, yMat);
    }

    /**
     * Registers any users, words and dependent values in {@code x} and
     * {@code y} that have not been seen before. The wrapped learner is grown
     * for new users and words.
     *
     * @param x map from user to that user's word weights
     * @param y map from dependent-value name to its value (only keys are used)
     */
    public void updateUserValues(Map<String, Map<String, Double>> x, Map<String, Double> y) {
        this.updateUserWords(x);
        this.updateValues(y);
    }

    /** Assigns the next free index to every previously unseen dependent value. */
    private void updateValues(Map<String, Double> y) {
        for (final String value : y.keySet()) {
            if (!this.values.containsKey(value)) {
                this.values.put(value, this.values.size());
            }
        }
    }

    /** Builds a 1 x |values| row matrix; every key of y must already be registered. */
    private Matrix constructYMatrix(Map<String, Double> y) {
        final SparseMatrix mat = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(1, this.values.size());
        for (final Map.Entry<String, Double> ent : y.entrySet()) {
            mat.setElement(0, this.values.get(ent.getKey()), ent.getValue());
        }
        return mat;
    }

    /** Converts row 0 of y back into a map keyed by every registered dependent value. */
    private Map<String, Double> constructYMap(Matrix y) {
        final Map<String, Double> ret = new HashMap<String, Double>();
        for (final Map.Entry<String, Integer> ent : this.values.entrySet()) {
            ret.put(ent.getKey(), y.getElement(0, ent.getValue()));
        }
        return ret;
    }

    /**
     * Builds a |vocabulary| x |users| matrix from x. Users or words that have
     * not been registered have no row/column in the matrix and are skipped.
     * This can only happen when predicting: {@link #process} registers every
     * key before calling this method.
     */
    private Matrix constructXMatrix(Map<String, Map<String, Double>> x) {
        final SparseMatrix mat = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(this.vocabulary.size(), this.users.size());
        for (final Map.Entry<String, Map<String, Double>> userwords : x.entrySet()) {
            final Integer userindex = this.users.get(userwords.getKey());
            if (userindex == null) {
                continue;
            }
            for (final Map.Entry<String, Double> ent : userwords.getValue().entrySet()) {
                final Integer wordindex = this.vocabulary.get(ent.getKey());
                if (wordindex == null) {
                    continue;
                }
                mat.setElement(wordindex, userindex, ent.getValue());
            }
        }
        return mat;
    }

    /** Registers unseen users and words, then grows the wrapped learner by the new counts. */
    private void updateUserWords(Map<String, Map<String, Double>> x) {
        int newUsers = 0;
        int newWords = 0;
        for (final Map.Entry<String, Map<String, Double>> userWords : x.entrySet()) {
            final String user = userWords.getKey();
            if (!this.users.containsKey(user)) {
                this.users.put(user, this.users.size());
                ++newUsers;
            }
            newWords += this.updateWords(userWords.getValue());
        }
        this.bilinearLearner.addU(newUsers);
        this.bilinearLearner.addW(newWords);
    }

    /** Assigns the next free index to every unseen word; returns how many were added. */
    private int updateWords(Map<String, Double> value) {
        int newWords = 0;
        for (final String word : value.keySet()) {
            if (!this.vocabulary.containsKey(word)) {
                this.vocabulary.put(word, this.vocabulary.size());
                ++newWords;
            }
        }
        return newWords;
    }

    /**
     * Returns a clone of the wrapped learner whose U and W matrices are copied
     * into new sparse matrices with {@code nusers} and {@code nwords} rows
     * respectively. The column counts are unchanged.
     * <p>
     * NOTE(review): presumably {@code nusers}/{@code nwords} must be at least the
     * current number of rows of U/W, otherwise {@code setSubMatrix} cannot fit
     * the existing values. Confirm against the matrix library.
     *
     * @param nusers number of rows of the returned U
     * @param nwords number of rows of the returned W
     * @return a resized clone of the wrapped learner
     */
    public BilinearSparseOnlineLearner getBilinearLearner(int nusers, int nwords) {
        final BilinearSparseOnlineLearner ret = this.bilinearLearner.clone();
        final Matrix u = ret.getU();
        final Matrix w = ret.getW();
        final SparseMatrix newu = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(nusers, u.getNumColumns());
        final SparseMatrix neww = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(nwords, w.getNumColumns());
        newu.setSubMatrix(0, 0, u);
        neww.setSubMatrix(0, 0, w);
        ret.setU(newu);
        ret.setW(neww);
        return ret;
    }

    /**
     * @return a clone of the wrapped learner
     */
    public BilinearSparseOnlineLearner getBilinearLearner() {
        return this.bilinearLearner.clone();
    }

    /**
     * Converts a string-keyed (x, y) pair into an (X, Y) matrix pair of the
     * given dimensions using the current index mappings. Every key in the pair
     * must already be registered.
     *
     * @param xy the user-to-word-weights map and the dependent-value map
     * @param nfeatures number of rows of X
     * @param nusers number of columns of X
     * @param ntasks number of columns of Y
     * @return the pair (X, Y)
     */
    public Pair<Matrix> asMatrixPair(IndependentPair<Map<String, Map<String, Double>>, Map<String, Double>> xy, int nfeatures, int nusers, int ntasks) {
        final SparseMatrix y = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(1, ntasks);
        final SparseMatrix x = SparseMatrixFactoryMTJ.INSTANCE.createMatrix(nfeatures, nusers);
        final Map<String, Double> ymap = xy.secondObject();
        final Map<String, Map<String, Double>> userFeatureMap = xy.firstObject();
        for (final Map.Entry<String, Double> yent : ymap.entrySet()) {
            y.setElement(0, this.values.get(yent.getKey()), yent.getValue());
        }
        for (final Map.Entry<String, Map<String, Double>> xent : userFeatureMap.entrySet()) {
            final int userind = this.users.get(xent.getKey());
            for (final Map.Entry<String, Double> fent : xent.getValue().entrySet()) {
                x.setElement(this.vocabulary.get(fent.getKey()), userind, fent.getValue());
            }
        }
        return new Pair<Matrix>(x, y);
    }

    /**
     * Predicts the dependent values for x. Users and words that were never
     * seen by {@link #process} have no learned parameters and are ignored.
     *
     * @param x map from user to that user's word weights
     * @return a value for every registered dependent value
     */
    @Override
    public Map<String, Double> predict(Map<String, Map<String, Double>> x) {
        final Matrix xMat = this.constructXMatrix(x);
        final Matrix yMat = this.bilinearLearner.predict(xMat);
        return this.constructYMap(yMat);
    }

    /**
     * @return the live word-to-index mapping (not a copy)
     */
    public BiMap<String, Integer> getVocabulary() {
        return this.vocabulary;
    }

    /**
     * Converts a string-keyed (x, y) pair into matrices sized to the currently
     * registered words, users and dependent values.
     *
     * @param in the user-to-word-weights map and the dependent-value map
     * @return the pair (X, Y)
     */
    public Pair<Matrix> asMatrixPair(IndependentPair<Map<String, Map<String, Double>>, Map<String, Double>> in) {
        return this.asMatrixPair(in, this.vocabulary.size(), this.users.size(), this.values.size());
    }

    /**
     * Converts string-keyed x and y into matrices sized to the currently
     * registered words, users and dependent values.
     *
     * @param x map from user to that user's word weights
     * @param y map from dependent-value name to its value
     * @return the pair (X, Y)
     */
    public Pair<Matrix> asMatrixPair(Map<String, Map<String, Double>> x, Map<String, Double> y) {
        final IndependentPair<Map<String, Map<String, Double>>, Map<String, Double>> xy = IndependentPair.pair(x, y);
        return this.asMatrixPair(xy, this.vocabulary.size(), this.users.size(), this.values.size());
    }

    /**
     * @return the live dependent-value-to-index mapping (not a copy)
     */
    public BiMap<String, Integer> getDependantValues() {
        return this.values;
    }

    /**
     * @return the live user-to-index mapping (not a copy)
     */
    public BiMap<String, Integer> getUsers() {
        return this.users;
    }

    /** Default parameters used by the no-argument constructor. */
    static class IncrementalBilinearSparseOnlineLearnerParams
    extends BilinearLearnerParameters {
        private static final long serialVersionUID = -1847045895118918210L;

        IncrementalBilinearSparseOnlineLearnerParams() {
        }
    }
}

