/*
 * Decompiled with CFR 0.152.
 */
package io.prestosql.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.prestosql.cost.PlanNodeStatsEstimate;
import io.prestosql.cost.StatsAndCosts;
import io.prestosql.metadata.AbstractMockMetadata;
import io.prestosql.metadata.Metadata;
import io.prestosql.metadata.MetadataManager;
import io.prestosql.sql.planner.Plan;
import io.prestosql.sql.planner.PlanNodeIdAllocator;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.assertions.ExpectedValueProvider;
import io.prestosql.sql.planner.assertions.ExpressionMatcher;
import io.prestosql.sql.planner.assertions.PlanAssert;
import io.prestosql.sql.planner.assertions.PlanMatchPattern;
import io.prestosql.sql.planner.iterative.Lookup;
import io.prestosql.sql.planner.iterative.rule.PushProjectionThroughJoin;
import io.prestosql.sql.planner.iterative.rule.test.PlanBuilder;
import io.prestosql.sql.planner.plan.Assignments;
import io.prestosql.sql.planner.plan.JoinNode;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.ProjectNode;
import io.prestosql.sql.tree.ArithmeticBinaryExpression;
import io.prestosql.sql.tree.ArithmeticUnaryExpression;
import io.prestosql.sql.tree.Expression;
import io.prestosql.testing.TestingSession;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.assertj.core.api.Assertions;
import org.testng.Assert;
import org.testng.annotations.Test;

public class TestPushProjectionThroughJoin {
    private final Metadata metadata = MetadataManager.createTestMetadataManager();

    @Test
    public void testPushesProjectionThroughJoin() {
        PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator();
        PlanBuilder p = new PlanBuilder(idAllocator, AbstractMockMetadata.dummyMetadata());
        Symbol a0 = p.symbol("a0");
        Symbol a1 = p.symbol("a1");
        Symbol a2 = p.symbol("a2");
        Symbol a3 = p.symbol("a3");
        Symbol b0 = p.symbol("b0");
        Symbol b1 = p.symbol("b1");
        Symbol b2 = p.symbol("b2");
        ProjectNode planNode = p.project(Assignments.of((Symbol)a3, (Expression)new ArithmeticUnaryExpression(ArithmeticUnaryExpression.Sign.MINUS, (Expression)a2.toSymbolReference()), (Symbol)b2, (Expression)new ArithmeticUnaryExpression(ArithmeticUnaryExpression.Sign.PLUS, (Expression)b1.toSymbolReference())), (PlanNode)p.join(JoinNode.Type.INNER, (PlanNode)p.project(Assignments.of((Symbol)a2, (Expression)new ArithmeticUnaryExpression(ArithmeticUnaryExpression.Sign.PLUS, (Expression)a0.toSymbolReference()), (Symbol)a1, (Expression)a1.toSymbolReference()), (PlanNode)p.project(Assignments.builder().putIdentity(a0).putIdentity(a1).build(), (PlanNode)p.values(a0, a1))), (PlanNode)p.values(b0, b1), new JoinNode.EquiJoinClause(a1, b1)));
        Optional rewritten = PushProjectionThroughJoin.pushProjectionThroughJoin((Metadata)this.metadata, (ProjectNode)planNode, (Lookup)Lookup.noLookup(), (PlanNodeIdAllocator)idAllocator);
        Assert.assertTrue((boolean)rewritten.isPresent());
        PlanAssert.assertPlan(TestingSession.testSessionBuilder().build(), AbstractMockMetadata.dummyMetadata(), node -> PlanNodeStatsEstimate.unknown(), new Plan((PlanNode)rewritten.get(), p.getTypes(), StatsAndCosts.empty()), Lookup.noLookup(), PlanMatchPattern.join(JoinNode.Type.INNER, (List<ExpectedValueProvider<JoinNode.EquiJoinClause>>)ImmutableList.of(aliases -> new JoinNode.EquiJoinClause(new Symbol("a1"), new Symbol("b1"))), PlanMatchPattern.strictProject((Map<String, ExpressionMatcher>)ImmutableMap.of((Object)"a3", (Object)PlanMatchPattern.expression("-(+a0)"), (Object)"a1", (Object)PlanMatchPattern.expression("a1")), PlanMatchPattern.strictProject((Map<String, ExpressionMatcher>)ImmutableMap.of((Object)"a0", (Object)PlanMatchPattern.expression("a0"), (Object)"a1", (Object)PlanMatchPattern.expression("a1")), PlanMatchPattern.values("a0", "a1"))), PlanMatchPattern.strictProject((Map<String, ExpressionMatcher>)ImmutableMap.of((Object)"b2", (Object)PlanMatchPattern.expression("+b1"), (Object)"b1", (Object)PlanMatchPattern.expression("b1")), PlanMatchPattern.values("b0", "b1"))).withExactOutputs("a3", "b2"));
    }

    @Test
    public void testDoesNotPushStraddlingProjection() {
        PlanBuilder p = new PlanBuilder(new PlanNodeIdAllocator(), AbstractMockMetadata.dummyMetadata());
        Symbol a = p.symbol("a");
        Symbol b = p.symbol("b");
        Symbol c = p.symbol("c");
        ProjectNode planNode = p.project(Assignments.of((Symbol)c, (Expression)new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.ADD, (Expression)a.toSymbolReference(), (Expression)b.toSymbolReference())), (PlanNode)p.join(JoinNode.Type.INNER, (PlanNode)p.values(a), (PlanNode)p.values(b), new JoinNode.EquiJoinClause[0]));
        Optional rewritten = PushProjectionThroughJoin.pushProjectionThroughJoin((Metadata)this.metadata, (ProjectNode)planNode, (Lookup)Lookup.noLookup(), (PlanNodeIdAllocator)new PlanNodeIdAllocator());
        Assertions.assertThat((Optional)rewritten).isEmpty();
    }

    @Test
    public void testDoesNotPushProjectionThroughOuterJoin() {
        PlanBuilder p = new PlanBuilder(new PlanNodeIdAllocator(), AbstractMockMetadata.dummyMetadata());
        Symbol a = p.symbol("a");
        Symbol b = p.symbol("b");
        Symbol c = p.symbol("c");
        ProjectNode planNode = p.project(Assignments.of((Symbol)c, (Expression)new ArithmeticUnaryExpression(ArithmeticUnaryExpression.Sign.MINUS, (Expression)a.toSymbolReference())), (PlanNode)p.join(JoinNode.Type.LEFT, (PlanNode)p.values(a), (PlanNode)p.values(b), new JoinNode.EquiJoinClause[0]));
        Optional rewritten = PushProjectionThroughJoin.pushProjectionThroughJoin((Metadata)this.metadata, (ProjectNode)planNode, (Lookup)Lookup.noLookup(), (PlanNodeIdAllocator)new PlanNodeIdAllocator());
        Assertions.assertThat((Optional)rewritten).isEmpty();
    }
}

