/*
 * Decompiled with CFR 0.152.
 */
package io.trino.operator.scalar;

import com.google.common.collect.ImmutableList;
import com.google.common.primitives.Primitives;
import io.airlift.bytecode.Access;
import io.airlift.bytecode.BytecodeBlock;
import io.airlift.bytecode.BytecodeNode;
import io.airlift.bytecode.ClassDefinition;
import io.airlift.bytecode.MethodDefinition;
import io.airlift.bytecode.Parameter;
import io.airlift.bytecode.ParameterizedType;
import io.airlift.bytecode.Scope;
import io.airlift.bytecode.Variable;
import io.airlift.bytecode.control.ForLoop;
import io.airlift.bytecode.control.IfStatement;
import io.airlift.bytecode.control.TryCatch;
import io.airlift.bytecode.expression.BytecodeExpression;
import io.airlift.bytecode.expression.BytecodeExpressions;
import io.trino.annotation.UsedByGeneratedCode;
import io.trino.metadata.SqlScalarFunction;
import io.trino.operator.scalar.ChoicesSpecializedSqlScalarFunction;
import io.trino.operator.scalar.SpecializedSqlScalarFunction;
import io.trino.spi.ErrorCodeSupplier;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.TrinoException;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.BufferedMapValueBuilder;
import io.trino.spi.block.DuplicateMapKeyException;
import io.trino.spi.block.MapValueBuilder;
import io.trino.spi.block.SqlMap;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.function.BoundSignature;
import io.trino.spi.function.FunctionMetadata;
import io.trino.spi.function.InvocationConvention;
import io.trino.spi.function.Signature;
import io.trino.spi.type.MapType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeSignature;
import io.trino.spi.type.TypeSignatureParameter;
import io.trino.sql.gen.CallSiteBinder;
import io.trino.sql.gen.LambdaMetafactoryGenerator;
import io.trino.sql.gen.SqlTypeBytecodeExpression;
import io.trino.sql.gen.lambda.BinaryFunctionInterface;
import io.trino.type.BlockTypeOperators;
import io.trino.type.UnknownType;
import io.trino.util.CompilerUtils;
import io.trino.util.Reflection;
import java.lang.invoke.MethodHandle;
import java.util.List;
import java.util.Optional;

/**
 * Scalar SQL function {@code transform_keys(map(K1, V), function(K1, V, K2))} returning
 * {@code map(K2, V)}: the lambda is applied to every entry of the input map and its result
 * becomes the key of the corresponding output entry, while the value is copied unchanged.
 *
 * <p>The function is registered as nondeterministic. The actual implementation is generated
 * as bytecode per bound signature, see {@link #specialize(BoundSignature)}.
 */
public final class MapTransformKeysFunction
        extends SqlScalarFunction {
    public static final String NAME = "transform_keys";
    private static final MethodHandle STATE_FACTORY = Reflection.methodHandle(MapTransformKeysFunction.class, "createState", MapType.class);

    /**
     * Registers the function metadata (name, generic signature, description).
     *
     * @param blockTypeOperators not used by this constructor; kept so the signature stays
     *        compatible with existing callers
     */
    public MapTransformKeysFunction(BlockTypeOperators blockTypeOperators) {
        super(FunctionMetadata.scalarBuilder(NAME)
                .signature(Signature.builder()
                        .typeVariable("K1")
                        .typeVariable("K2")
                        .typeVariable("V")
                        .returnType(TypeSignature.mapType(new TypeSignature("K2"), new TypeSignature("V")))
                        .argumentType(TypeSignature.mapType(new TypeSignature("K1"), new TypeSignature("V")))
                        .argumentType(TypeSignature.functionType(new TypeSignature("K1"), new TypeSignature("V"), new TypeSignature("K2")))
                        .build())
                .nondeterministic()
                .description("Apply lambda to each entry of the map and transform the key")
                .build());
    }

    /**
     * Builds the specialized implementation for a concrete signature.
     *
     * <p>The input key type is taken from argument 0, the output key type and the value type
     * from the return type. The per-instance state is produced by {@link #createState(MapType)}
     * bound to the output map type.
     *
     * @param boundSignature the signature with all type variables resolved
     * @return the specialization wrapping the generated {@code transform} method handle
     */
    @Override
    protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) {
        MapType inputMapType = (MapType) boundSignature.getArgumentType(0);
        MapType outputMapType = (MapType) boundSignature.getReturnType();

        Type inputKeyType = inputMapType.getKeyType();
        Type outputKeyType = outputMapType.getKeyType();
        Type valueType = outputMapType.getValueType();

        MethodHandle implementation = generateTransformKey(inputKeyType, outputKeyType, valueType);
        return new ChoicesSpecializedSqlScalarFunction(
                boundSignature,
                InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL,
                ImmutableList.of(
                        InvocationConvention.InvocationArgumentConvention.NEVER_NULL,
                        InvocationConvention.InvocationArgumentConvention.FUNCTION),
                ImmutableList.of(BinaryFunctionInterface.class),
                implementation,
                Optional.of(STATE_FACTORY.bindTo(outputMapType)));
    }

    /**
     * State factory looked up reflectively through {@code STATE_FACTORY}.
     *
     * @param mapType the output map type
     * @return the builder created by {@code BufferedMapValueBuilder.createBufferedDistinctStrict};
     *         the generated {@code transform} method casts it back to {@code BufferedMapValueBuilder}
     */
    @UsedByGeneratedCode
    public static Object createState(MapType mapType) {
        return BufferedMapValueBuilder.createBufferedDistinctStrict(mapType);
    }

    /**
     * Generates a class exposing
     * {@code public static SqlMap transform(Object state, ConnectorSession session, SqlMap map, BinaryFunctionInterface function)}
     * and returns a method handle to it.
     *
     * <p>The generated method calls {@code build(map.getSize(), entryBuilder)} on the state, where
     * {@code entryBuilder} is a {@code MapValueBuilder} lambda backed by the private method emitted by
     * {@link #generateTransformKeyInner}. A {@code DuplicateMapKeyException} raised by that call is
     * caught and the result of its {@code withDetailedMessage(transformedKeyType, session)} is thrown
     * instead.
     */
    private static MethodHandle generateTransformKey(Type keyType, Type transformedKeyType, Type valueType) {
        CallSiteBinder binder = new CallSiteBinder();
        ClassDefinition definition = new ClassDefinition(
                Access.a(Access.PUBLIC, Access.FINAL),
                CompilerUtils.makeClassName("MapTransformKey"),
                ParameterizedType.type(Object.class));
        definition.declareDefaultConstructor(Access.a(Access.PRIVATE));

        // The per-entry method is declared first; the public entry point below references it.
        MethodDefinition transformMap = generateTransformKeyInner(definition, binder, keyType, transformedKeyType, valueType);

        Parameter state = Parameter.arg("state", Object.class);
        Parameter session = Parameter.arg("session", ConnectorSession.class);
        Parameter map = Parameter.arg("map", SqlMap.class);
        Parameter function = Parameter.arg("function", BinaryFunctionInterface.class);
        MethodDefinition method = definition.declareMethod(
                Access.a(Access.PUBLIC, Access.STATIC),
                "transform",
                ParameterizedType.type(SqlMap.class),
                ImmutableList.of(state, session, map, function));

        BytecodeBlock body = method.getBody();
        Scope scope = method.getScope();

        Variable mapValueBuilder = scope.declareVariable(BufferedMapValueBuilder.class, "mapValueBuilder");
        body.append(mapValueBuilder.set(state.cast(BufferedMapValueBuilder.class)));

        BytecodeExpression mapEntryBuilder = LambdaMetafactoryGenerator.generateMetafactory(
                MapValueBuilder.class,
                transformMap,
                ImmutableList.of(session, map, function));

        Variable duplicateKeyException = scope.declareVariable(DuplicateMapKeyException.class, "e");

        BytecodeNode buildAndReturn = mapValueBuilder
                .invoke("build", SqlMap.class, map.invoke("getSize", int.class), mapEntryBuilder)
                .ret();
        // Handler: store the caught exception, then throw the exception returned by withDetailedMessage.
        BytecodeBlock rethrowDetailed = new BytecodeBlock()
                .putVariable(duplicateKeyException)
                .append(duplicateKeyException.invoke(
                        "withDetailedMessage",
                        DuplicateMapKeyException.class,
                        SqlTypeBytecodeExpression.constantType(binder, transformedKeyType),
                        session))
                .throwObject();
        body.append(new TryCatch(
                buildAndReturn,
                ImmutableList.of(new TryCatch.CatchBlock(
                        rethrowDetailed,
                        ImmutableList.of(ParameterizedType.type(DuplicateMapKeyException.class))))));

        Class<?> generatedClass = CompilerUtils.defineClass(definition, Object.class, binder.getBindings(), MapTransformKeysFunction.class.getClassLoader());
        return Reflection.methodHandle(generatedClass, "transform", Object.class, ConnectorSession.class, SqlMap.class, BinaryFunctionInterface.class);
    }

    /**
     * Emits
     * {@code private static void transform(ConnectorSession session, SqlMap map, BinaryFunctionInterface function, BlockBuilder keyBuilder, BlockBuilder valueBuilder)}
     * into {@code definition}.
     *
     * <p>For each entry of {@code map} the generated code loads the key and the value (a null value
     * is loaded as {@code null}), calls {@code function.apply(key, value)}, and then either throws a
     * {@code TrinoException} with {@code INVALID_FUNCTION_ARGUMENT} and the message
     * "map key cannot be null" when the result is null, or writes the result to {@code keyBuilder}
     * and appends the original value to {@code valueBuilder}.
     *
     * <p>When the input key type or the transformed key type is {@code UNKNOWN}, the loop body
     * throws that same exception instead of loading/writing a key.
     *
     * @return the declared method, used as the lambda target by {@link #generateTransformKey}
     */
    private static MethodDefinition generateTransformKeyInner(ClassDefinition definition, CallSiteBinder binder, Type keyType, Type transformedKeyType, Type valueType) {
        Parameter session = Parameter.arg("session", ConnectorSession.class);
        Parameter map = Parameter.arg("map", SqlMap.class);
        Parameter function = Parameter.arg("function", BinaryFunctionInterface.class);
        Parameter keyBuilder = Parameter.arg("keyBuilder", BlockBuilder.class);
        Parameter valueBuilder = Parameter.arg("valueBuilder", BlockBuilder.class);
        MethodDefinition method = definition.declareMethod(
                Access.a(Access.PRIVATE, Access.STATIC),
                "transform",
                ParameterizedType.type(void.class),
                ImmutableList.of(session, map, function, keyBuilder, valueBuilder));

        BytecodeBlock body = method.getBody();
        Scope scope = method.getScope();

        // Boxed Java types so that the generated locals can hold null.
        Class<?> keyJavaType = Primitives.wrap(keyType.getJavaType());
        Class<?> transformedKeyJavaType = Primitives.wrap(transformedKeyType.getJavaType());
        Class<?> valueJavaType = Primitives.wrap(valueType.getJavaType());

        Variable size = scope.declareVariable("size", body, map.invoke("getSize", int.class));
        Variable rawOffset = scope.declareVariable("rawOffset", body, map.invoke("getRawOffset", int.class));
        Variable rawKeyBlock = scope.declareVariable("rawKeyBlock", body, map.invoke("getRawKeyBlock", Block.class));
        Variable rawValueBlock = scope.declareVariable("rawValueBlock", body, map.invoke("getRawValueBlock", Block.class));

        Variable index = scope.declareVariable(int.class, "index");
        Variable keyElement = scope.declareVariable(keyJavaType, "keyElement");
        Variable transformedKeyElement = scope.declareVariable(transformedKeyJavaType, "transformedKeyElement");
        Variable valueElement = scope.declareVariable(valueJavaType, "valueElement");

        // throw new TrinoException(INVALID_FUNCTION_ARGUMENT, "map key cannot be null")
        BytecodeBlock throwNullKeyException = new BytecodeBlock()
                .append(BytecodeExpressions.newInstance(
                        TrinoException.class,
                        BytecodeExpressions.getStatic(StandardErrorCode.INVALID_FUNCTION_ARGUMENT.getDeclaringClass(), "INVALID_FUNCTION_ARGUMENT")
                                .cast(ErrorCodeSupplier.class),
                        BytecodeExpressions.constantString("map key cannot be null")))
                .throwObject();

        // Entries are addressed at (index + rawOffset) in the raw key/value blocks.
        SqlTypeBytecodeExpression keySqlType = SqlTypeBytecodeExpression.constantType(binder, keyType);
        BytecodeNode loadKeyElement;
        if (!keyType.equals(UnknownType.UNKNOWN)) {
            // The key is read without a null check.
            loadKeyElement = keyElement.set(
                    keySqlType.getValue(rawKeyBlock, BytecodeExpressions.add(index, rawOffset)).cast(keyJavaType));
        } else {
            // NOTE(review): the null assignment presumably only keeps keyElement initialized for
            // the later apply() call in the emitted method; at runtime this branch always throws.
            loadKeyElement = new BytecodeBlock()
                    .append(keyElement.set(BytecodeExpressions.constantNull(keyJavaType)))
                    .append(throwNullKeyException);
        }

        SqlTypeBytecodeExpression valueSqlType = SqlTypeBytecodeExpression.constantType(binder, valueType);
        BytecodeNode loadValueElement;
        if (!valueType.equals(UnknownType.UNKNOWN)) {
            loadValueElement = new IfStatement()
                    .condition(rawValueBlock.invoke("isNull", boolean.class, BytecodeExpressions.add(index, rawOffset)))
                    .ifTrue(valueElement.set(BytecodeExpressions.constantNull(valueJavaType)))
                    .ifFalse(valueElement.set(
                            valueSqlType.getValue(rawValueBlock, BytecodeExpressions.add(index, rawOffset)).cast(valueJavaType)));
        } else {
            loadValueElement = valueElement.set(BytecodeExpressions.constantNull(valueJavaType));
        }

        BytecodeNode writeKeyElement;
        if (!transformedKeyType.equals(UnknownType.UNKNOWN)) {
            // The value is copied from the raw value block, not from valueElement.
            BytecodeBlock writeEntry = new BytecodeBlock()
                    .append(SqlTypeBytecodeExpression.constantType(binder, transformedKeyType)
                            .writeValue(keyBuilder, transformedKeyElement.cast(transformedKeyType.getJavaType())))
                    .append(valueSqlType.invoke(
                            "appendTo",
                            void.class,
                            rawValueBlock,
                            BytecodeExpressions.add(index, rawOffset),
                            valueBuilder));
            writeKeyElement = new BytecodeBlock()
                    .append(transformedKeyElement.set(
                            function.invoke("apply", Object.class, keyElement.cast(Object.class), valueElement.cast(Object.class))
                                    .cast(transformedKeyJavaType)))
                    .append(new IfStatement()
                            .condition(BytecodeExpressions.equal(transformedKeyElement, BytecodeExpressions.constantNull(transformedKeyJavaType)))
                            .ifTrue(throwNullKeyException)
                            .ifFalse(writeEntry));
        } else {
            // The lambda result type is UNKNOWN: the generated loop body throws the null-key error.
            writeKeyElement = throwNullKeyException;
        }

        body.append(new ForLoop()
                .initialize(index.set(BytecodeExpressions.constantInt(0)))
                .condition(BytecodeExpressions.lessThan(index, size))
                .update(index.increment())
                .body(new BytecodeBlock()
                        .append(loadKeyElement)
                        .append(loadValueElement)
                        .append(writeKeyElement)));
        body.ret();
        return method;
    }
}

