/*
 * Decompiled with CFR 0.152.
 */
package io.substrait.isthmus.expression;

import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Multimap;
import com.google.common.collect.Multimaps;
import com.google.common.collect.Streams;
import io.substrait.expression.Expression;
import io.substrait.expression.ExpressionCreator;
import io.substrait.expression.FunctionArg;
import io.substrait.extension.SimpleExtension;
import io.substrait.function.ParameterizedType;
import io.substrait.function.ToTypeString;
import io.substrait.isthmus.TypeConverter;
import io.substrait.isthmus.Utils;
import io.substrait.isthmus.expression.EnumConverter;
import io.substrait.isthmus.expression.FunctionMappings;
import io.substrait.isthmus.expression.IgnoreNullableAndParameters;
import io.substrait.type.Type;
import io.substrait.type.TypeVisitor;
import io.substrait.util.Util;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlOperator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class that binds Calcite {@link SqlOperator}s to Substrait extension functions.
 *
 * <p>Substrait function declarations are grouped by lower-cased name and joined with the
 * operator signatures from {@link #getSigs()} (plus any caller-supplied signatures) on
 * {@code Sig.name()}. One {@link FunctionFinder} is built per bound operator; it resolves a
 * concrete call {@code C} to a declaration variant {@code F} and delegates construction of the
 * result to {@link #generateBinding}.
 *
 * @param <F> the Substrait function declaration type
 * @param <T> the type produced for a successfully bound call
 * @param <C> the call type being converted
 */
public abstract class FunctionConverter<
        F extends SimpleExtension.Function, T, C extends FunctionConverter.GenericCall> {
    private static final Logger LOGGER = LoggerFactory.getLogger(FunctionConverter.class);

    /** One matcher per bound Calcite operator, keyed by operator identity. */
    protected final Map<SqlOperator, FunctionFinder> signatures;

    protected final RelDataTypeFactory typeFactory;
    protected final TypeConverter typeConverter;
    protected final RexBuilder rexBuilder;

    /** Substrait function key ({@code F.key()}) to every Calcite operator bound to its name. */
    protected final Multimap<String, SqlOperator> substraitFuncKeyToSqlOperatorMap;

    /**
     * Creates a converter that uses only {@link #getSigs()} and {@link TypeConverter#DEFAULT}.
     *
     * @param functions Substrait function declarations to bind
     * @param typeFactory Calcite type factory
     */
    public FunctionConverter(List<F> functions, RelDataTypeFactory typeFactory) {
        this(functions, Collections.emptyList(), typeFactory, TypeConverter.DEFAULT);
    }

    /**
     * Creates a converter.
     *
     * <p>NOTE: this constructor calls the abstract {@link #getSigs()} before the subclass has
     * finished construction, so implementations must not rely on subclass instance state.
     *
     * @param functions Substrait function declarations; grouped by lower-cased {@code name()}
     * @param additionalSignatures extra operator signatures, registered ahead of {@link #getSigs()}
     * @param typeFactory Calcite type factory, also used to create {@link #rexBuilder}
     * @param typeConverter converter from Calcite types to Substrait types
     */
    public FunctionConverter(
            List<F> functions,
            List<FunctionMappings.Sig> additionalSignatures,
            RelDataTypeFactory typeFactory,
            TypeConverter typeConverter) {
        this.rexBuilder = new RexBuilder(typeFactory);
        this.typeConverter = typeConverter;
        this.typeFactory = typeFactory;

        ImmutableList<FunctionMappings.Sig> defaultSigs = getSigs();
        List<FunctionMappings.Sig> allSigs =
                new ArrayList<>(defaultSigs.size() + additionalSignatures.size());
        allSigs.addAll(additionalSignatures);
        allSigs.addAll(defaultSigs);

        // Declarations grouped by lower-cased name: the join key against Sig.name().
        ArrayListMultimap<String, F> functionsByName = ArrayListMultimap.create();
        for (F f : functions) {
            functionsByName.put(f.name().toLowerCase(Locale.ROOT), f);
        }

        Multimap<String, FunctionMappings.Sig> calciteOperators =
                allSigs.stream()
                        .collect(
                                Multimaps.toMultimap(
                                        FunctionMappings.Sig::name,
                                        Function.identity(),
                                        ArrayListMultimap::create));

        IdentityHashMap<SqlOperator, FunctionFinder> matcherMap = new IdentityHashMap<>();
        for (String name : functionsByName.keySet()) {
            Collection<FunctionMappings.Sig> sigs = calciteOperators.get(name);
            if (sigs.isEmpty()) {
                LOGGER.atDebug().log("No binding for function: {}", name);
                continue;
            }
            // Every operator bound to this name resolves against the same set of variants.
            List<F> variants = ImmutableList.copyOf(functionsByName.get(name));
            for (FunctionMappings.Sig sig : sigs) {
                matcherMap.put(sig.operator(), new FunctionFinder(name, sig.operator(), variants));
            }
        }

        Multimap<String, SqlOperator> keyToOperators = ArrayListMultimap.create();
        for (Map.Entry<String, F> entry : functionsByName.entries()) {
            for (FunctionMappings.Sig sig : calciteOperators.get(entry.getKey())) {
                keyToOperators.put(entry.getValue().key(), sig.operator());
            }
        }
        this.substraitFuncKeyToSqlOperatorMap = keyToOperators;
        this.signatures = matcherMap;
    }

    /**
     * Finds the Calcite operator bound to a Substrait function key.
     *
     * <p>If exactly one operator is bound, it is returned without consulting {@code outputType}.
     * Otherwise the candidates are narrowed to those whose type-based resolver lists the type
     * string of {@code outputType}.
     *
     * @param key Substrait function key
     * @param outputType Substrait output type of the call, used only to disambiguate
     * @return the single matching operator, or empty if none matches
     * @throws IllegalStateException if more than one operator remains after narrowing
     */
    public Optional<SqlOperator> getSqlOperatorFromSubstraitFunc(String key, Type outputType) {
        Map<SqlOperator, FunctionMappings.TypeBasedResolver> resolver = getTypeBasedResolver();
        Collection<SqlOperator> operators = substraitFuncKeyToSqlOperatorMap.get(key);
        if (operators.isEmpty()) {
            return Optional.empty();
        }
        if (operators.size() == 1) {
            return Optional.of(operators.iterator().next());
        }
        String outputTypeStr = outputType.accept(ToTypeString.INSTANCE);
        List<SqlOperator> resolvedOperators =
                operators.stream()
                        .filter(
                                operator ->
                                        resolver.containsKey(operator)
                                                && resolver.get(operator).types().contains(outputTypeStr))
                        .collect(Collectors.toList());
        if (resolvedOperators.size() == 1) {
            return Optional.of(resolvedOperators.get(0));
        }
        if (resolvedOperators.size() > 1) {
            throw new IllegalStateException(
                    String.format(
                            "Found %d SqlOperators: %s for ScalarFunction %s: ",
                            resolvedOperators.size(), resolvedOperators, key));
        }
        return Optional.empty();
    }

    private Map<SqlOperator, FunctionMappings.TypeBasedResolver> getTypeBasedResolver() {
        return FunctionMappings.OPERATOR_RESOLVER;
    }

    /** Returns the default operator signatures this converter binds. */
    protected abstract ImmutableList<FunctionMappings.Sig> getSigs();

    /** Coerces every argument to {@code targetType} (see {@link #coerceArgument}). */
    private static List<Expression> coerceArguments(List<Expression> arguments, Type targetType) {
        return arguments.stream()
                .map(argument -> coerceArgument(argument, targetType))
                .collect(Collectors.toList());
    }

    /**
     * Returns {@code argument} unchanged when its type already matches {@code type} (ignoring
     * nullability and parameters), otherwise wraps it in a cast that throws on failure.
     */
    private static Expression coerceArgument(Expression argument, Type type) {
        if (isMatch((ParameterizedType) type, (ParameterizedType) argument.getType())) {
            return argument;
        }
        return ExpressionCreator.cast(type, argument, Expression.FailureBehavior.THROW_EXCEPTION);
    }

    /**
     * Builds the result for a call that has been bound to {@code function}.
     *
     * @param call the call being converted
     * @param function the matched declaration variant
     * @param arguments converted (and possibly coerced) arguments
     * @param outputType Substrait output type of the call
     */
    protected abstract T generateBinding(
            C call, F function, List<? extends FunctionArg> arguments, Type outputType);

    /**
     * Returns true when {@code targetType} is a wildcard, or when {@code actualType} matches it
     * under {@link IgnoreNullableAndParameters}.
     */
    private static boolean isMatch(ParameterizedType actualType, ParameterizedType targetType) {
        if (targetType.isWildcard()) {
            return true;
        }
        return actualType.accept(new IgnoreNullableAndParameters(targetType));
    }

    /** Matches a single (common) input type, plus the call's output type, to a declaration. */
    @FunctionalInterface
    private interface SingularArgumentMatcher<F> {
        Optional<F> tryMatch(Type inputType, Type outputType);
    }

    /** Minimal view of a call that this converter can bind. */
    public interface GenericCall {
        /** Returns the call's operands. */
        Stream<RexNode> getOperands();

        /** Returns the call's Calcite result type. */
        RelDataType getType();
    }

    /**
     * Resolves calls to one Calcite operator against the Substrait declaration variants that
     * share a single function name.
     */
    protected class FunctionFinder {
        private final String substraitName;
        private final SqlOperator operator;
        private final List<F> functions;
        /** Lookup by full key and, where it differs, by the required-arguments-only key. */
        private final Map<String, F> directMap;
        /** Present when at least one variant takes a single repeated value type. */
        private final Optional<SingularArgumentMatcher<F>> singularInputType;
        private final Util.IntRange argRange;

        /**
         * @param substraitName lower-cased Substrait function name
         * @param operator the Calcite operator this finder serves
         * @param functions declaration variants; must be non-empty (the argument range is their
         *     union)
         */
        public FunctionFinder(String substraitName, SqlOperator operator, List<F> functions) {
            this.substraitName = substraitName;
            this.operator = operator;
            this.functions = functions;
            this.argRange =
                    Util.IntRange.of(
                            functions.stream()
                                    .mapToInt(t -> t.getRange().getStartInclusive())
                                    .min()
                                    .getAsInt(),
                            functions.stream()
                                    .mapToInt(t -> t.getRange().getEndExclusive())
                                    .max()
                                    .getAsInt());
            this.singularInputType = getSingularInputType(functions);

            ImmutableMap.Builder<String, F> keys = ImmutableMap.builder();
            for (F func : functions) {
                keys.put(func.key(), func);
                // Also register the key without optional arguments, so calls omitting them match.
                if (func.requiredArguments().size() != func.args().size()) {
                    keys.put(
                            SimpleExtension.Function.constructKey(
                                    substraitName, func.requiredArguments()),
                            func);
                }
            }
            this.directMap = keys.build();
        }

        /** Returns whether {@code count} operands fall within any variant's argument range. */
        public boolean allowedArgCount(int count) {
            return argRange.within(count);
        }

        /** Returns the first variant whose return type and required arguments accept the inputs. */
        private Optional<F> signatureMatch(List<Type> inputTypes, Type outputType) {
            for (F function : functions) {
                List<SimpleExtension.Argument> args = function.requiredArguments();
                if (function.returnType() instanceof ParameterizedType
                        && isMatch(
                                (ParameterizedType) outputType,
                                (ParameterizedType) function.returnType())
                        && inputTypesMatchDefinedArguments(inputTypes, args)) {
                    return Optional.of(function);
                }
            }
            return Optional.empty();
        }

        /**
         * Checks each input type against the declared value arguments. Inputs beyond the declared
         * list are checked against the last declared argument. All inputs bound to the same
         * wildcard must have the same type.
         *
         * <p>A declaration with no arguments only accepts an empty input list. A declaration with a
         * non-value argument (for example an enum) is reported as a mismatch rather than throwing.
         */
        private boolean inputTypesMatchDefinedArguments(
                List<Type> inputTypes, List<SimpleExtension.Argument> args) {
            if (args.isEmpty()) {
                return inputTypes.isEmpty();
            }
            Map<String, Set<Type>> wildcardToType = new HashMap<>();
            for (int i = 0; i < inputTypes.size(); i++) {
                Type givenType = inputTypes.get(i);
                SimpleExtension.Argument declared = args.get(Integer.min(i, args.size() - 1));
                if (!(declared instanceof SimpleExtension.ValueArgument)) {
                    return false;
                }
                ParameterizedType wantType = ((SimpleExtension.ValueArgument) declared).value();
                if (!isMatch((ParameterizedType) givenType, wantType)) {
                    return false;
                }
                if (wantType.isWildcard()) {
                    String wildcardName =
                            wantType.accept(ToTypeString.ToTypeLiteralStringLossless.INSTANCE);
                    wildcardToType.computeIfAbsent(wildcardName, k -> new HashSet<>()).add(givenType);
                }
            }
            return wildcardToType.values().stream().allMatch(s -> s.size() == 1);
        }

        /**
         * Builds a matcher over the variants whose required arguments are all value arguments of
         * one mutually matching type. Returns empty when no variant qualifies.
         */
        private Optional<SingularArgumentMatcher<F>> getSingularInputType(List<F> functions) {
            List<SingularArgumentMatcher<F>> matchers = new ArrayList<>();
            for (F f : functions) {
                ParameterizedType firstType = null;
                for (SimpleExtension.Argument a : f.requiredArguments()) {
                    if (!(a instanceof SimpleExtension.ValueArgument)) {
                        firstType = null;
                        break;
                    }
                    ParameterizedType pt = ((SimpleExtension.ValueArgument) a).value();
                    if (firstType == null) {
                        firstType = pt;
                    } else if (!isMatch(firstType, pt)) {
                        firstType = null;
                        break;
                    }
                }
                if (firstType != null) {
                    matchers.add(singular(f, firstType));
                }
            }
            if (matchers.isEmpty()) {
                return Optional.empty();
            }
            if (matchers.size() == 1) {
                return Optional.of(matchers.get(0));
            }
            return Optional.of(chained(matchers));
        }

        /** Matcher that yields {@code function} when the input type matches {@code type}. */
        private SingularArgumentMatcher<F> singular(F function, ParameterizedType type) {
            return (inputType, outputType) ->
                    isMatch((ParameterizedType) inputType, type)
                            ? Optional.of(function)
                            : Optional.empty();
        }

        /** Matcher that returns the first successful result among {@code matchers}, in order. */
        private SingularArgumentMatcher<F> chained(List<SingularArgumentMatcher<F>> matchers) {
            return (inputType, outputType) -> {
                for (SingularArgumentMatcher<F> matcher : matchers) {
                    Optional<F> outcome = matcher.tryMatch(inputType, outputType);
                    if (outcome.isPresent()) {
                        return outcome;
                    }
                }
                return Optional.empty();
            };
        }

        /**
         * Enumerates candidate argument-list suffixes for a direct key lookup. An operand that is
         * a literal holding a Java {@link Enum} may appear as either {@code "req"} or
         * {@code "opt"}; every other operand uses its type string. Candidates are joined with
         * {@code "_"}.
         */
        private Stream<String> matchKeys(List<RexNode> rexOperands, List<String> opTypes) {
            assert rexOperands.size() == opTypes.size();
            if (rexOperands.isEmpty()) {
                return Stream.of("");
            }
            List<List<String>> argTypeLists =
                    Streams.zip(
                                    rexOperands.stream(),
                                    opTypes.stream(),
                                    (rexArg, opType) -> {
                                        boolean isOption =
                                                rexArg instanceof RexLiteral
                                                        && ((RexLiteral) rexArg).getValue()
                                                                instanceof Enum;
                                        return isOption
                                                ? List.of("req", "opt")
                                                : List.of(opType);
                                    })
                            .collect(Collectors.toList());
            return Utils.crossProduct(argTypeLists).map(typList -> String.join("_", typList));
        }

        /**
         * Attempts to bind {@code call} to one of this finder's variants.
         *
         * <p>The strategies are tried in order:
         * <ol>
         *   <li>A direct key lookup on the operand type strings.</li>
         *   <li>If some variant has a single repeated input type:
         *     <ol>
         *       <li>a full signature match with per-argument coercion;</li>
         *       <li>a match on the operands' least-restrictive common type.</li>
         *     </ol>
         *   </li>
         * </ol>
         *
         * @param call the call to bind
         * @param topLevelConverter converts each operand to a Substrait expression
         * @return the binding, or empty if no variant matches or an enum operand cannot be mapped
         */
        public Optional<T> attemptMatch(C call, Function<RexNode, Expression> topLevelConverter) {
            List<RexNode> operandsList = call.getOperands().collect(Collectors.toList());
            List<Expression> operands =
                    operandsList.stream().map(topLevelConverter).collect(Collectors.toList());
            List<Type> opTypes =
                    operands.stream().map(Expression::getType).collect(Collectors.toList());
            Type outputType = typeConverter.toSubstrait(call.getType());
            List<String> typeStrings =
                    opTypes.stream()
                            .map(t -> t.accept(ToTypeString.INSTANCE))
                            .collect(Collectors.toList());

            Optional<String> directMatchKey =
                    matchKeys(operandsList, typeStrings)
                            .map(argList -> substraitName + ":" + argList)
                            .filter(directMap::containsKey)
                            .findFirst();
            if (directMatchKey.isPresent()) {
                F variant = directMap.get(directMatchKey.get());
                variant.validateOutputType(operands, outputType);
                List<FunctionArg> funcArgs =
                        IntStream.range(0, operandsList.size())
                                .<FunctionArg>mapToObj(
                                        i -> {
                                            RexNode rex = operandsList.get(i);
                                            if (EnumConverter.isEnumValue(rex)) {
                                                // null marks an enum literal with no mapping.
                                                return EnumConverter.fromRex(
                                                                variant, (RexLiteral) rex, i)
                                                        .orElse(null);
                                            }
                                            return operands.get(i);
                                        })
                                .collect(Collectors.toList());
                // noneMatch tolerates null elements; filter(isNull).findFirst() would throw NPE.
                if (funcArgs.stream().noneMatch(Objects::isNull)) {
                    return Optional.of(generateBinding(call, variant, funcArgs, outputType));
                }
                return Optional.empty();
            }

            if (singularInputType.isPresent()) {
                return matchCoerced(call, outputType, operands)
                        .or(() -> matchByLeastRestrictive(call, outputType, operands));
            }
            return Optional.empty();
        }

        /** Binds by coercing every operand to the operands' least-restrictive Calcite type. */
        private Optional<T> matchByLeastRestrictive(
                C call, Type outputType, List<Expression> operands) {
            RelDataType leastRestrictive =
                    typeFactory.leastRestrictive(
                            call.getOperands().map(RexNode::getType).collect(Collectors.toList()));
            if (leastRestrictive == null) {
                return Optional.empty();
            }
            Type type = typeConverter.toSubstrait(leastRestrictive);
            return singularInputType
                    .orElseThrow()
                    .tryMatch(type, outputType)
                    .map(
                            declaration -> {
                                List<Expression> coercedArgs = coerceArguments(operands, type);
                                declaration.validateOutputType(coercedArgs, outputType);
                                return generateBinding(call, declaration, coercedArgs, outputType);
                            });
        }

        /** Binds via a full signature match on the operands' converted Calcite types. */
        private Optional<T> matchCoerced(C call, Type outputType, List<Expression> expressions) {
            List<Type> operandTypes =
                    call.getOperands()
                            .map(RexNode::getType)
                            .map(typeConverter::toSubstrait)
                            .collect(Collectors.toList());
            Optional<F> matchFunction = signatureMatch(operandTypes, outputType);
            if (matchFunction.isEmpty()) {
                return Optional.empty();
            }
            List<Expression> coercedArgs =
                    Streams.zip(
                                    expressions.stream(),
                                    operandTypes.stream(),
                                    FunctionConverter::coerceArgument)
                            .collect(Collectors.toList());
            return Optional.of(
                    generateBinding(call, matchFunction.get(), coercedArgs, outputType));
        }

        protected String getSubstraitName() {
            return substraitName;
        }

        public SqlOperator getOperator() {
            return operator;
        }
    }
}

