/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.dyadranking.activelearning;

import ai.libs.jaicore.math.linearalgebra.Vector;
import ai.libs.jaicore.ml.core.exception.TrainingException;
import ai.libs.jaicore.ml.dyadranking.Dyad;
import ai.libs.jaicore.ml.dyadranking.activelearning.ARandomlyInitializingDyadRanker;
import ai.libs.jaicore.ml.dyadranking.activelearning.IDyadRankingPoolProvider;
import ai.libs.jaicore.ml.dyadranking.algorithm.PLNetDyadRanker;
import ai.libs.jaicore.ml.dyadranking.dataset.IDyadRankingInstance;
import ai.libs.jaicore.ml.dyadranking.dataset.SparseDyadRankingInstance;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import org.nd4j.linalg.primitives.Pair;

/**
 * Pool-based active dyad ranker which, in each active training step, builds a minibatch from
 * <ul>
 * <li>freshly queried rankings: for randomly ordered instances of the pool, the pair of alternatives
 * the current {@link PLNetDyadRanker} is least certain about is queried from the pool provider, and</li>
 * <li>replayed rankings: a random sample of rankings queried in earlier steps.</li>
 * </ul>
 * The share of replayed rankings per minibatch is controlled by {@code ratioOfOldInstancesForMinibatch}.
 */
public class PrototypicalPoolBasedActiveDyadRanker
extends ARandomlyInitializingDyadRanker {
    /** Ground-truth rankings obtained from the pool provider by this ranker's active steps so far. */
    private final List<IDyadRankingInstance> seenInstances;
    /** Fraction of each minibatch that is filled with previously seen rankings. */
    private double ratioOfOldInstancesForMinibatch;
    /** Stored and exposed via getter/setter; not used by {@link #activelyTrainWithOneInstance()}. */
    private int lengthOfTopRankingToConsider;

    /**
     * @param ranker the PL-Net ranker that is trained and asked for its least certain pairs
     * @param poolProvider provider of the instance pool and of ground-truth rankings
     * @param maxBatchSize maximum minibatch size (forwarded to the super class)
     * @param lengthOfTopRankingToConsider stored for configuration purposes only
     * @param ratioOfOldInstancesForMinibatch fraction of a minibatch filled with previously seen rankings
     * @param numberRandomQueriesAtStart number of random queries performed initially (forwarded to the super class)
     * @param seed random seed (forwarded to the super class)
     */
    public PrototypicalPoolBasedActiveDyadRanker(PLNetDyadRanker ranker, IDyadRankingPoolProvider poolProvider, int maxBatchSize, int lengthOfTopRankingToConsider, double ratioOfOldInstancesForMinibatch, int numberRandomQueriesAtStart, int seed) {
        super(ranker, poolProvider, seed, numberRandomQueriesAtStart, maxBatchSize);
        this.seenInstances = new ArrayList<>(poolProvider.getPool().size());
        this.ratioOfOldInstancesForMinibatch = ratioOfOldInstancesForMinibatch;
        this.lengthOfTopRankingToConsider = lengthOfTopRankingToConsider;
    }

    /**
     * Performs one active training step: samples previously seen rankings, queries new rankings for the
     * least certain pairs of randomly ordered pool instances, and updates the ranker with the union.
     *
     * @throws TrainingException if updating the ranker fails
     */
    @Override
    public void activelyTrainWithOneInstance() throws TrainingException {
        int minibatchSize = this.getMinibatchSize();
        // Clamp at 0 so a negative ratio cannot produce an invalid subList bound below.
        int numberOfOldInstances = Math.max(0, Math.min((int) (this.ratioOfOldInstancesForMinibatch * minibatchSize), this.seenInstances.size()));
        int numberOfNewInstances = minibatchSize - numberOfOldInstances;

        HashSet<IDyadRankingInstance> minibatch = new HashSet<>();

        // Sample the replayed rankings BEFORE adding this step's queries to the history, so that the
        // sample really consists of rankings from earlier steps and is not diluted by duplicates.
        Collections.shuffle(this.seenInstances);
        minibatch.addAll(this.seenInstances.subList(0, numberOfOldInstances));

        List<IDyadRankingInstance> newRankings = this.queryNewRankings(numberOfNewInstances);
        minibatch.addAll(newRankings);
        this.seenInstances.addAll(newRankings);

        this.updateRanker(minibatch);
    }

    /**
     * Queries up to {@code count} ground-truth rankings for pool instances visited in random order.
     * Instances with fewer than two alternatives cannot yield a pair and are skipped. Fewer than
     * {@code count} rankings are returned if the pool runs out of usable instances.
     */
    private List<IDyadRankingInstance> queryNewRankings(int count) {
        List<Vector> candidates = new ArrayList<>();
        for (Vector instanceFeatures : this.poolProvider.getInstanceFeatures()) {
            candidates.add(instanceFeatures);
        }
        Collections.shuffle(candidates);

        List<IDyadRankingInstance> rankings = new ArrayList<>(Math.max(0, count));
        for (Vector candidate : candidates) {
            if (rankings.size() >= count) {
                break;
            }
            List<Dyad> dyads = new ArrayList<>(this.poolProvider.getDyadsByInstance(candidate));
            if (dyads.size() < 2) {
                continue;
            }
            rankings.add(this.queryLeastCertainPair(dyads));
        }
        return rankings;
    }

    /**
     * Asks the ranker for the pair of alternatives (among {@code dyads}) with least certainty and
     * returns the ground-truth ranking of that pair as obtained from the pool provider.
     */
    private IDyadRankingInstance queryLeastCertainPair(List<Dyad> dyads) {
        Vector instance = dyads.get(0).getInstance();
        List<Vector> alternatives = new ArrayList<>(dyads.size());
        for (Dyad dyad : dyads) {
            alternatives.add(dyad.getAlternative());
        }
        SparseDyadRankingInstance queryRanking = new SparseDyadRankingInstance(instance, alternatives);
        IDyadRankingInstance queryPair = this.ranker.getPairWithLeastCertainty(queryRanking);

        List<Vector> pairAlternatives = new ArrayList<>(queryPair.length());
        for (Dyad dyad : queryPair) {
            pairAlternatives.add(dyad.getAlternative());
        }
        SparseDyadRankingInstance sparseQueryPair = new SparseDyadRankingInstance(queryPair.getDyadAtPosition(0).getInstance(), pairAlternatives);
        return this.poolProvider.query(sparseQueryPair);
    }

    /** @return the fraction of each minibatch filled with previously seen rankings */
    public double getRatioOfOldInstancesForMinibatch() {
        return this.ratioOfOldInstancesForMinibatch;
    }

    /** @param ratioOfOldInstancesForMinibatch the fraction of each minibatch filled with previously seen rankings */
    public void setRatioOfOldInstancesForMinibatch(double ratioOfOldInstancesForMinibatch) {
        this.ratioOfOldInstancesForMinibatch = ratioOfOldInstancesForMinibatch;
    }

    /** @return the configured top-ranking length (not used by the training step) */
    public int getLengthOfTopRankingToConsider() {
        return this.lengthOfTopRankingToConsider;
    }

    /** @param lengthOfTopRankingToConsider the top-ranking length to store (not used by the training step) */
    public void setLengthOfTopRankingToConsider(int lengthOfTopRankingToConsider) {
        this.lengthOfTopRankingToConsider = lengthOfTopRankingToConsider;
    }
}

