/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.dyadranking.activelearning;

import ai.libs.jaicore.math.linearalgebra.Vector;
import ai.libs.jaicore.ml.dyadranking.Dyad;
import ai.libs.jaicore.ml.dyadranking.activelearning.IDyadRankingPoolProvider;
import ai.libs.jaicore.ml.dyadranking.dataset.DyadRankingDataset;
import ai.libs.jaicore.ml.dyadranking.dataset.DyadRankingInstance;
import ai.libs.jaicore.ml.dyadranking.dataset.IDyadRankingInstance;
import ai.libs.jaicore.ml.dyadranking.dataset.SparseDyadRankingInstance;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.nd4j.linalg.primitives.Pair;

/**
 * Pool provider for active dyad ranking that answers queries from a fully known
 * {@link DyadRankingDataset}.
 *
 * <p>Every ranking of the dataset is stored in {@link #getPool()} and indexed by the instance
 * features (and the alternative features) of its first dyad. Every dyad is additionally indexed
 * by its instance features and by its alternative features. A query is answered by looking up the
 * stored ranking for the queried dyads' instance features and ordering the queried dyads by their
 * position in that ranking.
 *
 * <p>No synchronization is performed; all state is kept in plain {@link HashMap}s and
 * {@link HashSet}s.
 */
public class DyadDatasetPoolProvider
implements IDyadRankingPoolProvider {
    /** Dyads still available in the pool, grouped by their instance features. */
    private final HashMap<Vector, Set<Dyad>> dyadsByInstances = new HashMap<>();
    /** Dyads still available in the pool, grouped by their alternative features. */
    private final HashMap<Vector, Set<Dyad>> dyadsByAlternatives = new HashMap<>();
    /** Dataset rankings keyed by the instance features of their first dyad (later rankings overwrite earlier ones). */
    private final HashMap<Vector, IDyadRankingInstance> dyadRankingsByInstances = new HashMap<>();
    /** Dataset rankings keyed by the alternative features of their first dyad (later rankings overwrite earlier ones). */
    private final HashMap<Vector, IDyadRankingInstance> dyadRankingsByAlternatives = new HashMap<>();
    /** All rankings of the dataset, in iteration order. */
    private final List<IDyadRankingInstance> pool;
    /** Whether queried dyads are removed from the dyad indices. */
    private boolean removeDyadsWhenQueried = false;
    /** Rankings returned by {@link #query(IDyadRankingInstance)} so far. */
    private final HashSet<IDyadRankingInstance> queriedRankings = new HashSet<>();
    /** Number of calls to {@link #query(IDyadRankingInstance)}, including rejected ones. */
    private int numberQueries = 0;

    /**
     * Creates a pool provider backed by the given dataset.
     *
     * @param dataset the dataset whose rankings serve as ground truth and pool
     */
    public DyadDatasetPoolProvider(DyadRankingDataset dataset) {
        this.pool = new ArrayList<>(dataset.size());
        for (IDyadRankingInstance instance : dataset) {
            this.addDyadRankingInstance(instance);
        }
    }

    /**
     * Returns the rankings of the underlying dataset.
     *
     * @return the internal (mutable) list of rankings
     */
    @Override
    public Collection<IDyadRankingInstance> getPool() {
        return this.pool;
    }

    /**
     * Returns the true ranking of the queried dyads.
     *
     * <p>Each dyad is ordered by its position in the stored ranking for its instance features.
     * Dyads whose instance features have no stored ranking get position -1. Dyads missing from an
     * existing stored ranking are placed after all dyads that were found. Ties keep their query
     * order because the sort is stable. If removal is enabled via
     * {@link #setRemoveDyadsWhenQueried(boolean)}, the queried dyads are removed from the dyad
     * indices afterwards.
     *
     * @param queryInstance the dyads to rank; must be a {@link SparseDyadRankingInstance}
     * @return the dyads of {@code queryInstance} in their true order
     * @throws IllegalArgumentException if {@code queryInstance} is not a
     *         {@link SparseDyadRankingInstance}
     */
    @Override
    public IDyadRankingInstance query(IDyadRankingInstance queryInstance) {
        // Counted before validation, so rejected queries are included in getNumberQueries().
        ++this.numberQueries;
        if (!(queryInstance instanceof SparseDyadRankingInstance)) {
            throw new IllegalArgumentException("Currently only supports SparseDyadRankingInstances!");
        }
        List<Pair<Dyad, Integer>> dyadPositionPairs = new ArrayList<>(queryInstance.length());
        for (Dyad dyad : queryInstance) {
            int position = this.getPositionInRankingByInstanceFeatures(dyad);
            dyadPositionPairs.add(new Pair<>(dyad, position));
        }
        // Collections.sort is stable: dyads with equal positions keep their query order.
        Collections.sort(dyadPositionPairs, Comparator.comparing(Pair::getRight));
        List<Dyad> dyadList = new ArrayList<>(dyadPositionPairs.size());
        for (Pair<Dyad, Integer> pair : dyadPositionPairs) {
            dyadList.add(pair.getFirst());
        }
        DyadRankingInstance trueRanking = new DyadRankingInstance(dyadList);
        if (this.removeDyadsWhenQueried) {
            for (Dyad dyad : dyadList) {
                this.removeDyadFromPool(dyad);
            }
        }
        this.queriedRankings.add(trueRanking);
        return trueRanking;
    }

    /**
     * Returns the pool dyads with the given instance features.
     *
     * @param instanceFeatures the instance feature vector
     * @return the internal (live, mutable) set of matching dyads, or a new empty set if none
     */
    @Override
    public Set<Dyad> getDyadsByInstance(Vector instanceFeatures) {
        Set<Dyad> dyads = this.dyadsByInstances.get(instanceFeatures);
        return dyads != null ? dyads : new HashSet<>();
    }

    /**
     * Returns the pool dyads with the given alternative features.
     *
     * @param alternativeFeatures the alternative feature vector
     * @return the internal (live, mutable) set of matching dyads, or a new empty set if none
     */
    @Override
    public Set<Dyad> getDyadsByAlternative(Vector alternativeFeatures) {
        Set<Dyad> dyads = this.dyadsByAlternatives.get(alternativeFeatures);
        return dyads != null ? dyads : new HashSet<>();
    }

    /**
     * Adds a ranking to the pool and indexes it and its dyads.
     *
     * @param instance the ranking to add
     */
    private void addDyadRankingInstance(IDyadRankingInstance instance) {
        this.pool.add(instance);
        // An empty ranking has no first dyad to key it by and no dyads to index.
        if (instance.length() == 0) {
            return;
        }
        Dyad first = instance.getDyadAtPosition(0);
        this.dyadRankingsByInstances.put(first.getInstance(), instance);
        this.dyadRankingsByAlternatives.put(first.getAlternative(), instance);
        for (Dyad dyad : instance) {
            this.dyadsByInstances.computeIfAbsent(dyad.getInstance(), key -> new HashSet<>()).add(dyad);
            this.dyadsByAlternatives.computeIfAbsent(dyad.getAlternative(), key -> new HashSet<>()).add(dyad);
        }
    }

    /**
     * Looks up the position of a dyad in the stored ranking for its instance features.
     *
     * @param dyad the dyad to locate
     * @return -1 if no ranking is stored for the dyad's instance features; otherwise the index of
     *         the first equal dyad in that ranking, or the ranking's length if it is not contained
     */
    private int getPositionInRankingByInstanceFeatures(Dyad dyad) {
        IDyadRankingInstance ranking = this.dyadRankingsByInstances.get(dyad.getInstance());
        if (ranking == null) {
            return -1;
        }
        int length = ranking.length();
        for (int position = 0; position < length; position++) {
            if (ranking.getDyadAtPosition(position).equals(dyad)) {
                return position;
            }
        }
        return length;
    }

    /**
     * Returns the instance feature vectors that currently have dyads in the pool.
     *
     * @return a live view of the keys of the by-instance index; it changes when dyads are removed
     */
    @Override
    public Collection<Vector> getInstanceFeatures() {
        return this.dyadsByInstances.keySet();
    }

    /**
     * Removes a dyad from both the by-instance and the by-alternative index.
     *
     * @param dyad the dyad to remove
     */
    private void removeDyadFromPool(Dyad dyad) {
        removeFromIndex(this.dyadsByInstances, dyad.getInstance(), dyad);
        removeFromIndex(this.dyadsByAlternatives, dyad.getAlternative(), dyad);
    }

    /**
     * Removes {@code dyad} from the group stored under {@code key}. The whole group is dropped
     * once fewer than two dyads remain in it.
     *
     * <p>NOTE(review): presumably because fewer than two dyads cannot form a ranking; confirm.
     */
    private static void removeFromIndex(HashMap<Vector, Set<Dyad>> index, Vector key, Dyad dyad) {
        Set<Dyad> group = index.get(key);
        if (group == null) {
            return;
        }
        group.remove(dyad);
        if (group.size() < 2) {
            index.remove(key);
        }
    }

    /**
     * Sets whether queried dyads are removed from the dyad indices.
     *
     * @param flag {@code true} to remove queried dyads
     */
    @Override
    public void setRemoveDyadsWhenQueried(boolean flag) {
        this.removeDyadsWhenQueried = flag;
    }

    /**
     * Returns the number of dyads still present in the by-instance index.
     *
     * @return the total size of all per-instance dyad groups
     */
    @Override
    public int getPoolSize() {
        int size = 0;
        for (Set<Dyad> group : this.dyadsByInstances.values()) {
            size += group.size();
        }
        return size;
    }

    /**
     * Returns how often {@link #query(IDyadRankingInstance)} was called.
     *
     * @return the number of query calls, including rejected ones
     */
    public int getNumberQueries() {
        return this.numberQueries;
    }

    /**
     * Returns the distinct rankings returned by queries so far.
     *
     * @return a new dataset containing a snapshot of the queried rankings
     */
    @Override
    public DyadRankingDataset getQueriedRankings() {
        List<IDyadRankingInstance> snapshot = new ArrayList<>(this.queriedRankings);
        return new DyadRankingDataset(snapshot);
    }
}

