/*
 * Decompiled with CFR 0.152.
 */
package io.trino.operator.scalar;

import com.google.common.collect.ImmutableList;
import com.google.common.primitives.Primitives;
import io.airlift.bytecode.Access;
import io.airlift.bytecode.BytecodeBlock;
import io.airlift.bytecode.BytecodeNode;
import io.airlift.bytecode.ClassDefinition;
import io.airlift.bytecode.MethodDefinition;
import io.airlift.bytecode.Parameter;
import io.airlift.bytecode.ParameterizedType;
import io.airlift.bytecode.Scope;
import io.airlift.bytecode.Variable;
import io.airlift.bytecode.control.ForLoop;
import io.airlift.bytecode.control.IfStatement;
import io.airlift.bytecode.expression.BytecodeExpression;
import io.airlift.bytecode.expression.BytecodeExpressions;
import io.airlift.bytecode.instruction.VariableInstruction;
import io.trino.metadata.SqlScalarFunction;
import io.trino.operator.scalar.ChoicesSpecializedSqlScalarFunction;
import io.trino.operator.scalar.SpecializedSqlScalarFunction;
import io.trino.spi.ErrorCodeSupplier;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.TrinoException;
import io.trino.spi.block.ArrayValueBuilder;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.BufferedArrayValueBuilder;
import io.trino.spi.function.BoundSignature;
import io.trino.spi.function.FunctionMetadata;
import io.trino.spi.function.InvocationConvention;
import io.trino.spi.function.Signature;
import io.trino.spi.type.ArrayType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeSignature;
import io.trino.spi.type.TypeSignatureParameter;
import io.trino.sql.gen.CallSiteBinder;
import io.trino.sql.gen.LambdaMetafactoryGenerator;
import io.trino.sql.gen.SqlTypeBytecodeExpression;
import io.trino.sql.gen.lambda.UnaryFunctionInterface;
import io.trino.type.UnknownType;
import io.trino.util.CompilerUtils;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.util.List;
import java.util.Optional;

/**
 * SQL scalar function {@code transform(array(T), function(T, U)) -> array(U)}: applies a lambda
 * to each element of an array and returns an array of the lambda's results.
 *
 * <p>Every specialization generates a dedicated class (via airlift bytecode) for the bound
 * element types {@code T} and {@code U}. Null input elements are passed to the lambda as
 * {@code null}, and {@code null} lambda results are written as null output elements.
 */
public final class ArrayTransformFunction
        extends SqlScalarFunction {
    /**
     * Handle for {@code BufferedArrayValueBuilder.createBuffered(ArrayType)}. It is bound to the
     * concrete return type in {@link #specialize} and passed as the instance factory.
     */
    private static final MethodHandle CREATE_STATE;
    public static final ArrayTransformFunction ARRAY_TRANSFORM_FUNCTION;
    public static final String ARRAY_TRANSFORM_NAME = "transform";

    private ArrayTransformFunction() {
        super(FunctionMetadata.scalarBuilder(ARRAY_TRANSFORM_NAME)
                .signature(Signature.builder()
                        .typeVariable("T")
                        .typeVariable("U")
                        .returnType(TypeSignature.arrayType(new TypeSignature("U")))
                        .argumentType(TypeSignature.arrayType(new TypeSignature("T")))
                        .argumentType(TypeSignature.functionType(new TypeSignature("T"), new TypeSignature("U")))
                        .build())
                .description("Apply lambda to each element of the array")
                .build());
    }

    /**
     * Binds the function to concrete element types.
     *
     * <p>The array argument uses the NEVER_NULL convention and the lambda argument uses the
     * FUNCTION convention (as a {@link UnaryFunctionInterface}). The result uses FAIL_ON_NULL.
     * The instance factory is {@code CREATE_STATE} bound to the array return type.
     * NOTE(review): presumably the factory's {@link BufferedArrayValueBuilder} is reused across
     * invocations of one instance; confirm against ChoicesSpecializedSqlScalarFunction.
     */
    @Override
    protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) {
        Type inputType = ((ArrayType) boundSignature.getArgumentTypes().get(0)).getElementType();
        ArrayType returnType = (ArrayType) boundSignature.getReturnType();
        Type outputType = returnType.getElementType();
        return new ChoicesSpecializedSqlScalarFunction(
                boundSignature,
                InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL,
                ImmutableList.of(
                        InvocationConvention.InvocationArgumentConvention.NEVER_NULL,
                        InvocationConvention.InvocationArgumentConvention.FUNCTION),
                ImmutableList.of(UnaryFunctionInterface.class),
                generateTransform(inputType, outputType),
                Optional.of(CREATE_STATE.bindTo(returnType)));
    }

    /**
     * Generates a class with a public static method
     * {@code Block transform(BufferedArrayValueBuilder, Block, UnaryFunctionInterface)}.
     * That method builds the result array with {@code position count of the input block} entries,
     * filling it through the generated {@code transformValue} method.
     *
     * @return a handle to the generated static {@code transform} method
     * @throws TrinoException with GENERIC_INTERNAL_ERROR if the generated method cannot be looked up
     */
    private static MethodHandle generateTransform(Type inputType, Type outputType) {
        CallSiteBinder binder = new CallSiteBinder();
        ClassDefinition definition = new ClassDefinition(
                Access.a(Access.PUBLIC, Access.FINAL),
                CompilerUtils.makeClassName("ArrayTransform"),
                ParameterizedType.type(Object.class));
        definition.declareDefaultConstructor(Access.a(Access.PRIVATE));

        MethodDefinition transformValue = generateTransformValueInner(definition, binder, inputType, outputType);

        Parameter arrayValueBuilder = Parameter.arg("arrayValueBuilder", BufferedArrayValueBuilder.class);
        Parameter block = Parameter.arg("block", Block.class);
        Parameter function = Parameter.arg("function", UnaryFunctionInterface.class);
        MethodDefinition method = definition.declareMethod(
                Access.a(Access.PUBLIC, Access.STATIC),
                ARRAY_TRANSFORM_NAME,
                ParameterizedType.type(Block.class),
                ImmutableList.of(arrayValueBuilder, block, function));

        // Expose transformValue(block, function, elementBuilder) as an ArrayValueBuilder.
        // block and function are supplied as captured arguments; presumably elementBuilder is
        // the callback's own parameter.
        BytecodeExpression arrayBuilder = LambdaMetafactoryGenerator.generateMetafactory(
                ArrayValueBuilder.class,
                transformValue,
                ImmutableList.of(block, function));
        BytecodeExpression entryCount = block.invoke("getPositionCount", int.class);
        method.getBody()
                .append(arrayValueBuilder.invoke("build", Block.class, entryCount, arrayBuilder).ret());

        Class<?> generatedClass = CompilerUtils.defineClass(definition, Object.class, binder.getBindings(), ArrayTransformFunction.class.getClassLoader());
        try {
            return MethodHandles.lookup().findStatic(
                    generatedClass,
                    ARRAY_TRANSFORM_NAME,
                    MethodType.methodType(Block.class, BufferedArrayValueBuilder.class, Block.class, UnaryFunctionInterface.class));
        }
        catch (ReflectiveOperationException e) {
            throw new TrinoException(StandardErrorCode.GENERIC_INTERNAL_ERROR, e);
        }
    }

    /**
     * Declares the private static method
     * {@code void transformValue(Block block, UnaryFunctionInterface function, BlockBuilder elementBuilder)}
     * on {@code definition}.
     *
     * <p>For each position of {@code block}, the generated method loads the element, boxed to its
     * wrapper type or {@code null}, and invokes {@code function.apply}. It appends the result to
     * {@code elementBuilder}, writing a null element when the result is {@code null}.
     *
     * @return the declared method definition
     */
    private static MethodDefinition generateTransformValueInner(ClassDefinition definition, CallSiteBinder binder, Type inputType, Type outputType) {
        // Locals are declared with boxed types so they can hold null and be passed as Object.
        Class<?> inputJavaType = Primitives.wrap(inputType.getJavaType());
        Class<?> outputJavaType = Primitives.wrap(outputType.getJavaType());

        Parameter block = Parameter.arg("block", Block.class);
        Parameter function = Parameter.arg("function", UnaryFunctionInterface.class);
        Parameter elementBuilder = Parameter.arg("elementBuilder", BlockBuilder.class);
        MethodDefinition method = definition.declareMethod(
                Access.a(Access.PRIVATE, Access.STATIC),
                "transformValue",
                ParameterizedType.type(void.class),
                ImmutableList.of(block, function, elementBuilder));

        BytecodeBlock body = method.getBody();
        Scope scope = method.getScope();
        Variable positionCount = scope.declareVariable(int.class, "positionCount");
        Variable position = scope.declareVariable(int.class, "position");
        Variable inputElement = scope.declareVariable(inputJavaType, "inputElement");
        Variable outputElement = scope.declareVariable(outputJavaType, "outputElement");

        body.append(positionCount.set(block.invoke("getPositionCount", int.class)));

        // Load the current input element, or null if the position is null.
        BytecodeNode loadInputElement;
        if (!inputType.equals(UnknownType.UNKNOWN)) {
            loadInputElement = new IfStatement()
                    .condition(block.invoke("isNull", boolean.class, position))
                    .ifTrue(inputElement.set(BytecodeExpressions.constantNull(inputJavaType)))
                    .ifFalse(inputElement.set(SqlTypeBytecodeExpression.constantType(binder, inputType)
                            .getValue(block, position)
                            .cast(inputJavaType)));
        }
        else {
            // UNKNOWN elements are passed to the lambda as null without reading the block.
            loadInputElement = new BytecodeBlock()
                    .append(inputElement.set(BytecodeExpressions.constantNull(inputJavaType)));
        }

        // Append the lambda result. A null result, or any result for an UNKNOWN output type,
        // becomes a null element.
        BytecodeNode writeOutputElement;
        if (!outputType.equals(UnknownType.UNKNOWN)) {
            writeOutputElement = new IfStatement()
                    .condition(BytecodeExpressions.equal(outputElement, BytecodeExpressions.constantNull(outputJavaType)))
                    .ifTrue(elementBuilder.invoke("appendNull", BlockBuilder.class).pop())
                    // Unbox back to the type's declared Java type before writing.
                    .ifFalse(SqlTypeBytecodeExpression.constantType(binder, outputType)
                            .writeValue(elementBuilder, outputElement.cast(outputType.getJavaType())));
        }
        else {
            writeOutputElement = new BytecodeBlock()
                    .append(elementBuilder.invoke("appendNull", BlockBuilder.class).pop());
        }

        body.append(new ForLoop()
                .initialize(position.set(BytecodeExpressions.constantInt(0)))
                .condition(BytecodeExpressions.lessThan(position, positionCount))
                .update(VariableInstruction.incrementVariable(position, (byte) 1))
                .body(new BytecodeBlock()
                        .append(loadInputElement)
                        .append(outputElement.set(function.invoke("apply", Object.class, inputElement.cast(Object.class))
                                .cast(outputJavaType)))
                        .append(writeOutputElement)));

        body.ret();
        return method;
    }

    static {
        try {
            CREATE_STATE = MethodHandles.lookup().findStatic(
                    BufferedArrayValueBuilder.class,
                    "createBuffered",
                    MethodType.methodType(BufferedArrayValueBuilder.class, ArrayType.class));
        }
        catch (ReflectiveOperationException e) {
            throw new ExceptionInInitializerError(e);
        }
        ARRAY_TRANSFORM_FUNCTION = new ArrayTransformFunction();
    }
}

