/*
 * Decompiled with CFR 0.152.
 */
package io.trino.cost;

import com.google.common.base.Preconditions;
import io.trino.Session;
import io.trino.cost.CachingTableStatsProvider;
import io.trino.cost.ComposableStatsCalculator;
import io.trino.cost.PlanNodeStatsAssertion;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.cost.StatsCalculator;
import io.trino.cost.TableStatsProvider;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.optimizations.PlanNodeSearcher;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.testing.QueryRunner;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Consumer;

/**
 * Fluent test helper that computes statistics for a single {@link PlanNode} and passes the
 * result to caller-supplied assertions.
 *
 * <p>Statistics of the node's sources are not computed. They are read from an internal map.
 * The map is pre-populated with {@link PlanNodeStatsEstimate#unknown()} for every direct
 * source of the node, and entries can be overridden through the {@code withSourceStats}
 * methods.
 */
public class StatsCalculatorAssertion {
    private final QueryRunner queryRunner;
    private final Session session;
    private final PlanNode planNode;
    private final TypeProvider types;
    // Statistics handed to the calculator whenever it asks for a source's estimate.
    private final Map<PlanNode, PlanNodeStatsEstimate> sourcesStats;
    // When empty, each check builds a fresh CachingTableStatsProvider from the query runner's metadata.
    private Optional<TableStatsProvider> tableStatsProvider = Optional.empty();

    StatsCalculatorAssertion(QueryRunner queryRunner, Session session, PlanNode planNode, TypeProvider types) {
        this.queryRunner = Objects.requireNonNull(queryRunner, "queryRunner is null");
        this.session = Objects.requireNonNull(session, "session cannot be null");
        this.planNode = Objects.requireNonNull(planNode, "planNode is null");
        this.types = Objects.requireNonNull(types, "types is null");
        this.sourcesStats = new HashMap<>();
        // Default every direct source to "unknown" so it can be looked up even if the test
        // never overrides it.
        for (PlanNode source : planNode.getSources()) {
            this.sourcesStats.put(source, PlanNodeStatsEstimate.unknown());
        }
    }

    /**
     * Sets the statistics of the node's only source.
     *
     * @param sourceStats statistics to report for the single source
     * @return this assertion, for chaining
     * @throws IllegalStateException if the node does not have exactly one source
     */
    public StatsCalculatorAssertion withSourceStats(PlanNodeStatsEstimate sourceStats) {
        Preconditions.checkState(planNode.getSources().size() == 1, "expected single source");
        return withSourceStats(0, sourceStats);
    }

    /**
     * Sets the statistics of the direct source at {@code sourceIndex}.
     *
     * @param sourceIndex zero-based position in {@code planNode.getSources()}
     * @param sourceStats statistics to report for that source
     * @return this assertion, for chaining
     * @throws IllegalArgumentException if {@code sourceIndex} is negative or not less than the
     *     number of sources
     */
    public StatsCalculatorAssertion withSourceStats(int sourceIndex, PlanNodeStatsEstimate sourceStats) {
        int sourceCount = planNode.getSources().size();
        // Reject negative indexes here too, so they produce the same descriptive
        // IllegalArgumentException instead of an IndexOutOfBoundsException from List.get.
        Preconditions.checkArgument(
                sourceIndex >= 0 && sourceIndex < sourceCount,
                "invalid sourceIndex %s; planNode has %s sources",
                sourceIndex,
                sourceCount);
        sourcesStats.put(planNode.getSources().get(sourceIndex), sourceStats);
        return this;
    }

    /**
     * Sets the statistics of the node with id {@code planNodeId}. The node is located by
     * searching the plan rooted at the node under test.
     *
     * @param planNodeId id of the node whose statistics should be reported
     * @param sourceStats statistics to report for that node
     * @return this assertion, for chaining
     */
    public StatsCalculatorAssertion withSourceStats(PlanNodeId planNodeId, PlanNodeStatsEstimate sourceStats) {
        // NOTE(review): findOnlyElement presumably fails unless exactly one node matches — confirm.
        PlanNode sourceNode = PlanNodeSearcher.searchFrom(planNode)
                .where(node -> node.getId().equals(planNodeId))
                .findOnlyElement();
        sourcesStats.put(sourceNode, sourceStats);
        return this;
    }

    /**
     * Adds or overrides source statistics in bulk.
     *
     * @param stats statistics keyed by plan node
     * @return this assertion, for chaining
     */
    public StatsCalculatorAssertion withSourceStats(Map<PlanNode, PlanNodeStatsEstimate> stats) {
        sourcesStats.putAll(stats);
        return this;
    }

    /**
     * Uses the given provider for table statistics instead of the default
     * {@link CachingTableStatsProvider}.
     *
     * @param tableStatsProvider provider to use for subsequent checks
     * @return this assertion, for chaining
     */
    public StatsCalculatorAssertion withTableStatisticsProvider(TableStatsProvider tableStatsProvider) {
        this.tableStatsProvider = Optional.of(Objects.requireNonNull(tableStatsProvider, "tableStatsProvider is null"));
        return this;
    }

    /**
     * Computes statistics with the query runner's full stats calculator and passes them to the
     * given assertions.
     *
     * @param statisticsAssertionConsumer assertions to run against the computed estimate
     * @return this assertion, for chaining
     */
    public StatsCalculatorAssertion check(Consumer<PlanNodeStatsAssertion> statisticsAssertionConsumer) {
        PlanNodeStatsEstimate statsEstimate = queryRunner.getStatsCalculator().calculateStats(planNode, createContext());
        statisticsAssertionConsumer.accept(PlanNodeStatsAssertion.assertThat(statsEstimate));
        return this;
    }

    /**
     * Computes statistics with a single stats rule and passes them to the given assertions.
     *
     * @param rule the rule to evaluate against the node under test
     * @param statisticsAssertionConsumer assertions to run against the computed estimate
     * @return this assertion, for chaining
     * @throws IllegalStateException if the rule produces no estimate
     */
    public StatsCalculatorAssertion check(ComposableStatsCalculator.Rule<?> rule, Consumer<PlanNodeStatsAssertion> statisticsAssertionConsumer) {
        Optional<PlanNodeStatsEstimate> statsEstimate = calculatedStats(rule, planNode, createContext());
        Preconditions.checkState(statsEstimate.isPresent(), "Expected stats estimates to be present");
        statisticsAssertionConsumer.accept(PlanNodeStatsAssertion.assertThat(statsEstimate.get()));
        return this;
    }

    /**
     * Builds the calculation context shared by both {@code check} overloads.
     *
     * <p>The context reads source statistics from {@link #sourcesStats}. It uses the configured
     * table-stats provider, or a new {@link CachingTableStatsProvider} when none was set.
     */
    private StatsCalculator.Context createContext() {
        TableStatsProvider statsProvider = tableStatsProvider.orElseGet(
                () -> new CachingTableStatsProvider(queryRunner.getPlannerContext().getMetadata(), session));
        return new StatsCalculator.Context(this::getSourceStats, Lookup.noLookup(), session, types, statsProvider);
    }

    /**
     * Invokes {@code rule} on {@code node}, capturing the rule's wildcard node type as {@code T}.
     */
    @SuppressWarnings("unchecked") // caller is responsible for passing a rule that matches the node's type
    private static <T extends PlanNode> Optional<PlanNodeStatsEstimate> calculatedStats(ComposableStatsCalculator.Rule<T> rule, PlanNode node, StatsCalculator.Context context) {
        return rule.calculate((T) node, context);
    }

    /**
     * Returns the registered statistics for {@code sourceNode}.
     *
     * @throws IllegalArgumentException if no statistics were registered for the node
     */
    private PlanNodeStatsEstimate getSourceStats(PlanNode sourceNode) {
        Preconditions.checkArgument(sourcesStats.containsKey(sourceNode), "stats not found for source %s", sourceNode);
        return sourcesStats.get(sourceNode);
    }
}

