/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.dyadranking.algorithm.featuretransform;

import ai.libs.jaicore.basic.sets.Pair;
import ai.libs.jaicore.math.linearalgebra.DenseDoubleVector;
import ai.libs.jaicore.math.linearalgebra.Vector;
import ai.libs.jaicore.ml.core.exception.ConfigurationException;
import ai.libs.jaicore.ml.core.exception.PredictionException;
import ai.libs.jaicore.ml.core.exception.TrainingException;
import ai.libs.jaicore.ml.core.predictivemodel.IPredictiveModelConfiguration;
import ai.libs.jaicore.ml.dyadranking.Dyad;
import ai.libs.jaicore.ml.dyadranking.algorithm.IPLDyadRanker;
import ai.libs.jaicore.ml.dyadranking.algorithm.featuretransform.BiliniearFeatureTransform;
import ai.libs.jaicore.ml.dyadranking.algorithm.featuretransform.IDyadFeatureTransform;
import ai.libs.jaicore.ml.dyadranking.dataset.DyadRankingDataset;
import ai.libs.jaicore.ml.dyadranking.dataset.DyadRankingInstance;
import ai.libs.jaicore.ml.dyadranking.dataset.IDyadRankingInstance;
import ai.libs.jaicore.ml.dyadranking.optimizing.BilinFunction;
import ai.libs.jaicore.ml.dyadranking.optimizing.DyadRankingFeatureTransformNegativeLogLikelihood;
import ai.libs.jaicore.ml.dyadranking.optimizing.DyadRankingFeatureTransformNegativeLogLikelihoodDerivative;
import ai.libs.jaicore.ml.dyadranking.optimizing.IDyadRankingFeatureTransformPLGradientDescendableFunction;
import ai.libs.jaicore.ml.dyadranking.optimizing.IDyadRankingFeatureTransformPLGradientFunction;
import edu.stanford.nlp.optimization.DiffFunction;
import edu.stanford.nlp.optimization.QNMinimizer;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Dyad ranker that scores each dyad as {@code exp(w . z)}, where {@code z} is
 * the vector produced by an {@link IDyadFeatureTransform} for the dyad and
 * {@code w} is a weight vector fitted in {@link #train(DyadRankingDataset)}.
 *
 * <p>The weight vector is obtained by minimizing a {@link BilinFunction} with a
 * {@link QNMinimizer}. Until {@code train} has completed, the prediction
 * methods throw a {@link PredictionException}.
 */
public class FeatureTransformPLDyadRanker
implements IPLDyadRanker {
    private static final Logger log = LoggerFactory.getLogger(FeatureTransformPLDyadRanker.class);

    /** Second constructor argument of the initial weight vector (presumably the fill value of every component). */
    private static final double INITIAL_WEIGHT = 0.3;

    /** Function tolerance handed to the quasi-Newton minimizer. */
    private static final double OPTIMIZER_TOLERANCE = 0.01;

    /** Maps a dyad to the feature vector that is combined with {@link #w}. */
    private final IDyadFeatureTransform featureTransform;

    /** Learned weight vector; {@code null} until {@link #train(DyadRankingDataset)} has run. */
    private Vector w;

    // Both functions are initialized with the training data in train(...); the
    // optimization itself is driven by a BilinFunction.
    private final IDyadRankingFeatureTransformPLGradientDescendableFunction negativeLogLikelihood = new DyadRankingFeatureTransformNegativeLogLikelihood();
    private final IDyadRankingFeatureTransformPLGradientFunction negativeLogLikelihoodDerivative = new DyadRankingFeatureTransformNegativeLogLikelihoodDerivative();

    /**
     * Creates a ranker that uses a {@link BiliniearFeatureTransform}.
     */
    public FeatureTransformPLDyadRanker() {
        this(new BiliniearFeatureTransform());
    }

    /**
     * Creates a ranker that uses the given feature transform.
     *
     * @param featureTransform the transform applied to every dyad
     */
    public FeatureTransformPLDyadRanker(IDyadFeatureTransform featureTransform) {
        this.featureTransform = featureTransform;
    }

    /**
     * Orders the dyads of the given instance by their computed skill.
     *
     * @param instance the instance whose dyads are to be ranked
     * @return a new instance containing the same dyads, sorted by ascending skill
     * @throws PredictionException if the ranker has not been trained yet
     */
    @Override
    public IDyadRankingInstance predict(IDyadRankingInstance instance) throws PredictionException {
        if (this.w == null) {
            throw new PredictionException("The Ranker has not been trained yet.");
        }
        log.debug("Predicting ranking for instance {}", instance);
        List<Pair<Double, Dyad>> skillForDyads = new ArrayList<>();
        for (Dyad d : instance) {
            double skill = this.computeSkillForDyad(d);
            skillForDyads.add(new Pair<Double, Dyad>(skill, d));
        }
        // NOTE(review): the sort is ascending, i.e. the dyad with the LOWEST skill comes
        // first, whereas likelihoodOfParameter favours high skill at position 0. Kept as
        // is to preserve existing behaviour - confirm the intended ordering.
        List<Dyad> ranking = skillForDyads.stream()
            .sorted((p1, p2) -> Double.compare(p1.getX(), p2.getX()))
            .map(Pair::getY)
            .collect(Collectors.toList());
        return new DyadRankingInstance(ranking);
    }

    /**
     * Predicts a ranking for every instance of the dataset, in iteration order.
     *
     * @param dataset the instances to rank
     * @return one predicted ranking per instance of {@code dataset}
     * @throws PredictionException if the ranker has not been trained yet
     */
    @Override
    public List<IDyadRankingInstance> predict(DyadRankingDataset dataset) throws PredictionException {
        List<IDyadRankingInstance> predictions = new ArrayList<>();
        for (IDyadRankingInstance i : dataset) {
            predictions.add(this.predict(i));
        }
        return predictions;
    }

    /**
     * Computes the skill {@code exp(w . z)} of a dyad, {@code z} being its feature transform.
     *
     * @param dyad the dyad to score
     * @return the exponential of the dot product of the weights and the transformed dyad
     */
    private double computeSkillForDyad(Dyad dyad) {
        Vector featureTransformVector = this.featureTransform.transform(dyad);
        double dot = this.w.dotProduct(featureTransformVector);
        double val = Math.exp(dot);
        log.debug("Feature transform for dyad {} is {}. \n Dot-Product is {} and skill is {}", dyad, featureTransformVector, dot, val);
        return val;
    }

    /**
     * Fits the weight vector on the given dataset.
     *
     * <p>The length of the weight vector is derived from the first dyad of the
     * first instance, so the dataset must contain at least one instance.
     *
     * @param dataset the training data; must be non-null and non-empty
     * @throws TrainingException declared by the implemented interface
     * @throws IllegalArgumentException if {@code dataset} is {@code null} or empty
     */
    @Override
    public void train(DyadRankingDataset dataset) throws TrainingException {
        if (dataset == null || dataset.size() == 0) {
            // Fail with a descriptive message instead of an IndexOutOfBoundsException from get(0).
            throw new IllegalArgumentException("Cannot train on a null or empty dyad ranking dataset.");
        }
        Map<IDyadRankingInstance, Map<Dyad, Vector>> featureTransforms = this.featureTransform.getPreComputedFeatureTransforms(dataset);
        this.negativeLogLikelihood.initialize(dataset, featureTransforms);
        this.negativeLogLikelihoodDerivative.initialize(dataset, featureTransforms);

        // The first dyad determines the dimensions of the transformed feature space.
        Dyad firstDyad = ((IDyadRankingInstance)dataset.get(0)).getDyadAtPosition(0);
        int alternativeLength = firstDyad.getAlternative().length();
        int instanceLength = firstDyad.getInstance().length();
        int transformedLength = this.featureTransform.getTransformedVectorLength(alternativeLength, instanceLength);

        this.w = new DenseDoubleVector(transformedLength, INITIAL_WEIGHT);
        if (log.isDebugEnabled()) {
            // Evaluating the likelihood walks the whole dataset, so only do it when it is actually logged.
            log.debug("Likelihood of the randomly filled w is {}", this.likelihoodOfParameter(this.w, dataset));
        }
        BilinFunction fun = new BilinFunction(featureTransforms, dataset, transformedLength);
        QNMinimizer minimizer = new QNMinimizer();
        this.w = new DenseDoubleVector(minimizer.minimize((DiffFunction)fun, OPTIMIZER_TOLERANCE, this.w.asArray()));
        log.debug("Finished optimizing, the final w is {}", this.w);
    }

    /**
     * Computes the likelihood of the dataset under the weight vector {@code w}:
     * the product, over all instances and all positions {@code m}, of
     * {@code exp(w . z_m) / sum_{l >= m} exp(w . z_l)}.
     *
     * @param w the weight vector to evaluate
     * @param dataset the rankings to evaluate it on
     * @return the product of the per-position probabilities
     */
    private double likelihoodOfParameter(Vector w, DyadRankingDataset dataset) {
        double likelihood = 1.0;
        for (IDyadRankingInstance ranking : dataset) {
            int length = ranking.length();

            // Compute exp(w . z) once per dyad rather than once per (m, l) pair.
            double[] skills = new double[length];
            for (int m = 0; m < length; ++m) {
                Vector transformed = this.featureTransform.transform(ranking.getDyadAtPosition(m));
                skills[m] = Math.exp(w.dotProduct(transformed));
            }

            double rankingLikelihood = 1.0;
            for (int m = 0; m < length; ++m) {
                // Normalize over the dyads at position m and later.
                double denominator = 0.0;
                for (int l = m; l < length; ++l) {
                    denominator += skills[l];
                }
                rankingLikelihood *= skills[m] / denominator;
            }
            likelihood *= rankingLikelihood;
        }
        return likelihood;
    }

    /**
     * Not supported by this ranker.
     *
     * @return always {@code null}
     */
    @Override
    public IPredictiveModelConfiguration getConfiguration() {
        return null;
    }

    /**
     * Does nothing; the given configuration is ignored.
     *
     * @param configuration ignored
     * @throws ConfigurationException never thrown by this implementation
     */
    @Override
    public void setConfiguration(IPredictiveModelConfiguration configuration) throws ConfigurationException {
    }
}

