/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.ranking.dyad.learner.algorithm.featuretransform;

import ai.libs.jaicore.basic.sets.Pair;
import ai.libs.jaicore.math.linearalgebra.DenseDoubleVector;
import ai.libs.jaicore.ml.core.learner.ASupervisedLearner;
import ai.libs.jaicore.ml.ranking.RankingPredictionBatch;
import ai.libs.jaicore.ml.ranking.dyad.learner.algorithm.IPLDyadRanker;
import ai.libs.jaicore.ml.ranking.dyad.learner.algorithm.featuretransform.BiliniearFeatureTransform;
import ai.libs.jaicore.ml.ranking.dyad.learner.algorithm.featuretransform.IDyadFeatureTransform;
import ai.libs.jaicore.ml.ranking.dyad.learner.optimizing.BilinFunction;
import ai.libs.jaicore.ml.ranking.dyad.learner.optimizing.DyadRankingFeatureTransformNegativeLogLikelihood;
import ai.libs.jaicore.ml.ranking.dyad.learner.optimizing.DyadRankingFeatureTransformNegativeLogLikelihoodDerivative;
import ai.libs.jaicore.ml.ranking.dyad.learner.optimizing.IDyadRankingFeatureTransformPLGradientDescendableFunction;
import ai.libs.jaicore.ml.ranking.dyad.learner.optimizing.IDyadRankingFeatureTransformPLGradientFunction;
import ai.libs.jaicore.ml.ranking.label.learner.clusterbased.customdatatypes.Ranking;
import edu.stanford.nlp.optimization.DiffFunction;
import edu.stanford.nlp.optimization.QNMinimizer;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.api4.java.ai.ml.core.exception.PredictionException;
import org.api4.java.ai.ml.core.exception.TrainingException;
import org.api4.java.ai.ml.ranking.IRanking;
import org.api4.java.ai.ml.ranking.IRankingPredictionBatch;
import org.api4.java.ai.ml.ranking.dyad.dataset.IDyad;
import org.api4.java.ai.ml.ranking.dyad.dataset.IDyadRankingDataset;
import org.api4.java.ai.ml.ranking.dyad.dataset.IDyadRankingInstance;
import org.api4.java.common.math.IVector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A Plackett-Luce dyad ranker with a linear model on top of a dyad feature transform.
 * <p>
 * Each dyad {@code d} is assigned the skill {@code exp(w . phi(d))}, where {@code phi} is the configured
 * {@link IDyadFeatureTransform} and {@code w} is a weight vector learned in {@link #fit}. A prediction orders
 * the dyads of an instance by decreasing skill, so the first element of the returned ranking is the dyad
 * with the highest skill. This matches the convention of {@code likelihoodOfParameter}, where the dyad at
 * position {@code m} competes against all dyads at positions {@code >= m}.
 */
public class FeatureTransformPLDyadRanker
extends ASupervisedLearner<IDyadRankingInstance, IDyadRankingDataset, IRanking<IDyad>, IRankingPredictionBatch>
implements IPLDyadRanker {
    private static final Logger log = LoggerFactory.getLogger(FeatureTransformPLDyadRanker.class);

    /** Value every component of the weight vector is set to before optimization starts. */
    private static final double INITIAL_WEIGHT = 0.3;

    /** Function tolerance handed to the quasi-Newton minimizer. */
    private static final double OPTIMIZER_TOLERANCE = 0.01;

    /** Maps a dyad to the feature vector the weights act on. */
    private final IDyadFeatureTransform featureTransform;

    /** Learned weight vector; {@code null} until {@link #fit} has completed successfully. */
    private IVector w;

    // NOTE(review): both functions are initialized in fit(), but the optimizer minimizes a BilinFunction
    // instead of them. Confirm whether they are still needed.
    private final IDyadRankingFeatureTransformPLGradientDescendableFunction negativeLogLikelihood = new DyadRankingFeatureTransformNegativeLogLikelihood();
    private final IDyadRankingFeatureTransformPLGradientFunction negativeLogLikelihoodDerivative = new DyadRankingFeatureTransformNegativeLogLikelihoodDerivative();

    /** Creates a ranker that uses a {@link BiliniearFeatureTransform}. */
    public FeatureTransformPLDyadRanker() {
        this(new BiliniearFeatureTransform());
    }

    /**
     * Creates a ranker that uses the given feature transform.
     *
     * @param featureTransform transform mapping each dyad to the feature vector the weights act on
     * @throws NullPointerException if {@code featureTransform} is {@code null}
     */
    public FeatureTransformPLDyadRanker(IDyadFeatureTransform featureTransform) {
        this.featureTransform = Objects.requireNonNull(featureTransform, "featureTransform");
    }

    /**
     * Computes the Plackett-Luce skill {@code exp(w . phi(dyad))} of a dyad under the current weights.
     *
     * @param dyad the dyad to score
     * @return the (strictly positive, possibly infinite) skill of the dyad
     */
    private double computeSkillForDyad(IDyad dyad) {
        IVector featureTransformVector = this.featureTransform.transform(dyad);
        double dot = this.w.dotProduct(featureTransformVector);
        double val = Math.exp(dot);
        log.debug("Feature transform for dyad {} is {}. \n Dot-Product is {} and skill is {}", dyad, featureTransformVector, dot, val);
        return val;
    }

    /**
     * Computes the Plackett-Luce likelihood of the whole dataset under the weight vector {@code w}.
     * <p>
     * For every ranking, the dyad at position {@code m} contributes the factor
     * {@code skill_m / sum_{l >= m} skill_l}; all factors of all rankings are multiplied. Only used for
     * debug output. The product can underflow to {@code 0} on larger datasets.
     *
     * @param w the weight vector to evaluate
     * @param dataset the rankings to evaluate it on
     * @return the product of all Plackett-Luce factors
     */
    private double likelihoodOfParameter(IVector w, IDyadRankingDataset dataset) {
        double likelihood = 1.0;
        for (int n = 0; n < dataset.size(); n++) {
            IDyadRankingInstance instance = (IDyadRankingInstance) dataset.get(n);
            int numItems = instance.getNumberOfRankedElements();
            // Transform each dyad exactly once instead of once per position it appears in a denominator.
            double[] skills = new double[numItems];
            for (int m = 0; m < numItems; m++) {
                IDyad dyad = (IDyad) instance.getLabel().get(m);
                skills[m] = Math.exp(w.dotProduct(this.featureTransform.transform(dyad)));
            }
            // Walk the ranking bottom-up so that tailSum always equals sum_{l >= m} skills[l].
            double tailSum = 0.0;
            for (int m = numItems - 1; m >= 0; m--) {
                tailSum += skills[m];
                likelihood *= skills[m] / tailSum;
            }
        }
        return likelihood;
    }

    /**
     * Fits the weight vector on the given dataset.
     * <p>
     * The weights start at {@link #INITIAL_WEIGHT} in every component and are then optimized with a
     * {@link QNMinimizer} on a {@link BilinFunction} built from the precomputed feature transforms
     * (presumably the Plackett-Luce negative log-likelihood; TODO confirm). The feature dimensions are taken
     * from the first dyad of the first instance. The learned weights replace the current ones only if
     * the optimization completes.
     *
     * @param dataset the training rankings; must contain at least one instance, and its first ranking must be non-empty
     * @throws TrainingException if the dataset is empty or its first ranking has no elements
     * @throws InterruptedException as declared by the learner contract
     */
    public void fit(IDyadRankingDataset dataset) throws TrainingException, InterruptedException {
        if (dataset == null || dataset.size() == 0) {
            throw new TrainingException("Cannot fit the ranker on an empty dataset.");
        }
        IDyadRankingInstance firstInstance = (IDyadRankingInstance) dataset.get(0);
        if (firstInstance.getNumberOfRankedElements() == 0) {
            throw new TrainingException("Cannot determine the feature dimensions: the first ranking of the dataset is empty.");
        }
        IDyad firstDyad = (IDyad) firstInstance.getLabel().get(0);
        int alternativeLength = firstDyad.getAlternative().length();
        int instanceLength = firstDyad.getContext().length();
        int transformedLength = this.featureTransform.getTransformedVectorLength(alternativeLength, instanceLength);

        Map<IDyadRankingInstance, Map<IDyad, IVector>> featureTransforms = this.featureTransform.getPreComputedFeatureTransforms(dataset);
        this.negativeLogLikelihood.initialize(dataset, featureTransforms);
        this.negativeLogLikelihoodDerivative.initialize(dataset, featureTransforms);

        IVector initialW = new DenseDoubleVector(transformedLength, INITIAL_WEIGHT);
        // The likelihood touches every dyad of the dataset, so only compute it when it will actually be logged.
        if (log.isDebugEnabled()) {
            log.debug("Likelihood of the initial w is {}", this.likelihoodOfParameter(initialW, dataset));
        }
        BilinFunction fun = new BilinFunction(featureTransforms, dataset, transformedLength);
        QNMinimizer minimizer = new QNMinimizer();
        this.w = new DenseDoubleVector(minimizer.minimize((DiffFunction) fun, OPTIMIZER_TOLERANCE, initialW.asArray()));
        log.debug("Finished optimizing, the final w is {}", this.w);
    }

    /**
     * Ranks the dyads of an instance by decreasing Plackett-Luce skill.
     *
     * @param instance the dyads to rank
     * @return a ranking whose first element is the dyad with the highest skill
     * @throws PredictionException if the ranker has not been fitted yet
     */
    @Override
    public IRanking<IDyad> predict(IDyadRankingInstance instance) throws PredictionException, InterruptedException {
        if (this.w == null) {
            throw new PredictionException("The Ranker has not been trained yet.");
        }
        log.debug("Predicting ranking for instance {}", instance);
        List<Pair<Double, IDyad>> skillForDyads = new ArrayList<>();
        for (IDyad d : instance) {
            skillForDyads.add(new Pair<>(this.computeSkillForDyad(d), d));
        }
        // Descending by skill (p2 before p1): in the Plackett-Luce model the most skilled dyad is ranked first.
        return new Ranking<IDyad>(skillForDyads.stream()
                .sorted((p1, p2) -> Double.compare(p2.getX(), p1.getX()))
                .map(Pair::getY)
                .collect(Collectors.toList()));
    }

    /**
     * Ranks each instance of the given array independently.
     *
     * @param dTest the instances to rank
     * @return the rankings, in the same order as {@code dTest}
     * @throws PredictionException if the ranker has not been fitted yet
     */
    @Override
    public IRankingPredictionBatch predict(IDyadRankingInstance[] dTest) throws PredictionException, InterruptedException {
        List<IRanking<?>> rankings = new ArrayList<>(dTest.length);
        for (IDyadRankingInstance instance : dTest) {
            rankings.add(this.predict(instance));
        }
        return new RankingPredictionBatch(rankings);
    }
}

