/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.graph.Traverser;
import io.trino.Session;
import io.trino.cost.StatsAndCosts;
import io.trino.execution.QueryManagerConfig;
import io.trino.execution.warnings.WarningCollector;
import io.trino.plugin.tpch.TpchConnectorFactory;
import io.trino.security.AccessControl;
import io.trino.security.AllowAllAccessControl;
import io.trino.spi.connector.ConnectorFactory;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;
import io.trino.sql.planner.Plan;
import io.trino.sql.planner.PlanFragmenter;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.SubPlan;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.rule.test.PlanBuilder;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.OutputNode;
import io.trino.sql.planner.plan.PlanFragmentId;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.testing.LocalQueryRunner;
import io.trino.testing.TestingSession;
import io.trino.transaction.TransactionBuilder;
import io.trino.transaction.TransactionManager;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import org.assertj.core.api.Assertions;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

/**
 * Checks the partition count that {@link PlanFragmenter} records on each plan fragment.
 *
 * <p>The test hand-builds a plan containing nested repartitioning exchanges, each with an
 * explicit partition count, fragments it, and compares the per-fragment
 * {@code getPartitionCount()} values against a fixed expectation.
 */
public class TestPlanFragmentPartitionCount {
    /** Catalog name used both as the session's default catalog and for the registered TPCH catalog. */
    private static final String TEST_CATALOG_NAME = "test-catalog";

    private PlanFragmenter planFragmenter;
    private Session session;
    private LocalQueryRunner localQueryRunner;

    /**
     * Creates the session, a {@link LocalQueryRunner} with a TPCH catalog registered under
     * {@link #TEST_CATALOG_NAME}, and the {@link PlanFragmenter} under test.
     */
    @BeforeClass
    public void setUp() {
        session = TestingSession.testSessionBuilder().setCatalog(TEST_CATALOG_NAME).build();
        localQueryRunner = LocalQueryRunner.create(session);
        localQueryRunner.createCatalog(TEST_CATALOG_NAME, new TpchConnectorFactory(), ImmutableMap.of());
        planFragmenter = new PlanFragmenter(
                localQueryRunner.getMetadata(),
                localQueryRunner.getFunctionManager(),
                localQueryRunner.getTransactionManager(),
                localQueryRunner.getCatalogManager(),
                new QueryManagerConfig());
    }

    /**
     * Releases the fixtures created by {@link #setUp()}.
     *
     * <p>Because this method is declared with {@code alwaysRun = true}, it is written to cope
     * with a {@code setUp()} that failed before the query runner was assigned; otherwise a
     * {@link NullPointerException} here would be reported on top of the real setup failure.
     */
    @AfterClass(alwaysRun = true)
    public void tearDown() {
        planFragmenter = null;
        session = null;
        if (localQueryRunner != null) {
            localQueryRunner.close();
            localQueryRunner = null;
        }
    }

    /**
     * Builds a plan of the following shape, fragments it, and verifies the partition count
     * recorded on every resulting fragment.
     *
     * <pre>
     * Output
     *   Exchange(REPARTITION, hash on b, partitionCount = 3)
     *     Exchange(REPARTITION, arbitrary, partitionCount = 2)
     *       Join(INNER, b = d)
     *         Exchange(REPARTITION, hash on b, partitionCount = 5)
     *           Values(a, b)
     *         Exchange(REPARTITION, hash on d, partitionCount = 5)
     *           Values(c, d)
     *     Values(f, g, h, i)
     * </pre>
     */
    @Test
    public void testPartitionCountInPlanFragment() {
        PlanBuilder p = new PlanBuilder(new PlanNodeIdAllocator(), localQueryRunner.getMetadata(), session);
        Symbol a = p.symbol("a", VarcharType.VARCHAR);
        Symbol b = p.symbol("b", VarcharType.VARCHAR);
        Symbol c = p.symbol("c", VarcharType.VARCHAR);
        Symbol d = p.symbol("d", VarcharType.VARCHAR);
        Symbol f = p.symbol("f", VarcharType.VARCHAR);
        Symbol g = p.symbol("g", VarcharType.VARCHAR);
        Symbol h = p.symbol("h", VarcharType.VARCHAR);
        Symbol i = p.symbol("i", VarcharType.VARCHAR);

        // Kept as a single nested expression so that plan nodes are created in exactly the
        // same order as before (innermost first).
        OutputNode output = p.output(o -> o.source(
                p.exchange(e -> e.type(ExchangeNode.Type.REPARTITION)
                        .addSource(p.exchange(exc -> exc.type(ExchangeNode.Type.REPARTITION)
                                .addSource(p.join(
                                        JoinNode.Type.INNER,
                                        p.exchange(ex -> ex.type(ExchangeNode.Type.REPARTITION)
                                                .addSource(p.values(a, b))
                                                .addInputsSet(a, b)
                                                .fixedHashDistributionPartitioningScheme(ImmutableList.of(a, b), ImmutableList.of(b), 5)),
                                        p.exchange(ex -> ex.type(ExchangeNode.Type.REPARTITION)
                                                .addSource(p.values(c, d))
                                                .addInputsSet(c, d)
                                                .fixedHashDistributionPartitioningScheme(ImmutableList.of(c, d), ImmutableList.of(d), 5)),
                                        new JoinNode.EquiJoinClause(b, d)))
                                .addInputsSet(a, b, c, d)
                                .fixedArbitraryDistributionPartitioningScheme(ImmutableList.of(a, b, c, d), 2)))
                        .addSource(p.values(f, g, h, i))
                        .addInputsSet(a, b, c, d)
                        .addInputsSet(f, g, h, i)
                        .fixedHashDistributionPartitioningScheme(ImmutableList.of(a, b, c, d), ImmutableList.of(b), 3))));

        Plan plan = new Plan(output, p.getTypes(), StatsAndCosts.empty());
        SubPlan rootSubPlan = fragment(plan);

        // Collect fragment id -> partition count for every fragment, walking the sub-plan tree
        // depth-first from the root.
        ImmutableMap.Builder<PlanFragmentId, Optional<Integer>> partitionCounts = ImmutableMap.builder();
        Traverser.forTree(SubPlan::getChildren)
                .depthFirstPreOrder(rootSubPlan)
                .forEach(subPlan -> partitionCounts.put(
                        subPlan.getFragment().getId(),
                        subPlan.getFragment().getPartitionCount()));
        Map<PlanFragmentId, Optional<Integer>> actualPartitionCount = partitionCounts.buildOrThrow();

        // Fragments "0", "1" and "2" are expected to carry the counts 3, 2 and 5 that were given
        // to the three exchange levels above; fragments "3" to "5" are expected to carry none
        // (presumably the Values leaves - confirm against PlanFragmenter if this ever changes).
        Map<PlanFragmentId, Optional<Integer>> expectedPartitionCount = ImmutableMap.of(
                new PlanFragmentId("0"), Optional.of(3),
                new PlanFragmentId("1"), Optional.of(2),
                new PlanFragmentId("2"), Optional.of(5),
                new PlanFragmentId("3"), Optional.empty(),
                new PlanFragmentId("4"), Optional.empty(),
                new PlanFragmentId("5"), Optional.empty());

        // AssertJ convention: actual goes into assertThat(), expected into isEqualTo(), so that
        // failure messages label the two maps correctly.
        Assertions.assertThat(actualPartitionCount).isEqualTo(expectedPartitionCount);
    }

    /**
     * Fragments {@code plan} with the {@link PlanFragmenter} under test inside a
     * single-statement transaction.
     *
     * @param plan the plan to fragment
     * @return the root of the resulting sub-plan tree
     */
    private SubPlan fragment(Plan plan) {
        return inTransaction(transactionSession ->
                planFragmenter.createSubPlans(transactionSession, plan, false, WarningCollector.NOOP));
    }

    /**
     * Runs {@code transactionSessionConsumer} inside a single-statement transaction started
     * from the test session, using an allow-all access control.
     *
     * @param transactionSessionConsumer callback that receives the transactional session
     * @param <T> type of the value produced by the callback
     * @return whatever the callback returns
     */
    private <T> T inTransaction(Function<Session, T> transactionSessionConsumer) {
        return TransactionBuilder.transaction(localQueryRunner.getTransactionManager(), new AllowAllAccessControl())
                .singleStatement()
                .execute(session, transactionSession -> {
                    // The returned handle is deliberately discarded; the lookup is presumably
                    // made for its side effect of registering the catalog with the
                    // transaction - TODO confirm against Metadata.getCatalogHandle.
                    transactionSession.getCatalog().ifPresent(catalog ->
                            localQueryRunner.getMetadata().getCatalogHandle(transactionSession, catalog));
                    return transactionSessionConsumer.apply(transactionSession);
                });
    }
}

