/*
 * Decompiled with CFR 0.152.
 */
package com.amazonaws.athena.connector.lambda.handlers;

import com.amazonaws.athena.connector.lambda.data.Block;
import com.amazonaws.athena.connector.lambda.data.BlockAllocator;
import com.amazonaws.athena.connector.lambda.data.BlockAllocatorImpl;
import com.amazonaws.athena.connector.lambda.data.BlockUtils;
import com.amazonaws.athena.connector.lambda.data.FieldResolver;
import com.amazonaws.athena.connector.lambda.data.projectors.ArrowValueProjector;
import com.amazonaws.athena.connector.lambda.data.projectors.ProjectorUtils;
import com.amazonaws.athena.connector.lambda.data.writers.GeneratedRowWriter;
import com.amazonaws.athena.connector.lambda.data.writers.extractors.Extractor;
import com.amazonaws.athena.connector.lambda.data.writers.fieldwriters.FieldWriterFactory;
import com.amazonaws.athena.connector.lambda.request.FederationRequest;
import com.amazonaws.athena.connector.lambda.request.FederationResponse;
import com.amazonaws.athena.connector.lambda.request.PingRequest;
import com.amazonaws.athena.connector.lambda.request.PingResponse;
import com.amazonaws.athena.connector.lambda.serde.VersionedObjectMapperFactory;
import com.amazonaws.athena.connector.lambda.udf.UserDefinedFunctionRequest;
import com.amazonaws.athena.connector.lambda.udf.UserDefinedFunctionResponse;
import com.amazonaws.athena.connector.lambda.udf.UserDefinedFunctionType;
import com.amazonaws.services.lambda.runtime.Context;
import com.amazonaws.services.lambda.runtime.RequestStreamHandler;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import java.io.InputStream;
import java.io.OutputStream;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.math.BigDecimal;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import org.apache.arrow.vector.complex.reader.FieldReader;
import org.apache.arrow.vector.types.Types;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for Lambda handlers that serve Athena user-defined functions (UDFs).
 *
 * <p>{@link #handleRequest} deserializes the incoming {@link FederationRequest}. Ping requests
 * are answered directly. {@link UserDefinedFunctionRequest}s are dispatched to a public method
 * of the concrete subclass. That method is resolved by reflection from the request's method name
 * and the Java types of its input columns. It is invoked once per input row, and its results are
 * written into a single-column output {@link Block} that is serialized back to the caller.
 */
public abstract class UserDefinedFunctionHandler
implements RequestStreamHandler {
    private static final Logger logger = LoggerFactory.getLogger(UserDefinedFunctionHandler.class);
    /** Number of output columns a scalar UDF must produce. */
    private static final int RETURN_COLUMN_COUNT = 1;
    /** Source type reported back in {@link PingResponse}s. */
    private final String sourceType;

    /**
     * @param sourceType identifier reported in ping responses
     */
    public UserDefinedFunctionHandler(String sourceType) {
        this.sourceType = sourceType;
    }

    /**
     * Lambda entry point: reads one request from {@code inputStream} and writes the serialized
     * response to {@code outputStream}.
     *
     * <p>The block allocator, the deserialized request and the response are all closed before
     * this method returns. Checked exceptions are wrapped in {@link RuntimeException}; unchecked
     * ones propagate unchanged.
     *
     * @param inputStream  serialized {@link FederationRequest}
     * @param outputStream destination for the serialized response
     * @param context      Lambda context (unused)
     * @throws RuntimeException if the request is neither a ping nor a UDF request, or if
     *                          processing fails
     */
    @Override
    public final void handleRequest(InputStream inputStream, OutputStream outputStream, Context context) {
        try (BlockAllocatorImpl allocator = new BlockAllocatorImpl()) {
            ObjectMapper objectMapper = VersionedObjectMapperFactory.create(allocator);
            // The request itself is AutoCloseable; closing happens even when handling fails,
            // and close() failures are recorded as suppressed exceptions.
            try (FederationRequest rawRequest = objectMapper.readValue(inputStream, FederationRequest.class)) {
                if (rawRequest instanceof PingRequest) {
                    try (PingResponse response = this.doPing((PingRequest) rawRequest)) {
                        this.assertNotNull(response);
                        objectMapper.writeValue(outputStream, response);
                    }
                    return;
                }
                if (!(rawRequest instanceof UserDefinedFunctionRequest)) {
                    // readValue can yield null (e.g. for a JSON `null` payload); report that
                    // instead of failing with a NullPointerException on getClass().
                    String actualType = rawRequest == null ? "null" : String.valueOf(rawRequest.getClass());
                    throw new RuntimeException("Expected a UserDefinedFunctionRequest but found " + actualType);
                }
                this.doHandleRequest(allocator, objectMapper, (UserDefinedFunctionRequest) rawRequest, outputStream);
            }
            catch (Exception ex) {
                throw ex instanceof RuntimeException ? (RuntimeException) ex : new RuntimeException(ex);
            }
        }
    }

    /**
     * Processes a UDF request and serializes the response to {@code outputStream}.
     * The response is closed once it has been written.
     *
     * @throws Exception if processing or serialization fails
     */
    protected final void doHandleRequest(BlockAllocator allocator, ObjectMapper objectMapper, UserDefinedFunctionRequest req, OutputStream outputStream) throws Exception {
        logger.info("doHandleRequest: request[{}]", req);
        try (UserDefinedFunctionResponse response = this.processFunction(allocator, req)) {
            logger.info("doHandleRequest: response[{}]", response);
            this.assertNotNull(response);
            objectMapper.writeValue(outputStream, response);
        }
    }

    /**
     * Dispatches on the request's function type. Only {@code SCALAR} functions are supported.
     *
     * @throws UnsupportedOperationException for any other function type
     */
    @VisibleForTesting
    UserDefinedFunctionResponse processFunction(BlockAllocator allocator, UserDefinedFunctionRequest req) throws Exception {
        UserDefinedFunctionType functionType = req.getFunctionType();
        switch (functionType) {
            case SCALAR:
                return this.processScalarFunction(allocator, req);
            default:
                throw new UnsupportedOperationException("Unsupported function type " + functionType);
        }
    }

    /**
     * Resolves the scalar UDF method, evaluates it over every input row, and wraps the output
     * block in a response tagged with the method name.
     */
    private UserDefinedFunctionResponse processScalarFunction(BlockAllocator allocator, UserDefinedFunctionRequest req) throws Exception {
        Method udfMethod = this.extractScalarFunctionMethod(req);
        Block inputRecords = req.getInputRecords();
        Schema outputSchema = req.getOutputSchema();
        Block outputRecords = this.processRows(allocator, udfMethod, inputRecords, outputSchema);
        return new UserDefinedFunctionResponse(outputRecords, udfMethod.getName());
    }

    /**
     * Invokes {@code udfMethod} once per row of {@code inputRecords} and writes each result into
     * the first field of a newly allocated block built from {@code outputSchema}.
     *
     * <p>On failure, the partially written output block is closed before the exception is
     * rethrown. On success, ownership of the returned block passes to the caller.
     *
     * @return block with {@code inputRecords.getRowCount()} rows
     * @throws Exception if the UDF invocation or row writing fails
     */
    protected Block processRows(BlockAllocator allocator, Method udfMethod, Block inputRecords, Schema outputSchema) throws Exception {
        int rowCount = inputRecords.getRowCount();

        // One projector per input column, in field order; these become the UDF arguments.
        List<ArrowValueProjector> valueProjectors = Lists.newArrayList();
        for (Field field : inputRecords.getFields()) {
            FieldReader fieldReader = inputRecords.getFieldReader(field.getName());
            valueProjectors.add(ProjectorUtils.createArrowValueProjector(fieldReader));
        }

        Field outputField = outputSchema.getFields().get(0);
        GeneratedRowWriter outputRowWriter = this.createOutputRowWriter(outputField, valueProjectors, udfMethod);
        Block outputRecords = allocator.createBlock(outputSchema);
        outputRecords.setRowCount(rowCount);
        try {
            for (int rowNum = 0; rowNum < rowCount; ++rowNum) {
                // Input and output rows map 1:1, so the same index serves as both.
                outputRowWriter.writeRow(outputRecords, rowNum, rowNum);
            }
        }
        catch (Throwable t) {
            try {
                outputRecords.close();
            }
            catch (Exception e) {
                logger.error("Error closing output block", e);
            }
            throw t;
        }
        return outputRecords;
    }

    /**
     * Looks up the public method on this handler whose name matches the request's method name
     * and whose parameter types match the Java types of the input columns. It then verifies
     * that the method's return type matches the single output column.
     *
     * @throws IllegalStateException    if the output schema does not have exactly one column
     * @throws RuntimeException         if no matching method exists
     * @throws IllegalArgumentException if the method's return type does not match the output column
     */
    private Method extractScalarFunctionMethod(UserDefinedFunctionRequest req) {
        String methodName = req.getMethodName();
        Class<?>[] argumentTypes = this.extractJavaTypes(req.getInputRecords().getSchema());
        Class<?>[] returnTypes = this.extractJavaTypes(req.getOutputSchema());
        Preconditions.checkState(returnTypes.length == RETURN_COLUMN_COUNT,
                "Expecting %s return columns, found %s in method signature.", RETURN_COLUMN_COUNT, returnTypes.length);
        Class<?> returnType = returnTypes[0];

        Method udfMethod;
        try {
            udfMethod = this.getClass().getMethod(methodName, argumentTypes);
        }
        catch (NoSuchMethodException e) {
            String msg = "Failed to find UDF method. " + e.getMessage() + " Please make sure the method name contains only lowercase and the method signature (name and argument types) in Lambda matches the function signature defined in SQL.";
            throw new RuntimeException(msg, e);
        }
        logger.info("Found UDF method {} with input types [{}] and output types [{}]",
                methodName, Arrays.toString(argumentTypes), returnType.getName());

        if (!returnType.equals(udfMethod.getReturnType())) {
            throw new IllegalArgumentException("signature return type " + returnType + " does not match udf implementation return type " + udfMethod.getReturnType());
        }
        return udfMethod;
    }

    /**
     * Maps each field of {@code schema}, in order, to the Java class that {@link BlockUtils}
     * associates with its Arrow minor type.
     */
    private Class<?>[] extractJavaTypes(Schema schema) {
        List<Field> fields = schema.getFields();
        Class<?>[] types = new Class<?>[fields.size()];
        for (int i = 0; i < fields.size(); ++i) {
            Types.MinorType minorType = Types.getMinorTypeForArrowType(fields.get(i).getType());
            types[i] = BlockUtils.getJavaType(minorType);
        }
        return types;
    }

    /**
     * Builds the ping response and then gives subclasses a chance to react through
     * {@link #onPing}. Exceptions from {@code onPing} are logged and do not fail the ping.
     */
    private PingResponse doPing(PingRequest request) {
        // NOTE(review): 24 and 2 are PingResponse's trailing constructor arguments (presumably
        // capabilities and serde version) — confirm against PingResponse before changing.
        PingResponse response = new PingResponse(request.getCatalogName(), request.getQueryId(), this.sourceType, 24, 2);
        try {
            this.onPing(request);
        }
        catch (Exception ex) {
            logger.warn("doPing: encountered an exception while delegating onPing.", ex);
        }
        return response;
    }

    /**
     * Hook invoked on every ping request; the default implementation does nothing.
     *
     * @param request the ping request being answered
     */
    protected void onPing(PingRequest request) {
    }

    /** @throws RuntimeException if {@code response} is null */
    private void assertNotNull(FederationResponse response) {
        if (response == null) {
            throw new RuntimeException("Response was null");
        }
    }

    /**
     * Builds a row writer for the single output field. It uses a type-specific extractor when
     * one exists for the field's type, and falls back to a field-writer factory otherwise
     * (complex types).
     */
    private GeneratedRowWriter createOutputRowWriter(Field outputField, List<ArrowValueProjector> valueProjectors, Method udfMethod) {
        GeneratedRowWriter.RowWriterBuilder builder = GeneratedRowWriter.newBuilder();
        Extractor extractor = this.makeExtractor(outputField, valueProjectors, udfMethod);
        if (extractor != null) {
            builder.withExtractor(outputField.getName(), extractor);
        }
        else {
            builder.withFieldWriterFactory(outputField.getName(), this.makeFactory(outputField, valueProjectors, udfMethod));
        }
        return builder.build();
    }

    /**
     * Returns an extractor that invokes the UDF for a row and copies the (possibly null) result
     * into the output holder. Returns {@code null} for output types not handled here.
     *
     * <p>A single {@code arguments} buffer is shared by every invocation of the returned
     * extractor. It is refilled per row, which relies on rows being written sequentially.
     */
    private Extractor makeExtractor(Field outputField, List<ArrowValueProjector> valueProjectors, Method udfMethod) {
        Types.MinorType fieldType = Types.getMinorTypeForArrowType(outputField.getType());
        Object[] arguments = new Object[valueProjectors.size()];
        // NOTE(review): these lambdas target the generic Extractor type as emitted by the
        // decompiler; the original source presumably cast each one to its type-specific
        // extractor interface — confirm before recompiling.
        switch (fieldType) {
            case INT:
                return (inputRowNum, dst) -> {
                    Object result = this.invokeMethod(udfMethod, arguments, (Integer) inputRowNum, valueProjectors);
                    if (result == null) {
                        dst.isSet = 0;
                    }
                    else {
                        dst.isSet = 1;
                        dst.value = (Integer) result;
                    }
                };
            case DATEMILLI:
                return (inputRowNum, dst) -> {
                    Object result = this.invokeMethod(udfMethod, arguments, (Integer) inputRowNum, valueProjectors);
                    if (result == null) {
                        dst.isSet = 0;
                    }
                    else {
                        dst.isSet = 1;
                        // LocalDateTime is interpreted in BlockUtils.UTC_ZONE_ID.
                        dst.value = ((LocalDateTime) result).atZone(BlockUtils.UTC_ZONE_ID).toInstant().toEpochMilli();
                    }
                };
            case DATEDAY:
                return (inputRowNum, dst) -> {
                    Object result = this.invokeMethod(udfMethod, arguments, (Integer) inputRowNum, valueProjectors);
                    if (result == null) {
                        dst.isSet = 0;
                    }
                    else {
                        dst.isSet = 1;
                        dst.value = (int) ((LocalDate) result).toEpochDay();
                    }
                };
            case TINYINT:
                return (inputRowNum, dst) -> {
                    Object result = this.invokeMethod(udfMethod, arguments, (Integer) inputRowNum, valueProjectors);
                    if (result == null) {
                        dst.isSet = 0;
                    }
                    else {
                        dst.isSet = 1;
                        dst.value = (Byte) result;
                    }
                };
            case SMALLINT:
                return (inputRowNum, dst) -> {
                    Object result = this.invokeMethod(udfMethod, arguments, (Integer) inputRowNum, valueProjectors);
                    if (result == null) {
                        dst.isSet = 0;
                    }
                    else {
                        dst.isSet = 1;
                        dst.value = (Short) result;
                    }
                };
            case FLOAT4:
                return (inputRowNum, dst) -> {
                    Object result = this.invokeMethod(udfMethod, arguments, (Integer) inputRowNum, valueProjectors);
                    if (result == null) {
                        dst.isSet = 0;
                    }
                    else {
                        dst.isSet = 1;
                        dst.value = (Float) result;
                    }
                };
            case FLOAT8:
                return (inputRowNum, dst) -> {
                    Object result = this.invokeMethod(udfMethod, arguments, (Integer) inputRowNum, valueProjectors);
                    if (result == null) {
                        dst.isSet = 0;
                    }
                    else {
                        dst.isSet = 1;
                        dst.value = (Double) result;
                    }
                };
            case DECIMAL:
                return (inputRowNum, dst) -> {
                    Object result = this.invokeMethod(udfMethod, arguments, (Integer) inputRowNum, valueProjectors);
                    if (result == null) {
                        dst.isSet = 0;
                    }
                    else {
                        dst.isSet = 1;
                        dst.value = (BigDecimal) result;
                    }
                };
            case BIT:
                return (inputRowNum, dst) -> {
                    Object result = this.invokeMethod(udfMethod, arguments, (Integer) inputRowNum, valueProjectors);
                    if (result == null) {
                        dst.isSet = 0;
                    }
                    else {
                        dst.isSet = 1;
                        dst.value = ((Boolean) result) ? 1 : 0;
                    }
                };
            case BIGINT:
                return (inputRowNum, dst) -> {
                    Object result = this.invokeMethod(udfMethod, arguments, (Integer) inputRowNum, valueProjectors);
                    if (result == null) {
                        dst.isSet = 0;
                    }
                    else {
                        dst.isSet = 1;
                        dst.value = (Long) result;
                    }
                };
            case VARCHAR:
                return (inputRowNum, dst) -> {
                    Object result = this.invokeMethod(udfMethod, arguments, (Integer) inputRowNum, valueProjectors);
                    if (result == null) {
                        dst.isSet = 0;
                    }
                    else {
                        dst.isSet = 1;
                        dst.value = (String) result;
                    }
                };
            case VARBINARY:
                return (inputRowNum, dst) -> {
                    Object result = this.invokeMethod(udfMethod, arguments, (Integer) inputRowNum, valueProjectors);
                    if (result == null) {
                        dst.isSet = 0;
                    }
                    else {
                        dst.isSet = 1;
                        dst.value = (byte[]) result;
                    }
                };
            default:
                return null;
        }
    }

    /**
     * Returns a field-writer factory for complex output types (LIST, STRUCT). The writer
     * invokes the UDF for a row and stores the result with {@link BlockUtils#setComplexValue}.
     *
     * @throws IllegalArgumentException for any other output type
     */
    private FieldWriterFactory makeFactory(Field field, List<ArrowValueProjector> valueProjectors, Method udfMethod) {
        Object[] arguments = new Object[valueProjectors.size()];
        Types.MinorType fieldType = Types.getMinorTypeForArrowType(field.getType());
        switch (fieldType) {
            case LIST:
            case STRUCT:
                return (vector, extractor, ignored) -> (inputRowNum, outputRowNum) -> {
                    Object result = this.invokeMethod(udfMethod, arguments, (Integer) inputRowNum, valueProjectors);
                    BlockUtils.setComplexValue(vector, outputRowNum, FieldResolver.DEFAULT, result);
                    return true;
                };
            default:
                throw new IllegalArgumentException("Unsupported type " + fieldType);
        }
    }

    /**
     * Projects row {@code inputRowNum} of every input column into {@code arguments} and invokes
     * the UDF on this handler with them.
     *
     * @param arguments scratch buffer, overwritten on every call; must have one slot per projector
     * @return the UDF's return value (may be null)
     * @throws RuntimeException wrapping the UDF's own exception, a reflective access failure, or
     *                          an argument-type mismatch (with expected vs. actual types in the
     *                          message)
     */
    private Object invokeMethod(Method udfMethod, Object[] arguments, int inputRowNum, List<ArrowValueProjector> valueProjectors) {
        for (int col = 0; col < valueProjectors.size(); ++col) {
            arguments[col] = valueProjectors.get(col).project(inputRowNum);
        }
        try {
            return udfMethod.invoke(this, arguments);
        }
        catch (IllegalAccessException e) {
            throw new RuntimeException(e);
        }
        catch (InvocationTargetException e) {
            // Surface the exception thrown by the UDF itself. Fall back to the reflective
            // wrapper rather than building a RuntimeException with a null cause.
            Throwable cause = e.getCause();
            throw new RuntimeException(Objects.isNull(cause) ? e : cause);
        }
        catch (IllegalArgumentException e) {
            // Null projected values must not crash the diagnostic message itself.
            String msg = String.format("%s. Expected function types %s, got types %s",
                    e.getMessage(),
                    Arrays.stream(udfMethod.getParameterTypes()).map(Class::getName).collect(Collectors.toList()),
                    Arrays.stream(arguments).map(arg -> arg == null ? "null" : arg.getClass().getName()).collect(Collectors.toList()));
            throw new RuntimeException(msg, e);
        }
    }
}

