/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.schema.expressiontransforms;

import com.yahoo.schema.RankProfile;
import com.yahoo.schema.expressiontransforms.RankProfileTransformContext;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
import com.yahoo.searchlib.rankingexpression.rule.Operator;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
import java.util.ArrayList;
import java.util.List;

public class NormalizerFunctionExpander
extends ExpressionTransformer<RankProfileTransformContext> {
    public static final String NORMALIZE_LINEAR = "normalize_linear";
    public static final String RECIPROCAL_RANK = "reciprocal_rank";
    public static final String RECIPROCAL_RANK_FUSION = "reciprocal_rank_fusion";

    public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
        if (node instanceof ReferenceNode) {
            ReferenceNode r = (ReferenceNode)node;
            node = this.transformReference(r, context);
        }
        if (node instanceof CompositeNode) {
            CompositeNode composite = (CompositeNode)node;
            node = this.transformChildren(composite, context);
        }
        return node;
    }

    private ExpressionNode transformReference(ReferenceNode node, RankProfileTransformContext context) {
        Reference ref = node.reference();
        String name = ref.name();
        if (ref.output() != null) {
            return node;
        }
        RankProfile.RankingExpressionFunction f = context.rankProfile().getFunctions().get(name);
        if (f != null) {
            return node;
        }
        return switch (name) {
            case RECIPROCAL_RANK_FUSION -> this.transform(this.expandRRF(ref), context);
            case NORMALIZE_LINEAR -> this.transformNormLin(ref, context);
            case RECIPROCAL_RANK -> this.transformRRank(ref, context);
            default -> node;
        };
    }

    private ExpressionNode expandRRF(Reference ref) {
        Arguments args = ref.arguments();
        if (args.size() < 2) {
            throw new IllegalArgumentException("must have at least 2 arguments: " + String.valueOf(ref));
        }
        ArrayList<ReferenceNode> children = new ArrayList<ReferenceNode>();
        ArrayList<Operator> operators = new ArrayList<Operator>();
        for (ExpressionNode arg : args.expressions()) {
            if (!children.isEmpty()) {
                operators.add(Operator.plus);
            }
            children.add(new ReferenceNode(RECIPROCAL_RANK, List.of(arg), null));
        }
        return new OperationNode(children, operators);
    }

    private ExpressionNode transformNormLin(Reference ref, RankProfileTransformContext context) {
        Arguments args = ref.arguments();
        if (args.size() != 1) {
            throw new IllegalArgumentException("must have exactly 1 argument: " + String.valueOf(ref));
        }
        ExpressionNode input = (ExpressionNode)args.expressions().get(0);
        if (input instanceof ReferenceNode) {
            ReferenceNode inputRefNode = (ReferenceNode)input;
            Reference inputRef = inputRefNode.reference();
            RankProfile.RankFeatureNormalizer normalizer = RankProfile.RankFeatureNormalizer.linear(ref, inputRef);
            context.rankProfile().addFeatureNormalizer(normalizer);
            Reference newRef = Reference.fromIdentifier((String)normalizer.name());
            return new ReferenceNode(newRef);
        }
        throw new IllegalArgumentException("the first argument must be a simple feature: " + String.valueOf(ref) + " => " + String.valueOf(input.getClass()));
    }

    private ExpressionNode transformRRank(Reference ref, RankProfileTransformContext context) {
        ExpressionNode input;
        Arguments args = ref.arguments();
        if (args.size() < 1 || args.size() > 2) {
            throw new IllegalArgumentException("must have 1 or 2 arguments: " + String.valueOf(ref));
        }
        double k = 60.0;
        if (args.size() == 2) {
            ExpressionNode kArg = (ExpressionNode)args.expressions().get(1);
            if (kArg instanceof ConstantNode) {
                ConstantNode kNode = (ConstantNode)kArg;
                k = kNode.getValue().asDouble();
            } else {
                throw new IllegalArgumentException("the second argument (k) must be a constant in: " + String.valueOf(ref));
            }
        }
        if ((input = (ExpressionNode)args.expressions().get(0)) instanceof ReferenceNode) {
            ReferenceNode inputRefNode = (ReferenceNode)input;
            Reference inputRef = inputRefNode.reference();
            RankProfile.RankFeatureNormalizer normalizer = RankProfile.RankFeatureNormalizer.rrank(ref, inputRef, k);
            context.rankProfile().addFeatureNormalizer(normalizer);
            Reference newRef = Reference.fromIdentifier((String)normalizer.name());
            return new ReferenceNode(newRef);
        }
        throw new IllegalArgumentException("the first argument must be a simple feature: " + String.valueOf(ref));
    }
}

