/*
 * Decompiled with CFR 0.152.
 */
package org.graylog.shaded.opensearch2.org.opensearch.index.rankeval;

import java.io.IOException;
import java.util.List;
import java.util.Objects;
import java.util.OptionalInt;
import org.graylog.shaded.opensearch2.org.opensearch.common.io.stream.StreamInput;
import org.graylog.shaded.opensearch2.org.opensearch.common.io.stream.StreamOutput;
import org.graylog.shaded.opensearch2.org.opensearch.core.ParseField;
import org.graylog.shaded.opensearch2.org.opensearch.core.xcontent.ConstructingObjectParser;
import org.graylog.shaded.opensearch2.org.opensearch.core.xcontent.ToXContent;
import org.graylog.shaded.opensearch2.org.opensearch.core.xcontent.XContentBuilder;
import org.graylog.shaded.opensearch2.org.opensearch.core.xcontent.XContentParser;
import org.graylog.shaded.opensearch2.org.opensearch.index.rankeval.EvalQueryQuality;
import org.graylog.shaded.opensearch2.org.opensearch.index.rankeval.EvaluationMetric;
import org.graylog.shaded.opensearch2.org.opensearch.index.rankeval.MetricDetail;
import org.graylog.shaded.opensearch2.org.opensearch.index.rankeval.RatedDocument;
import org.graylog.shaded.opensearch2.org.opensearch.index.rankeval.RatedSearchHit;
import org.graylog.shaded.opensearch2.org.opensearch.search.SearchHit;

/**
 * Precision@K ranking evaluation metric.
 *
 * <p>For a single query, precision is the number of retrieved documents whose rating is at least
 * {@link #getRelevantRatingThreshold()}, divided by the number of retrieved documents that are
 * counted. Hits without a rating are counted as retrieved (and therefore non-relevant) unless
 * {@link #getIgnoreUnlabeled()} is {@code true}, in which case they are skipped entirely. When no
 * document is counted, the precision is {@code 0.0}.
 *
 * <p>The size of the evaluated search request is forced to {@code k} via {@link #forcedSearchSize()}.
 */
public class PrecisionAtK implements EvaluationMetric {

    /** Writeable name of this metric and the name of its XContent object. */
    public static final String NAME = "precision";

    private static final int DEFAULT_RELEVANT_RATING_THRESHOLD = 1;
    private static final boolean DEFAULT_IGNORE_UNLABELED = false;
    private static final int DEFAULT_K = 10;

    private static final ParseField RELEVANT_RATING_THRESHOLD_FIELD = new ParseField("relevant_rating_threshold");
    private static final ParseField IGNORE_UNLABELED_FIELD = new ParseField("ignore_unlabeled");
    private static final ParseField K_FIELD = new ParseField("k");

    // All three fields are optional; any that is absent falls back to its DEFAULT_* constant.
    private static final ConstructingObjectParser<PrecisionAtK, Void> PARSER = new ConstructingObjectParser<>(NAME, args -> {
        Integer relevantRatingThreshold = (Integer) args[0];
        Boolean ignoreUnlabeled = (Boolean) args[1];
        Integer k = (Integer) args[2];
        return new PrecisionAtK(
            relevantRatingThreshold == null ? DEFAULT_RELEVANT_RATING_THRESHOLD : relevantRatingThreshold,
            ignoreUnlabeled == null ? DEFAULT_IGNORE_UNLABELED : ignoreUnlabeled,
            k == null ? DEFAULT_K : k
        );
    });

    static {
        // Declaration order must match the args[] indices used by the builder lambda above.
        PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), RELEVANT_RATING_THRESHOLD_FIELD);
        PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), IGNORE_UNLABELED_FIELD);
        PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), K_FIELD);
    }

    private final int relevantRatingThreshold;
    private final boolean ignoreUnlabeled;
    private final int k;

    /**
     * Creates a precision metric.
     *
     * @param relevantRatingThreshold minimum rating (inclusive) for a document to count as relevant
     * @param ignoreUnlabeled if {@code true}, hits without a rating are not counted at all;
     *     otherwise they count as retrieved but not relevant
     * @param k the search window size forced on the evaluated request
     * @throws IllegalArgumentException if {@code relevantRatingThreshold} is negative or {@code k} is not positive
     */
    public PrecisionAtK(int relevantRatingThreshold, boolean ignoreUnlabeled, int k) {
        if (relevantRatingThreshold < 0) {
            // Zero is accepted, so the requirement is "non-negative", not "positive".
            throw new IllegalArgumentException("Relevant rating threshold for precision must be a non-negative integer.");
        }
        if (k <= 0) {
            throw new IllegalArgumentException("Window size k must be positive.");
        }
        this.relevantRatingThreshold = relevantRatingThreshold;
        this.ignoreUnlabeled = ignoreUnlabeled;
        this.k = k;
    }

    /**
     * Creates a precision metric with the default relevance threshold and window size.
     *
     * @param ignoreUnlabeled whether hits without a rating are left out of the computation
     */
    public PrecisionAtK(boolean ignoreUnlabeled) {
        this(DEFAULT_RELEVANT_RATING_THRESHOLD, ignoreUnlabeled, DEFAULT_K);
    }

    /** Creates a precision metric with all default settings. */
    public PrecisionAtK() {
        this(DEFAULT_RELEVANT_RATING_THRESHOLD, DEFAULT_IGNORE_UNLABELED, DEFAULT_K);
    }

    /** Reads a metric in the field order written by {@link #writeTo(StreamOutput)}. */
    PrecisionAtK(StreamInput in) throws IOException {
        this(in.readVInt(), in.readBoolean(), in.readVInt());
    }

    /** Parses the body of a {@code precision} metric object. */
    public static PrecisionAtK fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeVInt(relevantRatingThreshold);
        out.writeBoolean(ignoreUnlabeled);
        out.writeVInt(k);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        builder.startObject(NAME);
        builder.field(RELEVANT_RATING_THRESHOLD_FIELD.getPreferredName(), relevantRatingThreshold);
        builder.field(IGNORE_UNLABELED_FIELD.getPreferredName(), ignoreUnlabeled);
        builder.field(K_FIELD.getPreferredName(), k);
        builder.endObject();
        builder.endObject();
        return builder;
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    /** Returns the minimum rating (inclusive) at which a document counts as relevant. */
    public int getRelevantRatingThreshold() {
        return relevantRatingThreshold;
    }

    /** Returns whether hits without a rating are excluded from the computation. */
    public boolean getIgnoreUnlabeled() {
        return ignoreUnlabeled;
    }

    /** Returns the search window size. */
    public int getK() {
        return k;
    }

    @Override
    public OptionalInt forcedSearchSize() {
        return OptionalInt.of(k);
    }

    private boolean isRelevant(int rating) {
        return rating >= relevantRatingThreshold;
    }

    /**
     * Computes precision for one query.
     *
     * <p>Hits are paired with their ratings via {@link EvaluationMetric#joinHitsWithRatings}. Rated
     * hits always count as retrieved, and as relevant if their rating reaches the threshold. Unrated
     * hits count as retrieved only when unlabeled documents are not ignored.
     *
     * @return the quality for {@code taskId}, carrying a {@link Detail} with the raw counts and the rated hits
     */
    @Override
    public EvalQueryQuality evaluate(String taskId, SearchHit[] hits, List<RatedDocument> ratedDocs) {
        List<RatedSearchHit> ratedSearchHits = EvaluationMetric.joinHitsWithRatings(hits, ratedDocs);
        int relevantRetrieved = 0;
        int retrieved = 0;
        for (RatedSearchHit hit : ratedSearchHits) {
            OptionalInt rating = hit.getRating();
            if (rating.isPresent()) {
                retrieved++;
                if (isRelevant(rating.getAsInt())) {
                    relevantRetrieved++;
                }
            } else if (!ignoreUnlabeled) {
                // Unlabeled hit: counted as retrieved but never as relevant.
                retrieved++;
            }
        }
        // Guard against division by zero when nothing was counted.
        double precision = retrieved > 0 ? (double) relevantRetrieved / retrieved : 0.0;
        EvalQueryQuality evalQueryQuality = new EvalQueryQuality(taskId, precision);
        evalQueryQuality.setMetricDetails(new Detail(relevantRetrieved, retrieved));
        evalQueryQuality.addHitsAndRatings(ratedSearchHits);
        return evalQueryQuality;
    }

    @Override
    public final boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        PrecisionAtK other = (PrecisionAtK) obj;
        return relevantRatingThreshold == other.relevantRatingThreshold
            && ignoreUnlabeled == other.ignoreUnlabeled
            && k == other.k;
    }

    @Override
    public final int hashCode() {
        return Objects.hash(relevantRatingThreshold, ignoreUnlabeled, k);
    }

    /** Per-query breakdown of the counts from which a precision value was computed. */
    public static final class Detail implements MetricDetail {
        private static final ParseField RELEVANT_DOCS_RETRIEVED_FIELD = new ParseField("relevant_docs_retrieved");
        private static final ParseField DOCS_RETRIEVED_FIELD = new ParseField("docs_retrieved");

        // The boolean argument is ignoreUnknownFields; both counts are required constructor args.
        private static final ConstructingObjectParser<Detail, Void> PARSER = new ConstructingObjectParser<>(
            PrecisionAtK.NAME,
            true,
            args -> new Detail((Integer) args[0], (Integer) args[1])
        );

        static {
            PARSER.declareInt(ConstructingObjectParser.constructorArg(), RELEVANT_DOCS_RETRIEVED_FIELD);
            PARSER.declareInt(ConstructingObjectParser.constructorArg(), DOCS_RETRIEVED_FIELD);
        }

        private final int relevantRetrieved;
        private final int retrieved;

        Detail(int relevantRetrieved, int retrieved) {
            this.relevantRetrieved = relevantRetrieved;
            this.retrieved = retrieved;
        }

        /** Reads a detail in the field order written by {@link #writeTo(StreamOutput)}. */
        Detail(StreamInput in) throws IOException {
            this(in.readVInt(), in.readVInt());
        }

        /** Parses the inner fields of a precision detail object. */
        public static Detail fromXContent(XContentParser parser) {
            return PARSER.apply(parser, null);
        }

        /**
         * Writes both counts as VInts so that they match the {@code readVInt} calls in
         * {@link #Detail(StreamInput)}. For non-negative ints, the VInt and VLong encodings produce
         * identical bytes, so this is wire-compatible with the previous {@code writeVLong} calls.
         */
        @Override
        public void writeTo(StreamOutput out) throws IOException {
            out.writeVInt(relevantRetrieved);
            out.writeVInt(retrieved);
        }

        @Override
        public XContentBuilder innerToXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
            builder.field(RELEVANT_DOCS_RETRIEVED_FIELD.getPreferredName(), relevantRetrieved);
            builder.field(DOCS_RETRIEVED_FIELD.getPreferredName(), retrieved);
            return builder;
        }

        @Override
        public String getWriteableName() {
            return PrecisionAtK.NAME;
        }

        /** Returns the number of counted hits whose rating reached the relevance threshold. */
        public int getRelevantRetrieved() {
            return relevantRetrieved;
        }

        /** Returns the number of hits counted as retrieved (the precision denominator). */
        public int getRetrieved() {
            return retrieved;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            Detail other = (Detail) obj;
            return relevantRetrieved == other.relevantRetrieved && retrieved == other.retrieved;
        }

        @Override
        public int hashCode() {
            return Objects.hash(relevantRetrieved, retrieved);
        }
    }
}

