/*
 * Decompiled with CFR 0.152.
 */
package jaicore.search.algorithms.standard.mcts;

import ai.libs.jaicore.basic.ILoggingCustomizable;
import jaicore.search.algorithms.standard.mcts.IPathUpdatablePolicy;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * UCB1-style tree policy for Monte Carlo tree search.
 *
 * <p>Every node on a back-propagated path accumulates its observed scores and a visit count
 * (see {@link #updatePath(List, Double)}). When asked for an action, the policy first plays any
 * action whose successor has never been registered. Once every successor is known, it picks the
 * action whose successor optimizes {@code mean + sqrt(2 ln n / n_i)} (maximization) or
 * {@code mean - sqrt(2 ln n / n_i)} (minimization). Here {@code n} is the visit count of the
 * current node and {@code n_i} the visit count of the successor.
 *
 * <p>State is held in a plain {@link HashMap}; no synchronization is performed.
 *
 * @param <T> the node type
 * @param <A> the action type
 */
public class UCBPolicy<T, A>
implements IPathUpdatablePolicy<T, A, Double>,
ILoggingCustomizable {
    private String loggerName;
    private Logger logger = LoggerFactory.getLogger(UCBPolicy.class);
    /** {@code true} if higher scores are better, {@code false} if lower scores are better. */
    private final boolean maximize;
    /** Statistics for every node that has been registered or updated so far. */
    private final Map<T, NodeLabel> labels = new HashMap<>();

    /** Creates a policy that treats higher scores as better. */
    public UCBPolicy() {
        this(true);
    }

    /**
     * Creates a policy.
     *
     * @param maximize {@code true} if higher scores are better, {@code false} if lower scores are better
     */
    public UCBPolicy(boolean maximize) {
        this.maximize = maximize;
    }

    /**
     * Back-propagates a score along a path.
     *
     * <p>Each node on the path gets one additional visit, and the score is added to its statistics.
     * Nodes without a label yet receive a fresh one.
     *
     * @param path the nodes to update
     * @param score the observed score; must not be {@code null}
     */
    @Override
    public void updatePath(List<T> path, Double score) {
        this.logger.debug("Updating path {} with score {}", path, score);
        for (T node : path) {
            NodeLabel label = this.labels.computeIfAbsent(node, k -> new NodeLabel());
            label.visits++;
            label.scores.addValue(score);
            this.logger.trace("Updated label of node {}. Visits now {}, stats contains {} entries with mean {}", node, label.visits, label.scores.getN(), label.scores.getMean());
        }
    }

    /**
     * Chooses the next action to play from {@code node}.
     *
     * <p>If some action leads to a successor that has no label yet, the first such action is
     * returned. That successor is registered immediately (with zero visits), so it counts as
     * "tried" from then on. Otherwise, the action with the best UCB value is returned.
     *
     * <p>NOTE(review): the UCB phase expects {@code node} and all successors to have been updated
     * via {@link #updatePath(List, Double)} at least once. Otherwise the lookup for {@code node}
     * throws a {@code NullPointerException}, or a successor with zero visits yields a NaN score
     * (caught only by the assertions).
     *
     * @param node the node to choose an action for
     * @param actionsWithTheirSuccessors the available actions mapped to the nodes they lead to
     * @return the chosen action
     */
    @Override
    public A getAction(T node, Map<A, T> actionsWithTheirSuccessors) {
        Set<A> possibleActions = actionsWithTheirSuccessors.keySet();
        this.logger.debug("Deriving action for node {}. The {} options are: {}", node, possibleActions.size(), actionsWithTheirSuccessors);

        // Exploration phase: any action whose successor is unknown is played before UCB is applied.
        List<A> actionsThatHaveNotBeenTriedYet = possibleActions.stream()
                .filter(a -> !this.labels.containsKey(actionsWithTheirSuccessors.get(a)))
                .collect(Collectors.toList());
        if (!actionsThatHaveNotBeenTriedYet.isEmpty()) {
            A action = actionsThatHaveNotBeenTriedYet.get(0);
            T child = actionsWithTheirSuccessors.get(action);
            // Register the child now so the same action is not "dictated" again.
            this.labels.put(child, new NodeLabel());
            this.logger.info("Dictating action {}, because this was never played before.", action);
            return action;
        }

        // Start from the worst possible value so that any finite UCB score is accepted.
        // Double.MIN_VALUE must not be used here: it is the smallest *positive* double, so negative
        // UCB scores would never be selected when maximizing.
        double best = this.maximize ? Double.NEGATIVE_INFINITY : Double.POSITIVE_INFINITY;
        this.logger.debug("All actions have been tried. Label is: {}", this.labels.get(node));
        int n = this.labels.get(node).visits;
        A choice = null;
        for (A action : possibleActions) {
            T child = actionsWithTheirSuccessors.get(action);
            NodeLabel label = this.labels.get(child);
            assert (label.visits != 0) : "Visits of node " + child + " cannot be 0 if we already used this action before!";
            assert (label.scores.getN() != 0L) : "Number of observations cannot be 0 if we already visited this node before";
            this.logger.trace("Considering action {} whose successor state has stats {} and {} visits", action, label.scores.getMean(), label.visits);
            double exploration = Math.sqrt(2.0 * Math.log(n) / label.visits);
            // Upper confidence bound when maximizing, lower confidence bound when minimizing.
            double ucb = label.scores.getMean() + (this.maximize ? exploration : -exploration);
            assert !Double.isNaN(ucb) : "The UCB score is NaN, which cannot be the case. Score mean is " + label.scores.getMean() + ", number of visits is " + label.visits;
            boolean isBetter = this.maximize ? ucb > best : ucb < best;
            if (isBetter) {
                this.logger.trace("Updating best choice {} with {} since it is better than the current solution with performance {}", choice, action, best);
                best = ucb;
                choice = action;
            } else {
                this.logger.trace("Skipping current solution {} since its score {} is not better than the currently best {}.", action, ucb, best);
            }
        }
        assert (choice != null) : "Would return null, but this must not be the case!";
        this.logger.info("Recommending action {}.", choice);
        return choice;
    }

    /**
     * Returns the name of the logger currently used by this policy.
     *
     * @return the logger name, or {@code null} if {@link #setLoggerName(String)} was never called
     */
    public String getLoggerName() {
        return this.loggerName;
    }

    /**
     * Switches this policy to the logger with the given name.
     *
     * @param name the logger name
     */
    public void setLoggerName(String name) {
        this.loggerName = name;
        this.logger = LoggerFactory.getLogger(name);
    }

    /** Mutable per-node statistics: all observed scores and the number of visits. */
    static class NodeLabel {
        private final DescriptiveStatistics scores = new DescriptiveStatistics();
        private int visits;

        @Override
        public String toString() {
            return "NodeLabel [scores=" + this.scores + ", visits=" + this.visits + "]";
        }
    }
}

