/*
 * Decompiled with CFR 0.152.
 */
package io.substrait.isthmus.expression;

import io.substrait.expression.EnumArg;
import io.substrait.expression.Expression;
import io.substrait.expression.FunctionArg;
import io.substrait.extension.SimpleExtension;
import io.substrait.isthmus.expression.ScalarFunctionMapper;
import io.substrait.isthmus.expression.SubstraitFunctionMapping;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.fun.SqlTrimFunction;
import org.apache.calcite.sql.type.SqlTypeName;

/**
 * Translates between Calcite's single {@code TRIM} operator and the Substrait scalar functions
 * {@code trim}, {@code ltrim} and {@code rtrim}.
 *
 * <p>Calcite carries the trim direction as the first operand of the call (a
 * {@link SqlTrimFunction.Flag} symbol: {@code BOTH}, {@code LEADING} or {@code TRAILING}), whereas
 * Substrait uses a distinct function name per direction. This mapper drops that flag operand when
 * converting to Substrait and re-creates it as an enum argument when converting back.
 *
 * <p>The lookup table is built once in the constructor and is not modified afterwards.
 */
final class TrimFunctionMapper implements ScalarFunctionMapper {
    /** Substrait variants per trim kind; a kind is present only if at least one variant exists. */
    private final Map<Trim, List<SimpleExtension.ScalarFunctionVariant>> trimFunctions;

    /**
     * Indexes the available Substrait scalar function variants by trim kind.
     *
     * @param functions candidate scalar function variants; those named {@code trim},
     *     {@code ltrim} or {@code rtrim} are retained and all others are ignored
     */
    public TrimFunctionMapper(List<SimpleExtension.ScalarFunctionVariant> functions) {
        Map<Trim, List<SimpleExtension.ScalarFunctionVariant>> trims = new EnumMap<>(Trim.class);
        for (Trim t : Trim.values()) {
            List<SimpleExtension.ScalarFunctionVariant> funcs =
                    findFunction(t.substraitName(), functions);
            if (!funcs.isEmpty()) {
                trims.put(t, funcs);
            }
        }
        this.trimFunctions = Collections.unmodifiableMap(trims);
    }

    /** Returns an unmodifiable list of every variant in {@code functions} named {@code name}. */
    private static List<SimpleExtension.ScalarFunctionVariant> findFunction(
            String name, Collection<SimpleExtension.ScalarFunctionVariant> functions) {
        return functions.stream()
                .filter(f -> name.equals(f.name()))
                .collect(Collectors.toUnmodifiableList());
    }

    /**
     * Maps a Calcite {@code TRIM} call to the Substrait function for its trim direction.
     *
     * @param call the Calcite call to convert
     * @return the mapping (Substrait function name, the call's operands without the leading flag
     *     operand, and the candidate variants), or empty if {@code call} is not {@code TRIM}, its
     *     flag operand cannot be read, or no Substrait variant is registered for that direction
     */
    @Override
    public Optional<SubstraitFunctionMapping> toSubstrait(RexCall call) {
        if (!SqlStdOperatorTable.TRIM.equals(call.op)) {
            return Optional.empty();
        }
        // flatMap (rather than map returning null) keeps "no variant available" explicit.
        return getTrimCallType(call).flatMap(trim -> toMapping(trim, call));
    }

    /** Builds the Substrait mapping for an already-classified {@code TRIM} call, if possible. */
    private Optional<SubstraitFunctionMapping> toMapping(Trim trim, RexCall call) {
        List<SimpleExtension.ScalarFunctionVariant> functions =
                trimFunctions.getOrDefault(trim, List.of());
        if (functions.isEmpty()) {
            return Optional.empty();
        }
        // Operand 0 is the trim flag symbol; Substrait encodes it in the function name instead.
        List<RexNode> operands =
                call.getOperands().stream().skip(1L).collect(Collectors.toUnmodifiableList());
        return Optional.of(new SubstraitFunctionMapping(trim.substraitName(), operands, functions));
    }

    /**
     * Reads the trim direction from the first operand of a {@code TRIM} call.
     *
     * @param call the {@code TRIM} call
     * @return the trim kind, or empty if the first operand is missing, is not a {@code SYMBOL}
     *     literal, or does not hold a {@link SqlTrimFunction.Flag}
     */
    private static Optional<Trim> getTrimCallType(RexCall call) {
        List<RexNode> operands = call.getOperands();
        if (operands.isEmpty()) {
            return Optional.empty();
        }
        RexNode trimType = operands.get(0);
        // Check the node class as well as the SQL type before casting, so a SYMBOL-typed
        // operand that is not a literal yields empty instead of a ClassCastException.
        if (!(trimType instanceof RexLiteral)
                || trimType.getType().getSqlTypeName() != SqlTypeName.SYMBOL) {
            return Optional.empty();
        }
        Object value = ((RexLiteral) trimType).getValue();
        if (!(value instanceof SqlTrimFunction.Flag)) {
            return Optional.empty();
        }
        return Trim.fromFlag((SqlTrimFunction.Flag) value);
    }

    /**
     * Rebuilds the Calcite-style argument list for a Substrait {@code trim}/{@code ltrim}/
     * {@code rtrim} invocation by prepending an enum argument holding the name of the matching
     * {@link SqlTrimFunction.Flag}.
     *
     * @param expression the Substrait scalar function invocation
     * @return a new list containing the flag argument followed by the invocation's own arguments,
     *     or empty if the invoked function is not one of the trim functions
     */
    @Override
    public Optional<List<FunctionArg>> getExpressionArguments(
            Expression.ScalarFunctionInvocation expression) {
        String name = expression.declaration().name();
        return Trim.fromSubstraitName(name)
                .map(trim -> EnumArg.of(trim.flag().name()))
                .map(trimTypeArg -> {
                    List<FunctionArg> args = new ArrayList<>(expression.arguments().size() + 1);
                    args.add(trimTypeArg);
                    args.addAll(expression.arguments());
                    return args;
                });
    }

    /** The supported trim directions, pairing each Substrait function name with its Calcite flag. */
    private enum Trim {
        TRIM("trim", SqlTrimFunction.Flag.BOTH),
        LTRIM("ltrim", SqlTrimFunction.Flag.LEADING),
        RTRIM("rtrim", SqlTrimFunction.Flag.TRAILING);

        private final String substraitName;
        private final SqlTrimFunction.Flag flag;

        Trim(String substraitName, SqlTrimFunction.Flag flag) {
            this.substraitName = substraitName;
            this.flag = flag;
        }

        /** Returns the Substrait function name for this trim direction. */
        public String substraitName() {
            return this.substraitName;
        }

        /** Returns the Calcite trim flag for this trim direction. */
        public SqlTrimFunction.Flag flag() {
            return this.flag;
        }

        /** Returns the trim kind whose Calcite flag is {@code flag}, or empty if none matches. */
        public static Optional<Trim> fromFlag(SqlTrimFunction.Flag flag) {
            return Arrays.stream(Trim.values()).filter(t -> t.flag == flag).findFirst();
        }

        /** Returns the trim kind whose Substrait name equals {@code name}, or empty if none. */
        public static Optional<Trim> fromSubstraitName(String name) {
            return Arrays.stream(Trim.values()).filter(t -> t.substraitName.equals(name)).findFirst();
        }
    }
}

