/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.dyadranking.activelearning;

import ai.libs.jaicore.math.linearalgebra.Vector;
import ai.libs.jaicore.ml.core.exception.TrainingException;
import ai.libs.jaicore.ml.dyadranking.Dyad;
import ai.libs.jaicore.ml.dyadranking.activelearning.ARandomlyInitializingDyadRanker;
import ai.libs.jaicore.ml.dyadranking.activelearning.IDyadRankingPoolProvider;
import ai.libs.jaicore.ml.dyadranking.algorithm.PLNetDyadRanker;
import ai.libs.jaicore.ml.dyadranking.dataset.IDyadRankingInstance;
import ai.libs.jaicore.ml.dyadranking.dataset.SparseDyadRankingInstance;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import org.nd4j.linalg.primitives.Pair;

/**
 * Pool-based active dyad ranker that selects its queries by an upper-confidence-bound
 * style score: for every candidate dyad of a randomly drawn problem instance, the score
 * is the ranker's predicted skill plus the standard deviation reported by the dyad's
 * statistics. The two highest-scoring dyads form the pairwise query sent to the pool
 * provider.
 */
public class UCBPoolBasedActiveDyadRanker
extends ARandomlyInitializingDyadRanker {
    /**
     * Creates the ranker; all arguments are forwarded unchanged to the superclass constructor.
     *
     * @param ranker the PLNet dyad ranker that is trained actively
     * @param poolProvider provider of the dyad pool and of the ground-truth answers to queries
     * @param seed seed value handed to the superclass
     * @param numberRandomQueriesAtStart number of initial random queries, as interpreted by the superclass
     * @param minibatchSize number of queries collected per call of {@link #activelyTrainWithOneInstance()}
     */
    public UCBPoolBasedActiveDyadRanker(PLNetDyadRanker ranker, IDyadRankingPoolProvider poolProvider, int seed, int numberRandomQueriesAtStart, int minibatchSize) {
        super(ranker, poolProvider, seed, numberRandomQueriesAtStart, minibatchSize);
    }

    /**
     * Collects {@code getMinibatchSize()} pairwise queries and updates the ranker with them.
     *
     * <p>For each query a problem instance is drawn uniformly at random from
     * {@code getInstanceFeatures()}, every dyad the pool offers for that instance is scored
     * with {@code skill + standardDeviation}, and the two dyads with the highest score are
     * sent to the pool provider as a sparse ranking instance. The answers are gathered in a
     * set (so equal answers collapse into one element) and passed to {@code updateRanker}.
     *
     * @throws TrainingException declared by the overridden method; presumably raised while
     *         updating the ranker (not visible from this class, confirm against the superclass)
     * @throws IllegalStateException if the pool offers fewer than two dyads for the drawn
     *         instance, since no pairwise query can be built in that case
     */
    @Override
    public void activelyTrainWithOneInstance() throws TrainingException {
        HashSet<IDyadRankingInstance> minibatch = new HashSet<>();
        for (int minibatchIndex = 0; minibatchIndex < this.getMinibatchSize(); ++minibatchIndex) {
            // Draw a random problem instance and fetch its candidate dyads from the pool.
            int index = this.getRandom().nextInt(this.getInstanceFeatures().size());
            Vector problemInstance = this.getInstanceFeatures().get(index);
            ArrayList<Dyad> dyads = new ArrayList<>(this.poolProvider.getDyadsByInstance(problemInstance));

            // A pairwise query needs two alternatives; fail with a clear message instead of
            // an IndexOutOfBoundsException from the list accesses further down.
            if (dyads.size() < 2) {
                throw new IllegalStateException("Cannot build a pairwise query: the pool provides only "
                        + dyads.size() + " dyad(s) for the sampled problem instance, but at least 2 are required.");
            }

            // Score every candidate with predicted skill plus standard deviation.
            // NOTE(review): assumes getDyadStats() holds an entry for every pool dyad - confirm.
            ArrayList<Pair<Dyad, Double>> dyadsWithUCB = new ArrayList<>(dyads.size());
            for (Dyad dyad : dyads) {
                double skill = this.ranker.getSkillForDyad(dyad);
                double std = this.getDyadStats().get(dyad).getStandardDeviation();
                double ucb = skill + std;
                dyadsWithUCB.add(new Pair<>(dyad, ucb));
            }

            // Sort ascending on the negated score, i.e. highest score first. The sort is
            // stable, so dyads with equal scores keep their pool order.
            dyadsWithUCB.sort(Comparator.comparingDouble(p -> -p.getRight()));
            Dyad d1 = dyadsWithUCB.get(0).getFirst();
            Dyad d2 = dyadsWithUCB.get(1).getFirst();

            // Build the query from the two best alternatives and ask the pool for the ground truth.
            ArrayList<Vector> alts = new ArrayList<>(2);
            alts.add(d1.getAlternative());
            alts.add(d2.getAlternative());
            SparseDyadRankingInstance sparseQueryPair = new SparseDyadRankingInstance(d1.getInstance(), alts);
            IDyadRankingInstance groundTruthPair = this.poolProvider.query(sparseQueryPair);
            minibatch.add(groundTruthPair);
        }
        this.updateRanker(minibatch);
    }
}

