/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.ranking.dyad.learner.activelearning;

import ai.libs.jaicore.ml.ranking.dyad.dataset.DenseDyadRankingInstance;
import ai.libs.jaicore.ml.ranking.dyad.dataset.DyadRankingDataset;
import ai.libs.jaicore.ml.ranking.dyad.dataset.SparseDyadRankingInstance;
import ai.libs.jaicore.ml.ranking.dyad.learner.activelearning.IDyadRankingPoolProvider;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.api4.java.ai.ml.ranking.dyad.dataset.IDyad;
import org.api4.java.ai.ml.ranking.dyad.dataset.IDyadRankingDataset;
import org.api4.java.ai.ml.ranking.dyad.dataset.IDyadRankingInstance;
import org.api4.java.common.math.IVector;
import org.nd4j.linalg.primitives.Pair;

/**
 * Pool provider for pool-based active dyad ranking that answers queries from a fully labeled
 * {@link IDyadRankingDataset}.
 *
 * <p>The rankings of the dataset act as the oracle: {@link #query(IDyadRankingInstance)} orders
 * the queried dyads according to the stored ranking that shares their context (instance)
 * features. Dyads are additionally indexed by their context features and by their alternative
 * features.
 *
 * <p>No synchronization is performed; instances must not be shared between threads without
 * external locking.
 */
public class DyadDatasetPoolProvider implements IDyadRankingPoolProvider {

    /** Dyads currently in the pool, grouped by their context (instance) features. */
    private final HashMap<IVector, Set<IDyad>> dyadsByInstances = new HashMap<>();

    /** Dyads currently in the pool, grouped by their alternative features. */
    private final HashMap<IVector, Set<IDyad>> dyadsByAlternatives = new HashMap<>();

    /**
     * Stored ranking per context, keyed by the context of the ranking's first dyad. If several
     * rankings share a context, the one added last wins.
     */
    private final HashMap<IVector, IDyadRankingInstance> dyadRankingsByInstances = new HashMap<>();

    /** All rankings of the backing dataset, in dataset iteration order. */
    private final List<IDyadRankingInstance> pool;

    /** Whether dyads of an answered query are removed from the dyad indices. */
    private boolean removeDyadsWhenQueried = false;

    /** Rankings returned by {@link #query}, deduplicated via their equals/hashCode. */
    private final Set<IDyadRankingInstance> queriedRankings = new HashSet<>();

    /** Number of queries that passed input validation. */
    private int numberQueries = 0;

    /**
     * Creates a pool provider backed by the given dataset.
     *
     * @param dataset labeled dataset whose rankings form the pool and answer queries
     */
    public DyadDatasetPoolProvider(IDyadRankingDataset dataset) {
        this.pool = new ArrayList<>(dataset.size());
        for (IDyadRankingInstance instance : dataset) {
            this.addDyadRankingInstance(instance);
        }
    }

    /**
     * Returns the rankings this provider was built from.
     *
     * @return the internal list of rankings (not a copy); it is not updated when dyads are removed
     *         after queries
     */
    public Collection<IDyadRankingInstance> getPool() {
        return this.pool;
    }

    /**
     * Orders the dyads of the given query by their positions in the stored rankings.
     *
     * <p>Each dyad is placed according to its index in the stored ranking of its own context.
     * Dyads whose context is unknown sort first. Dyads of a known context that do not occur in
     * that ranking sort after all located dyads. Ties keep the query's iteration order, because
     * {@link List#sort} is stable. If removal on query is enabled, the returned dyads are removed
     * from the dyad indices afterwards.
     *
     * @param queryInstance the dyads to rank; must be a {@link SparseDyadRankingInstance}
     * @return a new dense ranking of the query's dyads
     * @throws IllegalArgumentException if {@code queryInstance} is not a
     *         {@link SparseDyadRankingInstance}
     */
    public IDyadRankingInstance query(IDyadRankingInstance queryInstance) {
        if (!(queryInstance instanceof SparseDyadRankingInstance)) {
            throw new IllegalArgumentException("Currently only supports SparseDyadRankingInstances!");
        }
        // Only count queries that are actually answered, not rejected ones.
        ++this.numberQueries;
        SparseDyadRankingInstance drInstance = (SparseDyadRankingInstance) queryInstance;

        // Compute each dyad's position once, then sort by it (stable, so ties keep query order).
        List<Pair<IDyad, Integer>> dyadPositionPairs = new ArrayList<>(drInstance.getNumberOfRankedElements());
        for (IDyad dyad : drInstance) {
            dyadPositionPairs.add(new Pair<>(dyad, this.getPositionInRankingByInstanceFeatures(dyad)));
        }
        dyadPositionPairs.sort(Comparator.comparingInt((Pair<IDyad, Integer> pair) -> pair.getRight()));

        List<IDyad> dyadList = new ArrayList<>(dyadPositionPairs.size());
        for (Pair<IDyad, Integer> pair : dyadPositionPairs) {
            dyadList.add(pair.getFirst());
        }
        DenseDyadRankingInstance trueRanking = new DenseDyadRankingInstance(dyadList);

        if (this.removeDyadsWhenQueried) {
            for (IDyad dyad : dyadList) {
                this.removeDyadFromPool(dyad);
            }
        }
        this.queriedRankings.add(trueRanking);
        return trueRanking;
    }

    /**
     * Returns the pooled dyads that have the given context features.
     *
     * @param instanceFeatures context features to look up
     * @return the internal (live) set for that context, or a new empty set if none exists
     */
    @Override
    public Set<IDyad> getDyadsByInstance(IVector instanceFeatures) {
        return this.dyadsByInstances.getOrDefault(instanceFeatures, new HashSet<>());
    }

    /**
     * Returns the pooled dyads that have the given alternative features.
     *
     * @param alternativeFeatures alternative features to look up
     * @return the internal (live) set for that alternative, or a new empty set if none exists
     */
    @Override
    public Set<IDyad> getDyadsByAlternative(IVector alternativeFeatures) {
        return this.dyadsByAlternatives.getOrDefault(alternativeFeatures, new HashSet<>());
    }

    /**
     * Adds a ranking to the pool and indexes its dyads by context and by alternative.
     *
     * @param instance ranking to add; empty rankings are added to the pool but cannot be found by
     *        context
     */
    private void addDyadRankingInstance(IDyadRankingInstance instance) {
        this.pool.add(instance);
        // Guard against empty rankings: get(0) would otherwise throw and abort construction.
        if (instance.getNumberOfRankedElements() > 0) {
            this.dyadRankingsByInstances.put(((IDyad) instance.getLabel().get(0)).getContext(), instance);
        }
        for (IDyad dyad : instance) {
            this.dyadsByInstances.computeIfAbsent(dyad.getContext(), key -> new HashSet<>()).add(dyad);
            this.dyadsByAlternatives.computeIfAbsent(dyad.getAlternative(), key -> new HashSet<>()).add(dyad);
        }
    }

    /**
     * Locates a dyad in the stored ranking of its context.
     *
     * @param dyad the dyad to locate
     * @return -1 if no ranking is stored for the dyad's context; otherwise the zero-based index of
     *         the first element equal to {@code dyad}, or the ranking's length if none matches
     */
    private int getPositionInRankingByInstanceFeatures(IDyad dyad) {
        IDyadRankingInstance ranking = this.dyadRankingsByInstances.get(dyad.getContext());
        if (ranking == null) {
            return -1;
        }
        int length = ranking.getNumberOfRankedElements();
        for (int position = 0; position < length; position++) {
            if (((IDyad) ranking.getLabel().get(position)).equals(dyad)) {
                return position;
            }
        }
        return length;
    }

    /**
     * Returns the context features that still have dyads in the pool.
     *
     * @return a live view of the context index's key set
     */
    @Override
    public Collection<IVector> getInstanceFeatures() {
        return this.dyadsByInstances.keySet();
    }

    /**
     * Removes a dyad from both the context index and the alternative index.
     *
     * @param dyad the dyad to remove
     */
    private void removeDyadFromPool(IDyad dyad) {
        removeFromIndex(this.dyadsByInstances, dyad.getContext(), dyad);
        removeFromIndex(this.dyadsByAlternatives, dyad.getAlternative(), dyad);
    }

    /**
     * Removes a dyad from one index. The whole entry is dropped once fewer than two dyads remain.
     *
     * @param index the index to update
     * @param key the key under which the dyad is stored
     * @param dyad the dyad to remove
     */
    private static void removeFromIndex(HashMap<IVector, Set<IDyad>> index, IVector key, IDyad dyad) {
        Set<IDyad> dyads = index.get(key);
        if (dyads == null) {
            return;
        }
        dyads.remove(dyad);
        // NOTE(review): presumably a single remaining dyad cannot form a ranking query, so the
        // remaining dyad is dropped too - confirm this is intended.
        if (dyads.size() < 2) {
            index.remove(key);
        }
    }

    /**
     * Controls whether dyads returned by {@link #query} are removed from the dyad indices.
     *
     * @param flag {@code true} to remove queried dyads
     */
    @Override
    public void setRemoveDyadsWhenQueried(boolean flag) {
        this.removeDyadsWhenQueried = flag;
    }

    /**
     * Returns the number of dyads currently indexed by context.
     *
     * @return the sum of the sizes of all context sets
     */
    @Override
    public int getPoolSize() {
        return this.dyadsByInstances.values().stream().mapToInt(Set::size).sum();
    }

    /**
     * Returns how many queries passed validation so far.
     *
     * @return the number of accepted queries
     */
    public int getNumberQueries() {
        return this.numberQueries;
    }

    /**
     * Returns the distinct rankings produced by {@link #query} so far.
     *
     * @return a new dataset containing a copy of the queried rankings
     */
    @Override
    public DyadRankingDataset getQueriedRankings() {
        return new DyadRankingDataset(new ArrayList<IDyadRankingInstance>(this.queriedRankings));
    }
}

