/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.gen;

import com.facebook.presto.bytecode.Access;
import com.facebook.presto.bytecode.ClassDefinition;
import com.facebook.presto.bytecode.ParameterizedType;
import com.facebook.presto.common.Page;
import com.facebook.presto.common.PageBuilder;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.common.type.BigintType;
import com.facebook.presto.common.type.BooleanType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.operator.DriverYieldSignal;
import com.facebook.presto.operator.index.PageRecordSet;
import com.facebook.presto.operator.project.CursorProcessor;
import com.facebook.presto.spi.function.FunctionHandle;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.analyzer.TypeSignatureProvider;
import com.facebook.presto.sql.gen.CommonSubExpressionRewriter;
import com.facebook.presto.sql.gen.CursorProcessorCompiler;
import com.facebook.presto.sql.gen.ExpressionCompiler;
import com.facebook.presto.sql.gen.PageFunctionCompiler;
import com.facebook.presto.sql.relational.Expressions;
import com.facebook.presto.testing.TestingConnectorSession;
import com.facebook.presto.util.CompilerUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.testng.Assert;
import org.testng.annotations.Test;

/**
 * Tests for {@link CursorProcessorCompiler}, focused on common sub-expression (CSE) optimization.
 *
 * <p>All fixtures share the sub-expression {@code X + Y}, where X, Y and Z are the BIGINT input
 * fields 0, 1 and 2.
 */
public class TestCursorProcessorCompiler {
    private static final Metadata METADATA = MetadataManager.createTestMetadataManager();
    private static final FunctionAndTypeManager FUNCTION_MANAGER = METADATA.getFunctionAndTypeManager();

    /** {@code X + Y}: the sub-expression shared by every other fixture. */
    private static final CallExpression ADD_X_Y = Expressions.call(
            OperatorType.ADD.name(),
            resolveBigintOperator(OperatorType.ADD),
            BigintType.BIGINT,
            Expressions.field(0, BigintType.BIGINT),
            Expressions.field(1, BigintType.BIGINT));

    /** {@code X + Y > 2}. */
    private static final CallExpression ADD_X_Y_GREATER_THAN_2 = compareAddXY(OperatorType.GREATER_THAN, 2L);

    /** {@code X + Y < 10}. */
    private static final CallExpression ADD_X_Y_LESS_THAN_10 = compareAddXY(OperatorType.LESS_THAN, 10L);

    /** {@code (X + Y) + Z}: contains {@link #ADD_X_Y} as its first argument. */
    private static final CallExpression ADD_X_Y_Z = Expressions.call(
            OperatorType.ADD.name(),
            resolveBigintOperator(OperatorType.ADD),
            BigintType.BIGINT,
            ADD_X_Y,
            Expressions.field(2, BigintType.BIGINT));

    @Test
    public void testRewriteRowExpressionWithCSE() {
        CursorProcessorCompiler cseCursorCompiler = new CursorProcessorCompiler(METADATA, true, Collections.emptyMap());
        ClassDefinition cursorProcessorClassDefinition = new ClassDefinition(
                Access.a(Access.PUBLIC, Access.FINAL),
                CompilerUtils.makeClassName(CursorProcessor.class.getSimpleName()),
                ParameterizedType.type(Object.class),
                ParameterizedType.type(CursorProcessor.class));

        // AND yields BOOLEAN. In this test the filter is only analyzed and rewritten, never compiled.
        RowExpression filter = new SpecialFormExpression(
                SpecialFormExpression.Form.AND, BooleanType.BOOLEAN, ADD_X_Y_GREATER_THAN_2);
        List<RowExpression> projections = ImmutableList.of(ADD_X_Y_Z);
        List<RowExpression> rowExpressions = ImmutableList.<RowExpression>builder()
                .addAll(projections)
                .add(filter)
                .build();

        Map<Integer, Map<RowExpression, VariableReferenceExpression>> commonSubExpressionsByLevel =
                CommonSubExpressionRewriter.collectCSEByLevel(rowExpressions);
        Map<VariableReferenceExpression, CommonSubExpressionRewriter.CommonSubExpressionFields> cseFields =
                CommonSubExpressionRewriter.CommonSubExpressionFields.declareCommonSubExpressionFields(
                        cursorProcessorClassDefinition, commonSubExpressionsByLevel);
        Map<RowExpression, VariableReferenceExpression> commonSubExpressions = commonSubExpressionsByLevel.values().stream()
                .flatMap(byExpression -> byExpression.entrySet().stream())
                .collect(ImmutableMap.toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));

        // X + Y is the only sub-expression occurring in both the projection and the filter.
        Assert.assertEquals(cseFields.size(), 1);
        VariableReferenceExpression cseVariable = cseFields.keySet().iterator().next();

        RowExpression rewrittenFilter = cseCursorCompiler.rewriteRowExpressionsWithCSE(ImmutableList.of(filter), commonSubExpressions).get(0);
        List<RowExpression> rewrittenProjections = cseCursorCompiler.rewriteRowExpressionsWithCSE(projections, commonSubExpressions);

        // (X + Y) + Z must now reference the CSE variable in place of X + Y.
        Assert.assertTrue(((CallExpression) rewrittenProjections.get(0)).getArguments().contains(cseVariable));
        // Likewise for the first conjunct X + Y > 2 of the filter.
        Assert.assertTrue(((CallExpression) ((SpecialFormExpression) rewrittenFilter).getArguments().get(0))
                .getArguments().contains(cseVariable));
    }

    @Test
    public void testCompilerWithCSE() {
        PageFunctionCompiler functionCompiler = new PageFunctionCompiler(METADATA, 0);
        ExpressionCompiler expressionCompiler = new ExpressionCompiler(METADATA, functionCompiler);

        // X + Y > 2 AND X + Y < 10 -- both conjuncts share X + Y with each other and with every projection.
        RowExpression filter = new SpecialFormExpression(
                SpecialFormExpression.Form.AND, BooleanType.BOOLEAN, ADD_X_Y_GREATER_THAN_2, ADD_X_Y_LESS_THAN_10);
        List<? extends RowExpression> projections = createIfProjectionList(5);

        Supplier<CursorProcessor> cseCursorProcessorSupplier = expressionCompiler.compileCursorProcessor(
                TestingConnectorSession.SESSION.getSqlFunctionProperties(), Optional.of(filter), projections, "key", true);
        Supplier<CursorProcessor> noCseCursorProcessorSupplier = expressionCompiler.compileCursorProcessor(
                TestingConnectorSession.SESSION.getSqlFunctionProperties(), Optional.of(filter), projections, "key", false);

        // Two identical BIGINT columns, so X = Y = v for each row value v in [0, 9].
        Page input = createLongBlockPage(2, 0L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L);
        List<Type> inputTypes = ImmutableList.of(BigintType.BIGINT, BigintType.BIGINT);
        List<Type> outputTypes = projections.stream().map(RowExpression::getType).collect(Collectors.toList());
        PageRecordSet recordSet = new PageRecordSet(inputTypes, input);

        Page pageFromCSE = runCursorProcessor(cseCursorProcessorSupplier, recordSet, outputTypes);
        Page pageFromNoCSE = runCursorProcessor(noCseCursorProcessorSupplier, recordSet, outputTypes);

        // Both compilation modes must agree...
        checkPageEqual(pageFromCSE, pageFromNoCSE);

        // ...and agree on the right answer: the filter keeps v in {2, 3, 4} (X + Y in {4, 6, 8}),
        // none of which exceeds 8, so projection i takes its ELSE branch and yields i + 1.
        Assert.assertEquals(pageFromCSE.getPositionCount(), 3);
        for (int channel = 0; channel < pageFromCSE.getChannelCount(); ++channel) {
            Block block = pageFromCSE.getBlock(channel);
            for (int position = 0; position < block.getPositionCount(); ++position) {
                Assert.assertEquals(block.getLong(position), channel + 1L);
            }
        }
    }

    /**
     * Resolves the (BIGINT, BIGINT) overload of the given operator.
     */
    private static FunctionHandle resolveBigintOperator(OperatorType operator) {
        return FUNCTION_MANAGER.resolveOperator(operator, TypeSignatureProvider.fromTypes(BigintType.BIGINT, BigintType.BIGINT));
    }

    /**
     * Builds the BOOLEAN comparison {@code X + Y <operator> bound}.
     *
     * @param operator a BIGINT comparison operator such as GREATER_THAN or LESS_THAN
     * @param bound the BIGINT constant on the right-hand side
     */
    private static CallExpression compareAddXY(OperatorType operator, long bound) {
        return Expressions.call(
                operator.name(),
                resolveBigintOperator(operator),
                BooleanType.BOOLEAN,
                ADD_X_Y,
                Expressions.constant(bound, BigintType.BIGINT));
    }

    /**
     * Obtains a processor from {@code processorSupplier} and runs it over a new cursor of
     * {@code recordSet}, collecting its output into a fresh page.
     */
    private static Page runCursorProcessor(Supplier<CursorProcessor> processorSupplier, PageRecordSet recordSet, List<Type> outputTypes) {
        PageBuilder pageBuilder = new PageBuilder(outputTypes);
        processorSupplier.get().process(
                TestingConnectorSession.SESSION.getSqlFunctionProperties(), new DriverYieldSignal(), recordSet.cursor(), pageBuilder);
        return pageBuilder.build();
    }

    /**
     * Builds a page of {@code blockCount} identical BIGINT blocks, each holding {@code values} in order.
     */
    private static Page createLongBlockPage(int blockCount, long... values) {
        Block[] blocks = new Block[blockCount];
        for (int i = 0; i < blockCount; ++i) {
            BlockBuilder builder = BigintType.BIGINT.createFixedSizeBlockBuilder(values.length);
            for (long value : values) {
                BigintType.BIGINT.writeLong(builder, value);
            }
            blocks[i] = builder.build();
        }
        return new Page(blocks);
    }

    /**
     * Builds {@code projectionCount} projections; projection {@code i} is
     * {@code IF(X + Y > 8, i, i + 1)}, so every projection shares the sub-expression X + Y.
     */
    private static List<? extends RowExpression> createIfProjectionList(int projectionCount) {
        // The condition is identical for every projection, so it is built once.
        CallExpression condition = compareAddXY(OperatorType.GREATER_THAN, 8L);
        return IntStream.range(0, projectionCount)
                .mapToObj(i -> new SpecialFormExpression(
                        SpecialFormExpression.Form.IF,
                        BigintType.BIGINT,
                        condition,
                        // BIGINT constants carry Long values; boxing the int index directly would produce an Integer.
                        Expressions.constant((long) i, BigintType.BIGINT),
                        Expressions.constant((long) i + 1L, BigintType.BIGINT)))
                .collect(ImmutableList.toImmutableList());
    }

    /**
     * Asserts that two BIGINT blocks have the same positions, nullness and values.
     */
    private static void checkBlockEqual(Block actual, Block expected) {
        Assert.assertEquals(actual.getPositionCount(), expected.getPositionCount());
        for (int position = 0; position < actual.getPositionCount(); ++position) {
            // Compare nullness explicitly so a null is never mistaken for a numeric value.
            Assert.assertEquals(actual.isNull(position), expected.isNull(position), "null mismatch at position " + position);
            if (!actual.isNull(position)) {
                Assert.assertEquals(actual.getLong(position), expected.getLong(position), "value mismatch at position " + position);
            }
        }
    }

    /**
     * Asserts that two pages have the same shape and equal BIGINT blocks in every channel.
     */
    private static void checkPageEqual(Page actual, Page expected) {
        Assert.assertEquals(actual.getPositionCount(), expected.getPositionCount());
        Assert.assertEquals(actual.getChannelCount(), expected.getChannelCount());
        for (int channel = 0; channel < actual.getChannelCount(); ++channel) {
            checkBlockEqual(actual.getBlock(channel), expected.getBlock(channel));
        }
    }
}

