/*
 * Decompiled with CFR 0.152.
 */
package io.prestosql.sql.gen;

import com.google.common.base.Joiner;
import com.google.common.collect.ImmutableList;
import io.prestosql.annotation.UsedByGeneratedCode;
import io.prestosql.metadata.FunctionArgumentDefinition;
import io.prestosql.metadata.FunctionBinding;
import io.prestosql.metadata.FunctionKind;
import io.prestosql.metadata.FunctionMetadata;
import io.prestosql.metadata.Signature;
import io.prestosql.metadata.SqlScalarFunction;
import io.prestosql.operator.scalar.AbstractTestFunctions;
import io.prestosql.operator.scalar.ChoicesScalarFunctionImplementation;
import io.prestosql.operator.scalar.ScalarFunctionImplementation;
import io.prestosql.spi.function.InvocationConvention;
import io.prestosql.spi.type.IntegerType;
import io.prestosql.spi.type.Type;
import io.prestosql.sql.gen.VarArgsToArrayAdapterGenerator;
import io.prestosql.util.Reflection;
import java.lang.invoke.MethodHandle;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

public class TestVarArgsToArrayAdapterGenerator
extends AbstractTestFunctions {
    @BeforeClass
    public void setUp() {
        this.registerScalarFunction(TestVarArgsSum.VAR_ARGS_SUM);
    }

    @Test
    public void testArrayElements() {
        this.assertFunction("var_args_sum()", (Type)IntegerType.INTEGER, 0);
        this.assertFunction("var_args_sum(1)", (Type)IntegerType.INTEGER, 1);
        this.assertFunction("var_args_sum(1, 2)", (Type)IntegerType.INTEGER, 3);
        this.assertFunction("var_args_sum(null)", (Type)IntegerType.INTEGER, null);
        this.assertFunction("var_args_sum(1, null, 2, null, 3)", (Type)IntegerType.INTEGER, null);
        this.assertFunction("var_args_sum(1, 2, 3)", (Type)IntegerType.INTEGER, 6);
        int k = 100;
        int expectedSum = (1 + k) * k / 2;
        this.assertFunction(String.format("var_args_sum(%s)", Joiner.on((String)",").join((Iterable)IntStream.rangeClosed(1, k).boxed().collect(Collectors.toSet()))), (Type)IntegerType.INTEGER, expectedSum);
    }

    public static class TestVarArgsSum
    extends SqlScalarFunction {
        public static final TestVarArgsSum VAR_ARGS_SUM = new TestVarArgsSum();
        private static final MethodHandle METHOD_HANDLE = Reflection.methodHandle(TestVarArgsSum.class, (String)"varArgsSum", (Class[])new Class[]{Object.class, long[].class});
        private static final MethodHandle USER_STATE_FACTORY = Reflection.methodHandle(TestVarArgsSum.class, (String)"createState", (Class[])new Class[0]);

        private TestVarArgsSum() {
            super(new FunctionMetadata(new Signature("var_args_sum", (List)ImmutableList.of(), (List)ImmutableList.of(), IntegerType.INTEGER.getTypeSignature(), (List)ImmutableList.of((Object)IntegerType.INTEGER.getTypeSignature()), true), false, (List)ImmutableList.of((Object)new FunctionArgumentDefinition(false)), false, false, "return sum of all the parameters", FunctionKind.SCALAR));
        }

        protected ScalarFunctionImplementation specialize(FunctionBinding functionBinding) {
            VarArgsToArrayAdapterGenerator.MethodHandleAndConstructor methodHandleAndConstructor = VarArgsToArrayAdapterGenerator.generateVarArgsToArrayAdapter(Long.TYPE, Long.TYPE, (int)functionBinding.getArity(), (MethodHandle)METHOD_HANDLE, (MethodHandle)USER_STATE_FACTORY);
            return new ChoicesScalarFunctionImplementation(functionBinding, InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL, Collections.nCopies(functionBinding.getArity(), InvocationConvention.InvocationArgumentConvention.NEVER_NULL), methodHandleAndConstructor.getMethodHandle(), Optional.of(methodHandleAndConstructor.getConstructor()));
        }

        @UsedByGeneratedCode
        public static Object createState() {
            return null;
        }

        @UsedByGeneratedCode
        public static long varArgsSum(Object state, long[] values) {
            long sum = 0L;
            for (long value : values) {
                sum += value;
            }
            return sum;
        }
    }
}

