/*
 * Decompiled with CFR 0.152.
 */
package org.hibernate.search.backend.lucene.search.predicate.impl;

import java.lang.invoke.MethodHandles;
import java.lang.reflect.Array;
import java.util.Objects;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
import org.hibernate.search.backend.lucene.logging.impl.Log;
import org.hibernate.search.backend.lucene.lowlevel.query.impl.VectorSimilarityFilterQuery;
import org.hibernate.search.backend.lucene.search.common.impl.AbstractLuceneValueFieldSearchQueryElementFactory;
import org.hibernate.search.backend.lucene.search.common.impl.LuceneSearchIndexScope;
import org.hibernate.search.backend.lucene.search.common.impl.LuceneSearchIndexValueFieldContext;
import org.hibernate.search.backend.lucene.search.predicate.impl.AbstractLuceneSingleFieldPredicate;
import org.hibernate.search.backend.lucene.search.predicate.impl.LuceneSearchPredicate;
import org.hibernate.search.backend.lucene.search.predicate.impl.PredicateRequestContext;
import org.hibernate.search.backend.lucene.types.codec.impl.LuceneFieldCodec;
import org.hibernate.search.backend.lucene.types.codec.impl.LuceneVectorFieldCodec;
import org.hibernate.search.engine.search.predicate.SearchPredicate;
import org.hibernate.search.engine.search.predicate.spi.KnnPredicateBuilder;
import org.hibernate.search.util.common.AssertionFailure;
import org.hibernate.search.util.common.logging.impl.LoggerFactory;

/**
 * Lucene implementation of the k-nearest-neighbors ("knn") predicate.
 * <p>
 * Two concrete variants exist: one for {@code float[]} vector fields, translated to a
 * {@link KnnFloatVectorQuery}, and one for {@code byte[]} vector fields, translated to a
 * {@link KnnByteVectorQuery}. When a minimum score was requested, the knn query is wrapped
 * with {@link VectorSimilarityFilterQuery#create}.
 *
 * @param <T> The vector representation handled by the concrete variant ({@code float[]} or {@code byte[]}).
 */
public abstract class LuceneKnnPredicate<T>
extends AbstractLuceneSingleFieldPredicate
implements LuceneSearchPredicate {
    private static final Log log = LoggerFactory.make(Log.class, MethodHandles.lookup());

    /** Value passed as {@code k} to the Lucene knn query. */
    protected final int k;
    /** The query vector, as returned by the field's vector codec {@code encode} method. */
    protected final T vector;
    /**
     * Minimum score required of hits, or {@code null} when none was set.
     * When non-null, the knn query is wrapped in a {@link VectorSimilarityFilterQuery}.
     */
    protected final Float requiredMinimumScore;
    /** Optional user-provided filter predicate; {@code null} when none was set. */
    private final LuceneSearchPredicate filter;

    private LuceneKnnPredicate(Builder<T> builder) {
        super(builder);
        this.k = builder.k;
        this.vector = builder.vector;
        this.filter = builder.filter;
        this.requiredMinimumScore = builder.requiredMinimumScore;
    }

    /**
     * Builds the filter query handed to the Lucene knn query.
     *
     * @param context The predicate request context.
     * @return The result of {@code context.appendTenantAndRoutingFilters(...)}, applied to the
     * user-provided filter's query, or to {@code null} when no filter was set.
     * NOTE(review): presumably this adds tenant/routing restrictions; the return value may be
     * {@code null} when there is nothing to filter on. Confirm against {@link PredicateRequestContext}.
     */
    protected Query prepareFilter(PredicateRequestContext context) {
        Query userFilter = this.filter == null ? null : this.filter.toQuery(context);
        return context.appendTenantAndRoutingFilters(userFilter);
    }

    /**
     * Shared builder state for both vector flavors.
     *
     * @param <F> The vector representation of the target field ({@code float[]} or {@code byte[]}).
     */
    private abstract static class Builder<F>
    extends AbstractLuceneSingleFieldPredicate.AbstractBuilder
    implements KnnPredicateBuilder {
        /** Codec of the target field; used to validate, encode and score vectors. */
        protected final LuceneVectorFieldCodec<F> vectorCodec;
        private int k;
        private F vector;
        private LuceneSearchPredicate filter;
        private Float requiredMinimumScore;

        /**
         * @param scope The search scope.
         * @param field The target field; its codec must be a {@link LuceneVectorFieldCodec}.
         * @throws AssertionFailure If the field's codec is not a vector codec.
         */
        protected Builder(LuceneSearchIndexScope<?> scope, LuceneSearchIndexValueFieldContext<F> field) {
            super(scope, field);
            // Held as Object to avoid a raw-typed local; only the vector-codec check matters here.
            Object codec = field.type().codec();
            if (!(codec instanceof LuceneVectorFieldCodec)) {
                // NOTE(review): presumably unreachable if knn factories are only registered
                // on vector fields. Confirm.
                throw new AssertionFailure("Attempting to use a knn predicate on a non-vector field.");
            }
            @SuppressWarnings("unchecked") // The codec comes from a field whose value type is F.
            LuceneVectorFieldCodec<F> typedCodec = (LuceneVectorFieldCodec<F>) codec;
            this.vectorCodec = typedCodec;
        }

        public void k(int k) {
            this.k = k;
        }

        /**
         * Validates the query vector against the field's codec, then stores its encoded form.
         *
         * @param vector The query vector. It must be an array whose component type equals the
         * codec's {@code vectorElementsType()} and whose length equals the codec's configured
         * dimensions.
         * @throws NullPointerException If {@code vector} is {@code null}.
         * @throws IllegalArgumentException If {@code vector} is not an array.
         */
        public void vector(Object vector) {
            Objects.requireNonNull(vector, "vector");
            Class<?> vectorClass = vector.getClass();
            if (!vectorClass.isArray()) {
                throw new IllegalArgumentException("Vector can only be either a float or a byte array (float[], byte[]).");
            }
            Class<?> expectedElementType = this.vectorCodec.vectorElementsType();
            Class<?> actualElementType = vectorClass.getComponentType();
            if (!expectedElementType.equals(actualElementType)) {
                throw log.vectorKnnMatchVectorTypeDiffersFromField(this.absoluteFieldPath, expectedElementType, actualElementType);
            }
            int expectedDimensions = this.vectorCodec.getConfiguredDimensions();
            int actualDimensions = Array.getLength(vector);
            if (actualDimensions != expectedDimensions) {
                throw log.vectorKnnMatchVectorDimensionDiffersFromField(this.absoluteFieldPath, expectedDimensions, actualDimensions);
            }
            this.vector = this.vectorCodec.encode(vector);
        }

        /** Restricts knn candidates with an additional predicate. */
        public void filter(SearchPredicate filter) {
            this.filter = LuceneSearchPredicate.from(this.scope, filter);
        }

        /** Requires hits to reach at least {@code score}. */
        public void requiredMinimumScore(float score) {
            this.requiredMinimumScore = score;
        }

        /**
         * Requires a minimum similarity. The value is converted to a score through the codec's
         * {@code similarityDistanceToScore} and then stored as a required minimum score.
         */
        public void requiredMinimumSimilarity(float similarity) {
            this.requiredMinimumScore(this.vectorCodec.similarityDistanceToScore(similarity));
        }
    }

    /** Knn predicate on {@code float[]} vector fields. */
    private static class LuceneFloatKnnPredicate
    extends LuceneKnnPredicate<float[]> {
        private LuceneFloatKnnPredicate(FloatBuilder builder) {
            super(builder);
        }

        @Override
        protected Query doToQuery(PredicateRequestContext context) {
            KnnFloatVectorQuery query = new KnnFloatVectorQuery(this.absoluteFieldPath, this.vector, this.k, this.prepareFilter(context));
            if (this.requiredMinimumScore == null) {
                return query;
            }
            return VectorSimilarityFilterQuery.create(query, this.requiredMinimumScore.floatValue());
        }

        private static class FloatBuilder
        extends Builder<float[]> {
            protected FloatBuilder(LuceneSearchIndexScope<?> scope, LuceneSearchIndexValueFieldContext<float[]> field) {
                super(scope, field);
            }

            public SearchPredicate build() {
                return new LuceneFloatKnnPredicate(this);
            }
        }
    }

    /** Knn predicate on {@code byte[]} vector fields. */
    private static class LuceneByteKnnPredicate
    extends LuceneKnnPredicate<byte[]> {
        private LuceneByteKnnPredicate(ByteBuilder builder) {
            super(builder);
        }

        @Override
        protected Query doToQuery(PredicateRequestContext context) {
            KnnByteVectorQuery query = new KnnByteVectorQuery(this.absoluteFieldPath, this.vector, this.k, this.prepareFilter(context));
            if (this.requiredMinimumScore == null) {
                return query;
            }
            return VectorSimilarityFilterQuery.create(query, this.requiredMinimumScore.floatValue());
        }

        private static class ByteBuilder
        extends Builder<byte[]> {
            protected ByteBuilder(LuceneSearchIndexScope<?> scope, LuceneSearchIndexValueFieldContext<byte[]> field) {
                super(scope, field);
            }

            public SearchPredicate build() {
                return new LuceneByteKnnPredicate(this);
            }
        }
    }

    /** Creates knn predicate builders for {@code byte[]} vector fields. */
    public static class ByteFactory
    extends AbstractLuceneValueFieldSearchQueryElementFactory<KnnPredicateBuilder, byte[]> {
        @Override
        public KnnPredicateBuilder create(LuceneSearchIndexScope<?> scope, LuceneSearchIndexValueFieldContext<byte[]> field) {
            return new LuceneByteKnnPredicate.ByteBuilder(scope, field);
        }
    }

    /** Creates knn predicate builders for {@code float[]} vector fields. */
    public static class FloatFactory
    extends AbstractLuceneValueFieldSearchQueryElementFactory<KnnPredicateBuilder, float[]> {
        @Override
        public KnnPredicateBuilder create(LuceneSearchIndexScope<?> scope, LuceneSearchIndexValueFieldContext<float[]> field) {
            return new LuceneFloatKnnPredicate.FloatBuilder(scope, field);
        }
    }
}

