/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.ranking.dyad.learner.activelearning;

import ai.libs.jaicore.ml.ranking.dyad.dataset.DyadRankingDataset;
import ai.libs.jaicore.ml.ranking.dyad.dataset.SparseDyadRankingInstance;
import ai.libs.jaicore.ml.ranking.dyad.learner.activelearning.ARandomlyInitializingDyadRanker;
import ai.libs.jaicore.ml.ranking.dyad.learner.activelearning.IDyadRankingPoolProvider;
import ai.libs.jaicore.ml.ranking.dyad.learner.algorithm.PLNetDyadRanker;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance;
import org.api4.java.ai.ml.core.exception.TrainingException;
import org.api4.java.ai.ml.ranking.dyad.dataset.IDyad;
import org.api4.java.ai.ml.ranking.dyad.dataset.IDyadRankingInstance;
import org.api4.java.common.math.IVector;
import org.nd4j.linalg.primitives.Pair;

/**
 * Pool-based active dyad ranker that builds each training minibatch from a mix of freshly queried
 * pairwise rankings and rankings that were queried in earlier rounds.
 *
 * <p>For every new query, an instance of the pool is picked in random order. The underlying
 * {@link PLNetDyadRanker} selects the pair of that instance's alternatives it is least certain
 * about, and the ground-truth ordering of that pair is obtained from the pool provider. A
 * configurable share of each minibatch is filled with rankings obtained in earlier rounds.
 */
public class PrototypicalPoolBasedActiveDyadRanker
extends ARandomlyInitializingDyadRanker {
    /** Every ground-truth ranking obtained from the pool provider in completed training rounds. */
    private final List<IDyadRankingInstance> seenInstances;
    /** Share of each minibatch (relative to the minibatch size) drawn from {@link #seenInstances}. */
    private double ratioOfOldInstancesForMinibatch;
    /** Stored and exposed via getter/setter; not read by the training logic of this class. */
    private int lengthOfTopRankingToConsider;

    /**
     * Creates a new prototypical pool-based active dyad ranker.
     *
     * @param ranker the PLNet dyad ranker that is trained and asked for least-certain pairs
     * @param poolProvider provider of the pool and of ground-truth answers to pairwise queries
     * @param maxBatchSize passed to the superclass; presumably the value later returned by
     *        {@code getMinibatchSize()} — confirm in {@code ARandomlyInitializingDyadRanker}
     * @param lengthOfTopRankingToConsider stored for retrieval via
     *        {@link #getLengthOfTopRankingToConsider()}
     * @param ratioOfOldInstancesForMinibatch share of each minibatch filled with previously
     *        queried rankings; negative values are treated as zero
     * @param numberRandomQueriesAtStart passed to the superclass (presumably the number of random
     *        queries issued before active selection starts — TODO confirm)
     * @param seed random seed passed to the superclass
     */
    public PrototypicalPoolBasedActiveDyadRanker(PLNetDyadRanker ranker, IDyadRankingPoolProvider poolProvider, int maxBatchSize, int lengthOfTopRankingToConsider, double ratioOfOldInstancesForMinibatch, int numberRandomQueriesAtStart, int seed) {
        super(ranker, poolProvider, seed, numberRandomQueriesAtStart, maxBatchSize);
        // The pool size is only a capacity hint for the list of seen rankings.
        this.seenInstances = new ArrayList<>(poolProvider.getPool().size());
        this.ratioOfOldInstancesForMinibatch = ratioOfOldInstancesForMinibatch;
        this.lengthOfTopRankingToConsider = lengthOfTopRankingToConsider;
    }

    /**
     * Performs one round of active training.
     *
     * <p>The round queries the pool provider for new pairwise rankings, combines them with
     * rankings seen in earlier rounds, and updates the ranker with the resulting minibatch.
     * Pool instances are visited in random order. Instances with fewer than two alternatives
     * cannot yield a pairwise query and are skipped. Old rankings are sampled only from earlier
     * rounds, so a ranking queried in this round never appears twice in the minibatch. If the
     * pool does not contain enough queryable instances, the minibatch simply holds fewer new
     * rankings.
     *
     * @throws TrainingException if querying or updating the ranker fails
     * @throws InterruptedException if the thread is interrupted during training
     */
    @Override
    public void activelyTrainWithOneInstance() throws TrainingException, InterruptedException {
        int minibatchSize = this.getMinibatchSize();
        // Clamp at zero so that a negative ratio cannot produce an invalid subList bound below.
        int numberOfOldInstances = Math.max(0, Integer.min((int) (this.ratioOfOldInstancesForMinibatch * minibatchSize), this.seenInstances.size()));
        int numberOfNewInstances = minibatchSize - numberOfOldInstances;

        List<IVector> candidateInstances = new ArrayList<>(Math.max(0, minibatchSize));
        for (IVector instanceFeatures : this.poolProvider.getInstanceFeatures()) {
            candidateInstances.add(instanceFeatures);
        }
        Collections.shuffle(candidateInstances);

        DyadRankingDataset minibatch = new DyadRankingDataset();
        List<IDyadRankingInstance> newlyQueried = new ArrayList<>();
        for (IVector candidate : candidateInstances) {
            if (newlyQueried.size() >= numberOfNewInstances) {
                break;
            }
            List<IDyad> dyads = new ArrayList<>(this.poolProvider.getDyadsByInstance(candidate));
            if (dyads.size() < 2) {
                // No pair of alternatives to compare for this instance; try the next one.
                continue;
            }
            IDyadRankingInstance groundTruthPair = this.queryLeastCertainPair(dyads);
            newlyQueried.add(groundTruthPair);
            minibatch.add(groundTruthPair);
        }

        // Sample the old share BEFORE registering this round's queries. Otherwise the "old"
        // rankings could duplicate the ones that were just added as new.
        Collections.shuffle(this.seenInstances);
        minibatch.addAll(this.seenInstances.subList(0, numberOfOldInstances));
        this.seenInstances.addAll(newlyQueried);

        this.updateRanker(minibatch);
    }

    /**
     * Asks the ranker for its least-certain pair among the given dyads and obtains that pair's
     * ground-truth ordering from the pool provider.
     *
     * @param dyads dyads of a single instance. The context of the first dyad is used as the
     *        instance for all alternatives. Must contain at least two elements.
     * @return the ground-truth ranking of the selected pair, as returned by the pool provider
     */
    private IDyadRankingInstance queryLeastCertainPair(List<IDyad> dyads) throws TrainingException, InterruptedException {
        IVector instance = dyads.get(0).getContext();
        List<IVector> alternatives = new ArrayList<>(dyads.size());
        for (IDyad dyad : dyads) {
            alternatives.add(dyad.getAlternative());
        }
        IDyadRankingInstance queryPair = this.ranker.getPairWithLeastCertainty(new SparseDyadRankingInstance(instance, alternatives));

        List<IVector> alternativePair = new ArrayList<>(queryPair.getNumberOfRankedElements());
        for (IDyad dyad : queryPair) {
            alternativePair.add(dyad.getAlternative());
        }
        SparseDyadRankingInstance sparseQueryPair = new SparseDyadRankingInstance(((IDyad) queryPair.getLabel().get(0)).getContext(), alternativePair);
        return (IDyadRankingInstance) this.poolProvider.query((ILabeledInstance) sparseQueryPair);
    }

    /**
     * Returns the share of each minibatch that is drawn from previously queried rankings.
     *
     * @return the ratio of old instances per minibatch
     */
    public double getRatioOfOldInstancesForMinibatch() {
        return this.ratioOfOldInstancesForMinibatch;
    }

    /**
     * Sets the share of each minibatch that is drawn from previously queried rankings.
     *
     * @param ratioOfOldInstancesForMinibatch the new ratio; negative values are treated as zero
     *        during training
     */
    public void setRatioOfOldInstancesForMinibatch(double ratioOfOldInstancesForMinibatch) {
        this.ratioOfOldInstancesForMinibatch = ratioOfOldInstancesForMinibatch;
    }

    /**
     * Returns the stored top-ranking length.
     *
     * @return the length of the top ranking to consider
     */
    public int getLengthOfTopRankingToConsider() {
        return this.lengthOfTopRankingToConsider;
    }

    /**
     * Sets the stored top-ranking length.
     *
     * @param lengthOfTopRankingToConsider the new length of the top ranking to consider
     */
    public void setLengthOfTopRankingToConsider(int lengthOfTopRankingToConsider) {
        this.lengthOfTopRankingToConsider = lengthOfTopRankingToConsider;
    }
}

