/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.community.store.embedding;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.store.embedding.CosineSimilarity;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.OptionalDouble;

/**
 * Re-ranks embedding search results with Maximal Marginal Relevance (MMR).
 *
 * <p>Each step picks the not-yet-selected candidate with the highest
 * {@code lambda * relevance - (1 - lambda) * diversity}. Here {@code relevance} is the candidate's
 * relevance to the query, and {@code diversity} is its highest cosine similarity to any match
 * already selected (0.0 before the first pick). A higher {@code lambda} favours relevance; a lower
 * {@code lambda} penalises redundancy more strongly.
 *
 * <p>Stateless utility class; not instantiable.
 */
public final class MmrSelector {
    /** Starting point for the running maximum MMR score, so any non-NaN score beats it. */
    private static final double INITIAL_MMR_SCORE = Double.NEGATIVE_INFINITY;
    /** Diversity penalty applied while nothing has been selected yet. */
    private static final double INITIAL_DIVERSITY_SCORE = 0.0;
    private static final double MIN_LAMBDA = 0.0;
    private static final double MAX_LAMBDA = 1.0;
    /** Index returned when no candidate produced a comparable (non-NaN) MMR score. */
    private static final int NO_CANDIDATE = -1;

    private MmrSelector() {
    }

    /**
     * Selects up to {@code maxResults} matches from {@code candidates} using MMR.
     *
     * @param queryEmbedding embedding of the query; must not be {@code null}
     * @param candidates     matches to choose from; {@code null} or empty yields an empty list
     * @param maxResults     maximum number of matches to return; must not be negative
     * @param lambda         relevance/diversity trade-off in [0.0, 1.0]; 1.0 ranks purely by
     *                       relevance, 0.0 purely by dissimilarity to already-selected matches
     * @param <T>            type of the embedded content carried by each match
     * @return a new mutable list. If {@code candidates} has at most {@code maxResults} elements it
     *         is a copy of them in their original order. Otherwise it holds exactly
     *         {@code maxResults} matches in MMR pick order.
     * @throws IllegalArgumentException if {@code queryEmbedding} is {@code null}, {@code lambda}
     *         is NaN or outside [0.0, 1.0], or {@code maxResults} is negative
     */
    public static <T> List<EmbeddingMatch<T>> select(
            Embedding queryEmbedding, List<EmbeddingMatch<T>> candidates, int maxResults, double lambda) {
        validateParameters(queryEmbedding, lambda, maxResults);
        if (Utils.isNullOrEmpty(candidates) || maxResults <= 0) {
            return new ArrayList<>();
        }
        if (candidates.size() <= maxResults) {
            // Everything fits: no re-ranking is needed, so the original order is kept.
            return new ArrayList<>(candidates);
        }
        return performMmrSelection(queryEmbedding, candidates, maxResults, lambda);
    }

    /** Rejects a null query, an out-of-range or NaN lambda, and a negative result count. */
    private static void validateParameters(Embedding queryEmbedding, double lambda, int maxResults) {
        if (queryEmbedding == null) {
            throw new IllegalArgumentException("Query embedding cannot be null");
        }
        // Explicit NaN check: NaN compares false against both bounds and would otherwise pass.
        if (Double.isNaN(lambda) || lambda < MIN_LAMBDA || lambda > MAX_LAMBDA) {
            throw new IllegalArgumentException("Lambda must be between 0.0 and 1.0 (inclusive), got: " + lambda);
        }
        if (maxResults < 0) {
            throw new IllegalArgumentException("Max results cannot be negative, got: " + maxResults);
        }
    }

    /**
     * Runs the greedy MMR loop. Called only when there are more candidates than {@code maxResults}.
     */
    private static <T> List<EmbeddingMatch<T>> performMmrSelection(
            Embedding queryEmbedding, List<EmbeddingMatch<T>> candidates, int maxResults, double lambda) {
        List<EmbeddingMatch<T>> ranked = new ArrayList<>(candidates);
        // List.sort is stable. Highest store score comes first, so MMR ties and the fallback
        // below both resolve in favour of the candidate the store ranked higher.
        ranked.sort(Comparator.<EmbeddingMatch<T>>comparingDouble(EmbeddingMatch::score).reversed());

        int count = ranked.size();
        // Relevance does not depend on what has been selected, so it is computed once.
        double[] relevance = new double[count];
        // Running max cosine similarity of each candidate to the matches selected so far.
        double[] maxSimilarity = new double[count];
        boolean[] taken = new boolean[count];
        for (int i = 0; i < count; i++) {
            relevance[i] = getRelevanceScore(ranked.get(i), queryEmbedding);
            maxSimilarity[i] = Double.NEGATIVE_INFINITY;
        }

        List<EmbeddingMatch<T>> selected = new ArrayList<>(maxResults);
        while (selected.size() < maxResults && selected.size() < count) {
            int best = findBestMmrIndex(relevance, maxSimilarity, taken, !selected.isEmpty(), lambda);
            if (best == NO_CANDIDATE) {
                // Every remaining MMR score was NaN. Fall back to the highest-scored candidate left.
                best = firstNotTaken(taken);
            }
            taken[best] = true;
            EmbeddingMatch<T> chosen = ranked.get(best);
            selected.add(chosen);
            if (selected.size() < maxResults) {
                // Only pay for similarity computations if another pick will follow.
                updateMaxSimilarity(ranked, chosen, taken, maxSimilarity);
            }
        }
        return selected;
    }

    /**
     * Returns the index of the untaken candidate with the highest MMR score, or
     * {@code NO_CANDIDATE} if none has a comparable score.
     */
    private static int findBestMmrIndex(
            double[] relevance, double[] maxSimilarity, boolean[] taken, boolean anySelected, double lambda) {
        double bestScore = INITIAL_MMR_SCORE;
        int bestIndex = NO_CANDIDATE;
        for (int i = 0; i < relevance.length; i++) {
            if (taken[i]) {
                continue;
            }
            double diversity = anySelected ? maxSimilarity[i] : INITIAL_DIVERSITY_SCORE;
            double mmrScore = calculateMmrScore(relevance[i], diversity, lambda);
            // Strict '>' means the earliest candidate wins ties and a NaN score never wins.
            if (mmrScore > bestScore) {
                bestScore = mmrScore;
                bestIndex = i;
            }
        }
        return bestIndex;
    }

    /** Folds the similarity to {@code chosen} into each untaken candidate's running maximum. */
    private static <T> void updateMaxSimilarity(
            List<EmbeddingMatch<T>> ranked, EmbeddingMatch<T> chosen, boolean[] taken, double[] maxSimilarity) {
        for (int i = 0; i < maxSimilarity.length; i++) {
            if (!taken[i]) {
                double similarity = CosineSimilarity.between(ranked.get(i).embedding(), chosen.embedding());
                // Math.max propagates NaN, the same as the stream max used previously.
                maxSimilarity[i] = Math.max(maxSimilarity[i], similarity);
            }
        }
    }

    /** Returns the first index not yet taken; the caller guarantees one exists. */
    private static int firstNotTaken(boolean[] taken) {
        for (int i = 0; i < taken.length; i++) {
            if (!taken[i]) {
                return i;
            }
        }
        throw new IllegalStateException("No candidates left to select");
    }

    /**
     * Returns the candidate's store score when positive, otherwise its cosine similarity to the
     * query.
     *
     * <p>NOTE(review): the store score and the raw cosine similarity may be on different scales.
     * Confirm this mix is intended.
     */
    private static <T> double getRelevanceScore(EmbeddingMatch<T> candidate, Embedding queryEmbedding) {
        if (candidate.score() > 0.0) {
            return candidate.score();
        }
        return CosineSimilarity.between(candidate.embedding(), queryEmbedding);
    }

    /** Computes {@code lambda * relevance - (1 - lambda) * diversity}. */
    private static double calculateMmrScore(double relevanceScore, double diversityScore, double lambda) {
        return lambda * relevanceScore - (1.0 - lambda) * diversityScore;
    }
}

