/*
 * Decompiled with CFR 0.152.
 */
package org.lenskit.predict.ordrec;

import it.unimi.dsi.fastutil.longs.LongCollection;
import it.unimi.dsi.fastutil.longs.LongIterator;
import it.unimi.dsi.fastutil.longs.LongIterators;
import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import javax.annotation.Nonnull;
import javax.inject.Inject;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.RealVector;
import org.grouplens.lenskit.iterative.IterationCount;
import org.grouplens.lenskit.iterative.LearningRate;
import org.grouplens.lenskit.iterative.RegularizationTerm;
import org.grouplens.lenskit.vectors.ImmutableSparseVector;
import org.grouplens.lenskit.vectors.MutableSparseVector;
import org.grouplens.lenskit.vectors.SparseVector;
import org.grouplens.lenskit.vectors.VectorEntry;
import org.lenskit.api.ItemScorer;
import org.lenskit.api.Result;
import org.lenskit.api.ResultMap;
import org.lenskit.basic.AbstractRatingPredictor;
import org.lenskit.data.dao.UserEventDAO;
import org.lenskit.data.history.UserHistory;
import org.lenskit.data.ratings.Rating;
import org.lenskit.data.ratings.Ratings;
import org.lenskit.predict.ordrec.OrdRecModel;
import org.lenskit.results.AbstractResult;
import org.lenskit.results.Results;
import org.lenskit.transform.quantize.Quantizer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Rating predictor that maps the scores of an underlying {@link ItemScorer} onto discrete
 * rating levels using the OrdRec ordinal model.
 *
 * <p>For each prediction request, the predictor:
 * <ol>
 *   <li>scores the user's rated items and the requested items with the base scorer;</li>
 *   <li>fits a fresh, per-request {@link OrdRecModel} to the user's quantized ratings by
 *       iterative updates scaled by the learning rate;</li>
 *   <li>predicts each requested item as the rating level with the highest probability under
 *       the fitted model.</li>
 * </ol>
 */
public class OrdRecRatingPredictor
extends AbstractRatingPredictor {
    private static final Logger logger = LoggerFactory.getLogger(OrdRecRatingPredictor.class);
    private final ItemScorer itemScorer;
    private final UserEventDAO userEventDao;
    private final Quantizer quantizer;
    private final double learningRate;
    private final double regTerm;
    private final int iterationCount;

    /**
     * Create a new OrdRec rating predictor.
     *
     * @param scorer    the base item scorer whose scores are mapped onto rating levels
     * @param dao       the DAO used to fetch the user's rating history
     * @param quantizer the quantizer defining the discrete rating levels
     * @param rate      the learning rate that scales each parameter update
     * @param reg       the regularization weight applied in each parameter update
     * @param niters    the number of passes over the user's ratings during training
     */
    @Inject
    public OrdRecRatingPredictor(ItemScorer scorer, UserEventDAO dao, Quantizer quantizer, @LearningRate double rate, @RegularizationTerm double reg, @IterationCount int niters) {
        this.userEventDao = dao;
        this.itemScorer = scorer;
        this.quantizer = quantizer;
        this.learningRate = rate;
        this.regTerm = reg;
        this.iterationCount = niters;
    }

    /**
     * Create a predictor with fixed hyperparameters: learning rate 0.001, regularization
     * 0.015 and 1000 training iterations.
     *
     * @param scorer the base item scorer
     * @param dao    the DAO used to fetch the user's rating history
     * @param q      the quantizer defining the discrete rating levels
     */
    OrdRecRatingPredictor(ItemScorer scorer, UserEventDAO dao, Quantizer q) {
        this.userEventDao = dao;
        this.itemScorer = scorer;
        this.quantizer = q;
        this.learningRate = 0.001;
        this.regTerm = 0.015;
        this.iterationCount = 1000;
    }

    /**
     * Build a vector holding the user's ratings.
     *
     * @param uid the user ID
     * @param dao the DAO to read the user's rating events from
     * @return the user's ratings keyed by item ID; an empty vector (never {@code null}) when
     *         the DAO returns no history for the user
     */
    private SparseVector makeUserVector(long uid, UserEventDAO dao) {
        UserHistory history = dao.getEventsForUser(uid, Rating.class);
        if (history == null) {
            // Callers read keySet() and iterate the result, so a missing history must become
            // an empty vector rather than null.
            return ImmutableSparseVector.create(Collections.<Long, Double>emptyMap());
        }
        return ImmutableSparseVector.create((Map)Ratings.userRatingVector((Collection)history));
    }

    /**
     * Derivative factor used in the parameter updates.
     *
     * <p>Returns 1 for the first parameter ({@code k == 0}) at any non-negative level
     * {@code r}. Returns {@code exp(beta)} for a later parameter {@code k > 0} when
     * {@code r >= k}. Returns 0 otherwise. NOTE(review): this is presumably the partial
     * derivative of threshold {@code r} with respect to parameter {@code k}, for thresholds of
     * the form {@code t1 + sum(exp(beta_i))}; confirm against {@link OrdRecModel}.
     *
     * @param r    the rating-level index
     * @param k    the parameter index (0 for {@code t1}, {@code k + 1} for beta entry {@code k})
     * @param beta the current value of that parameter
     * @return the derivative factor
     */
    private static double dBeta(int r, int k, double beta) {
        if (r >= 0 && k == 0) {
            return 1.0;
        }
        if (k > 0 && r >= k) {
            return Math.exp(beta);
        }
        return 0.0;
    }

    /**
     * Fit the model's threshold parameters to a user's ratings.
     *
     * <p>Runs {@code iterationCount} passes over the ratings. For each rating that has a base
     * score, it computes an update for {@code t1} and for every beta entry. Each update uses
     * the model's probabilities at the rating's quantized level and at the level below it, is
     * scaled by the learning rate, and includes a regularization term. The updates are then
     * applied via {@link OrdRecModel#update}.
     *
     * @param model   the model to train in place
     * @param ratings the user's ratings, keyed by item ID
     * @param scores  the base scorer's scores, keyed by item ID
     */
    private void trainModel(OrdRecModel model, SparseVector ratings, MutableSparseVector scores) {
        // NOTE(review): assumes getBeta() returns the model's live vector, so that
        // model.update(...) is reflected in beta.getEntry(k) below; confirm in OrdRecModel.
        RealVector beta = model.getBeta();
        RealVector deltaBeta = new ArrayRealVector(beta.getDimension());
        for (int j = 0; j < this.iterationCount; ++j) {
            for (VectorEntry rating : ratings) {
                long iid = rating.getKey();
                double score = scores.get(iid, Double.NaN);
                if (Double.isNaN(score)) {
                    // The base scorer produced no score for this rated item, so it gives no
                    // training signal. This mirrors the skip in computePredictions.
                    continue;
                }
                int r = this.quantizer.index(rating.getValue());
                double probEqualR = model.getProbEQ(score, r);
                double probLessR = model.getProbLE(score, r);
                double probLessR_1 = model.getProbLE(score, r - 1);
                // Loop-invariant per rating; same left-to-right evaluation order as inline.
                double slopeR = probLessR * (1.0 - probLessR);
                double slopeR1 = probLessR_1 * (1.0 - probLessR_1);

                double t1 = model.getT1();
                double dt1 = this.learningRate / probEqualR
                        * (slopeR * dBeta(r, 0, t1)
                           - slopeR1 * dBeta(r - 1, 0, t1)
                           - this.regTerm * t1);
                for (int k = 0; k < beta.getDimension(); ++k) {
                    double betaK = beta.getEntry(k);
                    double dbetaK = this.learningRate / probEqualR
                            * (slopeR * dBeta(r, k + 1, betaK)
                               - slopeR1 * dBeta(r - 1, k + 1, betaK)
                               - this.regTerm * betaK);
                    deltaBeta.setEntry(k, dbetaK);
                }
                model.update(dt1, deltaBeta);
            }
        }
    }

    /**
     * Predict ratings for a set of items.
     *
     * @param user  the user ID
     * @param items the items to predict
     * @return predicted ratings keyed by item ID; items the base scorer cannot score are omitted
     */
    @Override
    @Nonnull
    public Map<Long, Double> predict(long user, @Nonnull Collection<Long> items) {
        return this.computePredictions(user, items, false).scoreMap();
    }

    /**
     * Predict ratings for a set of items, with details.
     *
     * @param user  the user ID
     * @param items the items to predict
     * @return a result map whose entries are {@link FullResult}s; each one carries the base
     *         scorer's result and the probability distribution over rating levels
     */
    @Override
    @Nonnull
    public ResultMap predictWithDetails(long user, @Nonnull Collection<Long> items) {
        return this.computePredictions(user, items, true);
    }

    /**
     * Train a per-user OrdRec model and predict the requested items.
     *
     * @param user           the user ID
     * @param items          the items to predict
     * @param includeDetails whether to request detailed base results and return
     *                       {@link FullResult}s instead of plain results
     * @return one result per requested item that the base scorer could score
     */
    @Nonnull
    private ResultMap computePredictions(long user, @Nonnull Collection<Long> items, boolean includeDetails) {
        logger.debug("predicting {} items for {}", items.size(), user);
        SparseVector ratings = this.makeUserVector(user, this.userEventDao);

        // Rated items are scored too: their base scores are the training inputs.
        LongOpenHashSet allItems = new LongOpenHashSet((LongCollection)ratings.keySet());
        allItems.addAll(items);

        ResultMap baseResults = null;
        Map<Long, Double> scores;
        if (includeDetails) {
            baseResults = this.itemScorer.scoreWithDetails(user, allItems);
            scores = baseResults.scoreMap();
        } else {
            scores = this.itemScorer.score(user, allItems);
        }
        MutableSparseVector scoreVector = MutableSparseVector.create(scores);

        OrdRecModel params = new OrdRecModel(this.quantizer);
        this.trainModel(params, ratings, scoreVector);
        logger.debug("trained parameters for {}: {}", user, params);

        // Scratch buffer, refilled for each item by getProbDistribution.
        ArrayRealVector probabilities = new ArrayRealVector(params.getLevelCount());
        // Plain results and FullResults share this list, so it must be typed as Result.
        List<Result> results = new ArrayList<>(items.size());
        LongIterator iter = LongIterators.asLongIterator(items.iterator());
        while (iter.hasNext()) {
            long item = iter.nextLong();
            double score = scoreVector.get(item, Double.NaN);
            if (Double.isNaN(score)) {
                continue;
            }
            params.getProbDistribution(score, probabilities);
            int mlIdx = probabilities.getMaxIndex();
            double pred = this.quantizer.getIndexValue(mlIdx);
            if (includeDetails) {
                // Copy the distribution, because the scratch buffer is overwritten for the next item.
                results.add(new FullResult(baseResults.get(item), pred, new ArrayRealVector(probabilities)));
            } else {
                results.add(Results.create(item, pred));
            }
        }
        return Results.newResultMap(results);
    }

    /**
     * Detailed OrdRec prediction result. It wraps the base scorer's result together with the
     * predicted rating and the model's probability distribution over rating levels.
     */
    public static class FullResult
    extends AbstractResult
    implements Serializable {
        private static final long serialVersionUID = 1L;
        private final Result original;
        private final RealVector distribution;

        /**
         * @param orig  the base scorer's result for the item; its ID becomes this result's ID
         * @param score the predicted rating
         * @param probs the probability of each rating level; stored as given, not copied
         */
        FullResult(Result orig, double score, RealVector probs) {
            super(orig.getId(), score);
            this.original = orig;
            this.distribution = probs;
        }

        /**
         * @return the base scorer's result for this item
         */
        public Result getOriginalResult() {
            return this.original;
        }

        /**
         * @return the probability distribution over rating levels
         */
        public RealVector getDistribution() {
            return this.distribution;
        }
    }
}

