/*
 * Decompiled with CFR 0.152.
 */
package org.nlpub.watset.graph;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Random;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.jgrapht.Graph;
import org.nlpub.watset.graph.Clustering;
import org.nlpub.watset.graph.NodeWeighting;
import org.nlpub.watset.util.Maximizer;
import org.nlpub.watset.util.Neighbors;

/**
 * Chinese Whispers graph clustering: a randomized label-propagation procedure in which every
 * vertex starts in its own cluster and repeatedly adopts the neighbor label with the highest
 * accumulated weight.
 *
 * <p>Instances are stateful: {@link #fit()} must be called before {@link #getClusters()}.
 * NOTE(review): no synchronization is performed, so a single instance should not be fitted
 * or queried from several threads concurrently.
 *
 * @param <V> the vertex type
 * @param <E> the edge type
 */
public class ChineseWhispers<V, E>
implements Clustering<V> {
    /** Default upper bound on the number of label-propagation rounds. */
    public static final int ITERATIONS = 20;

    /** The graph being clustered. */
    protected final Graph<V, E> graph;
    /** Strategy that weighs how strongly a neighbor votes for its label. */
    protected final NodeWeighting<V, E> weighting;
    /** Maximum number of rounds {@link #fit()} performs. */
    protected final int iterations;
    /** Source of randomness for shuffling and tie-breaking. */
    protected final Random random;
    /** Current vertex-to-label assignment; {@code null} until {@link #fit()} runs. */
    protected Map<V, Integer> labels;
    /** Number of rounds in the last {@link #fit()} that changed at least one label. */
    protected int steps;

    /**
     * Creates a factory that builds a clusterer with the default iteration count and a fresh
     * {@link Random}.
     *
     * @param weighting the node weighting to use
     * @param <V>       the vertex type
     * @param <E>       the edge type
     * @return a function from a graph to a Chinese Whispers clusterer
     */
    public static <V, E> Function<Graph<V, E>, Clustering<V>> provider(NodeWeighting<V, E> weighting) {
        return graph -> new ChineseWhispers<>(graph, weighting);
    }

    /**
     * Creates a factory that builds a clusterer with explicit settings.
     *
     * @param weighting  the node weighting to use
     * @param iterations the maximum number of rounds
     * @param random     the random number generator to use
     * @param <V>        the vertex type
     * @param <E>        the edge type
     * @return a function from a graph to a Chinese Whispers clusterer
     */
    public static <V, E> Function<Graph<V, E>, Clustering<V>> provider(NodeWeighting<V, E> weighting, int iterations, Random random) {
        return graph -> new ChineseWhispers<>(graph, weighting, iterations, random);
    }

    /**
     * Creates a clusterer.
     *
     * @param graph      the graph to cluster
     * @param weighting  the node weighting to use
     * @param iterations the maximum number of rounds; a non-positive value means {@link #fit()}
     *                   performs no propagation and every vertex stays in its own cluster
     * @param random     the random number generator to use
     * @throws NullPointerException if {@code graph}, {@code weighting} or {@code random} is null
     */
    public ChineseWhispers(Graph<V, E> graph, NodeWeighting<V, E> weighting, int iterations, Random random) {
        this.graph = Objects.requireNonNull(graph);
        this.weighting = Objects.requireNonNull(weighting);
        this.iterations = iterations;
        this.random = Objects.requireNonNull(random);
    }

    /**
     * Creates a clusterer with {@link #ITERATIONS} rounds and a fresh {@link Random}.
     *
     * @param graph     the graph to cluster
     * @param weighting the node weighting to use
     */
    public ChineseWhispers(Graph<V, E> graph, NodeWeighting<V, E> weighting) {
        this(graph, weighting, ITERATIONS, new Random());
    }

    /**
     * Runs label propagation. Every vertex first receives a distinct integer label. Then, for
     * up to {@code iterations} rounds, the vertices are visited in a freshly shuffled order and
     * relabelled. The loop stops early as soon as a round changes no label.
     */
    @Override
    public void fit() {
        final List<V> nodes = new ArrayList<>(this.graph.vertexSet());
        this.labels = new HashMap<>(nodes.size());

        // Initially every vertex forms its own singleton cluster.
        int next = 0;
        for (final V node : nodes) {
            this.labels.put(node, next++);
        }

        this.steps = 0;
        while (this.steps < this.iterations) {
            // A new random visiting order each round avoids order-dependent artifacts.
            Collections.shuffle(nodes, this.random);
            if (this.step(nodes) == 0) {
                break; // converged: no vertex changed its label this round
            }
            ++this.steps;
        }
    }

    /**
     * Performs one propagation round over {@code nodes}, in the given order. Labels are updated
     * in place, so later vertices in the round see earlier updates.
     *
     * @param nodes the vertices to visit
     * @return the number of vertices whose label changed
     */
    protected int step(List<V> nodes) {
        int changed = 0;

        for (final V node : nodes) {
            final Map<Integer, Double> scores = this.score(this.graph, this.labels, this.weighting, node);

            // Best-scoring neighbor label; the vertex keeps its own label when it has no neighbors.
            final Optional<Map.Entry<Integer, Double>> label =
                    Maximizer.argmaxRandom(scores.entrySet().iterator(), Map.Entry::getValue, this.random);
            final int updated = label.isPresent() ? label.get().getKey() : this.labels.get(node);

            final Integer previous = this.labels.put(node, updated);
            // A missing previous label also counts as a change instead of failing on unboxing.
            if (previous == null || previous != updated) {
                ++changed;
            }
        }

        return changed;
    }

    /**
     * Groups vertices by their final label.
     *
     * @return one collection (a set) per label
     * @throws NullPointerException if {@link #fit()} has not been called
     */
    @Override
    public Collection<Collection<V>> getClusters() {
        Objects.requireNonNull(this.labels, "call fit() first");

        final Map<Integer, List<Map.Entry<V, Integer>>> groups = this.labels.entrySet().stream()
                .collect(Collectors.groupingBy(Map.Entry::getValue));

        final List<Collection<V>> clusters = new ArrayList<>(groups.size());
        for (final List<Map.Entry<V, Integer>> cluster : groups.values()) {
            clusters.add(cluster.stream().map(Map.Entry::getKey).collect(Collectors.toSet()));
        }
        return clusters;
    }

    /**
     * Sums, per label, the weights that {@code node}'s neighbors contribute.
     *
     * @param graph     the graph
     * @param labels    the current vertex labels
     * @param weighting the weighting applied to each (node, neighbor) pair
     * @param node      the vertex being scored
     * @return a map from label to its accumulated weight; empty if {@code node} has no neighbors
     */
    protected Map<Integer, Double> score(Graph<V, E> graph, Map<V, Integer> labels, NodeWeighting<V, E> weighting, V node) {
        final Map<Integer, Double> weights = new HashMap<>();
        final Iterator<V> neighbors = Neighbors.neighborIterator(graph, node);
        neighbors.forEachRemaining(neighbor -> {
            final int label = labels.get(neighbor);
            weights.merge(label, weighting.apply(graph, labels, node, neighbor), Double::sum);
        });
        return weights;
    }

    /**
     * Returns the configured maximum number of rounds.
     *
     * @return the iteration limit passed at construction
     */
    public int getIterations() {
        return this.iterations;
    }

    /**
     * Returns how many rounds of the last {@link #fit()} changed at least one label.
     *
     * @return the number of label-changing rounds (at most {@link #getIterations()})
     */
    public int getSteps() {
        return this.steps;
    }
}

