/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.common.type.BooleanType;
import com.facebook.presto.common.type.FunctionType;
import com.facebook.presto.common.type.IntegerType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.spi.Plugin;
import com.facebook.presto.spi.function.FunctionHandle;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.LambdaDefinitionExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.sql.analyzer.TypeSignatureProvider;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.planner.iterative.rule.TranslateExpressions;
import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest;
import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder;
import com.facebook.presto.sql.relational.Expressions;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.facebook.presto.sql.relational.OriginalExpressionUtils;
import com.facebook.presto.sql.tree.Expression;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Optional;
import org.testng.Assert;
import org.testng.annotations.Test;

/**
 * Exercises the aggregation rewrite rule exposed by
 * {@link TranslateExpressions#aggregationRowExpressionRewriteRule()}.
 *
 * <p>Each test feeds the rule a {@code reduce_agg} aggregation whose call arguments and filter
 * are SQL expressions wrapped via {@link OriginalExpressionUtils#castToRowExpression}, then
 * checks that the rewritten aggregation equals one built directly from row expressions and
 * that no wrapped SQL expression is left in its arguments or filter.
 */
public class TestTranslateExpressions
extends BaseRuleTest {
    private static final Metadata METADATA = MetadataManager.createTestMetadataManager();
    private static final FunctionAndTypeManager FUNCTION_MANAGER = METADATA.getFunctionAndTypeManager();
    private static final FunctionResolution FUNCTION_RESOLUTION = new FunctionResolution(FUNCTION_MANAGER);
    // Handle for reduce_agg looked up with (INTEGER, INTEGER, function type, function type).
    private static final FunctionHandle REDUCE_AGG = FUNCTION_MANAGER.lookupFunction(
            "reduce_agg",
            TypeSignatureProvider.fromTypes(
                    IntegerType.INTEGER,
                    IntegerType.INTEGER,
                    integerPairFunctionType(),
                    integerPairFunctionType()));

    public TestTranslateExpressions() {
        super(new Plugin[0]);
    }

    /**
     * Rewrites a four-argument {@code reduce_agg(input, 0, lambda, lambda)} call and compares the
     * outcome against the equivalent row-expression form.
     */
    @Test
    public void testTranslateAggregationWithLambda() {
        AggregationNode.Aggregation translated = this.translateReduceAgg("input", "0", "(x,y) -> x*y", "(a,b) -> a*b");
        AggregationNode.Aggregation expected = this.expectedReduceAgg(ImmutableList.<RowExpression>of(
                TestTranslateExpressions.integerVariable("input"),
                Expressions.constant(0L, IntegerType.INTEGER),
                this.multiplyLambda("x", "y"),
                this.multiplyLambda("a", "b")));
        Assert.assertEquals(translated, expected);
        Assert.assertFalse(TestTranslateExpressions.isUntranslated(translated));
    }

    /**
     * Same check as above, but the call carries only three arguments (the input and the two
     * lambdas, without the {@code 0} literal) while still using the {@code REDUCE_AGG} handle.
     * NOTE(review): presumably this mirrors the argument shape of an intermediate aggregation
     * step, as the test name suggests - confirm against the rule under test.
     */
    @Test
    public void testTranslateIntermediateAggregationWithLambda() {
        AggregationNode.Aggregation translated = this.translateReduceAgg("input", "(x,y) -> x*y", "(a,b) -> a*b");
        AggregationNode.Aggregation expected = this.expectedReduceAgg(ImmutableList.<RowExpression>of(
                TestTranslateExpressions.integerVariable("input"),
                this.multiplyLambda("x", "y"),
                this.multiplyLambda("a", "b")));
        Assert.assertEquals(translated, expected);
        Assert.assertFalse(TestTranslateExpressions.isUntranslated(translated));
    }

    /**
     * Runs the aggregation rewrite rule over a global aggregation on a single INTEGER column
     * named {@code input}. The aggregation is a {@code reduce_agg} call whose arguments are the
     * given SQL snippets and whose filter is {@code input > 10}, all in wrapped (untranslated) form.
     *
     * @param argumentSql SQL text of each call argument, in call order
     * @return the aggregation stored under the {@code reduce_agg} output variable after the rule ran
     */
    private AggregationNode.Aggregation translateReduceAgg(String... argumentSql) {
        PlanNode rewritten = this.tester()
                .assertThat(new TranslateExpressions(METADATA, new SqlParser()).aggregationRowExpressionRewriteRule())
                .on(p -> p.aggregation(builder -> builder
                        .globalGrouping()
                        .addAggregation(
                                Expressions.variable("reduce_agg", IntegerType.INTEGER),
                                new AggregationNode.Aggregation(
                                        new CallExpression(
                                                "reduce_agg",
                                                REDUCE_AGG,
                                                IntegerType.INTEGER,
                                                TestTranslateExpressions.untranslatedExpressions(argumentSql)),
                                        Optional.of(TestTranslateExpressions.untranslatedExpression("input > 10")),
                                        Optional.empty(),
                                        false,
                                        Optional.empty()))
                        .source(p.values(p.variable("input", IntegerType.INTEGER)))))
                .get();
        AggregationNode aggregationNode = (AggregationNode) rewritten;
        return aggregationNode.getAggregations().get(Expressions.variable("reduce_agg", IntegerType.INTEGER));
    }

    /**
     * Builds the aggregation the rule is expected to produce: a {@code reduce_agg} call over the
     * supplied row expressions, filtered by {@code input > 10} expressed as a GREATER_THAN call.
     *
     * @param arguments already-translated call arguments, in call order
     */
    private AggregationNode.Aggregation expectedReduceAgg(List<RowExpression> arguments) {
        CallExpression expectedCall = new CallExpression("reduce_agg", REDUCE_AGG, IntegerType.INTEGER, arguments);
        RowExpression expectedFilter = this.greaterThan(
                TestTranslateExpressions.integerVariable("input"),
                Expressions.constant(10L, IntegerType.INTEGER));
        return new AggregationNode.Aggregation(
                expectedCall,
                Optional.<RowExpression>of(expectedFilter),
                Optional.empty(),
                false,
                Optional.empty());
    }

    /**
     * Builds a two-parameter lambda over INTEGER arguments whose body multiplies the parameters.
     *
     * @param firstName name of the first lambda parameter
     * @param secondName name of the second lambda parameter
     */
    private LambdaDefinitionExpression multiplyLambda(String firstName, String secondName) {
        return new LambdaDefinitionExpression(
                ImmutableList.<Type>of(IntegerType.INTEGER, IntegerType.INTEGER),
                ImmutableList.of(firstName, secondName),
                this.multiply(
                        TestTranslateExpressions.integerVariable(firstName),
                        TestTranslateExpressions.integerVariable(secondName)));
    }

    /** Builds a GREATER_THAN call returning BOOLEAN, resolved from the operand types. */
    private CallExpression greaterThan(RowExpression left, RowExpression right) {
        FunctionHandle comparison = FUNCTION_RESOLUTION.comparisonFunction(OperatorType.GREATER_THAN, left.getType(), right.getType());
        return Expressions.call("GREATER_THAN", comparison, BooleanType.BOOLEAN, ImmutableList.<RowExpression>of(left, right));
    }

    /** Builds a MULTIPLY call resolved from the operand types; the result type is the left operand's type. */
    private CallExpression multiply(RowExpression left, RowExpression right) {
        FunctionHandle product = FUNCTION_RESOLUTION.arithmeticFunction(OperatorType.MULTIPLY, left.getType(), right.getType());
        return Expressions.call("MULTIPLY", product, left.getType(), ImmutableList.<RowExpression>of(left, right));
    }

    /**
     * Reports whether any call argument, or the filter when present, satisfies
     * {@link OriginalExpressionUtils#isExpression}.
     */
    private static boolean isUntranslated(AggregationNode.Aggregation aggregation) {
        for (RowExpression argument : aggregation.getCall().getArguments()) {
            if (OriginalExpressionUtils.isExpression(argument)) {
                return true;
            }
        }
        return aggregation.getFilter().filter(OriginalExpressionUtils::isExpression).isPresent();
    }

    /** Creates a function type from the argument type list (INTEGER, INTEGER) and the type INTEGER. */
    private static FunctionType integerPairFunctionType() {
        return new FunctionType(ImmutableList.<Type>of(IntegerType.INTEGER, IntegerType.INTEGER), IntegerType.INTEGER);
    }

    /** Creates a variable reference of type INTEGER with the given name. */
    private static RowExpression integerVariable(String name) {
        return Expressions.variable(name, IntegerType.INTEGER);
    }

    /** Parses one SQL snippet and wraps it as a row expression without translating it. */
    private static RowExpression untranslatedExpression(String sql) {
        return OriginalExpressionUtils.castToRowExpression(PlanBuilder.expression(sql));
    }

    /** Wraps each SQL snippet, preserving order, into an immutable list of untranslated row expressions. */
    private static List<RowExpression> untranslatedExpressions(String... sqlSnippets) {
        ImmutableList.Builder<RowExpression> wrapped = ImmutableList.builder();
        for (String sql : sqlSnippets) {
            wrapped.add(TestTranslateExpressions.untranslatedExpression(sql));
        }
        return wrapped.build();
    }
}

