/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.dyadranking.search;

import ai.libs.jaicore.math.linearalgebra.Vector;
import ai.libs.jaicore.ml.core.exception.PredictionException;
import ai.libs.jaicore.ml.dyadranking.Dyad;
import ai.libs.jaicore.ml.dyadranking.algorithm.IDyadRanker;
import ai.libs.jaicore.ml.dyadranking.dataset.DyadRankingDataset;
import ai.libs.jaicore.ml.dyadranking.dataset.DyadRankingInstance;
import ai.libs.jaicore.ml.dyadranking.dataset.IDyadRankingInstance;
import ai.libs.jaicore.ml.dyadranking.util.AbstractDyadScaler;
import ai.libs.jaicore.search.model.travesaltree.Node;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Queue;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link Queue} of search-tree {@link Node}s whose order is determined by an {@link IDyadRanker}.
 *
 * <p>Every inserted node is turned into a {@link Vector} by {@link #characterize(Node)} and paired with the context
 * characterization to form a {@link Dyad}. After each successful insertion, all pending dyads are ranked by the dyad
 * ranker and the queue order is rebuilt from that ranking, so iteration and removal follow the ranker's order.
 *
 * <p>Four internal structures are kept in sync: the ordered node list, the list of node characterizations, the list
 * of query dyads handed to the ranker, and a bidirectional node/characterization map used to translate a ranking back
 * into nodes. No synchronization is performed.
 *
 * @param <N> the {@code N} type argument of the queued {@link Node}s
 * @param <V> the {@code V} type argument of the queued {@link Node}s
 */
public abstract class ADyadRankedNodeQueue<N, V extends Comparable<V>> implements Queue<Node<N, V>> {

    private final Logger logger = LoggerFactory.getLogger(this.getClass());

    /** Ranker used to order the pending dyads; {@code null} until set when the single-argument constructor is used. */
    private IDyadRanker dyadRanker;

    /** Scaler applied to the characterization of every newly added node while {@link #useScaler} is set. */
    protected AbstractDyadScaler scaler;

    /** Whether {@link #scaler} is applied; becomes {@code true} once a non-null scaler is supplied. */
    private boolean useScaler = false;

    /** Nodes in the order of the most recent ranking. */
    private final List<Node<N, V>> queue = new ArrayList<>();

    /** Characterizations of the queued nodes, in insertion order. */
    private final List<Vector> nodeCharacterizations = new ArrayList<>();

    /** Context characterization as passed to the constructor; the source for a fresh copy when the scaler changes. */
    private final Vector originalContextCharacterization;

    /** Working copy of the context characterization (created via {@code addConstantToCopy(0.0)}), possibly scaled. */
    private Vector contextCharacterization;

    /** One dyad (context, node characterization) per queued node; this list is the input to the ranker. */
    private final List<Dyad> queryDyads = new ArrayList<>();

    /** Maps each queued node to its characterization and back; used to translate a ranking into nodes. */
    private final BiMap<Node<N, V>, Vector> nodesAndCharacterizationsMap = HashBiMap.create();

    /**
     * Creates a queue without ranker or scaler.
     *
     * @param contextCharacterization characterization of the context that is paired with every node; the queue works
     *        on a copy obtained via {@code addConstantToCopy(0.0)}
     */
    public ADyadRankedNodeQueue(Vector contextCharacterization) {
        this.contextCharacterization = contextCharacterization.addConstantToCopy(0.0);
        this.originalContextCharacterization = contextCharacterization;
        this.logger.trace("Construct ADyadNodeQueue with contexcharacterization {}", contextCharacterization);
    }

    /**
     * Creates a queue with the given ranker and (optional) scaler.
     *
     * <p>NOTE(review): unlike {@link #setScaler(AbstractDyadScaler)}, this constructor does not transform the context
     * characterization with the scaler; confirm whether that is intended (e.g. the scaler may not be fitted yet).
     *
     * @param contextCharacterization characterization of the context that is paired with every node
     * @param dyadRanker ranker used to order the queued nodes
     * @param scaler scaler for node characterizations, or {@code null} to disable scaling
     */
    public ADyadRankedNodeQueue(Vector contextCharacterization, IDyadRanker dyadRanker, AbstractDyadScaler scaler) {
        this(contextCharacterization);
        this.dyadRanker = dyadRanker;
        this.scaler = scaler;
        if (scaler != null) {
            this.useScaler = true;
        }
    }

    /**
     * Computes the feature vector that represents the given node inside a dyad.
     *
     * @param node the node to characterize
     * @return the node's characterization; this queue may modify it in place (NaN entries are set to zero, and it
     *         is handed to the scaler when one is used)
     */
    protected abstract Vector characterize(Node<N, V> node);

    @Override
    public int size() {
        return this.queue.size();
    }

    @Override
    public boolean isEmpty() {
        return this.queue.isEmpty();
    }

    @Override
    public boolean contains(Object o) {
        return this.queue.contains(o);
    }

    /**
     * Returns an iterator over the nodes in ranked order.
     *
     * <p>NOTE(review): this is the backing list's iterator. Calling {@code remove()} on it removes the node only from
     * the ordered list and leaves the characterization, dyad and map entry behind; use {@link #remove(Object)} instead.
     */
    @Override
    public Iterator<Node<N, V>> iterator() {
        return this.queue.iterator();
    }

    @Override
    public Object[] toArray() {
        return this.queue.toArray();
    }

    @Override
    public <T> T[] toArray(T[] a) {
        return this.queue.toArray(a);
    }

    /**
     * Removes the given node (first occurrence) together with its characterization and dyad.
     *
     * @param o the node to remove
     * @return {@code true} if the node was queued and has been removed
     */
    @Override
    public boolean remove(Object o) {
        if (!(o instanceof Node)) {
            return false;
        }
        int index = this.queue.indexOf(o);
        if (index < 0) {
            return false;
        }
        this.removeNodeAtPosition(index);
        return true;
    }

    @Override
    public boolean containsAll(Collection<?> c) {
        return this.queue.containsAll(c);
    }

    /**
     * Adds every node of {@code c} one at a time; each successful insertion triggers a full re-ranking.
     *
     * @param c nodes to add
     * @return {@code true} if at least one call to {@link #add(Node)} returned {@code true}
     */
    @Override
    public boolean addAll(Collection<? extends Node<N, V>> c) {
        this.logger.trace("Add {} nodes", c.size());
        boolean changed = false;
        for (Node<N, V> elem : c) {
            if (this.add(elem)) {
                changed = true;
            }
        }
        return changed;
    }

    /**
     * Removes every element of {@code c} via {@link #remove(Object)}.
     *
     * @param c elements to remove
     * @return {@code true} if at least one element was removed
     */
    @Override
    public boolean removeAll(Collection<?> c) {
        boolean changed = false;
        for (Object o : c) {
            if (this.remove(o)) {
                changed = true;
            }
        }
        return changed;
    }

    /** Not supported by this queue. */
    @Override
    public boolean retainAll(Collection<?> c) {
        throw new UnsupportedOperationException();
    }

    /** Removes all nodes and all associated characterizations and dyads. */
    @Override
    public void clear() {
        this.queue.clear();
        this.nodesAndCharacterizationsMap.clear();
        this.nodeCharacterizations.clear();
        // Without this, dyads of removed nodes would still be sent to the ranker on the next insertion.
        this.queryDyads.clear();
    }

    /**
     * Inserts a node and re-ranks all queued nodes.
     *
     * <p>The node is characterized, optionally scaled, NaN entries are replaced by zero, and its dyad is appended to
     * the pending query. The whole query is then ranked and the queue order is rebuilt from the ranking. If ranking
     * fails with a {@link PredictionException}, every registration made for the node is rolled back and the previous
     * queue order is kept.
     *
     * <p>NOTE(review): returns {@code true} for a node that is already queued although nothing changes; the
     * {@link Collection#add} contract would ask for {@code false}. Kept unchanged for existing callers.
     *
     * <p>NOTE(review): Guava's {@link BiMap#put} rejects a value already bound to another key, so two distinct nodes
     * with equal characterizations make this method throw {@link IllegalArgumentException}; that case is not handled.
     *
     * @param e the node to insert
     * @return {@code true} if the node is in the queue afterwards; {@code false} if it is {@code null} or ranking failed
     */
    @Override
    public boolean add(Node<N, V> e) {
        if (this.queue.contains(e)) {
            return true;
        }
        if (e == null) {
            return false;
        }
        // Remember the sizes so that a failed ranking can be undone completely.
        int characterizationsBefore = this.nodeCharacterizations.size();
        int dyadsBefore = this.queryDyads.size();
        try {
            this.logger.debug("Add node to OPEN, is Goal: {}", e.isGoal());
            Vector characterization = this.characterize(e);
            this.nodeCharacterizations.add(characterization);
            Dyad newDyad = new Dyad(this.contextCharacterization, characterization);
            this.queryDyads.add(newDyad);
            if (this.useScaler) {
                // NOTE(review): relies on transformAlternatives scaling the dyad's vectors in place; the transformed
                // dataset itself is discarded.
                DyadRankingDataset dataset = new DyadRankingDataset();
                dataset.add(new DyadRankingInstance(Arrays.asList(newDyad)));
                this.scaler.transformAlternatives(dataset);
            }
            this.replaceNaNByZeroes(characterization);
            this.nodesAndCharacterizationsMap.put(e, characterization);
            IDyadRankingInstance prediction = (IDyadRankingInstance) this.dyadRanker.predict(new DyadRankingInstance(this.queryDyads));
            this.rebuildQueue(prediction);
            return true;
        } catch (PredictionException e1) {
            this.logger.warn("Failed to characterize: {}", e1.getLocalizedMessage());
            // Undo every registration for e so that no stale dyad or map entry leaks into later rankings.
            this.nodeCharacterizations.subList(characterizationsBefore, this.nodeCharacterizations.size()).clear();
            this.queryDyads.subList(dyadsBefore, this.queryDyads.size()).clear();
            this.nodesAndCharacterizationsMap.remove(e);
            return false;
        }
    }

    /**
     * Replaces the queue order by the order of the given ranking, translating each dyad back into its node.
     * Dyads whose alternative is not mapped to a queued node are skipped with a warning.
     */
    private void rebuildQueue(IDyadRankingInstance ranking) {
        this.queue.clear();
        BiMap<Vector, Node<N, V>> nodeByCharacterization = this.nodesAndCharacterizationsMap.inverse();
        for (int i = 0; i < ranking.length(); ++i) {
            Node<N, V> toAdd = nodeByCharacterization.get(ranking.getDyadAtPosition(i).getAlternative());
            if (toAdd != null) {
                this.queue.add(toAdd);
            } else {
                this.logger.warn("Got a node in a prediction that doesnt exist");
            }
        }
    }

    /** Sets every NaN entry of {@code vector} to {@code 0.0}, in place. */
    private void replaceNaNByZeroes(Vector vector) {
        for (int i = 0; i < vector.length(); ++i) {
            if (Double.isNaN(vector.getValue(i))) {
                vector.setValue(i, 0.0);
            }
        }
    }

    /** Same as {@link #add(Node)}. */
    @Override
    public boolean offer(Node<N, V> e) {
        return this.add(e);
    }

    /**
     * Removes and returns the head (best-ranked node) of the queue.
     *
     * @return the removed head
     * @throws NoSuchElementException if the queue is empty
     */
    @Override
    public Node<N, V> remove() {
        if (this.queue.isEmpty()) {
            throw new NoSuchElementException("The queue is empty.");
        }
        return this.removeNodeAtPosition(0);
    }

    /**
     * Removes the node at the given position of the ranked order together with its characterization and dyad.
     *
     * @param i zero-based position in the current ranked order
     * @return the removed node
     * @throws IndexOutOfBoundsException if {@code i} is not a valid position
     */
    public Node<N, V> removeNodeAtPosition(int i) {
        Node<N, V> removedNode = this.queue.remove(i);
        this.logger.trace("Retrieve node from OPEN. Is goal: {}, Index: {}", removedNode.isGoal(), i);
        Vector removedAlternative = this.nodesAndCharacterizationsMap.remove(removedNode);
        if (removedAlternative != null) {
            this.nodeCharacterizations.remove(removedAlternative);
            // Each queued node contributed exactly one dyad, so only the first match is removed.
            for (Iterator<Dyad> it = this.queryDyads.iterator(); it.hasNext();) {
                if (it.next().getAlternative().equals(removedAlternative)) {
                    it.remove();
                    break;
                }
            }
        }
        return removedNode;
    }

    /**
     * Removes and returns the head of the queue.
     *
     * @return the removed head, or {@code null} if the queue is empty
     */
    @Override
    public Node<N, V> poll() {
        if (this.queue.isEmpty()) {
            return null;
        }
        return this.remove();
    }

    /**
     * Returns, without removing, the head of the queue.
     *
     * @return the head
     * @throws NoSuchElementException if the queue is empty
     */
    @Override
    public Node<N, V> element() {
        if (this.queue.isEmpty()) {
            throw new NoSuchElementException("The queue is empty.");
        }
        return this.queue.get(0);
    }

    /**
     * Returns, without removing, the head of the queue.
     *
     * @return the head, or {@code null} if the queue is empty
     */
    @Override
    public Node<N, V> peek() {
        if (this.queue.isEmpty()) {
            return null;
        }
        Node<N, V> head = this.queue.get(0);
        this.logger.trace("Peek from OPEN. Is goal: {}", head.isGoal());
        return head;
    }

    /** @return the ranker currently used to order the queue, possibly {@code null} */
    public IDyadRanker getDyadRanker() {
        return this.dyadRanker;
    }

    /**
     * Replaces the ranker. The new ranker is used from the next insertion on; the current order is not recomputed.
     *
     * @param dyadRanker the new ranker
     */
    public void setDyadRanker(IDyadRanker dyadRanker) {
        this.logger.trace("Update dyad ranker. Was {} now is {}", classOf(this.dyadRanker), classOf(dyadRanker));
        this.dyadRanker = dyadRanker;
    }

    /** @return the scaler applied to node characterizations, possibly {@code null} */
    public AbstractDyadScaler getScaler() {
        return this.scaler;
    }

    /**
     * Installs a scaler and rescales a fresh copy of the original context characterization with it.
     *
     * <p>NOTE(review): dyads already queued keep referencing the previous context vector, and their node
     * characterizations are not rescaled; confirm whether they should be rebuilt.
     *
     * @param scaler the new scaler
     * @throws NullPointerException if {@code scaler} is {@code null}
     */
    public void setScaler(AbstractDyadScaler scaler) {
        Objects.requireNonNull(scaler, "scaler");
        if (this.useScaler) {
            this.logger.trace("Update scaler. Was {} now is {}", classOf(this.scaler), scaler.getClass());
        } else {
            this.logger.trace("Now using scaler {}.", scaler.getClass());
            this.useScaler = true;
        }
        this.scaler = scaler;
        this.contextCharacterization = this.originalContextCharacterization.addConstantToCopy(0.0);
        this.transformContextCharacterization();
    }

    /**
     * Applies {@link AbstractDyadScaler#transformInstances} to the working context characterization by wrapping it in
     * a one-dyad dataset. NOTE(review): relies on the scaler transforming the vector in place.
     */
    private void transformContextCharacterization() {
        this.logger.trace("Transform context characterization with scaler {}", this.scaler.getClass());
        Dyad dyad = new Dyad(this.contextCharacterization, this.contextCharacterization);
        DyadRankingInstance instance = new DyadRankingInstance(Arrays.asList(dyad));
        DyadRankingDataset dataset = new DyadRankingDataset();
        dataset.add(instance);
        this.scaler.transformInstances(dataset);
    }

    /** Null-safe {@code getClass()} for log messages. */
    private static Class<?> classOf(Object o) {
        return o == null ? null : o.getClass();
    }
}

