/*
 * Decompiled with CFR 0.152.
 */
package smile.graph;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.graph.AdjacencyList;
import smile.math.MathEx;
import smile.math.distance.Distance;
import smile.math.distance.Metric;
import smile.neighbor.RandomProjectionTree;

/**
 * A k-nearest neighbor graph.
 *
 * <p>For graphs produced by the factory methods of this record, each row of
 * {@code neighbors}/{@code distances} is sorted by increasing distance.
 *
 * @param k the number of nearest neighbors of each node.
 * @param neighbors {@code neighbors[i][j]} is the node id of the j-th nearest neighbor of node i.
 * @param distances {@code distances[i][j]} is the distance from node i to {@code neighbors[i][j]}.
 * @param index {@code index[i]} is the position of node i in the data (or graph) this graph
 *              was derived from; the identity mapping when built directly from data.
 */
public record NearestNeighborGraph(int k, int[][] neighbors, double[][] distances, int[] index) {
    private static final Logger logger = LoggerFactory.getLogger(NearestNeighborGraph.class);

    /**
     * Creates a graph whose node i corresponds to sample i (identity index).
     *
     * @param k the number of nearest neighbors of each node.
     * @param neighbors the node ids of the nearest neighbors of each node.
     * @param distances the distances to the nearest neighbors of each node.
     */
    public NearestNeighborGraph(int k, int[][] neighbors, double[][] distances) {
        this(k, neighbors, distances, IntStream.range(0, neighbors.length).toArray());
    }

    /**
     * Returns the number of nodes.
     *
     * @return the number of nodes.
     */
    public int size() {
        return neighbors.length;
    }

    /**
     * Converts this neighbor lists into an adjacency-list graph whose edge
     * weights are the neighbor distances.
     *
     * @param digraph true to build a directed graph, false for an undirected one.
     * @return the adjacency-list graph.
     */
    public AdjacencyList graph(boolean digraph) {
        int n = neighbors.length;
        AdjacencyList graph = new AdjacencyList(n, digraph);
        for (int i = 0; i < n; i++) {
            int[] neighbor = neighbors[i];
            double[] distance = distances[i];
            for (int j = 0; j < neighbor.length; j++) {
                graph.setWeight(i, neighbor[j], distance[j]);
            }
        }
        return graph;
    }

    /**
     * Builds the exact k-nearest neighbor graph of vectors with Euclidean distance
     * by brute force.
     *
     * @param data the data vectors.
     * @param k the number of nearest neighbors.
     * @return the k-nearest neighbor graph.
     */
    public static NearestNeighborGraph of(double[][] data, int k) {
        return NearestNeighborGraph.of(data, MathEx::distance, k);
    }

    /**
     * Returns the subgraph of the largest connected component. If the graph is
     * already connected, returns this graph itself.
     *
     * <p>Node ids are renumbered within the component; the {@code index} of the
     * returned graph maps each new node id to its node id in this graph.
     *
     * @param digraph whether the connectivity analysis uses a directed graph.
     * @return the largest connected component.
     */
    public NearestNeighborGraph largest(boolean digraph) {
        AdjacencyList graph = graph(digraph);
        int[][] cc = graph.bfcc();
        if (cc.length == 1) {
            return this;
        }

        int[] index = Arrays.stream(cc)
                .max(Comparator.comparingInt((int[] component) -> component.length))
                .orElseThrow(NoSuchElementException::new);
        logger.info("{} connected components, largest one has {} samples.", cc.length, index.length);

        // The subgraph has index.length nodes; reverseIndex maps a node id of this
        // graph to its new id in the subgraph, so it must span all nodes of this graph.
        int n = index.length;
        int[] reverseIndex = new int[neighbors.length];
        for (int i = 0; i < n; i++) {
            reverseIndex[index[i]] = i;
        }

        int[][] nearest = new int[n][];
        double[][] dist = new double[n][];
        for (int i = 0; i < n; i++) {
            int[] ni = neighbors[index[i]];
            nearest[i] = new int[ni.length];
            for (int j = 0; j < ni.length; j++) {
                nearest[i][j] = reverseIndex[ni[j]];
            }
            // Copy so the subgraph does not alias the rows of this graph.
            dist[i] = distances[index[i]].clone();
        }
        return new NearestNeighborGraph(k, nearest, dist, index);
    }

    /**
     * Builds the exact k-nearest neighbor graph by brute force: the distance of
     * every pair of samples is evaluated (in parallel over samples).
     *
     * @param data the data objects.
     * @param distance the distance function.
     * @param k the number of nearest neighbors, at least 2.
     * @param <T> the type of data objects.
     * @return the k-nearest neighbor graph.
     * @throws IllegalArgumentException if {@code k < 2}.
     */
    public static <T> NearestNeighborGraph of(T[] data, Distance<T> distance, int k) {
        List<PriorityQueue<Neighbor>> heap = build(data, distance, k, (n, k_, i) -> IntStream.range(0, n).toArray());
        return toGraph(heap, k);
    }

    /**
     * Builds a random neighbor graph: each sample gets k distinct random
     * neighbors, then its list is improved with reverse neighbors (samples that
     * picked it) closer than its current farthest neighbor.
     *
     * @param data the data objects.
     * @param distance the distance function.
     * @param k the number of neighbors, at least 2 and less than the number of samples.
     * @param <T> the type of data objects.
     * @return the random neighbor graph.
     * @throws IllegalArgumentException if {@code k < 2} or {@code k >= data.length}.
     */
    public static <T> NearestNeighborGraph random(T[] data, Distance<T> distance, int k) {
        List<PriorityQueue<Neighbor>> heap = build(data, distance, k, NearestNeighborGraph::rejectionSample);
        extend(heap);
        return toGraph(heap, k);
    }

    /**
     * Builds one bounded max-heap of neighbors per sample from the candidate
     * indices produced by {@code candidates}. Each heap keeps at most k entries,
     * with the farthest kept neighbor at the head.
     */
    private static <T> List<PriorityQueue<Neighbor>> build(T[] data, Distance<T> distance, int k, CandidateGenerator candidates) {
        if (k < 2) {
            throw new IllegalArgumentException("k must be greater than 1: " + k);
        }

        int n = data.length;
        List<PriorityQueue<Neighbor>> heap = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            heap.add(new PriorityQueue<>());
        }

        // Each task writes only to its own heap, so the parallel loop needs no locking.
        IntStream.range(0, n).parallel().forEach(i -> {
            T xi = data[i];
            PriorityQueue<Neighbor> pq = heap.get(i);
            for (int j : candidates.generate(n, k, i)) {
                if (j == i) {
                    continue;
                }
                double dist = distance.d(xi, data[j]);
                if (pq.size() < k) {
                    pq.offer(new Neighbor(j, dist));
                } else if (dist < pq.peek().distance) {
                    // Reuse the evicted farthest neighbor object.
                    Neighbor neighbor = pq.poll();
                    neighbor.index = j;
                    neighbor.distance = dist;
                    pq.offer(neighbor);
                }
            }
        });
        return heap;
    }

    /**
     * Improves every heap with its reverse neighbors: if j lists i, then i is
     * offered to heap i as candidate j with the same distance, replacing the
     * farthest neighbor when closer. Assumes every heap is non-empty.
     */
    private static void extend(List<PriorityQueue<Neighbor>> heap) {
        int n = heap.size();
        List<Set<Integer>> neighbors = new ArrayList<>(n);
        List<List<Neighbor>> reverseNeighbors = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            neighbors.add(new HashSet<>());
            reverseNeighbors.add(new ArrayList<>());
        }

        for (int i = 0; i < n; i++) {
            Set<Integer> set = neighbors.get(i);
            for (Neighbor neighbor : heap.get(i)) {
                set.add(neighbor.index);
                reverseNeighbors.get(neighbor.index).add(new Neighbor(i, neighbor.distance));
            }
        }

        for (int i = 0; i < n; i++) {
            Set<Integer> set = neighbors.get(i);
            PriorityQueue<Neighbor> pq = heap.get(i);
            for (Neighbor reverse : reverseNeighbors.get(i)) {
                if (set.contains(reverse.index) || !(reverse.distance < pq.peek().distance)) {
                    continue;
                }
                Neighbor farthest = pq.poll();
                // Keep the membership set in sync with the heap.
                set.remove(farthest.index);
                set.add(reverse.index);
                farthest.index = reverse.index;
                farthest.distance = reverse.distance;
                pq.offer(farthest);
            }
        }
    }

    /**
     * Drains the heaps into a graph with rows sorted by increasing distance.
     * The heaps are emptied. A heap with fewer than k entries leaves the
     * trailing slots of its row at their default value 0.
     */
    private static NearestNeighborGraph toGraph(List<PriorityQueue<Neighbor>> heap, int k) {
        int n = heap.size();
        int[][] neighbors = new int[n][k];
        double[][] distances = new double[n][k];
        for (int i = 0; i < n; i++) {
            PriorityQueue<Neighbor> pq = heap.get(i);
            // The head is the farthest neighbor, so fill the row from the back.
            int j = pq.size();
            while (!pq.isEmpty()) {
                Neighbor neighbor = pq.poll();
                if (--j >= k) {
                    continue;
                }
                neighbors[i][j] = neighbor.index;
                distances[i][j] = neighbor.distance;
            }
        }
        return new NearestNeighborGraph(k, neighbors, distances);
    }

    /**
     * Builds an approximate k-nearest neighbor graph of vectors with Euclidean
     * distance by nearest neighbor descent, initialized with 5 random projection
     * trees of leaf size k, at most 50 candidates and 50 iterations, and
     * early-termination threshold 0.001.
     *
     * @param data the data vectors.
     * @param k the number of nearest neighbors, at least 2.
     * @return the approximate k-nearest neighbor graph.
     */
    public static NearestNeighborGraph descent(double[][] data, int k) {
        return NearestNeighborGraph.descent(data, k, 5, k, 50, 50, 0.001);
    }

    /**
     * Builds an approximate k-nearest neighbor graph of vectors with Euclidean
     * distance by nearest neighbor descent. The initial neighbor lists come from
     * all pairs of samples sharing a leaf in random projection trees.
     *
     * @param data the data vectors.
     * @param k the number of nearest neighbors, at least 2.
     * @param numTrees the number of random projection trees for initialization.
     * @param leafSize the leaf size of the random projection trees.
     * @param maxCandidates the maximum number of candidates per sample in each iteration.
     * @param maxIter the maximum number of iterations.
     * @param delta stop when the number of heap updates in an iteration is at most
     *              {@code delta * k * n}.
     * @return the approximate k-nearest neighbor graph.
     * @throws IllegalArgumentException if {@code k < 2}.
     */
    public static NearestNeighborGraph descent(double[][] data, int k, int numTrees, int leafSize, int maxCandidates, int maxIter, double delta) {
        if (k < 2) {
            throw new IllegalArgumentException("k must be greater than 1: " + k);
        }

        int n = data.length;
        List<PriorityQueue<Neighbor>> heapList = new ArrayList<>(n);
        List<Set<Integer>> neighborSetList = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            heapList.add(new PriorityQueue<>());
            neighborSetList.add(new HashSet<>());
        }

        for (int ti = 0; ti < numTrees; ti++) {
            RandomProjectionTree tree = RandomProjectionTree.of(data, leafSize, false);
            for (int[] leaf : tree.leafSamples()) {
                for (int li = 0; li < leaf.length; li++) {
                    int i = leaf[li];
                    double[] xi = data[i];
                    for (int lj = li + 1; lj < leaf.length; lj++) {
                        int j = leaf[lj];
                        double dist = MathEx.distance(xi, data[j]);
                        updateHeap(heapList.get(i), neighborSetList.get(i), k, j, dist);
                        updateHeap(heapList.get(j), neighborSetList.get(j), k, i, dist);
                    }
                }
            }
        }
        return NearestNeighborGraph.descent(data, MathEx::distance, heapList, k, maxCandidates, maxIter, delta);
    }

    /**
     * Offers {@code index} at {@code dist} to a bounded max-heap whose member
     * indices are mirrored in {@code set}. Indices already present are ignored.
     *
     * @return true if the heap changed.
     */
    private static boolean updateHeap(PriorityQueue<Neighbor> pq, Set<Integer> set, int k, int index, double dist) {
        if (set.contains(index)) {
            return false;
        }
        if (pq.size() < k) {
            pq.add(new Neighbor(index, dist));
            set.add(index);
            return true;
        }
        if (dist < pq.peek().distance) {
            Neighbor top = pq.poll();
            set.remove(top.index);
            set.add(index);
            top.distance = dist;
            top.index = index;
            pq.offer(top);
            return true;
        }
        return false;
    }

    /**
     * Builds an approximate k-nearest neighbor graph by nearest neighbor descent
     * with at most 50 candidates, 10 iterations and threshold 0.001.
     *
     * @param data the data objects.
     * @param distance the distance metric.
     * @param k the number of nearest neighbors, at least 2 and less than the number of samples.
     * @param <T> the type of data objects.
     * @return the approximate k-nearest neighbor graph.
     */
    public static <T> NearestNeighborGraph descent(T[] data, Metric<T> distance, int k) {
        return NearestNeighborGraph.descent(data, distance, k, 50, 10, 0.001);
    }

    /**
     * Builds an approximate k-nearest neighbor graph by nearest neighbor descent,
     * initialized with a random neighbor graph extended by reverse neighbors.
     *
     * @param data the data objects.
     * @param distance the distance metric.
     * @param k the number of nearest neighbors, at least 2 and less than the number of samples.
     * @param maxCandidates the maximum number of candidates per sample in each iteration.
     * @param maxIter the maximum number of iterations.
     * @param delta stop when the number of heap updates in an iteration is at most
     *              {@code delta * k * n}.
     * @param <T> the type of data objects.
     * @return the approximate k-nearest neighbor graph.
     * @throws IllegalArgumentException if {@code k < 2} or {@code k >= data.length}.
     */
    public static <T> NearestNeighborGraph descent(T[] data, Metric<T> distance, int k, int maxCandidates, int maxIter, double delta) {
        if (k < 2) {
            throw new IllegalArgumentException("k must be greater than 1: " + k);
        }
        List<PriorityQueue<Neighbor>> heap = build(data, distance, k, NearestNeighborGraph::rejectionSample);
        extend(heap);
        return NearestNeighborGraph.descent(data, distance, heap, k, maxCandidates, maxIter, delta);
    }

    /**
     * Iteratively refines the given heaps: each sample is compared with its
     * neighbors' neighbors and both heaps of every evaluated pair are updated.
     */
    private static <T> NearestNeighborGraph descent(T[] data, Metric<T> distance, List<PriorityQueue<Neighbor>> heapList, int k, int maxCandidates, int maxIter, double delta) {
        int n = data.length;
        List<Set<Integer>> neighborSetList = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            Set<Integer> set = new HashSet<>();
            for (Neighbor neighbor : heapList.get(i)) {
                set.add(neighbor.index);
            }
            neighborSetList.add(set);
        }

        for (int iter = 1; iter <= maxIter; iter++) {
            int count = 0;
            int[][] candidates = generateCandidates(heapList, maxCandidates);
            for (int i = 0; i < n; i++) {
                for (int j : candidates[i]) {
                    // A sample must never become its own neighbor (distance 0).
                    if (j == i) {
                        continue;
                    }
                    double dist = distance.d(data[i], data[j]);
                    if (updateHeap(heapList.get(i), neighborSetList.get(i), k, j, dist)) {
                        count++;
                    }
                    if (updateHeap(heapList.get(j), neighborSetList.get(j), k, i, dist)) {
                        count++;
                    }
                }
            }

            logger.info("NearestNeighborDescent iteration {}: {}", iter, count);
            if (count <= delta * k * n) {
                break;
            }
        }
        return toGraph(heapList, k);
    }

    /**
     * Collects, for every sample, the neighbors of its neighbors (in both
     * directions) as candidates, ranked by the shortest two-hop path length,
     * and keeps at most {@code maxCandidates} of them. A sample is never its
     * own candidate, and each candidate index appears at most once.
     */
    private static int[][] generateCandidates(List<PriorityQueue<Neighbor>> heapList, int maxCandidates) {
        int n = heapList.size();
        // candidate index -> shortest two-hop path length seen so far
        List<Map<Integer, Double>> candidates = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            candidates.add(new HashMap<>());
        }

        for (int i = 0; i < n; i++) {
            for (Neighbor ni : heapList.get(i)) {
                int j = ni.index;
                double dij = ni.distance;
                for (Neighbor nj : heapList.get(j)) {
                    int l = nj.index;
                    // The path i -> j -> i would make i a candidate of itself.
                    if (l == i) {
                        continue;
                    }
                    double path = dij + nj.distance;
                    candidates.get(i).merge(l, path, Math::min);
                    candidates.get(l).merge(i, path, Math::min);
                }
            }
        }

        int[][] result = new int[n][];
        for (int i = 0; i < n; i++) {
            result[i] = candidates.get(i).entrySet().stream()
                    .sorted(Map.Entry.<Integer, Double>comparingByValue())
                    .limit(maxCandidates)
                    .mapToInt(Map.Entry::getKey)
                    .toArray();
        }
        return result;
    }

    /**
     * Draws k distinct random indices from {@code [0, n)} excluding {@code i}.
     *
     * @throws IllegalArgumentException if {@code k >= n}, since only n - 1
     *         indices are eligible and the rejection loop would never end.
     */
    private static int[] rejectionSample(int n, int k, int i) {
        if (k >= n) {
            throw new IllegalArgumentException("k must be less than the number of samples " + n + ": " + k);
        }

        int[] samples = new int[k];
        for (int j = 0; j < k; j++) {
            int candidate;
            do {
                candidate = MathEx.randomInt(n);
            } while (candidate == i || contains(samples, j, candidate));
            samples[j] = candidate;
        }
        return samples;
    }

    /** Returns true if {@code value} occurs among the first {@code length} entries of {@code array}. */
    private static boolean contains(int[] array, int length, int value) {
        for (int l = 0; l < length; l++) {
            if (array[l] == value) {
                return true;
            }
        }
        return false;
    }

    /** Generates the candidate neighbor indices of sample {@code i} among {@code n} samples. */
    @FunctionalInterface
    private interface CandidateGenerator {
        int[] generate(int n, int k, int i);
    }

    /**
     * A (mutable) neighbor entry. The natural order is by decreasing distance, so
     * a {@link PriorityQueue} of neighbors keeps the farthest one at its head.
     * Equality and hash code depend only on the neighbor index.
     */
    private static class Neighbor implements Comparable<Neighbor> {
        public int index;
        public double distance;

        public Neighbor(int index, double distance) {
            this.index = index;
            this.distance = distance;
        }

        @Override
        public boolean equals(Object o) {
            return o instanceof Neighbor other && index == other.index;
        }

        @Override
        public int hashCode() {
            return index;
        }

        @Override
        public int compareTo(Neighbor o) {
            return Double.compare(o.distance, this.distance);
        }
    }
}

