/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.cf.taste.impl.recommender.svd;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.impl.common.FastIDSet;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.apache.mahout.cf.taste.impl.recommender.svd.Factorization;
import org.apache.mahout.cf.taste.impl.recommender.svd.RatingSGDFactorizer;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.RandomWrapper;

/**
 * SGD factorizer for the SVD++ model.
 *
 * <p>Besides the per-user factors {@code p} and the per-item factors held in
 * {@code itemVectors}, every item also owns an implicit-feedback vector {@code y}. During
 * training a user's effective vector is {@code p_u + |N(u)|^(-1/2) * sum(y_j, j in N(u))}.
 * Here {@code N(u)} is the set of items returned by {@code DataModel.getItemIDsFromUser}.
 *
 * <p>Columns 0..2 of every vector are reserved: latent features start at column 3.
 * Column 1 of {@code p} and column 2 of the item vectors are trained with the bias learning
 * rate and bias regularisation.
 */
public final class SVDPlusPlusFactorizer
extends RatingSGDFactorizer {
    /** Explicit per-user factors, numUsers x numFeatures, indexed by internal user index. */
    private double[][] p;
    /** Implicit-feedback per-item factors, numItems x numFeatures, indexed by internal item index. */
    private double[][] y;
    /** Internal user index -> internal indexes of the items that user has preferences for. */
    private Map<Integer, List<Integer>> itemsByUser;

    /**
     * Creates a factorizer with the built-in default hyper-parameters.
     *
     * <p>Defaults: learningRate 0.01, preventOverfitting 0.1, randomNoise 0.01 and
     * learningRateDecay 1.0. Also sets biasLearningRate 0.7 and biasReg 0.33.
     *
     * @param dataModel     preference data to factorize
     * @param numFeatures   length of each factor vector, including the 3 reserved columns
     * @param numIterations number of SGD passes
     * @throws TasteException if the superclass constructor fails
     */
    public SVDPlusPlusFactorizer(DataModel dataModel, int numFeatures, int numIterations) throws TasteException {
        this(dataModel, numFeatures, 0.01, 0.1, 0.01, numIterations, 1.0);
        this.biasLearningRate = 0.7;
        this.biasReg = 0.33;
    }

    /**
     * Creates a factorizer with explicit hyper-parameters. All of them are forwarded
     * unchanged to {@link RatingSGDFactorizer}.
     *
     * @throws TasteException if the superclass constructor fails
     */
    public SVDPlusPlusFactorizer(DataModel dataModel, int numFeatures, double learningRate, double preventOverfitting, double randomNoise, int numIterations, double learningRateDecay) throws TasteException {
        super(dataModel, numFeatures, learningRate, preventOverfitting, randomNoise, numIterations, learningRateDecay);
    }

    /**
     * Prepares training state.
     *
     * <p>Runs the superclass preparation, then randomly initialises {@code p} and {@code y}.
     * It also caches, for every user, the internal indexes of the items that user has
     * preferences for.
     *
     * @throws TasteException if the data model cannot be read
     */
    @Override
    protected void prepareTraining() throws TasteException {
        super.prepareTraining();
        RandomWrapper random = RandomUtils.getRandom();
        // p is drawn before y, row by row, so the sequence of random draws is unchanged.
        this.p = newLatentMatrix(this.dataModel.getNumUsers(), random);
        this.y = newLatentMatrix(this.dataModel.getNumItems(), random);

        // Item indexes are needed for every rating of a user, so resolve them once here.
        this.itemsByUser = new HashMap<Integer, List<Integer>>();
        LongPrimitiveIterator userIDs = this.dataModel.getUserIDs();
        while (userIDs.hasNext()) {
            long userId = userIDs.nextLong();
            int userIndex = this.userIndex(userId);
            FastIDSet itemIDsFromUser = this.dataModel.getItemIDsFromUser(userId);
            List<Integer> itemIndexes = new ArrayList<Integer>(itemIDsFromUser.size());
            this.itemsByUser.put(userIndex, itemIndexes);
            for (long itemID : itemIDsFromUser) {
                itemIndexes.add(this.itemIndex(itemID));
            }
        }
    }

    /**
     * Allocates a {@code rows x numFeatures} matrix.
     *
     * <p>Columns 0..2 keep Java's default 0.0. Every later column is
     * {@code nextGaussian() * randomNoise}, filled row by row.
     */
    private double[][] newLatentMatrix(int rows, RandomWrapper random) {
        double[][] matrix = new double[rows][this.numFeatures];
        for (double[] row : matrix) {
            for (int feature = 3; feature < this.numFeatures; ++feature) {
                row[feature] = random.nextGaussian() * this.randomNoise;
            }
        }
        return matrix;
    }

    /**
     * Trains the model and folds the implicit factors into the user vectors.
     *
     * <p>For every user, {@code y} is summed over the user's items into {@code userVectors},
     * the result is divided by {@code sqrt(|N(u)|)}, and {@code p} is added.
     *
     * @return the factorization built from {@code userVectors} and {@code itemVectors}
     * @throws TasteException if preparation or training fails
     */
    @Override
    public Factorization factorize() throws TasteException {
        // NOTE(review): if RatingSGDFactorizer.factorize() itself calls prepareTraining(),
        // this explicit call initialises all training state twice. Confirm before removing.
        this.prepareTraining();
        super.factorize();
        for (int userIndex = 0; userIndex < this.userVectors.length; ++userIndex) {
            List<Integer> ratedItems = this.itemsByUser.get(userIndex);
            double[] userVector = this.userVectors[userIndex];
            // NOTE(review): this starts from whatever the superclass left in userVectors.
            // updateParameters() instead builds its user vector starting from zeros, so the
            // two differ unless userVectors is all zero. Confirm this is intended.
            accumulateImplicitFactors(userVector, ratedItems);
            double denominator = implicitNormalizer(ratedItems);
            double[] explicitFactors = this.p[userIndex];
            for (int feature = 0; feature < userVector.length; ++feature) {
                // Narrowed to float precision, matching the training-time representation.
                userVector[feature] = (float) (userVector[feature] / denominator + explicitFactors[feature]);
            }
        }
        return this.createFactorization(this.userVectors, this.itemVectors);
    }

    /**
     * Performs one SGD step for a single observed rating.
     *
     * <p>Updates the user's {@code p} row, the item's row of {@code itemVectors}, and the
     * {@code y} rows of every item the user has interacted with.
     *
     * @param userID              external user id
     * @param itemID              external item id
     * @param rating              observed preference value
     * @param currentLearningRate decayed learning rate for the current iteration
     */
    @Override
    protected void updateParameters(long userID, long itemID, float rating, double currentLearningRate) {
        int userIndex = this.userIndex(userID);
        int itemIndex = this.itemIndex(itemID);
        List<Integer> ratedItems = this.itemsByUser.get(userIndex);
        double[] userVector = this.p[userIndex];
        double[] itemVector = this.itemVectors[itemIndex];

        // Effective user vector: p_u + |N(u)|^(-1/2) * sum of y_j over the user's items.
        double[] pPlusY = new double[this.numFeatures];
        accumulateImplicitFactors(pPlusY, ratedItems);
        double denominator = implicitNormalizer(ratedItems);
        for (int feature = 0; feature < pPlusY.length; ++feature) {
            // Narrowed to float precision, as in factorize().
            pPlusY[feature] = (float) (pPlusY[feature] / denominator + userVector[feature]);
        }

        double err = (double) rating - this.predictRating(pPlusY, itemIndex);
        double normalizedError = err / denominator;

        // Bias columns use the dedicated bias learning rate and regularisation.
        double biasStep = this.biasLearningRate * currentLearningRate;
        userVector[1] += biasStep * (err - this.biasReg * this.preventOverfitting * userVector[1]);
        itemVector[2] += biasStep * (err - this.biasReg * this.preventOverfitting * itemVector[2]);

        for (int feature = 3; feature < this.numFeatures; ++feature) {
            // Capture pre-update values so every gradient uses the same parameter snapshot.
            double pF = userVector[feature];
            double iF = itemVector[feature];
            userVector[feature] += currentLearningRate * (err * iF - this.preventOverfitting * pF);
            itemVector[feature] += currentLearningRate * (err * pPlusY[feature] - this.preventOverfitting * iF);
            double commonUpdate = normalizedError * iF;
            for (int ratedItem : ratedItems) {
                double[] implicitFactors = this.y[ratedItem];
                // Use the decayed rate, like every other update in this step. Previously the
                // undecayed learningRate field was used here, so decay never applied to y.
                implicitFactors[feature] += currentLearningRate * (commonUpdate - this.preventOverfitting * implicitFactors[feature]);
            }
        }
    }

    /**
     * Adds {@code y[j][f]} into {@code target[f]} for every item index {@code j} in
     * {@code items} and every latent column {@code f >= 3}. Columns 0..2 of target are untouched.
     */
    private void accumulateImplicitFactors(double[] target, List<Integer> items) {
        for (int item : items) {
            double[] implicitFactors = this.y[item];
            for (int feature = 3; feature < this.numFeatures; ++feature) {
                target[feature] += implicitFactors[feature];
            }
        }
    }

    /**
     * Returns {@code sqrt(|N(u)|)}, the SVD++ normaliser for the implicit term.
     *
     * <p>Returns 1.0 for a user with no items. In that case the implicit sum is zero anyway,
     * and dividing by {@code sqrt(0)} would otherwise write Infinity/NaN into the factors.
     */
    private static double implicitNormalizer(List<Integer> items) {
        return items.isEmpty() ? 1.0 : Math.sqrt(items.size());
    }

    /** Dot product of {@code userVector} with the item's factor row over all numFeatures columns. */
    private double predictRating(double[] userVector, int itemIndex) {
        double[] itemVector = this.itemVectors[itemIndex];
        double sum = 0.0;
        for (int feature = 0; feature < this.numFeatures; ++feature) {
            sum += userVector[feature] * itemVector[feature];
        }
        return sum;
    }
}

