/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.gen;

import com.google.common.base.Joiner;
import com.google.common.base.Preconditions;
import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.primitives.Primitives;
import io.airlift.bytecode.BytecodeBlock;
import io.airlift.bytecode.BytecodeNode;
import io.airlift.bytecode.OpCode;
import io.airlift.bytecode.Scope;
import io.airlift.bytecode.Variable;
import io.airlift.bytecode.control.IfStatement;
import io.airlift.bytecode.expression.BytecodeExpression;
import io.airlift.bytecode.expression.BytecodeExpressions;
import io.airlift.bytecode.instruction.LabelNode;
import io.airlift.slice.Slice;
import io.trino.metadata.FunctionManager;
import io.trino.metadata.ResolvedFunction;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.function.BoundSignature;
import io.trino.spi.function.FunctionNullability;
import io.trino.spi.function.InOut;
import io.trino.spi.function.InvocationConvention;
import io.trino.spi.function.ScalarFunctionImplementation;
import io.trino.spi.type.Type;
import io.trino.sql.gen.Binding;
import io.trino.sql.gen.Bootstrap;
import io.trino.sql.gen.CallSiteBinder;
import io.trino.sql.gen.InputReferenceCompiler;
import io.trino.type.FunctionType;
import java.lang.constant.Constable;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.invoke.TypeDescriptor;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;

/**
 * Bytecode-generation helpers shared by the expression compilers in {@code io.trino.sql.gen}.
 *
 * <p>Many helpers emit code that reads or writes the local {@code wasNull} flag, and function
 * invocations may load the {@code session} local; both are looked up by name in the supplied
 * {@link Scope}.
 */
public final class BytecodeUtils {
    private BytecodeUtils() {
    }

    /**
     * Emits a check of {@code wasNull}; when it is set, pops {@code stackArgsToPop}, pushes the Java
     * default value of {@code returnType} and jumps to {@code label}. The flag is left unchanged.
     *
     * @param stackArgsToPop types to pop, in pop order (the first entry is the value on top of the stack)
     */
    public static BytecodeNode ifWasNullPopAndGoto(Scope scope, LabelNode label, Class<?> returnType, Class<?>... stackArgsToPop) {
        return handleNullValue(scope, label, returnType, ImmutableList.copyOf(stackArgsToPop), false);
    }

    /**
     * {@link Iterable} variant of {@link #ifWasNullPopAndGoto(Scope, LabelNode, Class, Class[])}.
     */
    public static BytecodeNode ifWasNullPopAndGoto(Scope scope, LabelNode label, Class<?> returnType, Iterable<? extends Class<?>> stackArgsToPop) {
        return handleNullValue(scope, label, returnType, ImmutableList.copyOf(stackArgsToPop), false);
    }

    /**
     * Same as {@link #ifWasNullPopAndGoto(Scope, LabelNode, Class, Class[])}, but also resets
     * {@code wasNull} to {@code false} as part of the check.
     */
    public static BytecodeNode ifWasNullClearPopAndGoto(Scope scope, LabelNode label, Class<?> returnType, Class<?>... stackArgsToPop) {
        return handleNullValue(scope, label, returnType, ImmutableList.copyOf(stackArgsToPop), true);
    }

    /**
     * Builds the {@code if (wasNull) { pop...; push default(returnType); goto label; }} statement.
     *
     * @param stackArgsToPop types popped in list order; the first entry must be the top of the stack
     * @param clearNullFlag when true, {@code wasNull} is reset to {@code false} after being loaded
     *        for the condition
     */
    public static BytecodeNode handleNullValue(Scope scope, LabelNode label, Class<?> returnType, List<Class<?>> stackArgsToPop, boolean clearNullFlag) {
        Variable wasNull = scope.getVariable("wasNull");

        BytecodeBlock nullCheck = new BytecodeBlock()
                .setDescription("ifWasNullGoto")
                .append(wasNull);
        String clearComment = null;
        if (clearNullFlag) {
            // the flag value was already loaded for the condition above, so it can be reset here
            nullCheck.append(wasNull.set(BytecodeExpressions.constantFalse()));
            clearComment = "clear wasNull";
        }

        BytecodeBlock isNull = new BytecodeBlock();
        for (Class<?> parameterType : stackArgsToPop) {
            isNull.pop(parameterType);
        }
        isNull.pushJavaDefault(returnType);
        String loadDefaultComment = "loadJavaDefault(" + returnType.getName() + ")";
        isNull.gotoLabel(label);

        String popComment = null;
        if (!stackArgsToPop.isEmpty()) {
            popComment = String.format("pop(%s)", Joiner.on(", ").join(stackArgsToPop));
        }

        String description = Joiner.on(", ").skipNulls().join(clearComment, popComment, loadDefaultComment, "goto " + label.getLabel());
        return new IfStatement("if wasNull then %s", description)
                .condition(nullCheck)
                .ifTrue(isNull);
    }

    /**
     * Unboxes the wrapper on top of the stack into {@code unboxedType}.
     *
     * @throws UnsupportedOperationException unless {@code unboxedType} is {@code long}, {@code double}
     *         or {@code boolean}
     */
    public static BytecodeNode unboxPrimitive(Class<?> unboxedType) {
        BytecodeBlock block = new BytecodeBlock().comment("unbox primitive");
        if (unboxedType == long.class) {
            return block.invokeVirtual(Long.class, "longValue", long.class);
        }
        if (unboxedType == double.class) {
            return block.invokeVirtual(Double.class, "doubleValue", double.class);
        }
        if (unboxedType == boolean.class) {
            return block.invokeVirtual(Boolean.class, "booleanValue", boolean.class);
        }
        throw new UnsupportedOperationException("not yet implemented: " + unboxedType);
    }

    /**
     * Binds {@code constant} (as a constant method handle of {@code type}) and returns an expression
     * that loads it.
     */
    public static BytecodeExpression loadConstant(CallSiteBinder callSiteBinder, Object constant, Class<?> type) {
        Binding binding = callSiteBinder.bind(MethodHandles.constant(type, constant));
        return loadConstant(binding);
    }

    /**
     * Returns an {@code invokedynamic} expression, bootstrapped through {@link Bootstrap}, that
     * produces the value of the zero-argument {@code binding}.
     */
    public static BytecodeExpression loadConstant(Binding binding) {
        return BytecodeExpressions.invokeDynamic(
                Bootstrap.BOOTSTRAP_METHOD,
                ImmutableList.of(binding.getBindingId()),
                "constant_" + binding.getBindingId(),
                binding.getType().returnType());
    }

    /**
     * Generates a call to {@code resolvedFunction} with already-compiled, non-lambda arguments.
     */
    public static BytecodeNode generateInvocation(Scope scope, ResolvedFunction resolvedFunction, FunctionManager functionManager, List<BytecodeNode> arguments, CallSiteBinder binder) {
        return generateInvocation(
                scope,
                resolvedFunction.getSignature().getName().getFunctionName(),
                resolvedFunction.getFunctionNullability(),
                invocationConvention -> functionManager.getScalarFunctionImplementation(resolvedFunction, invocationConvention),
                arguments,
                binder);
    }

    /**
     * Generates a call to a function supplied by {@code functionImplementationProvider}.
     *
     * @throws IllegalArgumentException while generating, if the implementation needs an instance
     *         factory or takes lambda arguments, neither of which this simple form supports
     */
    public static BytecodeNode generateInvocation(Scope scope, String functionName, FunctionNullability functionNullability, Function<InvocationConvention, ScalarFunctionImplementation> functionImplementationProvider, List<BytecodeNode> arguments, CallSiteBinder binder) {
        Function<MethodHandle, BytecodeNode> noInstanceFactory = instanceFactory -> {
            throw new IllegalArgumentException("Simple method invocation can not be used with functions that require an instance factory");
        };
        List<Function<Optional<Class<?>>, BytecodeNode>> argumentCompilers = arguments.stream()
                .map(BytecodeUtils::simpleArgument)
                .collect(ImmutableList.toImmutableList());
        return generateFullInvocation(
                scope,
                functionName,
                functionNullability,
                Collections.nCopies(arguments.size(), false),
                functionImplementationProvider,
                noInstanceFactory,
                argumentCompilers,
                binder);
    }

    // Wraps a precompiled argument as an argument compiler that rejects lambda requests.
    private static Function<Optional<Class<?>>, BytecodeNode> simpleArgument(BytecodeNode argument) {
        return lambdaInterface -> {
            Preconditions.checkArgument(lambdaInterface.isEmpty(), "Simple method invocation can not be used with functions that have lambda arguments");
            return argument;
        };
    }

    /**
     * Generates a call to {@code resolvedFunction}. Arguments whose declared type is a
     * {@link FunctionType} are compiled with the lambda interface the chosen implementation
     * requires; {@code instanceFactory} produces the receiver when the implementation needs one.
     */
    public static BytecodeNode generateFullInvocation(Scope scope, ResolvedFunction resolvedFunction, FunctionManager functionManager, Function<MethodHandle, BytecodeNode> instanceFactory, List<Function<Optional<Class<?>>, BytecodeNode>> argumentCompilers, CallSiteBinder binder) {
        List<Boolean> argumentIsFunctionType = resolvedFunction.getSignature().getArgumentTypes().stream()
                .map(FunctionType.class::isInstance)
                .collect(ImmutableList.toImmutableList());
        return generateFullInvocation(
                scope,
                resolvedFunction.getSignature().getName().getFunctionName(),
                resolvedFunction.getFunctionNullability(),
                argumentIsFunctionType,
                invocationConvention -> functionManager.getScalarFunctionImplementation(resolvedFunction, invocationConvention),
                instanceFactory,
                argumentCompilers,
                binder);
    }

    private static BytecodeNode generateFullInvocation(Scope scope, String functionName, FunctionNullability functionNullability, List<Boolean> argumentIsFunctionType, Function<InvocationConvention, ScalarFunctionImplementation> functionImplementationProvider, Function<MethodHandle, BytecodeNode> instanceFactory, List<Function<Optional<Class<?>>, BytecodeNode>> argumentCompilers, CallSiteBinder binder) {
        Verify.verify(argumentIsFunctionType.size() == argumentCompilers.size());

        // Choose a convention per argument. Lambda arguments are compiled later, once the
        // implementation reports which functional interface each one must implement.
        List<InvocationConvention.InvocationArgumentConvention> argumentConventions = new ArrayList<>();
        List<BytecodeNode> arguments = new ArrayList<>();
        for (int i = 0; i < argumentIsFunctionType.size(); i++) {
            if (argumentIsFunctionType.get(i)) {
                argumentConventions.add(InvocationConvention.InvocationArgumentConvention.FUNCTION);
                arguments.add(null);
            } else {
                BytecodeNode argument = argumentCompilers.get(i).apply(Optional.empty());
                argumentConventions.add(getPreferredArgumentConvention(argument, argumentCompilers.size(), functionNullability.isArgumentNullable(i)));
                arguments.add(argument);
            }
        }
        InvocationConvention invocationConvention = new InvocationConvention(
                argumentConventions,
                functionNullability.isReturnNullable()
                        ? InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN
                        : InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL,
                true,
                true);
        ScalarFunctionImplementation implementation = functionImplementationProvider.apply(invocationConvention);
        Binding binding = binder.bind(implementation.getMethodHandle());

        LabelNode end = new LabelNode("end");
        BytecodeBlock block = new BytecodeBlock().setDescription("invoke " + functionName);
        Optional<BytecodeNode> instance = implementation.getInstanceFactory().map(instanceFactory);

        MethodType methodType = binding.getType();
        Class<?> returnType = methodType.returnType();
        Class<?> unboxedReturnType = Primitives.unwrap(returnType);
        Class<?>[] parameterTypes = methodType.parameterArray();

        // Types pushed on the operand stack so far, in push order; reversed when a null argument
        // forces an early exit so they are popped top-first.
        List<Class<?>> stackTypes = new ArrayList<>();
        boolean instanceIsBound = false;
        // index into the method handle's parameters
        int currentParameterIndex = 0;
        // index into the function's arguments (argumentConventions / arguments / argumentCompilers)
        int realParameterIndex = 0;
        // index into implementation.getLambdaInterfaces()
        int lambdaArgumentIndex = 0;
        while (currentParameterIndex < parameterTypes.length) {
            Class<?> type = parameterTypes[currentParameterIndex];
            stackTypes.add(type);
            if (instance.isPresent() && !instanceIsBound) {
                // the first parameter receives the instance produced by the instance factory
                Preconditions.checkState(type.equals(implementation.getInstanceFactory().get().type().returnType()), "Mismatched type for instance parameter");
                block.append(instance.get());
                instanceIsBound = true;
            } else if (type == ConnectorSession.class) {
                block.append(scope.getVariable("session"));
            } else {
                switch (invocationConvention.getArgumentConvention(realParameterIndex)) {
                    case NEVER_NULL: {
                        block.append(arguments.get(realParameterIndex));
                        Preconditions.checkArgument(!Primitives.isWrapperType(type), "Non-nullable argument must not be primitive wrapper type");
                        block.append(ifWasNullPopAndGoto(scope, end, unboxedReturnType, Lists.reverse(stackTypes)));
                        break;
                    }
                    case NULL_FLAG: {
                        // value followed by its null flag: two method handle parameters
                        block.append(arguments.get(realParameterIndex));
                        block.append(scope.getVariable("wasNull"));
                        block.append(scope.getVariable("wasNull").set(BytecodeExpressions.constantFalse()));
                        stackTypes.add(boolean.class);
                        currentParameterIndex++;
                        break;
                    }
                    case BOXED_NULLABLE: {
                        block.append(arguments.get(realParameterIndex));
                        block.append(boxPrimitiveIfNecessary(scope, type));
                        block.append(scope.getVariable("wasNull").set(BytecodeExpressions.constantFalse()));
                        break;
                    }
                    case BLOCK_POSITION: {
                        // block followed by an int position: two method handle parameters
                        InputReferenceCompiler.InputReferenceNode inputReferenceNode = (InputReferenceCompiler.InputReferenceNode) arguments.get(realParameterIndex);
                        block.append(inputReferenceNode.produceBlockAndPosition());
                        stackTypes.add(int.class);
                        if (!functionNullability.isArgumentNullable(realParameterIndex)) {
                            block.append(scope.getVariable("wasNull").set(inputReferenceNode.blockAndPositionIsNull()));
                            block.append(ifWasNullPopAndGoto(scope, end, unboxedReturnType, Lists.reverse(stackTypes)));
                        }
                        currentParameterIndex++;
                        break;
                    }
                    case IN_OUT: {
                        // IN_OUT passes a single InOut object (one method handle parameter), so unlike
                        // NULL_FLAG / BLOCK_POSITION the parameter index must not advance an extra step.
                        block.append(arguments.get(realParameterIndex));
                        if (!functionNullability.isArgumentNullable(realParameterIndex)) {
                            block.append(arguments.get(realParameterIndex));
                            // NOTE(review): if InOut is declared as an interface this should be
                            // invokeInterface rather than invokeVirtual — confirm against the SPI.
                            block.invokeVirtual(InOut.class, "isNull", boolean.class);
                            block.putVariable(scope.getVariable("wasNull"));
                            block.append(ifWasNullPopAndGoto(scope, end, unboxedReturnType, Lists.reverse(stackTypes)));
                        }
                        break;
                    }
                    case FUNCTION: {
                        Class<?> lambdaInterface = implementation.getLambdaInterfaces().get(lambdaArgumentIndex);
                        block.append(argumentCompilers.get(realParameterIndex).apply(Optional.of(lambdaInterface)));
                        lambdaArgumentIndex++;
                        break;
                    }
                    default: {
                        throw new UnsupportedOperationException(String.format("Unsupported argument convention type: %s", invocationConvention.getArgumentConvention(realParameterIndex)));
                    }
                }
                realParameterIndex++;
            }
            currentParameterIndex++;
        }
        block.append(invoke(binding, functionName));
        if (functionNullability.isReturnNullable()) {
            block.append(unboxPrimitiveIfNecessary(scope, returnType));
        }
        block.visitLabel(end);
        return block;
    }

    /**
     * Chooses the convention for a non-lambda argument. {@code BLOCK_POSITION} and {@code NULL_FLAG}
     * each consume two method handle parameters, so they are only chosen when there are at most 64
     * arguments (presumably to stay well under the JVM's parameter-slot limit); otherwise one-slot
     * conventions are used.
     */
    private static InvocationConvention.InvocationArgumentConvention getPreferredArgumentConvention(BytecodeNode argument, int argumentCount, boolean nullable) {
        if (argumentCount <= 64) {
            if (argument instanceof InputReferenceCompiler.InputReferenceNode) {
                return InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION;
            }
            if (nullable) {
                return InvocationConvention.InvocationArgumentConvention.NULL_FLAG;
            }
        }
        return nullable ? InvocationConvention.InvocationArgumentConvention.BOXED_NULLABLE : InvocationConvention.InvocationArgumentConvention.NEVER_NULL;
    }

    /**
     * Handles a possibly-null {@code boxedType} reference on top of the stack. If it is null,
     * {@code wasNull} is set to {@code true}; for wrapper types the null is then replaced by the
     * primitive default, otherwise the null reference is left on the stack. Non-null wrapper values
     * are unboxed.
     */
    public static BytecodeBlock unboxPrimitiveIfNecessary(Scope scope, Class<?> boxedType) {
        BytecodeBlock block = new BytecodeBlock();
        LabelNode end = new LabelNode("end");
        Class<?> unboxedType = Primitives.unwrap(boxedType);
        Variable wasNull = scope.getVariable("wasNull");
        if (unboxedType.isPrimitive()) {
            LabelNode notNull = new LabelNode("notNull");
            block.dup(boxedType)
                    .ifNotNullGoto(notNull)
                    .append(wasNull.set(BytecodeExpressions.constantTrue()))
                    .comment("swap boxed null with unboxed default")
                    .pop(boxedType)
                    .pushJavaDefault(unboxedType)
                    .gotoLabel(end)
                    .visitLabel(notNull)
                    .append(unboxPrimitive(unboxedType));
        } else {
            block.dup(boxedType)
                    .ifNotNullGoto(end)
                    .append(wasNull.set(BytecodeExpressions.constantTrue()));
        }
        block.visitLabel(end);
        return block;
    }

    /**
     * For {@code Long}, {@code Double} and {@code Boolean}: replaces the primitive on top of the
     * stack with {@code null} when {@code wasNull} is set, otherwise boxes it. Returns a no-op for
     * non-wrapper reference types.
     *
     * @throws IllegalArgumentException if {@code type} is primitive
     * @throws UnsupportedOperationException for any other wrapper type
     */
    public static BytecodeNode boxPrimitiveIfNecessary(Scope scope, Class<?> type) {
        Preconditions.checkArgument(!type.isPrimitive(), "cannot box into primitive type");
        if (!Primitives.isWrapperType(type)) {
            return OpCode.NOP;
        }
        BytecodeBlock notNull = new BytecodeBlock().comment("box primitive");
        Class<?> expectedCurrentStackType;
        if (type == Long.class) {
            notNull.invokeStatic(Long.class, "valueOf", Long.class, long.class);
            expectedCurrentStackType = long.class;
        } else if (type == Double.class) {
            notNull.invokeStatic(Double.class, "valueOf", Double.class, double.class);
            expectedCurrentStackType = double.class;
        } else if (type == Boolean.class) {
            notNull.invokeStatic(Boolean.class, "valueOf", Boolean.class, boolean.class);
            expectedCurrentStackType = boolean.class;
        } else {
            throw new UnsupportedOperationException("not yet implemented: " + type);
        }
        BytecodeBlock condition = new BytecodeBlock().append(scope.getVariable("wasNull"));
        BytecodeBlock wasNull = new BytecodeBlock()
                .pop(expectedCurrentStackType)
                .pushNull()
                .checkCast(type);
        return new IfStatement()
                .condition(condition)
                .ifTrue(wasNull)
                .ifFalse(notNull);
    }

    /**
     * Returns an {@code invokedynamic} call of {@code binding} with the given arguments.
     */
    public static BytecodeExpression invoke(Binding binding, String name, BytecodeExpression... parameters) {
        return invoke(binding, name, ImmutableList.copyOf(parameters));
    }

    /**
     * Returns an {@code invokedynamic} call of {@code binding}, named after the sanitized
     * {@code name}, bootstrapped through {@link Bootstrap}.
     */
    public static BytecodeExpression invoke(Binding binding, String name, List<BytecodeExpression> parameters) {
        return BytecodeExpressions.invokeDynamic(
                Bootstrap.BOOTSTRAP_METHOD,
                ImmutableList.of(binding.getBindingId()),
                sanitizeName(name),
                binding.getType(),
                parameters);
    }

    /**
     * Returns an argument-less {@code invokedynamic} call of {@code binding} named after the
     * function in {@code signature}.
     */
    public static BytecodeExpression invoke(Binding binding, BoundSignature signature) {
        return invoke(binding, signature.getName().getFunctionName());
    }

    /**
     * Replaces every character other than ASCII letters, digits, {@code _} and {@code $} with
     * {@code _}; the result is used as the {@code invokedynamic} call name.
     */
    public static String sanitizeName(String name) {
        return name.replaceAll("[^A-Za-z0-9_$]", "_");
    }

    /**
     * Emits code that consumes {@code [BlockBuilder output, value]} (value on top of the stack).
     * If {@code wasNullVariable} is true the value is dropped and {@code output.appendNull()} is
     * called; otherwise {@code type.writeX(output, value)} is invoked, where {@code X} is the simple
     * name of the wrapped Java type ({@code Long}, {@code Double}, {@code Boolean}, {@code Slice} or
     * {@code Object}).
     */
    public static BytecodeNode generateWrite(CallSiteBinder callSiteBinder, Scope scope, Variable wasNullVariable, Type type) {
        Class<?> valueJavaType = type.getJavaType();
        if (!valueJavaType.isPrimitive() && valueJavaType != Slice.class) {
            valueJavaType = Object.class;
        }
        String methodName = "write" + Primitives.wrap(valueJavaType).getSimpleName();

        // the value is on top of the stack, the output builder right below it
        Variable tempValue = scope.createTempVariable(valueJavaType);
        Variable tempOutput = scope.createTempVariable(BlockBuilder.class);

        BytecodeBlock writeNull = new BytecodeBlock()
                .comment("output.appendNull();")
                .pop(valueJavaType)
                .invokeInterface(BlockBuilder.class, "appendNull", BlockBuilder.class)
                .pop();
        BytecodeBlock writeValue = new BytecodeBlock()
                .comment("%s.%s(output, %s)", type.getTypeSignature(), methodName, valueJavaType.getSimpleName())
                .putVariable(tempValue)
                .putVariable(tempOutput)
                .append(loadConstant(callSiteBinder.bind(type, Type.class)))
                .getVariable(tempOutput)
                .getVariable(tempValue)
                .invokeInterface(Type.class, methodName, void.class, BlockBuilder.class, valueJavaType);

        return new BytecodeBlock()
                .comment("if (wasNull)")
                .append(new IfStatement()
                        .condition(wasNullVariable)
                        .ifTrue(writeNull)
                        .ifFalse(writeValue));
    }
}

