/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.ranking.dyad.learner.activelearning;

import ai.libs.jaicore.ml.ranking.dyad.dataset.DyadRankingDataset;
import ai.libs.jaicore.ml.ranking.dyad.dataset.SparseDyadRankingInstance;
import ai.libs.jaicore.ml.ranking.dyad.learner.activelearning.ARandomlyInitializingDyadRanker;
import ai.libs.jaicore.ml.ranking.dyad.learner.activelearning.IDyadRankingPoolProvider;
import ai.libs.jaicore.ml.ranking.dyad.learner.algorithm.PLNetDyadRanker;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance;
import org.api4.java.ai.ml.core.exception.TrainingException;
import org.api4.java.ai.ml.ranking.dyad.dataset.IDyad;
import org.api4.java.ai.ml.ranking.dyad.dataset.IDyadRankingInstance;
import org.api4.java.common.math.IVector;
import org.nd4j.linalg.primitives.Pair;

/**
 * Pool-based active dyad ranker that chooses its pairwise queries by an upper confidence bound (UCB).
 *
 * <p>For every query of a minibatch, a problem instance is drawn at random from the instance
 * features. Each dyad the pool provider offers for that instance is scored with
 * {@code skill + standardDeviation}: the ranker's skill estimate for the dyad plus the standard
 * deviation of the statistics recorded for that dyad. The two highest-scoring dyads are sent to the
 * pool provider as a pairwise query, and the answers are collected into a minibatch that is handed
 * to {@code updateRanker}.
 *
 * <p>NOTE(review): the initial random queries (see {@code numberRandomQueriesAtStart}) are presumably
 * issued by {@link ARandomlyInitializingDyadRanker}; confirm there.
 */
public class UCBPoolBasedActiveDyadRanker extends ARandomlyInitializingDyadRanker {

    /**
     * Creates a UCB-driven active ranker. All arguments are forwarded unchanged to the superclass.
     *
     * @param ranker the PLNet ranker to train (its skill estimates are used in the UCB score)
     * @param poolProvider provider of the candidate dyads per instance and of answers to queries
     * @param seed seed for the superclass' random source
     * @param numberRandomQueriesAtStart presumably the number of random queries issued before
     *        UCB-driven ones; interpreted by the superclass
     * @param minibatchSize forwarded to the superclass; each call of
     *        {@link #activelyTrainWithOneInstance()} performs {@code getMinibatchSize()} draws
     */
    public UCBPoolBasedActiveDyadRanker(PLNetDyadRanker ranker, IDyadRankingPoolProvider poolProvider, int seed, int numberRandomQueriesAtStart, int minibatchSize) {
        super(ranker, poolProvider, seed, numberRandomQueriesAtStart, minibatchSize);
    }

    /**
     * Issues one minibatch of UCB-selected pairwise queries and updates the ranker with the answers.
     *
     * <p>A drawn problem instance for which the pool offers fewer than two dyads cannot produce a
     * pairwise query and is skipped. The minibatch may therefore hold fewer than
     * {@code getMinibatchSize()} rankings, possibly none.
     *
     * @throws TrainingException propagated from the ranker/pool calls or from {@code updateRanker}
     * @throws InterruptedException propagated from the ranker/pool calls or from {@code updateRanker}
     */
    @Override
    public void activelyTrainWithOneInstance() throws TrainingException, InterruptedException {
        DyadRankingDataset minibatch = new DyadRankingDataset();
        for (int minibatchIndex = 0; minibatchIndex < this.getMinibatchSize(); ++minibatchIndex) {
            int index = this.getRandom().nextInt(this.getInstanceFeatures().size());
            IVector problemInstance = this.getInstanceFeatures().get(index);
            List<IDyad> dyads = new ArrayList<>(this.poolProvider.getDyadsByInstance(problemInstance));
            if (dyads.size() < 2) {
                // A pairwise query needs two alternatives; without this guard picking the
                // second-best dyad would fail with an IndexOutOfBoundsException.
                continue;
            }

            // UCB score per dyad, in pool order: estimated skill plus its recorded spread.
            double[] ucbScores = new double[dyads.size()];
            for (int i = 0; i < ucbScores.length; i++) {
                IDyad dyad = dyads.get(i);
                double skill = this.ranker.getSkillForDyad(dyad);
                double std = this.getDyadStats().get(dyad).getStandardDeviation();
                ucbScores[i] = skill + std;
            }

            int[] topTwo = indicesOfTwoHighestScores(ucbScores);
            IDyad d1 = dyads.get(topTwo[0]);
            IDyad d2 = dyads.get(topTwo[1]);

            List<IVector> alts = new ArrayList<>(2);
            alts.add(d1.getAlternative());
            alts.add(d2.getAlternative());
            // Both dyads come from the same sampled instance, so d1's context is used for the
            // pair (presumably identical to d2's context; confirm against the pool provider).
            SparseDyadRankingInstance sparseQueryPair = new SparseDyadRankingInstance(d1.getContext(), alts);
            IDyadRankingInstance groundTruthPair = (IDyadRankingInstance) this.poolProvider.query((ILabeledInstance) sparseQueryPair);
            minibatch.add(groundTruthPair);
        }
        this.updateRanker(minibatch);
    }

    /**
     * Finds the positions of the two highest scores, best first, in a single pass.
     *
     * <p>The result equals the first two positions of a stable sort by descending score. Ties keep
     * the earlier position first, and NaN scores are ordered after every number.
     *
     * @param scores the scores to inspect; must contain at least two entries
     * @return {@code {indexOfBest, indexOfSecondBest}}
     */
    private static int[] indicesOfTwoHighestScores(final double[] scores) {
        int best = -1;
        int second = -1;
        for (int i = 0; i < scores.length; i++) {
            if (best < 0 || ranksBefore(scores[i], scores[best])) {
                // The previous best is still ahead of everything else seen so far.
                second = best;
                best = i;
            } else if (second < 0 || ranksBefore(scores[i], scores[second])) {
                second = i;
            }
        }
        return new int[] { best, second };
    }

    /**
     * Tells whether score {@code a} must be ranked strictly ahead of score {@code b}.
     *
     * <p>Both values are negated before comparing, so NaN, which {@link Double#compare} treats as
     * the largest value, ends up ranked last.
     *
     * @param a candidate score
     * @param b score to compare against
     * @return {@code true} if {@code a} ranks strictly before {@code b}
     */
    private static boolean ranksBefore(final double a, final double b) {
        return Double.compare(-a, -b) < 0;
    }
}

