/*
 * Decompiled with CFR 0.152.
 */
package smile.neighbor;

import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import smile.math.MathEx;
import smile.neighbor.LSH;
import smile.neighbor.Neighbor;
import smile.neighbor.RNNSearch;
import smile.neighbor.lsh.Bucket;
import smile.neighbor.lsh.Hash;
import smile.neighbor.lsh.MultiProbeHash;
import smile.neighbor.lsh.MultiProbeSample;
import smile.neighbor.lsh.PosterioriModel;
import smile.sort.HeapSelect;
import smile.util.IntArrayList;

/**
 * Multi-probe locality-sensitive hashing (LSH) index for {@code double[]} keys.
 *
 * <p>Each of the {@code L} hash tables is a {@link MultiProbeHash}. After one of the
 * {@code fit} methods has trained one {@link PosterioriModel} per table, queries probe the
 * bucket sequence chosen by that model in every table, aiming at a requested recall.
 * Until {@code fit} is called, the inherited single-argument query methods
 * ({@link #nearest(double[])}, {@link #search(double[], int)},
 * {@link #search(double[], double, List)}) delegate to {@link LSH}. The overloads that take
 * an explicit {@code recall} and {@code T} require a fitted model.
 *
 * <p>Distances are computed with {@link MathEx#distance(double[], double[])}. A stored key
 * that is the very same array object as the query ({@code ==}) is never reported as a
 * neighbor.
 *
 * @param <E> the type of the data objects stored alongside the keys
 */
public class MPLSH<E> extends LSH<E> {
    private static final long serialVersionUID = 2L;

    /** Default {@code H} used by the 4-argument constructor (forwarded to each hash). */
    private static final int DEFAULT_H = 1017881;
    /** Default {@code Nz} forwarded to {@link PosterioriModel} by the short {@code fit} overloads. */
    private static final int DEFAULT_NZ = 2500;
    /** Default {@code sigma} forwarded to {@link PosterioriModel} by the short {@code fit} overloads. */
    private static final double DEFAULT_SIGMA = 0.2;
    /** Recall requested by the single-argument query methods once a model is fitted. */
    private static final double DEFAULT_RECALL = 0.95;
    /** {@code T} requested by the single-argument query methods once a model is fitted. */
    private static final int DEFAULT_T = 100;

    /** One trained probing model per hash table; {@code null} until {@code fit} is called. */
    private List<PosterioriModel> model;

    /**
     * Creates an index with {@code H = 1017881}.
     *
     * @param d forwarded to each {@link MultiProbeHash} (presumably the key dimensionality)
     * @param L the number of hash tables
     * @param k forwarded to each {@link MultiProbeHash}
     * @param w forwarded to each {@link MultiProbeHash}
     */
    public MPLSH(int d, int L, int k, double w) {
        this(d, L, k, w, DEFAULT_H);
    }

    /**
     * Creates an index whose {@code L} hash tables are {@link MultiProbeHash} instances.
     *
     * @param d forwarded to each {@link MultiProbeHash} (presumably the key dimensionality)
     * @param L the number of hash tables
     * @param k forwarded to each {@link MultiProbeHash}
     * @param w forwarded to each {@link MultiProbeHash}
     * @param H forwarded to each {@link MultiProbeHash}
     */
    public MPLSH(int d, int L, int k, double w, int H) {
        super(d, L, k, w, H);
    }

    /** Fills {@code hash} with {@code L} freshly created {@link MultiProbeHash} tables. */
    @Override
    protected void initHashTable(int d, int L, int k, double w, int H) {
        this.hash = new ArrayList<>(L);
        for (int i = 0; i < L; ++i) {
            this.hash.add(new MultiProbeHash(d, k, w, H));
        }
    }

    @Override
    public String toString() {
        return "Multi-Probe " + super.toString();
    }

    /**
     * Fits the probing models with {@code Nz = 2500} and {@code sigma = 0.2}.
     *
     * @see #fit(RNNSearch, double[][], double, int, double)
     */
    public void fit(RNNSearch<double[], double[]> range, double[][] samples, double radius) {
        fit(range, samples, radius, DEFAULT_NZ);
    }

    /**
     * Fits the probing models with {@code sigma = 0.2}.
     *
     * @see #fit(RNNSearch, double[][], double, int, double)
     */
    public void fit(RNNSearch<double[], double[]> range, double[][] samples, double radius, int Nz) {
        fit(range, samples, radius, Nz, DEFAULT_SIGMA);
    }

    /**
     * Trains one {@link PosterioriModel} per hash table.
     *
     * <p>For every sample, all points within {@code radius} are retrieved with {@code range}.
     * The sample and those neighbor vectors form one training example. The new models replace
     * the current ones only after all of them have been built successfully.
     *
     * @param range  radius-search structure used to find each sample's neighbors
     * @param samples the training queries
     * @param radius the neighborhood radius passed to {@code range}
     * @param Nz     forwarded to {@link PosterioriModel}
     * @param sigma  forwarded to {@link PosterioriModel}
     */
    public void fit(RNNSearch<double[], double[]> range, double[][] samples, double radius, int Nz, double sigma) {
        MultiProbeSample[] training = new MultiProbeSample[samples.length];
        List<Neighbor<double[], double[]>> found = new ArrayList<>();
        for (int i = 0; i < samples.length; ++i) {
            training[i] = new MultiProbeSample(samples[i], new LinkedList<double[]>());
            found.clear();
            range.search(samples[i], radius, found);
            for (Neighbor<double[], double[]> neighbor : found) {
                // Use the neighbor's own vector rather than this.keys.get(neighbor.index()):
                // the index is relative to `range`, which need not share this index's ordering.
                training[i].neighbors.add(neighbor.key());
            }
        }

        // Build into a local list so a failure part-way through cannot leave a partial model.
        List<PosterioriModel> models = new ArrayList<>(this.hash.size());
        for (Hash h : this.hash) {
            models.add(new PosterioriModel((MultiProbeHash) h, training, Nz, sigma));
        }
        this.model = models;
    }

    /**
     * Returns the nearest neighbor of {@code q}.
     *
     * <p>Delegates to {@link LSH} if no model has been fitted. Otherwise the model is queried
     * with {@code recall = 0.95} and {@code T = 100}.
     */
    @Override
    public Neighbor<double[], E> nearest(double[] q) {
        if (this.model == null) {
            return super.nearest(q);
        }
        return nearest(q, DEFAULT_RECALL, DEFAULT_T);
    }

    /**
     * Returns the nearest neighbor of {@code q} among the multi-probe candidates.
     *
     * @param q      the query
     * @param recall the target recall, in {@code [0, 1]}
     * @param T      forwarded to {@link PosterioriModel#getProbeSequence}
     *               (presumably the maximum number of probes per table)
     * @return the closest candidate, or {@code null} if there is none other than {@code q} itself.
     *         Among equidistant candidates, the first one in probe order wins.
     * @throws IllegalArgumentException if {@code recall} is outside {@code [0, 1]} or NaN
     * @throws IllegalStateException if no model has been fitted
     */
    public Neighbor<double[], E> nearest(double[] q, double recall, int T) {
        checkRecall(recall);

        double[] bestKey = null;
        int bestIndex = -1;
        double bestDistance = Double.MAX_VALUE;
        for (int index : getCandidates(q, recall, T)) {
            double[] key = (double[]) this.keys.get(index);
            if (q == key) {
                continue; // never report the query object itself
            }
            double distance = MathEx.distance(q, key);
            if (distance < bestDistance) {
                bestIndex = index;
                bestDistance = distance;
                bestKey = key;
            }
        }
        return bestIndex == -1 ? null : new Neighbor<>(bestKey, this.data.get(bestIndex), bestIndex, bestDistance);
    }

    /**
     * Returns the {@code k} nearest neighbors of {@code q}.
     *
     * <p>Delegates to {@link LSH} if no model has been fitted. Otherwise the model is queried
     * with {@code recall = 0.95} and {@code T = 100}.
     */
    @Override
    public Neighbor<double[], E>[] search(double[] q, int k) {
        if (this.model == null) {
            return super.search(q, k);
        }
        return search(q, k, DEFAULT_RECALL, DEFAULT_T);
    }

    /**
     * Returns up to {@code k} nearest neighbors of {@code q} among the multi-probe candidates,
     * ordered by {@link HeapSelect#sort()}.
     *
     * <p>The result never contains {@code null}. It is shorter than {@code k} when fewer
     * candidates (excluding {@code q} itself) are found.
     *
     * @param q      the query
     * @param k      the number of neighbors requested, at least 1
     * @param recall the target recall, in {@code [0, 1]}
     * @param T      forwarded to {@link PosterioriModel#getProbeSequence}
     * @throws IllegalArgumentException if {@code recall} is outside {@code [0, 1]} or NaN, or {@code k < 1}
     * @throws IllegalStateException if no model has been fitted
     */
    public Neighbor<double[], E>[] search(double[] q, int k, double recall, int T) {
        checkRecall(recall);
        if (k < 1) {
            throw new IllegalArgumentException("Invalid k: " + k);
        }

        Set<Integer> candidates = getCandidates(q, recall, T);
        List<Neighbor<double[], E>> found = new ArrayList<>(candidates.size());
        for (int index : candidates) {
            double[] key = (double[]) this.keys.get(index);
            if (q != key) {
                found.add(new Neighbor<>(key, this.data.get(index), index, MathEx.distance(q, key)));
            }
        }

        // Size the heap by the candidates actually kept (q excluded) so every slot gets filled;
        // sizing by candidates.size() left a trailing null whenever q itself was a candidate.
        @SuppressWarnings("unchecked")
        Neighbor<double[], E>[] buffer = (Neighbor<double[], E>[]) new Neighbor[Math.min(k, found.size())];
        if (buffer.length == 0) {
            return buffer;
        }
        HeapSelect<Neighbor<double[], E>> heap = new HeapSelect<>(buffer);
        for (Neighbor<double[], E> neighbor : found) {
            heap.add(neighbor);
        }
        heap.sort();
        return heap.toArray();
    }

    /**
     * Appends to {@code neighbors} every point within {@code radius} of {@code q}.
     *
     * <p>Delegates to {@link LSH} if no model has been fitted. Otherwise the model is queried
     * with {@code recall = 0.95} and {@code T = 100}.
     */
    @Override
    public void search(double[] q, double radius, List<Neighbor<double[], E>> neighbors) {
        if (this.model == null) {
            super.search(q, radius, neighbors);
        } else {
            search(q, radius, neighbors, DEFAULT_RECALL, DEFAULT_T);
        }
    }

    /**
     * Appends to {@code neighbors} every multi-probe candidate within distance {@code radius}
     * (inclusive) of {@code q}, excluding {@code q} itself, in candidate order.
     *
     * @param q         the query
     * @param radius    the search radius, strictly positive
     * @param neighbors the output list; results are appended to it
     * @param recall    the target recall, in {@code [0, 1]}
     * @param T         forwarded to {@link PosterioriModel#getProbeSequence}
     * @throws IllegalArgumentException if {@code radius} is not positive (or NaN), or
     *         {@code recall} is outside {@code [0, 1]} or NaN
     * @throws IllegalStateException if no model has been fitted
     */
    public void search(double[] q, double radius, List<Neighbor<double[], E>> neighbors, double recall, int T) {
        // Written as a negated positive test so NaN is rejected too.
        if (!(radius > 0.0)) {
            throw new IllegalArgumentException("Invalid radius: " + radius);
        }
        checkRecall(recall);

        for (int index : getCandidates(q, recall, T)) {
            double[] key = (double[]) this.keys.get(index);
            if (q == key) {
                continue;
            }
            double distance = MathEx.distance(q, key);
            if (distance <= radius) {
                neighbors.add(new Neighbor<>(key, this.data.get(index), index, distance));
            }
        }
    }

    /**
     * Validates a recall argument.
     *
     * @throws IllegalArgumentException if {@code recall} is outside {@code [0, 1]} or NaN
     */
    private static void checkRecall(double recall) {
        // NaN fails both comparisons, so it is rejected as well.
        if (!(recall >= 0.0 && recall <= 1.0)) {
            throw new IllegalArgumentException("Invalid recall: " + recall);
        }
    }

    /**
     * Collects the indices of all points in the buckets probed for {@code q}, deduplicated and
     * in first-seen order.
     *
     * @throws IllegalStateException if no model has been fitted
     */
    private Set<Integer> getCandidates(double[] q, double recall, int T) {
        if (this.model == null) {
            throw new IllegalStateException("The multi-probe model has not been fit");
        }

        int tables = this.hash.size();
        // Per-table success probability alpha such that 1 - (1 - alpha)^L == recall.
        double alpha = 1.0 - Math.pow(1.0 - recall, 1.0 / tables);

        Set<Integer> candidates = new LinkedHashSet<>();
        for (int i = 0; i < tables; ++i) {
            Hash table = (Hash) this.hash.get(i);
            IntArrayList probes = this.model.get(i).getProbeSequence(q, alpha, T);
            for (int j = 0; j < probes.size(); ++j) {
                Bucket bucket = table.get(probes.get(j));
                if (bucket == null) {
                    continue; // nothing was ever hashed into this bucket
                }
                IntArrayList points = bucket.points();
                for (int p = 0; p < points.size(); ++p) {
                    candidates.add(points.get(p));
                }
            }
        }
        return candidates;
    }
}

