/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.search.ranking;

import ai.vespa.models.evaluation.FunctionEvaluator;
import com.yahoo.component.annotation.Inject;
import com.yahoo.data.access.Inspector;
import com.yahoo.data.access.helpers.MatchFeatureData;
import com.yahoo.data.access.helpers.MatchFeatureFilter;
import com.yahoo.search.Query;
import com.yahoo.search.Result;
import com.yahoo.search.query.Ranking;
import com.yahoo.search.query.Sorting;
import com.yahoo.search.query.ranking.RankFeatures;
import com.yahoo.search.ranking.Evaluator;
import com.yahoo.search.ranking.HitRescorer;
import com.yahoo.search.ranking.RankProfilesEvaluator;
import com.yahoo.search.ranking.RankProfilesEvaluatorFactory;
import com.yahoo.search.ranking.ResultReranker;
import com.yahoo.search.ranking.SimpleEvaluator;
import com.yahoo.search.result.ErrorMessage;
import com.yahoo.search.result.FeatureData;
import com.yahoo.search.result.Hit;
import com.yahoo.search.result.HitGroup;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.tensor.Tensor;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.logging.Logger;

public class GlobalPhaseRanker {
    private static final Logger logger = Logger.getLogger(GlobalPhaseRanker.class.getName());
    private final RankProfilesEvaluatorFactory factory;

    @Inject
    public GlobalPhaseRanker(RankProfilesEvaluatorFactory factory) {
        this.factory = factory;
        logger.fine(() -> "Using factory: " + factory);
    }

    public Optional<ErrorMessage> validateNoSorting(Query query, String schema) {
        RankProfilesEvaluator.GlobalPhaseData data = this.globalPhaseDataFor(query, schema).orElse(null);
        if (data == null) {
            return Optional.empty();
        }
        Sorting sorting = query.getRanking().getSorting();
        if (sorting == null || sorting.fieldOrders() == null) {
            return Optional.empty();
        }
        for (Sorting.FieldOrder fieldOrder : sorting.fieldOrders()) {
            if (fieldOrder.getSorter().getName().equals("[rank]") && fieldOrder.getSortOrder() == Sorting.Order.DESCENDING) continue;
            return Optional.of(ErrorMessage.createIllegalQuery("Sorting is not supported with global phase"));
        }
        return Optional.empty();
    }

    public void rerankHits(Query query, Result result, String schema) {
        RankProfilesEvaluator.GlobalPhaseData data = this.globalPhaseDataFor(query, schema).orElse(null);
        if (data == null) {
            return;
        }
        Supplier<FunctionEvaluator> functionEvaluatorSource = data.functionEvaluatorSource();
        List<NameAndValue> prepared = this.findFromQuery(query, data.needInputs());
        Supplier<Evaluator> supplier = () -> {
            FunctionEvaluator evaluator = (FunctionEvaluator)functionEvaluatorSource.get();
            SimpleEvaluator simple = new SimpleEvaluator(evaluator);
            for (NameAndValue entry : prepared) {
                simple.bind(entry.name(), entry.value());
            }
            return simple;
        };
        int rerankCount = data.rerankCount();
        if (rerankCount < 0) {
            rerankCount = 100;
        }
        ResultReranker.rerankHits(result, new HitRescorer(supplier), rerankCount);
        this.hideImplicitMatchFeatures(result, data.matchFeaturesToHide());
    }

    private void hideImplicitMatchFeatures(Result result, Collection<String> namesToHide) {
        if (namesToHide.size() == 0) {
            return;
        }
        MatchFeatureFilter filter = new MatchFeatureFilter(namesToHide);
        Iterator<Hit> iterator = result.hits().deepIterator();
        while (iterator.hasNext()) {
            FeatureData matchFeatures;
            Inspector inspector;
            Object object;
            Hit hit = iterator.next();
            if (hit.isMeta() || hit instanceof HitGroup || !((object = hit.getField("matchfeatures")) instanceof FeatureData) || !((inspector = (matchFeatures = (FeatureData)object).inspect()) instanceof MatchFeatureData.HitValue)) continue;
            MatchFeatureData.HitValue hitValue = (MatchFeatureData.HitValue)inspector;
            MatchFeatureData.HitValue newValue = hitValue.subsetFilter((Function)filter);
            if (newValue.fieldCount() == 0) {
                hit.removeField("matchfeatures");
                continue;
            }
            hit.setField("matchfeatures", newValue);
        }
    }

    private Optional<RankProfilesEvaluator.GlobalPhaseData> globalPhaseDataFor(Query query, String schema) {
        return this.factory.evaluatorForSchema(schema).flatMap(evaluator -> evaluator.getGlobalPhaseData(query.getRanking().getProfile()));
    }

    List<NameAndValue> findFromQuery(Query query, List<String> needInputs) {
        ArrayList<NameAndValue> result = new ArrayList<NameAndValue>();
        Ranking ranking = query.getRanking();
        RankFeatures rankFeatures = ranking.getFeatures();
        Map<String, List<Object>> rankProps = ranking.getProperties().asMap();
        for (String needed : needInputs) {
            Object object;
            Optional optRef = Reference.simple((String)needed);
            if (optRef.isEmpty()) continue;
            Reference ref = (Reference)optRef.get();
            if (ref.name().equals("constant")) {
                result.add(new NameAndValue(needed, null));
                continue;
            }
            if (!ref.isSimple() || !ref.name().equals("query")) continue;
            String queryFeatureName = (String)ref.simpleArgument().get();
            Optional<Tensor> feature = rankFeatures.getTensor(queryFeatureName);
            if (feature.isPresent()) {
                result.add(new NameAndValue(needed, feature.get()));
                continue;
            }
            List<Object> objList = rankProps.get(queryFeatureName);
            if (objList == null || objList.size() != 1 || !((object = objList.get(0)) instanceof Tensor)) continue;
            Tensor t = (Tensor)object;
            result.add(new NameAndValue(needed, t));
        }
        return result;
    }

    record NameAndValue(String name, Tensor value) {
    }
}

