/*
 * Decompiled with CFR 0.152.
 */
package io.trino.operator.scalar;

import com.google.common.collect.ImmutableList;
import com.google.common.primitives.Primitives;
import io.airlift.bytecode.Access;
import io.airlift.bytecode.BytecodeBlock;
import io.airlift.bytecode.BytecodeNode;
import io.airlift.bytecode.ClassDefinition;
import io.airlift.bytecode.MethodDefinition;
import io.airlift.bytecode.Parameter;
import io.airlift.bytecode.ParameterizedType;
import io.airlift.bytecode.Scope;
import io.airlift.bytecode.Variable;
import io.airlift.bytecode.control.ForLoop;
import io.airlift.bytecode.control.IfStatement;
import io.airlift.bytecode.expression.BytecodeExpression;
import io.airlift.bytecode.expression.BytecodeExpressions;
import io.trino.annotation.UsedByGeneratedCode;
import io.trino.metadata.SqlScalarFunction;
import io.trino.operator.scalar.ChoicesSpecializedSqlScalarFunction;
import io.trino.operator.scalar.SpecializedSqlScalarFunction;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.BufferedMapValueBuilder;
import io.trino.spi.block.MapValueBuilder;
import io.trino.spi.block.SqlMap;
import io.trino.spi.function.BoundSignature;
import io.trino.spi.function.FunctionMetadata;
import io.trino.spi.function.InvocationConvention;
import io.trino.spi.function.Signature;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.MapType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeSignature;
import io.trino.spi.type.TypeSignatureParameter;
import io.trino.sql.gen.CallSiteBinder;
import io.trino.sql.gen.LambdaMetafactoryGenerator;
import io.trino.sql.gen.SqlTypeBytecodeExpression;
import io.trino.sql.gen.lambda.BinaryFunctionInterface;
import io.trino.type.UnknownType;
import io.trino.util.CompilerUtils;
import io.trino.util.Reflection;
import java.lang.invoke.MethodHandle;
import java.util.List;
import java.util.Optional;

/**
 * Implements the scalar function
 * {@code map_filter(map(K, V), function(K, V, boolean)) -> map(K, V)}: the result contains only
 * those entries of the input map for which the predicate lambda returns {@code true}. Entries whose
 * predicate result is {@code false} or {@code null} are dropped.
 *
 * <p>For every bound {@link MapType} a class is generated at specialization time whose bytecode is
 * specialized for the concrete key and value types.
 */
public final class MapFilterFunction
        extends SqlScalarFunction {
    public static final MapFilterFunction MAP_FILTER_FUNCTION = new MapFilterFunction();

    // Bound to the concrete MapType in specialize() to produce the per-call state object.
    private static final MethodHandle STATE_FACTORY = Reflection.methodHandle(MapFilterFunction.class, "createState", MapType.class);

    private MapFilterFunction() {
        // Declared nondeterministic. NOTE(review): presumably because the supplied lambda may
        // itself be nondeterministic -- confirm.
        super(FunctionMetadata.scalarBuilder("map_filter")
                .signature(Signature.builder()
                        .typeVariable("K")
                        .typeVariable("V")
                        .returnType(TypeSignature.mapType(new TypeSignature("K"), new TypeSignature("V")))
                        .argumentType(TypeSignature.mapType(new TypeSignature("K"), new TypeSignature("V")))
                        .argumentType(TypeSignature.functionType(
                                new TypeSignature("K"),
                                new TypeSignature("V"),
                                BooleanType.BOOLEAN.getTypeSignature()))
                        .build())
                .nondeterministic()
                .description("return map containing entries that match the given predicate")
                .build());
    }

    /**
     * Binds the function to the concrete map type of the call site.
     *
     * <p>The return convention is {@code FAIL_ON_NULL}. The map argument is {@code NEVER_NULL}, and
     * the predicate is passed as a {@link BinaryFunctionInterface}. The implementation is the
     * generated {@code filter} method, and its state comes from {@link #createState(MapType)}.
     */
    @Override
    public SpecializedSqlScalarFunction specialize(BoundSignature boundSignature) {
        MapType mapType = (MapType) boundSignature.getReturnType();
        return new ChoicesSpecializedSqlScalarFunction(
                boundSignature,
                InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL,
                ImmutableList.of(
                        InvocationConvention.InvocationArgumentConvention.NEVER_NULL,
                        InvocationConvention.InvocationArgumentConvention.FUNCTION),
                ImmutableList.of(BinaryFunctionInterface.class),
                generateFilter(mapType),
                Optional.of(STATE_FACTORY.bindTo(mapType)));
    }

    /**
     * Creates the state object passed to the generated {@code filter} method: a
     * {@link BufferedMapValueBuilder} for {@code mapType}. The generated code casts its
     * {@code state} argument back to {@link BufferedMapValueBuilder}.
     */
    @UsedByGeneratedCode
    public static Object createState(MapType mapType) {
        return BufferedMapValueBuilder.createBuffered(mapType);
    }

    /**
     * Generates and loads a class exposing
     * {@code public static SqlMap filter(Object state, SqlMap map, BinaryFunctionInterface function)}
     * and returns a handle to that method.
     *
     * <p>The generated method casts {@code state} to {@link BufferedMapValueBuilder} and returns
     * {@code state.build(map.getSize(), entryBuilder)}. Here {@code entryBuilder} is a
     * {@link MapValueBuilder} lambda that forwards to the inner method declared by
     * {@link #generateFilterInner}, with {@code map} and {@code function} captured.
     */
    private static MethodHandle generateFilter(MapType mapType) {
        CallSiteBinder binder = new CallSiteBinder();
        ClassDefinition definition = new ClassDefinition(
                Access.a(Access.PUBLIC, Access.FINAL),
                CompilerUtils.makeClassName("MapFilter"),
                ParameterizedType.type(Object.class));
        definition.declareDefaultConstructor(Access.a(Access.PRIVATE));

        MethodDefinition filterKeyValue = generateFilterInner(definition, binder, mapType);

        Parameter state = Parameter.arg("state", Object.class);
        Parameter map = Parameter.arg("map", SqlMap.class);
        Parameter function = Parameter.arg("function", BinaryFunctionInterface.class);
        MethodDefinition method = definition.declareMethod(
                Access.a(Access.PUBLIC, Access.STATIC),
                "filter",
                ParameterizedType.type(SqlMap.class),
                ImmutableList.of(state, map, function));

        BytecodeBlock body = method.getBody();
        Scope scope = method.getScope();

        Variable mapValueBuilder = scope.declareVariable(BufferedMapValueBuilder.class, "mapValueBuilder");
        body.append(mapValueBuilder.set(state.cast(BufferedMapValueBuilder.class)));

        // MapValueBuilder lambda: map and function are captured and passed as the leading
        // arguments of filterKeyValue; the builder supplies keyBuilder and valueBuilder.
        BytecodeExpression mapEntryBuilder = LambdaMetafactoryGenerator.generateMetafactory(
                MapValueBuilder.class,
                filterKeyValue,
                ImmutableList.of(map, function));

        // The input map's size is passed to build(). NOTE(review): presumably an upper bound or
        // sizing hint, since filtering can only remove entries -- confirm against BufferedMapValueBuilder.
        body.append(mapValueBuilder.invoke("build", SqlMap.class, map.invoke("getSize", int.class), mapEntryBuilder).ret());

        Class<?> generatedClass = CompilerUtils.defineClass(definition, Object.class, binder.getBindings(), MapFilterFunction.class.getClassLoader());
        return Reflection.methodHandle(generatedClass, "filter", Object.class, SqlMap.class, BinaryFunctionInterface.class);
    }

    /**
     * Declares on {@code definition} the private static method
     * {@code void filter(SqlMap map, BinaryFunctionInterface function, BlockBuilder keyBuilder, BlockBuilder valueBuilder)}.
     *
     * <p>For every entry of {@code map} the generated method does the following:
     * <ul>
     *   <li>loads the key and value;</li>
     *   <li>calls {@code function.apply(key, value)};</li>
     *   <li>only when the result is non-null and {@code true}, copies the entry's key and value
     *       positions into {@code keyBuilder} and {@code valueBuilder}.</li>
     * </ul>
     *
     * @return the declared method, used as the lambda target by {@link #generateFilter}
     */
    private static MethodDefinition generateFilterInner(ClassDefinition definition, CallSiteBinder binder, MapType mapType) {
        Parameter map = Parameter.arg("map", SqlMap.class);
        Parameter function = Parameter.arg("function", BinaryFunctionInterface.class);
        Parameter keyBuilder = Parameter.arg("keyBuilder", BlockBuilder.class);
        Parameter valueBuilder = Parameter.arg("valueBuilder", BlockBuilder.class);
        MethodDefinition method = definition.declareMethod(
                Access.a(Access.PRIVATE, Access.STATIC),
                "filter",
                ParameterizedType.type(void.class),
                ImmutableList.of(map, function, keyBuilder, valueBuilder));

        BytecodeBlock body = method.getBody();
        Scope scope = method.getScope();

        Type keyType = mapType.getKeyType();
        Type valueType = mapType.getValueType();
        // Elements are handed to the lambda as Objects, so primitive Java types are boxed.
        Class<?> keyJavaType = Primitives.wrap(keyType.getJavaType());
        Class<?> valueJavaType = Primitives.wrap(valueType.getJavaType());

        // The map's entries occupy positions [rawOffset, rawOffset + size) of the raw key/value blocks.
        Variable size = scope.declareVariable("size", body, map.invoke("getSize", int.class));
        Variable rawOffset = scope.declareVariable("rawOffset", body, map.invoke("getRawOffset", int.class));
        Variable rawKeyBlock = scope.declareVariable("rawKeyBlock", body, map.invoke("getRawKeyBlock", Block.class));
        Variable rawValueBlock = scope.declareVariable("rawValueBlock", body, map.invoke("getRawValueBlock", Block.class));

        Variable index = scope.declareVariable(int.class, "index");
        Variable keyElement = scope.declareVariable(keyJavaType, "keyElement");
        Variable valueElement = scope.declareVariable(valueJavaType, "valueElement");
        Variable keep = scope.declareVariable(Boolean.class, "keep");

        SqlTypeBytecodeExpression keySqlType = SqlTypeBytecodeExpression.constantType(binder, keyType);
        SqlTypeBytecodeExpression valueSqlType = SqlTypeBytecodeExpression.constantType(binder, valueType);

        // keyElement = keyType value at (rawOffset + index). Keys are read without a null check.
        // NOTE(review): this relies on map keys never being null. UNKNOWN keys always load null.
        BytecodeNode loadKeyElement;
        if (keyType.equals(UnknownType.UNKNOWN)) {
            loadKeyElement = new BytecodeBlock().append(keyElement.set(BytecodeExpressions.constantNull(keyJavaType)));
        } else {
            loadKeyElement = keyElement.set(keySqlType.getValue(rawKeyBlock, BytecodeExpressions.add(index, rawOffset)).cast(keyJavaType));
        }

        // valueElement = null if the value position is null, otherwise the valueType value there.
        // UNKNOWN values always load null.
        BytecodeNode loadValueElement;
        if (valueType.equals(UnknownType.UNKNOWN)) {
            loadValueElement = new BytecodeBlock().append(valueElement.set(BytecodeExpressions.constantNull(valueJavaType)));
        } else {
            loadValueElement = new IfStatement()
                    .condition(rawValueBlock.invoke("isNull", boolean.class, BytecodeExpressions.add(index, rawOffset)))
                    .ifTrue(valueElement.set(BytecodeExpressions.constantNull(valueJavaType)))
                    .ifFalse(valueElement.set(valueSqlType.getValue(rawValueBlock, BytecodeExpressions.add(index, rawOffset)).cast(valueJavaType)));
        }

        // Copies the current entry's key and value positions into the output builders.
        BytecodeNode appendEntry = new BytecodeBlock()
                .append(keySqlType.invoke("appendTo", void.class, rawKeyBlock, BytecodeExpressions.add(index, rawOffset), keyBuilder))
                .append(valueSqlType.invoke("appendTo", void.class, rawValueBlock, BytecodeExpressions.add(index, rawOffset), valueBuilder));

        // A null predicate result is treated like false: the entry is dropped.
        BytecodeExpression keepIsTrue = BytecodeExpressions.and(
                BytecodeExpressions.notEqual(keep, BytecodeExpressions.constantNull(Boolean.class)),
                keep.cast(boolean.class));

        body.append(new ForLoop()
                .initialize(index.set(BytecodeExpressions.constantInt(0)))
                .condition(BytecodeExpressions.lessThan(index, size))
                .update(index.increment())
                .body(new BytecodeBlock()
                        .append(loadKeyElement)
                        .append(loadValueElement)
                        .append(keep.set(function.invoke("apply", Object.class, keyElement.cast(Object.class), valueElement.cast(Object.class)).cast(Boolean.class)))
                        .append(new IfStatement("if (keep != null && keep) ...")
                                .condition(keepIsTrue)
                                .ifTrue(appendEntry))));
        body.ret();
        return method;
    }
}

