/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.ranking.dyad.learner.activelearning;

import ai.libs.jaicore.ml.ranking.dyad.dataset.DyadRankingDataset;
import ai.libs.jaicore.ml.ranking.dyad.dataset.SparseDyadRankingInstance;
import ai.libs.jaicore.ml.ranking.dyad.learner.activelearning.ActiveDyadRanker;
import ai.libs.jaicore.ml.ranking.dyad.learner.activelearning.IDyadRankingPoolProvider;
import ai.libs.jaicore.ml.ranking.dyad.learner.algorithm.PLNetDyadRanker;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance;
import org.api4.java.ai.ml.core.exception.TrainingException;
import org.api4.java.ai.ml.ranking.dyad.dataset.IDyad;
import org.api4.java.ai.ml.ranking.dyad.dataset.IDyadRankingInstance;
import org.api4.java.common.math.IVector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Active dyad ranker that spends its first iterations on randomly drawn pairwise
 * queries before delegating to {@link #activelyTrainWithOneInstance()}.
 *
 * <p>During the first {@code numberRandomQueriesAtStart} iterations each iteration builds a
 * minibatch of up to {@code minibatchSize} random queries (one random instance, two random
 * dyads of that instance), asks the pool provider for the ranking and updates the ranker
 * with it. After every update the skill the ranker reports for each pooled dyad is recorded
 * in a per-dyad {@link SummaryStatistics}.
 */
public abstract class ARandomlyInitializingDyadRanker
extends ActiveDyadRanker {
    private static final Logger LOGGER = LoggerFactory.getLogger(ARandomlyInitializingDyadRanker.class);

    /** Number of dyads a single random query is made of. */
    private static final int DYADS_PER_QUERY = 2;

    private final int numberRandomQueriesAtStart;
    private final Map<IDyad, SummaryStatistics> dyadStats = new HashMap<>();
    private final List<IVector> instanceFeatures;
    private final Random random;
    private final int minibatchSize;
    private int iteration;

    /**
     * Creates the ranker and registers an empty statistics object for every dyad the pool
     * provider currently reports for its instances.
     *
     * @param ranker the ranker that is trained actively
     * @param poolProvider provider of the instance/dyad pool and of the query results
     * @param seed seed of the random source used for drawing instances and dyads
     * @param numberRandomQueriesAtStart number of iterations that use random queries
     * @param minibatchSize number of random queries drawn per random iteration
     */
    public ARandomlyInitializingDyadRanker(PLNetDyadRanker ranker, IDyadRankingPoolProvider poolProvider, int seed, int numberRandomQueriesAtStart, int minibatchSize) {
        super(ranker, poolProvider);
        // Private mutable copy: the list is shuffled in place while drawing random instances.
        this.instanceFeatures = new ArrayList<>(poolProvider.getInstanceFeatures());
        this.numberRandomQueriesAtStart = numberRandomQueriesAtStart;
        this.minibatchSize = minibatchSize;
        this.iteration = 0;
        for (IVector instance : this.instanceFeatures) {
            for (IDyad dyad : poolProvider.getDyadsByInstance(instance)) {
                this.dyadStats.put(dyad, new SummaryStatistics());
            }
        }
        this.random = new Random(seed);
    }

    /**
     * Runs {@code numberOfQueries} iterations. An iteration is a random minibatch update as
     * long as fewer than {@code numberRandomQueriesAtStart} iterations have been performed,
     * and a call to {@link #activelyTrainWithOneInstance()} afterwards.
     *
     * <p>A {@link TrainingException} raised while updating the ranker with a random minibatch
     * is logged and does not abort the remaining iterations.
     *
     * @param numberOfQueries number of iterations to perform
     * @throws TrainingException if {@link #activelyTrainWithOneInstance()} fails
     * @throws InterruptedException if training is interrupted
     */
    @Override
    public void activelyTrain(int numberOfQueries) throws TrainingException, InterruptedException {
        for (int i = 0; i < numberOfQueries; ++i) {
            if (this.iteration < this.numberRandomQueriesAtStart) {
                DyadRankingDataset minibatch = this.drawRandomMinibatch();
                try {
                    // The update is attempted even if fewer queries than requested could be drawn.
                    this.updateRanker(minibatch);
                }
                catch (TrainingException e) {
                    LOGGER.error("Updating the dyad ranking learner did not succeed.", (Throwable)e);
                }
            } else {
                this.activelyTrainWithOneInstance();
            }
            ++this.iteration;
        }
    }

    /**
     * Draws up to {@code minibatchSize} random pairwise queries and collects the rankings the
     * pool provider returns for them. Nothing is drawn if there are no instances.
     */
    private DyadRankingDataset drawRandomMinibatch() throws InterruptedException {
        DyadRankingDataset minibatch = new DyadRankingDataset();
        for (int batchIndex = 0; batchIndex < this.minibatchSize && !this.instanceFeatures.isEmpty(); ++batchIndex) {
            // Shuffle-and-take-first is kept so that seeded runs draw the same sequence as before.
            Collections.shuffle(this.instanceFeatures, this.random);
            IVector instance = this.instanceFeatures.get(0);
            List<IDyad> dyads = new ArrayList<>(this.poolProvider.getDyadsByInstance(instance));
            if (dyads.size() < DYADS_PER_QUERY) {
                // A pairwise query needs two dyads; skip this draw instead of failing on get(1).
                LOGGER.warn("Skipping random query: instance has only {} dyad(s) in the pool.", dyads.size());
                continue;
            }
            Collections.shuffle(dyads, this.random);
            IDyad first = dyads.get(0);
            IDyad second = dyads.get(1);
            List<IVector> alternatives = new ArrayList<>(DYADS_PER_QUERY);
            alternatives.add(first.getAlternative());
            alternatives.add(second.getAlternative());
            SparseDyadRankingInstance queryInstance = new SparseDyadRankingInstance(first.getContext(), alternatives);
            IDyadRankingInstance trueRanking = (IDyadRankingInstance)this.poolProvider.query((ILabeledInstance)queryInstance);
            minibatch.add(trueRanking);
        }
        return minibatch;
    }

    /** @return the number of initial iterations that use random queries */
    public int getNumberRandomQueriesAtStart() {
        return this.numberRandomQueriesAtStart;
    }

    /** @return the number of iterations performed so far by {@link #activelyTrain(int)} */
    public int getIteration() {
        return this.iteration;
    }

    /** @return the internal (live, mutable) map from dyad to its recorded skill statistics */
    public Map<IDyad, SummaryStatistics> getDyadStats() {
        return this.dyadStats;
    }

    /** @return the internal (live, mutable) list of instance feature vectors; it is reshuffled by random queries */
    public List<IVector> getInstanceFeatures() {
        return this.instanceFeatures;
    }

    /** @return the random source used by this ranker */
    public Random getRandom() {
        return this.random;
    }

    /** @return the number of random queries drawn per random iteration */
    public int getMinibatchSize() {
        return this.minibatchSize;
    }

    /**
     * Performs one active-learning iteration once the random initialization phase is over.
     *
     * @throws TrainingException if the iteration fails
     * @throws InterruptedException if training is interrupted
     */
    @Override
    public abstract void activelyTrainWithOneInstance() throws TrainingException, InterruptedException;

    /**
     * Fits the ranker on the given minibatch and afterwards records, for every dyad the pool
     * provider reports for the known instances, the skill the ranker assigns to it.
     *
     * @param minibatch the rankings to fit the ranker on
     * @throws TrainingException if fitting the ranker fails
     * @throws InterruptedException if fitting is interrupted
     */
    public void updateRanker(DyadRankingDataset minibatch) throws TrainingException, InterruptedException {
        this.ranker.fit(minibatch);
        for (IVector inst : this.getInstanceFeatures()) {
            for (IDyad dyad : this.poolProvider.getDyadsByInstance(inst)) {
                double skill = this.ranker.getSkillForDyad(dyad);
                SummaryStatistics stats = this.dyadStats.get(dyad);
                if (stats == null) {
                    // Dyad was not in the pool at construction time; start tracking it now
                    // rather than failing with a NullPointerException.
                    stats = new SummaryStatistics();
                    this.dyadStats.put(dyad, stats);
                }
                stats.addValue(skill);
            }
        }
    }
}

