/*
 * Decompiled with CFR 0.152.
 */
package org.graylog.shaded.opensearch2.org.apache.lucene.search;

import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Objects;
import org.graylog.shaded.opensearch2.org.apache.lucene.index.FieldInfo;
import org.graylog.shaded.opensearch2.org.apache.lucene.index.IndexReader;
import org.graylog.shaded.opensearch2.org.apache.lucene.index.LeafReaderContext;
import org.graylog.shaded.opensearch2.org.apache.lucene.index.VectorSimilarityFunction;
import org.graylog.shaded.opensearch2.org.apache.lucene.index.VectorValues;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.BooleanClause;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.BooleanQuery;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.DocIdSetIterator;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.Explanation;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.HitQueue;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.IndexSearcher;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.KnnVectorFieldExistsQuery;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.MatchNoDocsQuery;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.Query;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.QueryVisitor;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.ScoreDoc;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.ScoreMode;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.Scorer;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.SimpleCollector;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.TopDocs;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.TopDocsCollector;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.TotalHits;
import org.graylog.shaded.opensearch2.org.apache.lucene.search.Weight;
import org.graylog.shaded.opensearch2.org.apache.lucene.util.BitSet;
import org.graylog.shaded.opensearch2.org.apache.lucene.util.BitSetIterator;
import org.graylog.shaded.opensearch2.org.apache.lucene.util.Bits;
import org.graylog.shaded.opensearch2.org.apache.lucene.util.FixedBitSet;

/**
 * Finds the {@code k} documents whose vectors in {@code field} are nearest to {@code target},
 * optionally restricted to documents matching {@code filter}.
 *
 * <p>The query does its work in {@link #rewrite(IndexReader)}. It searches every leaf, merges the
 * per-leaf hits into a global top-k, and rewrites itself into a {@link DocAndScoreQuery} that
 * replays those fixed hits and scores. If nothing is found it rewrites to a
 * {@link MatchNoDocsQuery}.
 */
public class KnnVectorQuery extends Query {
    private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS;
    private final String field;
    private final float[] target;
    private final int k;
    /** Optional restriction on candidate documents; {@code null} means no filter. */
    private final Query filter;

    /**
     * Creates an unfiltered nearest-neighbour query.
     *
     * @param field the vector field to search
     * @param target the query vector
     * @param k number of nearest documents to return
     * @throws IllegalArgumentException if {@code k < 1}
     */
    public KnnVectorQuery(String field, float[] target, int k) {
        this(field, target, k, null);
    }

    /**
     * Creates a nearest-neighbour query restricted to documents matching {@code filter}.
     *
     * @param field the vector field to search
     * @param target the query vector
     * @param k number of nearest documents to return
     * @param filter optional filter; {@code null} disables filtering
     * @throws IllegalArgumentException if {@code k < 1}
     */
    public KnnVectorQuery(String field, float[] target, int k, Query filter) {
        this.field = field;
        this.target = target;
        this.k = k;
        if (k < 1) {
            throw new IllegalArgumentException("k must be at least 1, got: " + k);
        }
        this.filter = filter;
    }

    /**
     * Searches every leaf of {@code reader}, merges the per-leaf hits into a global top-k, and
     * returns a query that matches exactly those documents with those scores.
     */
    @Override
    public Query rewrite(IndexReader reader) throws IOException {
        TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()];

        BitSetCollector filterCollector = null;
        if (this.filter != null) {
            // Collect, per leaf, the docs that match the filter AND actually have a vector.
            // exactSearch relies on the second condition (see its assertion).
            filterCollector = new BitSetCollector(reader.leaves().size());
            IndexSearcher indexSearcher = new IndexSearcher(reader);
            BooleanQuery booleanQuery = new BooleanQuery.Builder()
                    .add(this.filter, BooleanClause.Occur.FILTER)
                    .add(new KnnVectorFieldExistsQuery(this.field), BooleanClause.Occur.FILTER)
                    .build();
            indexSearcher.search(booleanQuery, filterCollector);
        }

        for (LeafReaderContext ctx : reader.leaves()) {
            TopDocs results = this.searchLeaf(ctx, filterCollector);
            // Per-leaf hits are segment-relative; shift them into the reader-wide doc ID space.
            if (ctx.docBase > 0) {
                for (ScoreDoc scoreDoc : results.scoreDocs) {
                    scoreDoc.doc += ctx.docBase;
                }
            }
            perLeafResults[ctx.ord] = results;
        }

        TopDocs topK = TopDocs.merge(this.k, perLeafResults);
        if (topK.scoreDocs.length == 0) {
            return new MatchNoDocsQuery();
        }
        return this.createRewrittenQuery(reader, topK);
    }

    /**
     * Finds the top-k hits within one leaf.
     *
     * <p>Without a filter, runs an approximate search over live docs with no visit limit. With a
     * filter, the strategy depends on how many docs the filter accepts:
     * <ul>
     *   <li>none: returns no results;
     *   <li>at most {@code k}: scores them all exactly;
     *   <li>otherwise: runs an approximate search limited to visiting that many nodes, and falls
     *       back to an exact search if the result's total-hits relation is not {@code EQUAL_TO}.
     *       Presumably this signals that the visit limit was hit; confirm against
     *       {@code searchNearestVectors}.
     * </ul>
     */
    private TopDocs searchLeaf(LeafReaderContext ctx, BitSetCollector filterCollector) throws IOException {
        if (filterCollector == null) {
            Bits acceptDocs = ctx.reader().getLiveDocs();
            return this.approximateSearch(ctx, acceptDocs, Integer.MAX_VALUE);
        }
        BitSetIterator filterIterator = filterCollector.getIterator(ctx.ord);
        if (filterIterator == null || filterIterator.cost() == 0L) {
            return NO_RESULTS;
        }
        if (filterIterator.cost() <= (long) this.k) {
            // Few enough candidates that brute force is cheaper than the graph search.
            return this.exactSearch(ctx, filterIterator);
        }
        BitSet acceptDocs = filterIterator.getBitSet();
        int visitedLimit = (int) filterIterator.cost();
        TopDocs results = this.approximateSearch(ctx, acceptDocs, visitedLimit);
        if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO) {
            return results;
        }
        return this.exactSearch(ctx, filterIterator);
    }

    /**
     * Delegates to the leaf reader's nearest-vector search.
     *
     * @return the reader's results, or an empty {@code TopDocs} if it returned {@code null}
     */
    private TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitedLimit) throws IOException {
        TopDocs results = context.reader().searchNearestVectors(this.field, this.target, this.k, acceptDocs, visitedLimit);
        return results != null ? results : NO_RESULTS;
    }

    /**
     * Brute-force scores every document produced by {@code acceptIterator} and keeps the best
     * {@code k}.
     *
     * <p>Every accepted doc must have a vector in {@code field}; this is asserted. The filter
     * built in {@link #rewrite(IndexReader)} guarantees it. Entries with a negative score are
     * dropped at the end. NOTE(review): this appears intended to discard the queue's
     * pre-populated placeholder entries, and it would also drop genuine negative scores; confirm
     * that similarity scores are non-negative.
     *
     * @param context the leaf to search
     * @param acceptIterator the documents to score, in segment-relative doc IDs
     * @return up to {@code k} hits in descending score order; total hits equals the iterator's cost
     */
    protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator) throws IOException {
        FieldInfo fi = context.reader().getFieldInfos().fieldInfo(this.field);
        if (fi == null || fi.getVectorDimension() == 0) {
            return NO_RESULTS;
        }
        VectorSimilarityFunction similarityFunction = fi.getVectorSimilarityFunction();
        VectorValues vectorValues = context.reader().getVectorValues(this.field);

        HitQueue queue = new HitQueue(this.k, true);
        ScoreDoc topDoc = queue.top();
        int doc;
        while ((doc = acceptIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
            int vectorDoc = vectorValues.advance(doc);
            assert vectorDoc == doc;
            float[] vector = vectorValues.vectorValue();
            float score = similarityFunction.convertToScore(similarityFunction.compare(vector, this.target));
            if (score >= topDoc.score) {
                // Replace the current weakest entry in place, then re-heapify.
                topDoc.score = score;
                topDoc.doc = doc;
                topDoc = queue.updateTop();
            }
        }

        while (queue.size() > 0 && queue.top().score < 0.0f) {
            queue.pop();
        }
        // pop() yields ascending scores, so fill the array from the back.
        ScoreDoc[] topScoreDocs = new ScoreDoc[queue.size()];
        for (int i = topScoreDocs.length - 1; i >= 0; --i) {
            topScoreDocs[i] = queue.pop();
        }
        TotalHits totalHits = new TotalHits(acceptIterator.cost(), TotalHits.Relation.EQUAL_TO);
        return new TopDocs(totalHits, topScoreDocs);
    }

    /**
     * Builds a {@link DocAndScoreQuery} from the merged hits.
     *
     * <p>Sorts {@code topK.scoreDocs} in place by global doc ID, which is why the resulting arrays
     * can be binary-searched.
     */
    private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
        int len = topK.scoreDocs.length;
        Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));
        int[] docs = new int[len];
        float[] scores = new float[len];
        for (int i = 0; i < len; ++i) {
            docs[i] = topK.scoreDocs[i].doc;
            scores[i] = topK.scoreDocs[i].score;
        }
        int[] segmentStarts = this.findSegmentStarts(reader, docs);
        return new DocAndScoreQuery(this.k, docs, scores, segmentStarts, reader.getContext().id());
    }

    /**
     * Computes, for each leaf {@code i}, the index of the first entry of the sorted global
     * {@code docs} array that belongs to that leaf.
     *
     * @return an array of length {@code leaves + 1}; leaf {@code i} owns the range
     *     {@code [starts[i], starts[i + 1])}, and the last element is {@code docs.length}
     */
    private int[] findSegmentStarts(IndexReader reader, int[] docs) {
        int[] starts = new int[reader.leaves().size() + 1];
        starts[starts.length - 1] = docs.length;
        if (starts.length == 2) {
            return starts;
        }
        int resultIndex = 0;
        for (int i = 1; i < starts.length - 1; ++i) {
            int upper = reader.leaves().get(i).docBase;
            // Search only from the previous leaf's start onwards; docs is sorted ascending.
            resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper);
            if (resultIndex < 0) {
                resultIndex = -1 - resultIndex;
            }
            starts[i] = resultIndex;
        }
        return starts;
    }

    @Override
    public String toString(String field) {
        return this.getClass().getSimpleName() + ":" + this.field + "[" + this.target[0] + ",...][" + this.k + "]";
    }

    @Override
    public void visit(QueryVisitor visitor) {
        if (visitor.acceptField(this.field)) {
            visitor.visitLeaf(this);
        }
    }

    /**
     * Two queries are equal when field, target vector, {@code k} and filter all match. The filter
     * must take part: queries differing only by filter select different documents.
     */
    @Override
    public boolean equals(Object obj) {
        if (!this.sameClassAs(obj)) {
            return false;
        }
        KnnVectorQuery other = (KnnVectorQuery) obj;
        return other.k == this.k
                && other.field.equals(this.field)
                && Arrays.equals(other.target, this.target)
                && Objects.equals(other.filter, this.filter);
    }

    @Override
    public int hashCode() {
        return Objects.hash(this.classHash(), this.field, this.k, Arrays.hashCode(this.target), this.filter);
    }

    /**
     * A query that matches a fixed set of documents with precomputed scores.
     *
     * <p>{@code docs} holds reader-wide (global) doc IDs sorted ascending, and {@code scores} is
     * parallel to it. Leaf {@code i} owns the slice {@code [segmentStarts[i], segmentStarts[i+1])}.
     * The query may only be used with the reader whose context identity it was built for.
     */
    static class DocAndScoreQuery extends Query {
        private final int k;
        private final int[] docs;
        private final float[] scores;
        private final int[] segmentStarts;
        private final Object contextIdentity;

        DocAndScoreQuery(int k, int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity) {
            this.k = k;
            this.docs = docs;
            this.scores = scores;
            this.segmentStarts = segmentStarts;
            this.contextIdentity = contextIdentity;
        }

        @Override
        public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
            // Identity comparison: the doc IDs are only meaningful for the exact reader they came from.
            if (searcher.getIndexReader().getContext().id() != this.contextIdentity) {
                throw new IllegalStateException("This DocAndScore query was created by a different reader");
            }
            return new Weight(this) {

                @Override
                public Explanation explain(LeafReaderContext context, int doc) {
                    // doc is segment-relative while docs holds global IDs; translate before searching.
                    int found = Arrays.binarySearch(docs, doc + context.docBase);
                    if (found < 0) {
                        return Explanation.noMatch("not in top " + k);
                    }
                    return Explanation.match(scores[found], "within top " + k);
                }

                @Override
                public Scorer scorer(final LeafReaderContext context) {
                    return new Scorer(this) {
                        // This leaf's slice of docs/scores: [lower, upper).
                        final int lower = segmentStarts[context.ord];
                        final int upper = segmentStarts[context.ord + 1];
                        // Index into docs/scores of the current hit; -1 before the first nextDoc().
                        int upTo = -1;

                        @Override
                        public DocIdSetIterator iterator() {
                            return new DocIdSetIterator() {

                                @Override
                                public int docID() {
                                    // Unqualified on purpose: resolves to the enclosing scorer's helper.
                                    return docIdNoShadow();
                                }

                                @Override
                                public int nextDoc() {
                                    if (upTo == -1) {
                                        upTo = lower;
                                    } else {
                                        ++upTo;
                                    }
                                    return docIdNoShadow();
                                }

                                @Override
                                public int advance(int target) throws IOException {
                                    return slowAdvance(target);
                                }

                                @Override
                                public long cost() {
                                    return upper - lower;
                                }
                            };
                        }

                        /**
                         * Returns the maximum score among this leaf's remaining hits up to and
                         * including segment-relative {@code docid}.
                         */
                        @Override
                        public float getMaxScore(int docid) {
                            int globalDocId = docid + context.docBase;
                            float maxScore = 0.0f;
                            // Start at this leaf's slice, never at another segment's entries.
                            for (int idx = Math.max(this.lower, this.upTo); idx < this.upper && docs[idx] <= globalDocId; ++idx) {
                                maxScore = Math.max(maxScore, scores[idx]);
                            }
                            return maxScore;
                        }

                        @Override
                        public float score() {
                            return scores[this.upTo];
                        }

                        /**
                         * Returns the segment-relative ID of the first remaining hit at or after
                         * {@code docid}, or {@code NO_MORE_DOCS} if there is none.
                         */
                        @Override
                        public int advanceShallow(int docid) {
                            int start = Math.max(this.upTo, this.lower);
                            int docidIndex = Arrays.binarySearch(docs, start, this.upper, docid + context.docBase);
                            if (docidIndex < 0) {
                                docidIndex = -1 - docidIndex;
                            }
                            if (docidIndex >= this.upper) {
                                return DocIdSetIterator.NO_MORE_DOCS;
                            }
                            // Scorer doc IDs are segment-relative; docs holds global IDs.
                            return docs[docidIndex] - context.docBase;
                        }

                        /**
                         * Shared docID logic, named so the nested iterator's own docID() does not
                         * shadow it.
                         */
                        private int docIdNoShadow() {
                            if (this.upTo == -1) {
                                return -1;
                            }
                            if (this.upTo >= this.upper) {
                                return DocIdSetIterator.NO_MORE_DOCS;
                            }
                            return docs[this.upTo] - context.docBase;
                        }

                        @Override
                        public int docID() {
                            return this.docIdNoShadow();
                        }
                    };
                }

                @Override
                public boolean isCacheable(LeafReaderContext ctx) {
                    return true;
                }
            };
        }

        @Override
        public String toString(String field) {
            return "DocAndScore[" + this.k + "]";
        }

        @Override
        public void visit(QueryVisitor visitor) {
            visitor.visitLeaf(this);
        }

        @Override
        public boolean equals(Object obj) {
            if (!this.sameClassAs(obj)) {
                return false;
            }
            DocAndScoreQuery other = (DocAndScoreQuery) obj;
            return this.contextIdentity == other.contextIdentity
                    && Arrays.equals(this.docs, other.docs)
                    && Arrays.equals(this.scores, other.scores);
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.classHash(), this.contextIdentity, Arrays.hashCode(this.docs), Arrays.hashCode(this.scores));
        }
    }

    /**
     * Collects matching docs into one {@link FixedBitSet} per leaf and counts the matches per leaf.
     * A leaf that was never visited has a {@code null} bit set.
     */
    private static class BitSetCollector extends SimpleCollector {
        private final BitSet[] bitSets;
        /** Number of docs collected per leaf, used as the exact cost of that leaf's iterator. */
        private final int[] cost;
        /** Ordinal of the leaf currently being collected. */
        private int ord;

        private BitSetCollector(int numLeaves) {
            this.bitSets = new BitSet[numLeaves];
            this.cost = new int[numLeaves];
        }

        /**
         * Returns an iterator over the docs collected for leaf {@code contextOrd}, or {@code null}
         * if that leaf was never visited.
         */
        public BitSetIterator getIterator(int contextOrd) {
            if (this.bitSets[contextOrd] == null) {
                return null;
            }
            return new BitSetIterator(this.bitSets[contextOrd], this.cost[contextOrd]);
        }

        @Override
        public void collect(int doc) throws IOException {
            this.bitSets[this.ord].set(doc);
            this.cost[this.ord]++;
        }

        @Override
        protected void doSetNextReader(LeafReaderContext context) throws IOException {
            this.bitSets[context.ord] = new FixedBitSet(context.reader().maxDoc());
            this.ord = context.ord;
        }

        @Override
        public ScoreMode scoreMode() {
            return ScoreMode.COMPLETE_NO_SCORES;
        }
    }
}

