/*
 * Decompiled with CFR 0.152.
 */
package io.substrait.isthmus.expression;

import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Multimap;
import com.google.common.collect.Multimaps;
import com.google.common.collect.Streams;
import io.substrait.expression.Expression;
import io.substrait.expression.ExpressionCreator;
import io.substrait.expression.FunctionArg;
import io.substrait.extension.SimpleExtension;
import io.substrait.function.ParameterizedType;
import io.substrait.function.ToTypeString;
import io.substrait.isthmus.TypeConverter;
import io.substrait.isthmus.Utils;
import io.substrait.isthmus.expression.EnumConverter;
import io.substrait.isthmus.expression.FunctionMappings;
import io.substrait.isthmus.expression.IgnoreNullableAndParameters;
import io.substrait.type.Type;
import io.substrait.type.TypeVisitor;
import io.substrait.util.Util;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlOperator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class that matches Calcite calls against Substrait function declarations.
 *
 * <p>At construction time the Substrait functions are grouped by lower-cased name and joined with
 * the operator signatures from {@link #getSigs()} (plus any additional signatures) by that name.
 * For each bound Calcite {@link SqlOperator} a {@link FunctionFinder} is built, and a reverse index
 * from Substrait function key to candidate operators is recorded.
 *
 * @param <F> the Substrait function declaration type handled by this converter
 * @param <T> the binding type produced by {@link #generateBinding}
 * @param <C> the call wrapper type being converted
 */
public abstract class FunctionConverter<F extends SimpleExtension.Function, T, C extends GenericCall> {
    static final Logger logger = LoggerFactory.getLogger(FunctionConverter.class);
    /** One finder per bound Calcite operator, keyed by operator identity. */
    protected final Map<SqlOperator, FunctionFinder> signatures;
    protected final RelDataTypeFactory typeFactory;
    protected final TypeConverter typeConverter;
    protected final RexBuilder rexBuilder;
    /** Substrait function key to every Calcite operator bound to that function's name. */
    protected final Multimap<String, SqlOperator> substraitFuncKeyToSqlOperatorMap;

    /**
     * Creates a converter with no additional signatures and {@link TypeConverter#DEFAULT}.
     *
     * @param functions Substrait function declarations to make available
     * @param typeFactory Calcite type factory
     */
    public FunctionConverter(List<F> functions, RelDataTypeFactory typeFactory) {
        this(functions, Collections.emptyList(), typeFactory, TypeConverter.DEFAULT);
    }

    /**
     * Creates a converter.
     *
     * <p>NOTE: the abstract {@link #getSigs()} is invoked while this constructor runs, i.e. before
     * any subclass constructor body has executed.
     *
     * @param functions Substrait function declarations to make available
     * @param additionalSignatures extra operator/name bindings, placed before those of {@link #getSigs()}
     * @param typeFactory Calcite type factory
     * @param typeConverter converter between Calcite and Substrait types
     */
    public FunctionConverter(List<F> functions, List<FunctionMappings.Sig> additionalSignatures, RelDataTypeFactory typeFactory, TypeConverter typeConverter) {
        this.rexBuilder = new RexBuilder(typeFactory);
        this.typeConverter = typeConverter;
        ImmutableList<FunctionMappings.Sig> defaultSigs = this.getSigs();
        List<FunctionMappings.Sig> allSigs = new ArrayList<>(defaultSigs.size() + additionalSignatures.size());
        allSigs.addAll(additionalSignatures);
        allSigs.addAll(defaultSigs);
        this.typeFactory = typeFactory;
        this.substraitFuncKeyToSqlOperatorMap = ArrayListMultimap.create();

        // Substrait declarations grouped by lower-cased function name.
        ListMultimap<String, F> functionsByName = ArrayListMultimap.create();
        for (F function : functions) {
            functionsByName.put(function.name().toLowerCase(Locale.ROOT), function);
        }
        // Calcite operator bindings grouped by the Substrait name they bind to.
        ListMultimap<String, FunctionMappings.Sig> sigsByName = allSigs.stream()
                .collect(Multimaps.toMultimap(FunctionMappings.Sig::name, sig -> sig, () -> ArrayListMultimap.create()));

        IdentityHashMap<SqlOperator, FunctionFinder> matcherMap = new IdentityHashMap<>();
        for (String name : functionsByName.keySet()) {
            // ListMultimap.get never returns null: a name without bindings yields an empty list.
            List<FunctionMappings.Sig> sigs = sigsByName.get(name);
            if (sigs.isEmpty()) {
                logger.info("Dropping function due to no binding: {}", name);
                continue;
            }
            // Every key of keySet() has at least one value, so this list is never empty.
            List<F> variants = functionsByName.get(name);
            for (FunctionMappings.Sig sig : sigs) {
                matcherMap.put(sig.operator(), new FunctionFinder(name, sig.operator(), variants));
            }
        }
        for (Map.Entry<String, F> entry : functionsByName.entries()) {
            for (FunctionMappings.Sig sig : sigsByName.get(entry.getKey())) {
                this.substraitFuncKeyToSqlOperatorMap.put(entry.getValue().key(), sig.operator());
            }
        }
        this.signatures = matcherMap;
    }

    /**
     * Finds the Calcite operator bound to a Substrait function key.
     *
     * <p>If exactly one operator is bound to {@code key} it is returned directly. Otherwise the
     * candidates are narrowed to operators whose {@link FunctionMappings.TypeBasedResolver} lists the
     * string form of {@code outputType}.
     *
     * @param key Substrait function key
     * @param outputType Substrait output type of the call, used to disambiguate
     * @return the single resolved operator, or empty if none is bound or none resolves
     * @throws RuntimeException if more than one operator resolves for {@code outputType}
     */
    public Optional<SqlOperator> getSqlOperatorFromSubstraitFunc(String key, Type outputType) {
        Map<SqlOperator, FunctionMappings.TypeBasedResolver> resolver = this.getTypeBasedResolver();
        if (!this.substraitFuncKeyToSqlOperatorMap.containsKey(key)) {
            return Optional.empty();
        }
        Collection<SqlOperator> operators = this.substraitFuncKeyToSqlOperatorMap.get(key);
        if (operators.size() == 1) {
            return Optional.of(operators.iterator().next());
        }
        String outputTypeStr = outputType.accept(ToTypeString.INSTANCE);
        List<SqlOperator> resolvedOperators = operators.stream()
                .filter(operator -> resolver.containsKey(operator) && resolver.get(operator).types().contains(outputTypeStr))
                .collect(Collectors.toList());
        if (resolvedOperators.size() == 1) {
            return Optional.of(resolvedOperators.get(0));
        }
        if (resolvedOperators.size() > 1) {
            throw new RuntimeException(String.format("Found %d SqlOperators: %s for ScalarFunction %s: ", resolvedOperators.size(), resolvedOperators, key));
        }
        return Optional.empty();
    }

    /** Returns the operator-to-resolver table from {@link FunctionMappings#OPERATOR_RESOLVER}. */
    private Map<SqlOperator, FunctionMappings.TypeBasedResolver> getTypeBasedResolver() {
        return FunctionMappings.OPERATOR_RESOLVER;
    }

    /**
     * Returns the default operator/name bindings for this converter. Called from the constructor.
     */
    protected abstract ImmutableList<FunctionMappings.Sig> getSigs();

    /** Coerces every argument to {@code type}; see {@link #coerceArgument}. */
    private static List<Expression> coerceArguments(List<Expression> arguments, Type type) {
        return arguments.stream().map(a -> FunctionConverter.coerceArgument(a, type)).collect(Collectors.toList());
    }

    /**
     * Returns {@code argument} unchanged if its type matches {@code type}; otherwise wraps it in a
     * cast to {@code type} using {@link Expression.FailureBehavior#THROW_EXCEPTION}.
     */
    private static Expression coerceArgument(Expression argument, Type type) {
        if (FunctionConverter.isMatch(type, (ParameterizedType) argument.getType())) {
            return argument;
        }
        return ExpressionCreator.cast(type, argument, Expression.FailureBehavior.THROW_EXCEPTION);
    }

    /**
     * Builds the converter-specific result for a matched call.
     *
     * @param var1 the call being converted
     * @param var2 the matched Substrait function declaration
     * @param var3 the converted (and possibly coerced) arguments
     * @param var4 the Substrait output type of the call
     * @return the binding
     */
    protected abstract T generateBinding(C var1, F var2, List<FunctionArg> var3, Type var4);

    /**
     * Combines matchers, returning the first non-empty result in order. Not currently used within
     * this class.
     */
    @SafeVarargs
    private static <M> SignatureMatcher<M> chainedSignature(SignatureMatcher<M>... matchers) {
        if (matchers.length == 0) {
            return (types, outputType) -> Optional.empty();
        }
        if (matchers.length == 1) {
            return matchers[0];
        }
        return (types, outputType) -> {
            for (SignatureMatcher<M> m : matchers) {
                Optional<M> t = m.tryMatch(types, outputType);
                if (t.isPresent()) {
                    return t;
                }
            }
            return Optional.empty();
        };
    }

    /**
     * Returns true if {@code type} is a wildcard. Otherwise delegates to
     * {@link IgnoreNullableAndParameters} (presumably ignoring nullability and type parameters, as
     * its name suggests — TODO confirm).
     */
    private static boolean isMatch(Type inputType, ParameterizedType type) {
        if (type.isWildcard()) {
            return true;
        }
        return inputType.accept(new IgnoreNullableAndParameters(type));
    }

    /** Same as {@link #isMatch(Type, ParameterizedType)} for a parameterized input type. */
    private static boolean isMatch(ParameterizedType inputType, ParameterizedType type) {
        if (type.isWildcard()) {
            return true;
        }
        return inputType.accept(new IgnoreNullableAndParameters(type));
    }

    /**
     * Matches calls to one Calcite operator against all Substrait variants sharing a function name.
     */
    protected class FunctionFinder {
        private final String name;
        private final SqlOperator operator;
        private final List<F> functions;
        /** Variant key (and required-arguments-only key, when different) to variant. */
        private final Map<String, F> directMap;
        private final SignatureMatcher<F> matcher;
        /** Present if at least one variant's required arguments all share one value type. */
        private final Optional<SingularArgumentMatcher<F>> singularInputType;
        /** Union of the argument-count ranges of all variants. */
        private final Util.IntRange argRange;

        /**
         * @param name lower-cased Substrait function name
         * @param operator the Calcite operator this finder serves
         * @param functions the Substrait variants; must be non-empty (the range min/max below
         *     call {@code getAsInt()})
         */
        public FunctionFinder(String name, SqlOperator operator, List<F> functions) {
            this.name = name;
            this.operator = operator;
            this.functions = functions;
            this.argRange = Util.IntRange.of(
                    functions.stream().mapToInt(t -> t.getRange().getStartInclusive()).min().getAsInt(),
                    functions.stream().mapToInt(t -> t.getRange().getEndExclusive()).max().getAsInt());
            this.matcher = FunctionFinder.getSignatureMatcher(operator, functions);
            this.singularInputType = FunctionFinder.getSingularInputType(functions);
            // ImmutableMap.Builder.build() rejects duplicate keys.
            ImmutableMap.Builder<String, F> directMapBuilder = ImmutableMap.builder();
            for (F func : functions) {
                directMapBuilder.put(func.key(), func);
                if (func.requiredArguments().size() != func.args().size()) {
                    // Also index the variant by its required arguments only.
                    directMapBuilder.put(SimpleExtension.Function.constructKey(name, func.requiredArguments()), func);
                }
            }
            this.directMap = directMapBuilder.build();
        }

        /** Returns whether {@code count} lies within the combined argument range of all variants. */
        public boolean allowedArgCount(int count) {
            return this.argRange.within(count);
        }

        /**
         * Builds a matcher returning the first variant whose return type matches the output type and
         * whose required arguments accept the input types.
         */
        private static <F extends SimpleExtension.Function> SignatureMatcher<F> getSignatureMatcher(SqlOperator operator, List<F> functions) {
            return (inputTypes, outputType) -> {
                for (F function : functions) {
                    List<SimpleExtension.Argument> args = function.requiredArguments();
                    boolean returnMatches = function.returnType() instanceof ParameterizedType
                            && FunctionConverter.isMatch(outputType, (ParameterizedType) function.returnType());
                    if (returnMatches && FunctionFinder.inputTypesSatisfyDefinedArguments(inputTypes, args)) {
                        return Optional.of(function);
                    }
                }
                return Optional.empty();
            };
        }

        /**
         * Checks each input type against the declared argument at the same position. Inputs past the
         * last declared argument are checked against the last one (presumably for variadic
         * functions). Declared arguments beyond {@code inputTypes.size()} are not checked. All inputs
         * bound to the same wildcard must share a single type.
         *
         * @return false if any input does not match, if the declared argument is not a value
         *     argument, or if {@code args} is empty while inputs are present
         */
        private static boolean inputTypesSatisfyDefinedArguments(List<Type> inputTypes, List<SimpleExtension.Argument> args) {
            if (args.isEmpty()) {
                // Nothing to index into (args.size() - 1 would be -1); only an empty input matches.
                return inputTypes.isEmpty();
            }
            Map<String, Set<Type>> wildcardToType = new HashMap<>();
            for (int i = 0; i < inputTypes.size(); ++i) {
                Type givenType = inputTypes.get(i);
                SimpleExtension.Argument declared = args.get(Math.min(i, args.size() - 1));
                if (!(declared instanceof SimpleExtension.ValueArgument)) {
                    // A non-value (e.g. enum) argument cannot be satisfied by a plain type.
                    return false;
                }
                ParameterizedType wantType = ((SimpleExtension.ValueArgument) declared).value();
                if (!FunctionConverter.isMatch(givenType, wantType)) {
                    return false;
                }
                if (wantType.isWildcard()) {
                    wildcardToType
                            .computeIfAbsent(wantType.accept(ToTypeString.ToTypeLiteralStringLossless.INSTANCE), k -> new HashSet<>())
                            .add(givenType);
                }
            }
            return wildcardToType.values().stream().allMatch(s -> s.size() == 1);
        }

        /**
         * Builds a matcher over the variants whose required arguments are all value arguments of
         * mutually matching types. When several variants qualify, they are tried in order.
         */
        @SuppressWarnings({"rawtypes", "unchecked"}) // chained(...) exposes a raw API
        private static <F extends SimpleExtension.Function> Optional<SingularArgumentMatcher<F>> getSingularInputType(List<F> functions) {
            List<SingularArgumentMatcher> matchers = new ArrayList<>();
            for (F f : functions) {
                ParameterizedType firstType = null;
                for (SimpleExtension.Argument a : f.requiredArguments()) {
                    if (!(a instanceof SimpleExtension.ValueArgument)) {
                        firstType = null;
                        break;
                    }
                    ParameterizedType pt = ((SimpleExtension.ValueArgument) a).value();
                    if (firstType == null) {
                        firstType = pt;
                        continue;
                    }
                    if (!FunctionConverter.isMatch(firstType, pt)) {
                        firstType = null;
                        break;
                    }
                }
                if (firstType != null) {
                    matchers.add(FunctionFinder.singular(f, firstType));
                }
            }
            if (matchers.isEmpty()) {
                return Optional.empty();
            }
            if (matchers.size() == 1) {
                return Optional.of(matchers.get(0));
            }
            return Optional.of(FunctionFinder.chained(matchers));
        }

        /** Returns a matcher yielding {@code function} whenever the input type matches {@code type}. */
        public static <F extends SimpleExtension.Function> SingularArgumentMatcher<F> singular(F function, ParameterizedType type) {
            return (inputType, outputType) -> FunctionConverter.isMatch(inputType, type) ? Optional.of(function) : Optional.empty();
        }

        /** Returns a matcher yielding the first non-empty result of {@code matchers}, in order. */
        public static SingularArgumentMatcher chained(List<SingularArgumentMatcher> matchers) {
            return (inputType, outputType) -> {
                for (SingularArgumentMatcher s : matchers) {
                    Optional outcome = s.tryMatch(inputType, outputType);
                    if (outcome.isPresent()) {
                        return outcome;
                    }
                }
                return Optional.empty();
            };
        }

        /**
         * Enumerates candidate argument-signature strings. Each operand contributes its type string,
         * except literals whose value is a Java {@link Enum}, which contribute both {@code "req"} and
         * {@code "opt"}. Every combination is joined with {@code "_"}. No operands yields {@code ""}.
         */
        private Stream<String> matchKeys(List<RexNode> rexOperands, List<String> opTypes) {
            assert (rexOperands.size() == opTypes.size());
            if (rexOperands.isEmpty()) {
                return Stream.of("");
            }
            List<List<String>> argTypeLists = Streams.zip(rexOperands.stream(), opTypes.stream(), (rexArg, opType) -> {
                boolean isOption = rexArg instanceof RexLiteral && ((RexLiteral) rexArg).getValue() instanceof Enum;
                return isOption ? List.of("req", "opt") : List.of(opType);
            }).collect(Collectors.toList());
            return Utils.crossProduct(argTypeLists).map(typList -> String.join("_", typList));
        }

        /**
         * Tries to bind {@code call} to one of this finder's variants.
         *
         * <p>First a direct lookup of {@code name + ":" + argKey} in the direct map is attempted. If a
         * key matches but an enum operand cannot be converted, empty is returned without trying other
         * strategies. Otherwise, when a singular input type exists, {@link #matchCoerced} and then
         * {@link #matchByLeastRestrictive} are tried.
         *
         * @param call the call to convert
         * @param topLevelConverter converts each Calcite operand to a Substrait expression
         * @return the binding, or empty if no variant matches
         */
        public Optional<T> attemptMatch(C call, Function<RexNode, Expression> topLevelConverter) {
            List<Expression> operands = call.getOperands().map(topLevelConverter).collect(Collectors.toList());
            List<Type> opTypes = operands.stream().map(Expression::getType).collect(Collectors.toList());
            Type outputType = FunctionConverter.this.typeConverter.toSubstrait(call.getType());
            List<String> typeStrings = opTypes.stream().map(t -> t.accept(ToTypeString.INSTANCE)).collect(Collectors.toList());
            Stream<String> possibleKeys = this.matchKeys(call.getOperands().collect(Collectors.toList()), typeStrings);
            Optional<String> directMatchKey = possibleKeys.map(argList -> this.name + ":" + argList).filter(this.directMap::containsKey).findFirst();
            if (directMatchKey.isPresent()) {
                F variant = this.directMap.get(directMatchKey.get());
                variant.validateOutputType(operands, outputType);
                List<FunctionArg> funcArgs = Streams.zip(call.getOperands(), operands.stream(), (r, o) -> {
                    if (EnumConverter.isEnumValue(r)) {
                        // null marks an enum operand that could not be converted for this variant.
                        return (FunctionArg) EnumConverter.fromRex(variant, (RexLiteral) r).orElse(null);
                    }
                    return o;
                }).collect(Collectors.toList());
                boolean allArgsMapped = funcArgs.stream().noneMatch(arg -> arg == null);
                if (allArgsMapped) {
                    return Optional.of(FunctionConverter.this.generateBinding(call, variant, funcArgs, outputType));
                }
                return Optional.empty();
            }
            if (this.singularInputType.isPresent()) {
                Optional<T> coerced = this.matchCoerced(call, outputType, operands);
                if (coerced.isPresent()) {
                    return coerced;
                }
                Optional<T> leastRestrictive = this.matchByLeastRestrictive(call, outputType, operands);
                if (leastRestrictive.isPresent()) {
                    return leastRestrictive;
                }
            }
            return Optional.empty();
        }

        /**
         * Computes the Calcite least-restrictive type of all operands, matches it with the singular
         * input matcher, and casts every operand to it where needed. Requires that the singular
         * input matcher is present.
         */
        private Optional<T> matchByLeastRestrictive(C call, Type outputType, List<Expression> operands) {
            RelDataType leastRestrictive = FunctionConverter.this.typeFactory.leastRestrictive(call.getOperands().map(RexNode::getType).collect(Collectors.toList()));
            if (leastRestrictive == null) {
                return Optional.empty();
            }
            Type type = FunctionConverter.this.typeConverter.toSubstrait(leastRestrictive);
            Optional<F> out = this.singularInputType.get().tryMatch(type, outputType);
            if (!out.isPresent()) {
                return Optional.empty();
            }
            F declaration = out.get();
            List<Expression> coercedArgs = FunctionConverter.coerceArguments(operands, type);
            declaration.validateOutputType(coercedArgs, outputType);
            List<FunctionArg> funcArgs = coercedArgs.stream().map(FunctionArg.class::cast).collect(Collectors.toList());
            return Optional.of(FunctionConverter.this.generateBinding(call, declaration, funcArgs, outputType));
        }

        /**
         * Matches the operands' Substrait types against the variants' signatures. On a match, each
         * converted operand is cast to its Calcite operand's Substrait type when they differ.
         */
        private Optional<T> matchCoerced(C call, Type outputType, List<Expression> operands) {
            List<Type> allTypes = call.getOperands().map(RexNode::getType).map(FunctionConverter.this.typeConverter::toSubstrait).collect(Collectors.toList());
            Optional<F> matchFunction = this.matcher.tryMatch(allTypes, outputType);
            if (!matchFunction.isPresent()) {
                return Optional.empty();
            }
            List<Expression> coerced = Streams.zip(operands.stream(), allTypes.stream(), FunctionConverter::coerceArgument).collect(Collectors.toList());
            List<FunctionArg> funcArgs = coerced.stream().map(FunctionArg.class::cast).collect(Collectors.toList());
            return Optional.of(FunctionConverter.this.generateBinding(call, matchFunction.get(), funcArgs, outputType));
        }

        /** Returns the lower-cased Substrait function name this finder serves. */
        protected String getName() {
            return this.name;
        }

        /** Returns the Calcite operator this finder serves. */
        public SqlOperator getOperator() {
            return this.operator;
        }
    }

    /** Matches a list of input types plus an output type to a function declaration. */
    @FunctionalInterface
    public static interface SignatureMatcher<F> {
        public Optional<F> tryMatch(List<Type> var1, Type var2);
    }

    /** Matches a single (shared) input type plus an output type to a function declaration. */
    @FunctionalInterface
    public static interface SingularArgumentMatcher<F> {
        public Optional<F> tryMatch(Type var1, Type var2);
    }

    /** Minimal view of a Calcite call: its operands and its result type. */
    public static interface GenericCall {
        public Stream<RexNode> getOperands();

        public RelDataType getType();
    }
}

