/*
 * Decompiled with CFR 0.152.
 */
package org.graylog.shaded.opensearch2.org.apache.lucene.search;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.FutureTask;
import org.graylog.shaded.opensearch2.org.apache.lucene.index.FieldInfo;
import org.graylog.shaded.opensearch2.org.apache.lucene.index.IndexReader;
import org.graylog.shaded.opensearch2.org.apache.lucene.index.LeafReaderContext;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.BooleanClause;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.BooleanQuery;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.DocIdSetIterator;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.Explanation;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.FieldExistsQuery;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.FilteredDocIdSetIterator;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.HitQueue;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.IndexSearcher;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.MatchNoDocsQuery;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.Query;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.QueryVisitor;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.ScoreDoc;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.ScoreMode;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.Scorer;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.SliceExecutor;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.TopDocs;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.TopDocsCollector;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.TotalHits;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.VectorScorer;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.Weight;
import org.graylog.shaded.opensearch2.org.apache.lucene.util.BitSet;
import org.graylog.shaded.opensearch2.org.apache.lucene.util.BitSetIterator;
import org.graylog.shaded.opensearch2.org.apache.lucene.util.Bits;
import org.graylog.shaded.opensearch2.org.apache.lucene.util.ThreadInterruptedException;

abstract class AbstractKnnVectorQuery
extends Query {
    /** Shared empty result handed back when a leaf has nothing to contribute. */
    private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS;
    /** Name of the vector field to search; never {@code null}. */
    protected final String field;
    /** Number of hits to keep; always at least 1. */
    protected final int k;
    /** Optional filter restricting the candidate documents; {@code null} when no filter applies. */
    private final Query filter;

    /**
     * Creates a query for the {@code k} best hits on {@code field}.
     *
     * @param field name of the vector field; must not be {@code null}
     * @param k number of hits to return; must be at least 1
     * @param filter optional filter on the candidate documents, may be {@code null}
     * @throws NullPointerException if {@code field} is {@code null}
     * @throws IllegalArgumentException if {@code k} is smaller than 1
     */
    public AbstractKnnVectorQuery(String field, int k, Query filter) {
        // Validate in the same order as before: null field first, then k.
        Objects.requireNonNull(field, "field");
        if (k < 1) {
            throw new IllegalArgumentException("k must be at least 1, got: " + k);
        }
        this.field = field;
        this.k = k;
        this.filter = filter;
    }

    /**
     * Runs the vector search eagerly and replaces this query with one that replays the hits.
     *
     * <p>When a filter is configured it is combined with a {@link FieldExistsQuery} on the vector
     * field and turned into a non-scoring {@link Weight}. Each leaf is then searched, either
     * through the searcher's {@link SliceExecutor} or sequentially when there is none, and the
     * per-leaf results are merged down to the top {@code k}.
     *
     * @param indexSearcher searcher providing the reader, slices and optional executor
     * @return a {@link MatchNoDocsQuery} when nothing matched, otherwise a query over the merged hits
     * @throws IOException if rewriting the filter or creating its weight fails
     */
    @Override
    public Query rewrite(IndexSearcher indexSearcher) throws IOException {
        IndexReader reader = indexSearcher.getIndexReader();

        Weight filterWeight = null;
        if (this.filter != null) {
            BooleanQuery.Builder filterBuilder = new BooleanQuery.Builder();
            filterBuilder.add(this.filter, BooleanClause.Occur.FILTER);
            // Only documents that actually carry the vector field are valid candidates.
            filterBuilder.add(new FieldExistsQuery(this.field), BooleanClause.Occur.FILTER);
            Query rewrittenFilter = indexSearcher.rewrite(filterBuilder.build());
            filterWeight = indexSearcher.createWeight(rewrittenFilter, ScoreMode.COMPLETE_NO_SCORES, 1.0f);
        }

        SliceExecutor sliceExecutor = indexSearcher.getSliceExecutor();
        TopDocs[] perLeafResults;
        if (sliceExecutor != null) {
            perLeafResults = this.parallelSearch(indexSearcher.getSlices(), filterWeight, sliceExecutor);
        } else {
            perLeafResults = this.sequentialSearch(reader.leaves(), filterWeight);
        }

        TopDocs topK = TopDocs.merge(this.k, perLeafResults);
        return topK.scoreDocs.length == 0 ? new MatchNoDocsQuery() : this.createRewrittenQuery(reader, topK);
    }

    /**
     * Searches every leaf on the calling thread.
     *
     * @param leafReaderContexts leaves to search
     * @param filterWeight weight of the optional filter, or {@code null}
     * @return one result per leaf, stored at the leaf's {@code ord}
     * @throws RuntimeException wrapping any exception raised while searching a leaf
     */
    private TopDocs[] sequentialSearch(List<LeafReaderContext> leafReaderContexts, Weight filterWeight) {
        try {
            TopDocs[] resultsByOrd = new TopDocs[leafReaderContexts.size()];
            for (LeafReaderContext leaf : leafReaderContexts) {
                TopDocs leafHits = this.searchLeaf(leaf, filterWeight);
                resultsByOrd[leaf.ord] = leafHits;
            }
            return resultsByOrd;
        }
        catch (Exception failure) {
            // Every failure, checked or unchecked, is surfaced as the cause of a RuntimeException.
            throw new RuntimeException(failure);
        }
    }

    /**
     * Searches the leaves slice by slice through the given executor.
     *
     * <p>One task is created per slice; each task searches the leaves of its slice in order. After
     * all tasks have been handed to the executor, their results are concatenated in slice order.
     *
     * @param slices leaf slices to search, one task per slice
     * @param filterWeight weight of the optional filter, or {@code null}
     * @param sliceExecutor executor that runs the per-slice tasks
     * @return the per-leaf results of all slices, concatenated in slice order
     * @throws RuntimeException wrapping the cause of any failed task
     * @throws ThreadInterruptedException if the calling thread is interrupted while waiting
     */
    private TopDocs[] parallelSearch(IndexSearcher.LeafSlice[] slices, Weight filterWeight, SliceExecutor sliceExecutor) {
        List<FutureTask<TopDocs[]>> tasks = new ArrayList<FutureTask<TopDocs[]>>(slices.length);
        int segmentsCount = 0;
        // The loop variable must be the one the lambda captures; it is effectively final per iteration.
        for (IndexSearcher.LeafSlice slice : slices) {
            segmentsCount += slice.leaves.length;
            tasks.add(new FutureTask<TopDocs[]>(() -> {
                TopDocs[] results = new TopDocs[slice.leaves.length];
                int i = 0;
                for (LeafReaderContext context : slice.leaves) {
                    results[i++] = this.searchLeaf(context, filterWeight);
                }
                return results;
            }));
        }
        sliceExecutor.invokeAll(tasks);
        TopDocs[] topDocs = new TopDocs[segmentsCount];
        int i = 0;
        // Typed FutureTask avoids the raw type and the unchecked cast on get().
        for (FutureTask<TopDocs[]> task : tasks) {
            try {
                for (TopDocs docs : task.get()) {
                    topDocs[i++] = docs;
                }
            }
            catch (ExecutionException e) {
                throw new RuntimeException(e.getCause());
            }
            catch (InterruptedException e) {
                throw new ThreadInterruptedException(e);
            }
        }
        return topDocs;
    }

    /**
     * Searches one leaf and shifts the returned doc ids by the leaf's {@code docBase}.
     *
     * @param ctx leaf to search
     * @param filterWeight weight of the optional filter, or {@code null}
     * @return the leaf's hits; the doc ids are adjusted in place
     * @throws IOException if the leaf search fails
     */
    private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight) throws IOException {
        TopDocs leafHits = this.getLeafResults(ctx, filterWeight);
        int base = ctx.docBase;
        if (base <= 0) {
            // Nothing to shift for a leaf whose doc base is zero.
            return leafHits;
        }
        ScoreDoc[] hits = leafHits.scoreDocs;
        for (int i = 0; i < hits.length; i++) {
            hits[i].doc += base;
        }
        return leafHits;
    }

    /**
     * Chooses between approximate and exact search for one leaf.
     *
     * <ul>
     *   <li>Without a filter, an approximate search runs over the live docs with a limit of
     *       {@code Integer.MAX_VALUE}.
     *   <li>With a filter that yields no scorer for this leaf, the empty result is returned.
     *   <li>Otherwise the accepted docs are collected into a bit set. If more than {@code k} docs
     *       are accepted, an approximate search is tried with the accepted count as its limit.
     *       Its result is used only when its hit relation is {@code EQUAL_TO}; in every other
     *       case an exact search over the accepted docs is run.
     * </ul>
     *
     * @param ctx leaf to search
     * @param filterWeight weight of the optional filter, or {@code null}
     * @return the hits of this leaf with leaf-local doc ids
     * @throws IOException if reading the leaf fails
     */
    private TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight) throws IOException {
        Bits liveDocs = ctx.reader().getLiveDocs();
        int maxDoc = ctx.reader().maxDoc();
        if (filterWeight == null) {
            return this.approximateSearch(ctx, liveDocs, Integer.MAX_VALUE);
        }

        Scorer filterScorer = filterWeight.scorer(ctx);
        if (filterScorer == null) {
            return NO_RESULTS;
        }

        BitSet acceptDocs = this.createBitSet(filterScorer.iterator(), liveDocs, maxDoc);
        int acceptCount = acceptDocs.cardinality();
        if (acceptCount > this.k) {
            TopDocs approximate = this.approximateSearch(ctx, acceptDocs, acceptCount);
            if (approximate.totalHits.relation == TotalHits.Relation.EQUAL_TO) {
                return approximate;
            }
            // NOTE(review): a non-EQUAL_TO relation presumably means the limit was hit — confirm.
        }
        return this.exactSearch(ctx, new BitSetIterator(acceptDocs, acceptCount));
    }

    /**
     * Collects the documents of {@code iterator} that are also live into a bit set.
     *
     * <p>When there are no deletions ({@code liveDocs == null}) and the iterator is already backed
     * by a bit set, that bit set is returned directly instead of being copied.
     *
     * @param iterator documents accepted by the filter
     * @param liveDocs live documents of the leaf, or {@code null} when all documents are live
     * @param maxDoc size passed to {@code BitSet.of} for the resulting bit set
     * @return bit set of accepted, live documents
     * @throws IOException if iterating the filter fails
     */
    private BitSet createBitSet(DocIdSetIterator iterator, final Bits liveDocs, int maxDoc) throws IOException {
        boolean noDeletions = liveDocs == null;
        if (noDeletions && iterator instanceof BitSetIterator) {
            BitSetIterator bitSetIterator = (BitSetIterator) iterator;
            return bitSetIterator.getBitSet();
        }
        FilteredDocIdSetIterator liveOnly = new FilteredDocIdSetIterator(iterator) {
            @Override
            protected boolean match(int doc) {
                if (liveDocs == null) {
                    return true;
                }
                return liveDocs.get(doc);
            }
        };
        return BitSet.of(liveOnly, maxDoc);
    }

    /**
     * Performs the approximate search on one leaf; implemented by subclasses.
     *
     * <p>As called from this class: {@code var1} is the leaf being searched, {@code var2} is the
     * set of documents that may be returned (the leaf's live docs, possibly {@code null}, or the
     * filter's accepted docs), and {@code var3} is a limit ({@code Integer.MAX_VALUE} without a
     * filter, otherwise the number of accepted docs). The caller uses the result only when its
     * {@code totalHits.relation} is {@code EQUAL_TO} in the filtered case.
     * NOTE(review): the limit is presumably the maximum number of vectors to visit — confirm
     * against the implementations.
     *
     * @throws IOException if reading the leaf fails
     */
    protected abstract TopDocs approximateSearch(LeafReaderContext var1, Bits var2, int var3) throws IOException;

    /**
     * Creates the scorer that {@code exactSearch} uses to score each accepted document.
     *
     * <p>As called from this class: {@code var1} is the leaf being searched and {@code var2} is
     * the {@link FieldInfo} of the vector field in that leaf.
     *
     * @throws IOException if the scorer cannot be created
     */
    abstract VectorScorer createVectorScorer(LeafReaderContext var1, FieldInfo var2) throws IOException;

    /**
     * Scores every document of {@code acceptIterator} and keeps the {@code k} best.
     *
     * @param context leaf being searched
     * @param acceptIterator documents to score; it is consumed by this call
     * @return the best hits ordered from highest to lowest score, with leaf-local doc ids, or the
     *     empty result when the leaf has no info for the field or the field's vector dimension is 0
     * @throws IOException if reading vectors fails
     */
    protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator) throws IOException {
        FieldInfo fieldInfo = context.reader().getFieldInfos().fieldInfo(this.field);
        if (fieldInfo == null || fieldInfo.getVectorDimension() == 0) {
            return NO_RESULTS;
        }

        VectorScorer vectorScorer = this.createVectorScorer(context, fieldInfo);
        HitQueue queue = new HitQueue(this.k, true);
        // The queue's top is its weakest entry; a better hit overwrites it in place.
        ScoreDoc weakest = (ScoreDoc) queue.top();
        for (int doc = acceptIterator.nextDoc();
                doc != DocIdSetIterator.NO_MORE_DOCS;
                doc = acceptIterator.nextDoc()) {
            boolean advanced = vectorScorer.advanceExact(doc);
            assert advanced;
            float score = vectorScorer.score();
            if (score > weakest.score) {
                weakest.score = score;
                weakest.doc = doc;
                weakest = (ScoreDoc) queue.updateTop();
            }
        }

        // Drop negative-score entries left at the top of the queue.
        // NOTE(review): presumably the placeholders from the pre-populated HitQueue — confirm.
        while (queue.size() > 0 && ((ScoreDoc) queue.top()).score < 0.0f) {
            queue.pop();
        }

        // Popping yields the weakest first, so fill the array from the back.
        int hitCount = queue.size();
        ScoreDoc[] topScoreDocs = new ScoreDoc[hitCount];
        for (int slot = hitCount - 1; slot >= 0; --slot) {
            topScoreDocs[slot] = (ScoreDoc) queue.pop();
        }
        return new TopDocs(new TotalHits(acceptIterator.cost(), TotalHits.Relation.EQUAL_TO), topScoreDocs);
    }

    /**
     * Builds the query that replays the merged hits.
     *
     * <p>The hits in {@code topK.scoreDocs} are sorted in place by doc id, split into parallel
     * doc and score arrays, and tagged with the start offset of each segment.
     *
     * @param reader reader the hits were collected from
     * @param topK merged hits; its {@code scoreDocs} array is re-ordered by this call
     * @return a {@link DocAndScoreQuery} bound to the reader's context identity
     */
    private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
        ScoreDoc[] hits = topK.scoreDocs;
        Arrays.sort(hits, Comparator.comparingInt(hit -> hit.doc));

        int[] docs = new int[hits.length];
        float[] scores = new float[hits.length];
        int slot = 0;
        for (ScoreDoc hit : hits) {
            docs[slot] = hit.doc;
            scores[slot] = hit.score;
            slot++;
        }

        int[] segmentStarts = AbstractKnnVectorQuery.findSegmentStarts(reader, docs);
        Object readerIdentity = reader.getContext().id();
        return new DocAndScoreQuery(this.k, docs, scores, segmentStarts, readerIdentity);
    }

    /**
     * Computes, for each leaf of {@code reader}, the index in {@code docs} of its first hit.
     *
     * <p>The result has one entry per leaf plus a final entry equal to {@code docs.length}, so
     * the hits of leaf {@code i} are {@code docs[starts[i] .. starts[i + 1])}.
     *
     * @param reader reader whose leaves define the segment boundaries
     * @param docs doc ids in ascending order (required by the binary search)
     * @return the start offsets, of length {@code leaves + 1}
     */
    static int[] findSegmentStarts(IndexReader reader, int[] docs) {
        List<LeafReaderContext> leaves = reader.leaves();
        int leafCount = leaves.size();
        int[] starts = new int[leafCount + 1];
        starts[leafCount] = docs.length;
        if (leafCount == 1) {
            // A single segment owns every hit; starts is already {0, docs.length}.
            return starts;
        }

        int from = 0;
        for (int leaf = 1; leaf < leafCount; ++leaf) {
            int docBase = leaves.get(leaf).docBase;
            int position = Arrays.binarySearch(docs, from, docs.length, docBase);
            if (position < 0) {
                // Not found: convert the encoded insertion point back to an index.
                position = -1 - position;
            }
            starts[leaf] = position;
            from = position;
        }
        return starts;
    }

    /**
     * Reports this query as a leaf to {@code visitor}, provided the visitor accepts the field.
     *
     * @param visitor visitor to notify
     */
    @Override
    public void visit(QueryVisitor visitor) {
        if (!visitor.acceptField(this.field)) {
            return;
        }
        visitor.visitLeaf(this);
    }

    /**
     * Two queries are equal when they have the same runtime class, {@code k}, field and filter.
     *
     * @param o object to compare with
     * @return {@code true} if {@code o} is an equal query
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (o == null) {
            return false;
        }
        if (o.getClass() != this.getClass()) {
            return false;
        }
        AbstractKnnVectorQuery other = (AbstractKnnVectorQuery) o;
        if (this.k != other.k) {
            return false;
        }
        if (!Objects.equals(this.field, other.field)) {
            return false;
        }
        return Objects.equals(this.filter, other.filter);
    }

    /**
     * Hash over field, {@code k} and filter, combined in that order with the usual 31-multiplier
     * scheme (the same value {@code Objects.hash(field, k, filter)} produces).
     *
     * @return the hash code
     */
    @Override
    public int hashCode() {
        int result = 1;
        result = 31 * result + Objects.hashCode(this.field);
        result = 31 * result + Integer.hashCode(this.k);
        result = 31 * result + Objects.hashCode(this.filter);
        return result;
    }

    /**
     * @return the name of the vector field this query searches
     */
    public String getField() {
        return field;
    }

    /**
     * @return the number of hits this query keeps
     */
    public int getK() {
        return k;
    }

    /**
     * @return the filter restricting the candidate documents, or {@code null} if none was given
     */
    public Query getFilter() {
        return filter;
    }

    /**
     * Query that replays a precomputed list of hits.
     *
     * <p>{@code docs} holds reader-wide doc ids in ascending order and {@code scores} the matching
     * scores at the same indices. {@code segmentStarts[ord]} is the index of the first hit that
     * belongs to the leaf with that {@code ord}; {@code segmentStarts[ord + 1]} is one past its
     * last hit. The query may only be executed against the reader context it was created for.
     */
    static class DocAndScoreQuery extends Query {
        private final int k;
        private final int[] docs;
        private final float[] scores;
        private final int[] segmentStarts;
        private final Object contextIdentity;

        /**
         * @param k number of hits that was requested; used in explanations and {@code toString}
         * @param docs reader-wide doc ids in ascending order
         * @param scores score of each doc, parallel to {@code docs}
         * @param segmentStarts per-leaf start offsets into {@code docs}, with a trailing end offset
         * @param contextIdentity identity of the reader context the hits were collected from
         */
        DocAndScoreQuery(int k, int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity) {
            this.k = k;
            this.docs = docs;
            this.scores = scores;
            this.segmentStarts = segmentStarts;
            this.contextIdentity = contextIdentity;
        }

        /**
         * Creates a weight that serves the stored hits, with every score multiplied by {@code boost}.
         *
         * @throws IllegalStateException if {@code searcher} reads from a different reader context
         *     than the one this query was created for
         */
        @Override
        public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, final float boost) throws IOException {
            // Identity comparison on purpose: the doc ids are only meaningful for that exact reader.
            if (searcher.getIndexReader().getContext().id() != this.contextIdentity) {
                throw new IllegalStateException("This DocAndScore query was created by a different reader");
            }
            return new Weight(this) {

                @Override
                public Explanation explain(LeafReaderContext context, int doc) {
                    int found = Arrays.binarySearch(docs, doc + context.docBase);
                    if (found < 0) {
                        return Explanation.noMatch("not in top " + k, new Explanation[0]);
                    }
                    return Explanation.match((Number) Float.valueOf(scores[found] * boost), "within top " + k, new Explanation[0]);
                }

                @Override
                public int count(LeafReaderContext context) {
                    return segmentStarts[context.ord + 1] - segmentStarts[context.ord];
                }

                @Override
                public Scorer scorer(final LeafReaderContext context) {
                    if (segmentStarts[context.ord] == segmentStarts[context.ord + 1]) {
                        // No hits fall into this leaf.
                        return null;
                    }
                    return new Scorer(this) {
                        // Range of this leaf's hits within docs/scores: [lower, upper).
                        final int lower = segmentStarts[context.ord];
                        final int upper = segmentStarts[context.ord + 1];
                        // Index of the current hit; -1 until iteration starts.
                        int upTo = -1;

                        @Override
                        public DocIdSetIterator iterator() {
                            return new DocIdSetIterator() {

                                @Override
                                public int docID() {
                                    // Unqualified call: resolves to the enclosing scorer's helper.
                                    return docIdNoShadow();
                                }

                                @Override
                                public int nextDoc() {
                                    if (upTo == -1) {
                                        upTo = lower;
                                    } else {
                                        ++upTo;
                                    }
                                    return docIdNoShadow();
                                }

                                @Override
                                public int advance(int target) throws IOException {
                                    return slowAdvance(target);
                                }

                                @Override
                                public long cost() {
                                    return upper - lower;
                                }
                            };
                        }

                        @Override
                        public float getMaxScore(int docId) {
                            docId += context.docBase;
                            float maxScore = 0.0f;
                            // Start at this leaf's first hit rather than index 0, so hits of earlier
                            // leaves are not folded into the bound before iteration has begun.
                            for (int idx = Math.max(lower, upTo); idx < upper && docs[idx] <= docId; ++idx) {
                                maxScore = Math.max(maxScore, scores[idx]);
                            }
                            return maxScore * boost;
                        }

                        @Override
                        public float score() {
                            return scores[upTo] * boost;
                        }

                        @Override
                        public int advanceShallow(int docid) {
                            int start = Math.max(upTo, lower);
                            int docidIndex = Arrays.binarySearch(docs, start, upper, docid + context.docBase);
                            if (docidIndex < 0) {
                                docidIndex = -1 - docidIndex;
                            }
                            if (docidIndex >= upper) {
                                return DocIdSetIterator.NO_MORE_DOCS;
                            }
                            return docs[docidIndex];
                        }

                        /**
                         * Current leaf-local doc id. It has a distinct name so the nested iterator
                         * can call it without resolving to its own {@code docID()}.
                         */
                        private int docIdNoShadow() {
                            if (upTo == -1) {
                                return -1;
                            }
                            if (upTo >= upper) {
                                return DocIdSetIterator.NO_MORE_DOCS;
                            }
                            return docs[upTo] - context.docBase;
                        }

                        @Override
                        public int docID() {
                            return docIdNoShadow();
                        }
                    };
                }

                @Override
                public boolean isCacheable(LeafReaderContext ctx) {
                    return true;
                }
            };
        }

        @Override
        public String toString(String field) {
            return "DocAndScore[" + this.k + "]";
        }

        @Override
        public void visit(QueryVisitor visitor) {
            visitor.visitLeaf(this);
        }

        /** Equal when created for the same reader context with identical docs and scores. */
        @Override
        public boolean equals(Object obj) {
            if (!this.sameClassAs(obj)) {
                return false;
            }
            DocAndScoreQuery other = (DocAndScoreQuery) obj;
            return this.contextIdentity == other.contextIdentity
                && Arrays.equals(this.docs, other.docs)
                && Arrays.equals(this.scores, other.scores);
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.classHash(), this.contextIdentity, Arrays.hashCode(this.docs), Arrays.hashCode(this.scores));
        }
    }
}

