/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.ranking.dyad.learner.search;

import ai.libs.jaicore.ml.ranking.dyad.dataset.DenseDyadRankingInstance;
import ai.libs.jaicore.ml.ranking.dyad.dataset.DyadRankingDataset;
import ai.libs.jaicore.ml.ranking.dyad.learner.Dyad;
import ai.libs.jaicore.ml.ranking.dyad.learner.algorithm.IDyadRanker;
import ai.libs.jaicore.ml.ranking.dyad.learner.util.AbstractDyadScaler;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Queue;
import org.api4.java.ai.graphsearch.problem.pathsearch.pathevaluation.IEvaluatedPath;
import org.api4.java.ai.ml.core.exception.PredictionException;
import org.api4.java.ai.ml.ranking.IRanking;
import org.api4.java.ai.ml.ranking.dataset.IRankingInstance;
import org.api4.java.ai.ml.ranking.dyad.dataset.IDyad;
import org.api4.java.common.math.IVector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link Queue} of evaluated search paths (nodes) whose order is determined by a dyad ranker.
 * <p>
 * Every added node is characterized by {@link #characterize(IEvaluatedPath)}. The characterization is paired with
 * the context characterization into a {@link Dyad}. After each successful {@link #add(IEvaluatedPath)}, all stored
 * dyads are ranked by the {@link IDyadRanker} and the queue is rebuilt in the predicted order. {@link #peek()} and
 * {@link #poll()} therefore operate on the head of the most recent ranking.
 * <p>
 * Nodes are mapped to their characterizations through a Guava {@link BiMap}. Two distinct nodes must therefore not
 * have equal characterization vectors, because {@link HashBiMap} rejects a value that is already bound to another key.
 * <p>
 * This class performs no synchronization of its own.
 *
 * @param <N> node type of the search paths
 * @param <V> evaluation type of the search paths
 */
public abstract class ADyadRankedNodeQueue<N, V extends Comparable<V>>
implements Queue<IEvaluatedPath<N, ?, V>> {
    private final Logger logger = LoggerFactory.getLogger(this.getClass());
    /** Ranker used to order the stored dyads; {@code null} until set when the single-argument constructor is used. */
    private IDyadRanker dyadRanker;
    /** Optional scaler applied to node characterizations (and to the context when set via {@link #setScaler}). */
    protected AbstractDyadScaler scaler;
    /** Whether {@link #scaler} is applied to newly added node characterizations. */
    private boolean useScaler = false;
    /** The stored nodes, in the order of the most recent ranking. */
    private final List<IEvaluatedPath<N, ?, V>> queue = new ArrayList<>();
    /** Characterizations of all stored nodes, in insertion order. */
    private final List<IVector> nodeCharacterizations = new ArrayList<>();
    /** The context characterization exactly as passed to the constructor. */
    private final IVector originalContextCharacterization;
    /** Working copy of the context characterization used when building new dyads. */
    private IVector contextCharacterization;
    /** One (context, node characterization) dyad per stored node; this list is what gets ranked. */
    private final List<IDyad> queryDyads = new ArrayList<>();
    /** Node to characterization mapping; its inverse translates ranked dyads back into nodes. */
    private final BiMap<IEvaluatedPath<N, ?, V>, IVector> nodesAndCharacterizationsMap = HashBiMap.create();

    /**
     * Creates a queue for the given context without a ranker or scaler.
     * <p>
     * NOTE(review): {@link #add(IEvaluatedPath)} dereferences the ranker, so {@link #setDyadRanker(IDyadRanker)}
     * must be called before nodes are added.
     *
     * @param contextCharacterization characterization of the context; a copy of it is used as the working context
     */
    public ADyadRankedNodeQueue(IVector contextCharacterization) {
        this.contextCharacterization = contextCharacterization.addConstantToCopy(0.0);
        this.originalContextCharacterization = contextCharacterization;
        this.logger.trace("Construct ADyadNodeQueue with contexcharacterization {}", contextCharacterization);
    }

    /**
     * Creates a queue for the given context that ranks nodes with {@code dyadRanker}.
     * <p>
     * NOTE(review): unlike {@link #setScaler(AbstractDyadScaler)}, this constructor does not pass the context
     * characterization through the scaler. Confirm whether that is intended.
     *
     * @param contextCharacterization characterization of the context
     * @param dyadRanker ranker used to order the nodes
     * @param scaler scaler for node characterizations, or {@code null} to disable scaling
     */
    public ADyadRankedNodeQueue(IVector contextCharacterization, IDyadRanker dyadRanker, AbstractDyadScaler scaler) {
        this(contextCharacterization);
        this.dyadRanker = dyadRanker;
        this.scaler = scaler;
        this.useScaler = scaler != null;
    }

    /**
     * Computes the feature vector (alternative) describing the given node.
     *
     * @param node the node to characterize
     * @return the node's characterization
     */
    protected abstract IVector characterize(IEvaluatedPath<N, ?, V> node);

    @Override
    public int size() {
        return this.queue.size();
    }

    @Override
    public boolean isEmpty() {
        return this.queue.isEmpty();
    }

    @Override
    public boolean contains(Object o) {
        return this.queue.contains(o);
    }

    /**
     * Returns an iterator over the nodes in ranked order.
     * <p>
     * NOTE(review): this is the internal list's iterator. Removing through it bypasses the dyad/characterization
     * bookkeeping and should be avoided.
     */
    @Override
    public Iterator<IEvaluatedPath<N, ?, V>> iterator() {
        return this.queue.iterator();
    }

    @Override
    public Object[] toArray() {
        return this.queue.toArray();
    }

    @Override
    public <T> T[] toArray(T[] a) {
        return this.queue.toArray(a);
    }

    /**
     * Removes the given node, together with its characterization and dyad.
     *
     * @param o the node to remove
     * @return {@code true} if the node was stored and has been removed
     */
    @Override
    public boolean remove(Object o) {
        if (!(o instanceof IEvaluatedPath)) {
            return false;
        }
        int index = this.queue.lastIndexOf(o);
        if (index == -1) {
            return false;
        }
        this.removeNodeAtPosition(index);
        return true;
    }

    @Override
    public boolean containsAll(Collection<?> c) {
        return this.queue.containsAll(c);
    }

    /**
     * Adds every node of {@code c}. Each successful addition triggers a full re-ranking.
     *
     * @param c the nodes to add
     * @return {@code true} if at least one call to {@link #add(IEvaluatedPath)} returned {@code true}
     */
    @Override
    public boolean addAll(Collection<? extends IEvaluatedPath<N, ?, V>> c) {
        this.logger.trace("Add {} nodes", c.size());
        boolean changed = false;
        for (IEvaluatedPath<N, ?, V> elem : c) {
            if (this.add(elem)) {
                changed = true;
            }
        }
        return changed;
    }

    @Override
    public boolean removeAll(Collection<?> c) {
        boolean changed = false;
        for (Object o : c) {
            if (this.remove(o)) {
                changed = true;
            }
        }
        return changed;
    }

    /** Not supported. */
    @Override
    public boolean retainAll(Collection<?> c) {
        throw new UnsupportedOperationException();
    }

    /** Removes all nodes, characterizations and query dyads. */
    @Override
    public void clear() {
        this.queue.clear();
        this.nodesAndCharacterizationsMap.clear();
        this.nodeCharacterizations.clear();
        // Stale dyads would otherwise be ranked again on the next add().
        this.queryDyads.clear();
    }

    /**
     * Characterizes {@code e}, stores its dyad, re-ranks all stored dyads and rebuilds the queue in ranked order.
     * <p>
     * If ranking fails, all bookkeeping created by this call is undone, so the queue remains in its previous state.
     *
     * @param e the node to add
     * @return {@code true} if the node was added or was already present; {@code false} if {@code e} is
     *         {@code null}, ranking threw a {@link PredictionException}, or the thread was interrupted
     */
    @Override
    public boolean add(IEvaluatedPath<N, ?, V> e) {
        if (this.queue.contains(e)) {
            return true;
        }
        if (e == null) {
            return false;
        }
        final int characterizationCount = this.nodeCharacterizations.size();
        final int dyadCount = this.queryDyads.size();
        boolean mapped = false;
        boolean committed = false;
        try {
            this.logger.debug("Add node to OPEN.");
            IVector characterization = this.characterize(e);
            this.nodeCharacterizations.add(characterization);
            Dyad newDyad = new Dyad(this.contextCharacterization, characterization);
            this.queryDyads.add(newDyad);
            if (this.useScaler) {
                // NOTE(review): presumably scales the dyad's alternative (i.e. `characterization`) in place — confirm.
                DyadRankingDataset dataset = new DyadRankingDataset();
                dataset.add(new DenseDyadRankingInstance(Arrays.asList(newDyad)));
                this.scaler.transformAlternatives(dataset);
            }
            this.replaceNaNByZeroes(characterization);
            // Put into the hash-based map only after the vector's values are final.
            this.nodesAndCharacterizationsMap.put(e, characterization);
            mapped = true;
            this.rebuildQueueFromRanking();
            committed = true;
            return true;
        } catch (PredictionException e1) {
            this.logger.warn("Failed to characterize: {}", e1.getLocalizedMessage());
            return false;
        } catch (InterruptedException e1) {
            Thread.currentThread().interrupt();
            return false;
        } finally {
            if (!committed) {
                this.rollbackAdd(e, mapped, characterizationCount, dyadCount);
            }
        }
    }

    /**
     * Ranks all stored dyads and replaces the queue contents with the corresponding nodes in predicted order.
     * The queue is only modified after the complete new order has been built.
     */
    private void rebuildQueueFromRanking() throws PredictionException, InterruptedException {
        IRanking prediction = this.dyadRanker.predict((IRankingInstance) new DenseDyadRankingInstance(this.queryDyads));
        IRanking ranking = (IRanking) prediction.getPrediction();
        List<IEvaluatedPath<N, ?, V>> reordered = new ArrayList<>(ranking.size());
        for (int i = 0; i < ranking.size(); ++i) {
            IEvaluatedPath<N, ?, V> toAdd =
                    this.nodesAndCharacterizationsMap.inverse().get(((Dyad) ranking.get(i)).getAlternative());
            if (toAdd != null) {
                reordered.add(toAdd);
            } else {
                this.logger.warn("Got a node in a prediction that doesnt exist");
            }
        }
        this.queue.clear();
        this.queue.addAll(reordered);
    }

    /** Undoes the bookkeeping of a failed {@link #add(IEvaluatedPath)} call. */
    private void rollbackAdd(IEvaluatedPath<N, ?, V> node, boolean mapped, int characterizationCount, int dyadCount) {
        if (mapped) {
            this.nodesAndCharacterizationsMap.remove(node);
        }
        this.nodeCharacterizations.subList(characterizationCount, this.nodeCharacterizations.size()).clear();
        this.queryDyads.subList(dyadCount, this.queryDyads.size()).clear();
    }

    /** Replaces every NaN entry of {@code vector} with {@code 0.0}, modifying it in place. */
    private void replaceNaNByZeroes(IVector vector) {
        for (int i = 0; i < vector.length(); ++i) {
            if (Double.isNaN(vector.getValue(i))) {
                vector.setValue(i, 0.0);
            }
        }
    }

    @Override
    public boolean offer(IEvaluatedPath<N, ?, V> e) {
        return this.add(e);
    }

    /**
     * Removes and returns the head of the queue.
     * <p>
     * NOTE(review): on an empty queue this throws {@link IndexOutOfBoundsException} (from {@code List.remove}),
     * not the {@code NoSuchElementException} specified by {@link Queue#remove()}.
     */
    @Override
    public IEvaluatedPath<N, ?, V> remove() {
        return this.removeNodeAtPosition(0);
    }

    /**
     * Removes the node at position {@code i} of the current ranking, together with its characterization and dyad.
     *
     * @param i position in the queue
     * @return the removed node
     * @throws IndexOutOfBoundsException if {@code i} is not a valid position
     */
    public IEvaluatedPath<N, ?, V> removeNodeAtPosition(int i) {
        IEvaluatedPath<N, ?, V> removedNode = this.queue.remove(i);
        this.logger.trace("Retrieve node from OPEN. Index: {}", i);
        IVector removedAlternative = this.nodesAndCharacterizationsMap.remove(removedNode);
        this.nodeCharacterizations.remove(removedAlternative);
        int index = -1;
        for (int j = 0; j < this.queryDyads.size(); ++j) {
            if (this.queryDyads.get(j).getAlternative().equals(removedAlternative)) {
                index = j;
                break;
            }
        }
        // Only remove when a matching dyad was found; remove(-1) would throw.
        if (index != -1) {
            this.queryDyads.remove(index);
        }
        return removedNode;
    }

    @Override
    public IEvaluatedPath<N, ?, V> poll() {
        return this.queue.isEmpty() ? null : this.remove();
    }

    /**
     * Returns the head of the queue without removing it.
     * <p>
     * NOTE(review): on an empty queue this throws {@link IndexOutOfBoundsException} (from {@code List.get}),
     * not the {@code NoSuchElementException} specified by {@link Queue#element()}.
     */
    @Override
    public IEvaluatedPath<N, ?, V> element() {
        return this.queue.get(0);
    }

    @Override
    public IEvaluatedPath<N, ?, V> peek() {
        if (this.queue.isEmpty()) {
            return null;
        }
        this.logger.trace("Peek from OPEN.");
        return this.element();
    }

    public IDyadRanker getDyadRanker() {
        return this.dyadRanker;
    }

    /**
     * Replaces the dyad ranker used for subsequent rankings.
     *
     * @param dyadRanker the new ranker
     */
    public void setDyadRanker(IDyadRanker dyadRanker) {
        this.logger.trace("Update dyad ranker. Was {} now is {}", classOf(this.dyadRanker), classOf(dyadRanker));
        this.dyadRanker = dyadRanker;
    }

    public AbstractDyadScaler getScaler() {
        return this.scaler;
    }

    /**
     * Sets the scaler and resets the working context characterization from the original one.
     * <p>
     * A non-null scaler is applied to the fresh context copy and to the characterizations of nodes added afterwards.
     * Passing {@code null} disables scaling. Dyads of nodes that are already stored keep their previous context.
     *
     * @param scaler the new scaler, or {@code null} to disable scaling
     */
    public void setScaler(AbstractDyadScaler scaler) {
        if (scaler == null) {
            this.logger.trace("Disable scaling.");
            this.useScaler = false;
            this.scaler = null;
            this.contextCharacterization = this.originalContextCharacterization.addConstantToCopy(0.0);
            return;
        }
        if (this.useScaler) {
            this.logger.trace("Update scaler. Was {} now is {}", classOf(this.scaler), scaler.getClass());
        } else {
            this.logger.trace("Now using scaler {}.", scaler.getClass());
            this.useScaler = true;
        }
        this.scaler = scaler;
        this.contextCharacterization = this.originalContextCharacterization.addConstantToCopy(0.0);
        this.transformContextCharacterization();
    }

    /**
     * Passes the working context characterization to the scaler's {@code transformInstances} as the instance part of
     * a one-dyad dataset. NOTE(review): presumably the vector is transformed in place — confirm.
     */
    private void transformContextCharacterization() {
        this.logger.trace("Transform context characterization with scaler {}", this.scaler.getClass());
        Dyad dyad = new Dyad(this.contextCharacterization, this.contextCharacterization);
        DyadRankingDataset dataset = new DyadRankingDataset();
        DenseDyadRankingInstance instance = new DenseDyadRankingInstance(Arrays.asList(dyad));
        dataset.add(instance);
        this.scaler.transformInstances(dataset);
    }

    /** Null-safe {@code getClass()} for logging. */
    private static Class<?> classOf(Object o) {
        return o == null ? null : o.getClass();
    }
}

