/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule.test;

import io.trino.Session;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.cost.StatsCalculator;
import io.trino.cost.StatsProvider;
import io.trino.cost.TableStatsProvider;
import io.trino.security.AccessControl;
import io.trino.spi.transaction.IsolationLevel;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.iterative.rule.test.PlanBuilder;
import io.trino.sql.planner.iterative.rule.test.RuleAssert;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.testing.LocalQueryRunner;
import io.trino.testing.TestingSession;
import io.trino.transaction.TransactionId;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;

/**
 * Fluent test helper that pairs a {@link Rule} with a plan produced by a {@link PlanBuilder}
 * and hands both to a {@link RuleAssert}.
 *
 * <p>The configuration methods ({@link #setSystemProperty}, {@link #withSession},
 * {@link #overrideStats}) mutate this builder in place and return {@code this}. The class has
 * no synchronization, so it should be confined to a single thread.
 */
public class RuleBuilder {
    private final Rule<?> rule;
    private final LocalQueryRunner queryRunner;
    private Session session;
    private final TestingStatsCalculator statsCalculator;

    RuleBuilder(Rule<?> rule, LocalQueryRunner queryRunner, Session session) {
        this.rule = Objects.requireNonNull(rule, "rule is null");
        this.queryRunner = Objects.requireNonNull(queryRunner, "queryRunner is null");
        this.session = Objects.requireNonNull(session, "session is null");
        // Wrap the runner's calculator so individual plan nodes can have their stats overridden.
        this.statsCalculator = new TestingStatsCalculator(queryRunner.getStatsCalculator());
    }

    /**
     * Replaces the current session with a copy that has the given system property set.
     *
     * @param key system session property name
     * @param value property value
     * @return this builder
     */
    public RuleBuilder setSystemProperty(String key, String value) {
        return withSession(Session.builder(session).setSystemProperty(key, value).build());
    }

    /**
     * Replaces the session used when the plan is built in {@link #on}.
     *
     * @param session the new session; must not be {@code null}
     * @return this builder
     * @throws NullPointerException if {@code session} is {@code null}
     */
    public RuleBuilder withSession(Session session) {
        // Validate eagerly, consistent with the constructor, instead of failing later in on().
        this.session = Objects.requireNonNull(session, "session is null");
        return this;
    }

    /**
     * Forces the stats calculator to return {@code nodeStats} for the plan node with the given id
     * instead of delegating to the query runner's calculator.
     *
     * @param nodeId id of the plan node whose stats are overridden
     * @param nodeStats stats to report for that node; must not be {@code null}
     * @return this builder
     * @throws NullPointerException if {@code nodeStats} is {@code null}
     */
    public RuleBuilder overrideStats(String nodeId, PlanNodeStatsEstimate nodeStats) {
        statsCalculator.setNodeStats(new PlanNodeId(nodeId), nodeStats);
        return this;
    }

    /**
     * Starts a READ_UNCOMMITTED transaction and begins a query on it. It then builds the plan
     * supplied by {@code planProvider} and returns a {@link RuleAssert} bound to that
     * transactional session.
     *
     * <p>If anything fails before the {@link RuleAssert} is returned, the query is cleaned up and
     * the transaction is aborted before the original failure is rethrown.
     *
     * @param planProvider builds the plan under test from a {@link PlanBuilder}
     * @return an assertion object for applying {@code rule} to the built plan
     */
    public RuleAssert on(Function<PlanBuilder, PlanNode> planProvider) {
        Session baseSession = TestingSession.testSession(this.session);
        TransactionId transactionId = queryRunner.getTransactionManager().beginTransaction(IsolationLevel.READ_UNCOMMITTED, false, false);
        Session transactionSession = baseSession.beginTransactionId(transactionId, queryRunner.getTransactionManager(), (AccessControl) queryRunner.getAccessControl());
        queryRunner.getMetadata().beginQuery(transactionSession);
        try {
            // Resolve the session's default catalog (if any) within this transaction before planning.
            // NOTE(review): presumably this registers the catalog with the transaction — confirm.
            transactionSession.getCatalog().ifPresent(catalog -> queryRunner.getMetadata().getCatalogHandle(transactionSession, catalog));
            PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator();
            PlanBuilder builder = new PlanBuilder(idAllocator, queryRunner.getPlannerContext(), transactionSession);
            PlanNode plan = planProvider.apply(builder);
            TypeProvider types = builder.getTypes();
            return new RuleAssert(rule, queryRunner, statsCalculator, transactionSession, idAllocator, plan, types);
        }
        catch (Throwable failure) {
            abortQuery(transactionSession, transactionId, failure);
            // Precise rethrow: the try body throws no checked exceptions, so neither does this.
            throw failure;
        }
    }

    /**
     * Undoes the setup done in {@link #on} after a failure. It calls {@code cleanupQuery} with the
     * same session passed to {@code beginQuery} and then aborts the transaction. Both steps always
     * run. Any exception they throw is attached to {@code failure} as a suppressed exception, so
     * it neither masks the original failure nor skips the other step.
     */
    private void abortQuery(Session transactionSession, TransactionId transactionId, Throwable failure) {
        try {
            queryRunner.getMetadata().cleanupQuery(transactionSession);
        }
        catch (Throwable cleanupFailure) {
            addSuppressed(failure, cleanupFailure);
        }
        try {
            queryRunner.getTransactionManager().asyncAbort(transactionId);
        }
        catch (Throwable abortFailure) {
            addSuppressed(failure, abortFailure);
        }
    }

    /** Attaches {@code secondary} to {@code primary}, guarding against illegal self-suppression. */
    private static void addSuppressed(Throwable primary, Throwable secondary) {
        if (primary != secondary) {
            primary.addSuppressed(secondary);
        }
    }

    /**
     * {@link StatsCalculator} that returns explicitly overridden stats for selected plan node ids
     * and delegates every other request to a wrapped calculator.
     */
    private static class TestingStatsCalculator implements StatsCalculator {
        private final StatsCalculator delegate;
        private final Map<PlanNodeId, PlanNodeStatsEstimate> stats = new HashMap<>();

        TestingStatsCalculator(StatsCalculator delegate) {
            this.delegate = Objects.requireNonNull(delegate, "delegate is null");
        }

        @Override
        public PlanNodeStatsEstimate calculateStats(PlanNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider) {
            // Overrides are never null (enforced in setNodeStats), so null means "not overridden".
            PlanNodeStatsEstimate override = stats.get(node.getId());
            if (override != null) {
                return override;
            }
            return delegate.calculateStats(node, sourceStats, lookup, session, types, tableStatsProvider);
        }

        /**
         * Registers {@code nodeStats} as the stats to report for {@code nodeId}, replacing any
         * earlier override for the same id.
         *
         * @throws NullPointerException if either argument is {@code null}
         */
        public void setNodeStats(PlanNodeId nodeId, PlanNodeStatsEstimate nodeStats) {
            stats.put(
                    Objects.requireNonNull(nodeId, "nodeId is null"),
                    Objects.requireNonNull(nodeStats, "nodeStats is null"));
        }
    }
}

