/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.optimizations;

import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import io.trino.Session;
import io.trino.cost.StatsAndCosts;
import io.trino.metadata.FunctionManager;
import io.trino.metadata.MergeHandle;
import io.trino.metadata.Metadata;
import io.trino.metadata.TableExecuteHandle;
import io.trino.metadata.TableHandle;
import io.trino.spi.connector.BeginTableExecuteResult;
import io.trino.sql.planner.optimizations.PlanNodeSearcher;
import io.trino.sql.planner.optimizations.PlanOptimizer;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.MergeWriterNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.SimplePlanRewriter;
import io.trino.sql.planner.plan.StatisticsWriterNode;
import io.trino.sql.planner.plan.TableExecuteNode;
import io.trino.sql.planner.plan.TableFinishNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.sql.planner.plan.TableWriterNode;
import io.trino.sql.planner.plan.UnionNode;
import io.trino.sql.planner.planprinter.PlanPrinter;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

/**
 * Plan optimizer that turns write "references" into concrete, begun write targets.
 *
 * <p>For every {@link TableFinishNode}, the writer target of its child is located. The child may be a
 * {@link TableWriterNode}, {@link TableExecuteNode} or {@link MergeWriterNode}, possibly behind
 * {@link ExchangeNode}/{@link UnionNode}. The matching {@code Metadata.begin*} call is then made,
 * and the resulting target is pushed down to the writer nodes beneath the finish node.
 * {@link StatisticsWriterNode}s are rewritten with the handle returned by
 * {@code Metadata.beginStatisticsCollection}.
 */
public class BeginTableWrite
        implements PlanOptimizer {
    private final Metadata metadata;
    private final FunctionManager functionManager;

    /**
     * @param metadata used to begin writes and to resolve table names; must not be null
     * @param functionManager within this class, used only to render the plan when reporting a failure;
     *        must not be null
     */
    public BeginTableWrite(Metadata metadata, FunctionManager functionManager) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
        this.functionManager = Objects.requireNonNull(functionManager, "functionManager is null");
    }

    /**
     * Rewrites {@code plan}, beginning every table write it contains.
     *
     * <p>If the rewrite fails with a {@link RuntimeException}, a text rendering of the input plan is
     * attached to it as a suppressed exception. The original exception is then rethrown.
     */
    @Override
    public PlanNode optimize(PlanNode plan, PlanOptimizer.Context context) {
        try {
            return SimplePlanRewriter.rewriteWith(new Rewriter(context.session()), plan, Optional.empty());
        }
        catch (RuntimeException e) {
            try {
                // NOTE(review): a nest level of 4 presumably keeps the rendering readable inside a stack trace
                int nestLevel = 4;
                String explain = PlanPrinter.textLogicalPlan(plan, this.metadata, this.functionManager, StatsAndCosts.empty(), context.session(), nestLevel, false);
                e.addSuppressed(new Exception("Current plan:\n" + explain));
            }
            catch (RuntimeException ignored) {
                // Rendering the plan is best-effort diagnostics only. A failure here must not
                // replace the original exception, which is rethrown below.
            }
            throw e;
        }
    }

    /**
     * Returns the writer target carried by the rewrite context.
     *
     * @throws IllegalStateException if no target has been set, i.e. the writer node was not
     *         reached through a {@link TableFinishNode}
     */
    private static TableWriterNode.WriterTarget getContextTarget(SimplePlanRewriter.RewriteContext<Optional<TableWriterNode.WriterTarget>> context) {
        return context.get().orElseThrow(() -> new IllegalStateException("WriterTarget not present"));
    }

    /**
     * Rewriter whose context carries the resolved writer target.
     *
     * <p>The context starts empty and is set when a {@link TableFinishNode} is visited.
     */
    private class Rewriter
            extends SimplePlanRewriter<Optional<TableWriterNode.WriterTarget>> {
        private final Session session;

        public Rewriter(Session session) {
            this.session = session;
        }

        /**
         * Replaces the writer's target with the one resolved by the enclosing finish node, and
         * rewrites its source under the same context.
         */
        @Override
        public PlanNode visitTableWriter(TableWriterNode node, SimplePlanRewriter.RewriteContext<Optional<TableWriterNode.WriterTarget>> context) {
            TableWriterNode.WriterTarget writerTarget = getContextTarget(context);
            return new TableWriterNode(
                    node.getId(),
                    context.rewrite(node.getSource(), context.get()),
                    writerTarget,
                    node.getRowCountSymbol(),
                    node.getFragmentSymbol(),
                    node.getColumns(),
                    node.getColumnNames(),
                    node.getPartitioningScheme(),
                    node.getStatisticsAggregation(),
                    node.getStatisticsAggregationDescriptor());
        }

        /**
         * Replaces the node's target with the resolved {@link TableWriterNode.TableExecuteTarget}.
         *
         * <p>The handle of the single update-target table scan in its source is swapped for the
         * target's source handle. Exactly one such scan must exist.
         */
        @Override
        public PlanNode visitTableExecute(TableExecuteNode node, SimplePlanRewriter.RewriteContext<Optional<TableWriterNode.WriterTarget>> context) {
            TableWriterNode.TableExecuteTarget tableExecuteTarget = (TableWriterNode.TableExecuteTarget) getContextTarget(context);
            return new TableExecuteNode(
                    node.getId(),
                    this.rewriteModifyTableScan(node.getSource(), tableExecuteTarget.getSourceHandle().orElseThrow(), false),
                    tableExecuteTarget,
                    node.getRowCountSymbol(),
                    node.getFragmentSymbol(),
                    node.getColumns(),
                    node.getColumnNames(),
                    node.getPartitioningScheme());
        }

        /**
         * Replaces the node's target with the resolved {@link TableWriterNode.MergeTarget}.
         *
         * <p>Any update-target table scan in its source is pointed at the target's handle. Zero or
         * one such scan is allowed.
         */
        @Override
        public PlanNode visitMergeWriter(MergeWriterNode mergeNode, SimplePlanRewriter.RewriteContext<Optional<TableWriterNode.WriterTarget>> context) {
            TableWriterNode.MergeTarget mergeTarget = (TableWriterNode.MergeTarget) getContextTarget(context);
            return new MergeWriterNode(
                    mergeNode.getId(),
                    this.rewriteModifyTableScan(mergeNode.getSource(), mergeTarget.getHandle(), true),
                    mergeTarget,
                    mergeNode.getProjectedSymbols(),
                    mergeNode.getPartitioningScheme(),
                    mergeNode.getOutputSymbols());
        }

        /**
         * Rewrites the source, then begins statistics collection.
         *
         * <p>The node's {@link StatisticsWriterNode.WriteStatisticsReference} is replaced with the
         * resulting {@link StatisticsWriterNode.WriteStatisticsHandle}.
         */
        @Override
        public PlanNode visitStatisticsWriterNode(StatisticsWriterNode node, SimplePlanRewriter.RewriteContext<Optional<TableWriterNode.WriterTarget>> context) {
            PlanNode child = context.rewrite(node.getSource(), context.get());
            StatisticsWriterNode.WriteStatisticsReference reference = (StatisticsWriterNode.WriteStatisticsReference) node.getTarget();
            StatisticsWriterNode.WriteStatisticsHandle analyzeHandle = new StatisticsWriterNode.WriteStatisticsHandle(
                    BeginTableWrite.this.metadata.beginStatisticsCollection(this.session, reference.getHandle()));
            return new StatisticsWriterNode(node.getId(), child, analyzeHandle, node.getRowCountSymbol(), node.isRowCountEnabled(), node.getDescriptor());
        }

        /**
         * Resolves and begins the write performed beneath this finish node.
         *
         * <p>The begun target is then pushed down to the writer nodes in the child and recorded on
         * the finish node itself.
         */
        @Override
        public PlanNode visitTableFinish(TableFinishNode node, SimplePlanRewriter.RewriteContext<Optional<TableWriterNode.WriterTarget>> context) {
            PlanNode child = node.getSource();
            TableWriterNode.WriterTarget originalTarget = this.getWriterTarget(child);
            TableWriterNode.WriterTarget newTarget = this.createWriterTarget(originalTarget, child);
            PlanNode rewrittenChild = context.rewrite(child, Optional.of(newTarget));
            return new TableFinishNode(node.getId(), rewrittenChild, newTarget, node.getRowCountSymbol(), node.getStatisticsAggregation(), node.getStatisticsAggregationDescriptor());
        }

        /**
         * Finds the writer target of the writer node feeding a finish node.
         *
         * <p>The search looks through {@link ExchangeNode} and {@link UnionNode}; all their sources
         * must resolve to the same (equal) target. For table-execute and merge writers, the target is
         * rebuilt with the handle of the update-target table scan found beneath the writer.
         *
         * @throws IllegalArgumentException if {@code node} is not a supported writer/exchange/union
         *         node, if an update-target scan lookup fails, or if an exchange/union resolves to
         *         more than one distinct target
         */
        public TableWriterNode.WriterTarget getWriterTarget(PlanNode node) {
            if (node instanceof TableWriterNode tableWriterNode) {
                return tableWriterNode.getTarget();
            }
            if (node instanceof TableExecuteNode tableExecuteNode) {
                TableWriterNode.TableExecuteTarget target = tableExecuteNode.getTarget();
                return new TableWriterNode.TableExecuteTarget(
                        target.getExecuteHandle(),
                        this.findTableScanHandleForTableExecute(tableExecuteNode.getSource()),
                        target.getSchemaTableName(),
                        target.getWriterScalingOptions());
            }
            if (node instanceof MergeWriterNode mergeWriterNode) {
                TableWriterNode.MergeTarget mergeTarget = mergeWriterNode.getTarget();
                // Without an update-target scan, the existing merge target is used unchanged.
                return this.findTableScanHandleForMergeWriter(mergeWriterNode.getSource())
                        .map(tableHandle -> new TableWriterNode.MergeTarget(
                                tableHandle,
                                mergeTarget.getMergeHandle(),
                                mergeTarget.getSchemaTableName(),
                                mergeTarget.getMergeParadigmAndTypes(),
                                findSourceTableHandles(node),
                                mergeTarget.getUpdateCaseColumnHandles()))
                        .orElse(mergeTarget);
            }
            if (node instanceof ExchangeNode || node instanceof UnionNode) {
                Set<TableWriterNode.WriterTarget> writerTargets = node.getSources().stream()
                        .map(this::getWriterTarget)
                        .collect(Collectors.toSet());
                return Iterables.getOnlyElement(writerTargets);
            }
            throw new IllegalArgumentException("Invalid child for TableCommitNode: " + node.getClass().getSimpleName());
        }

        /**
         * Begins the write described by {@code target} via the matching {@code Metadata.begin*} call.
         *
         * @param target the (possibly unresolved) target found by {@link #getWriterTarget}
         * @param planNode the plan under the finish node; searched for source table scans where the
         *        resulting target records them
         * @return the concrete target built from the handle returned by {@code Metadata}
         * @throws IllegalArgumentException if the target type is not handled
         */
        private TableWriterNode.WriterTarget createWriterTarget(TableWriterNode.WriterTarget target, PlanNode planNode) {
            Metadata metadata = BeginTableWrite.this.metadata;
            if (target instanceof TableWriterNode.CreateReference create) {
                return new TableWriterNode.CreateTarget(
                        metadata.beginCreateTable(this.session, create.getCatalog(), create.getTableMetadata(), create.getLayout(), create.isReplace()),
                        create.getTableMetadata().getTable(),
                        target.supportsMultipleWritersPerPartition(metadata, this.session),
                        target.getMaxWriterTasks(metadata, this.session),
                        target.getWriterScalingOptions(metadata, this.session),
                        create.isReplace());
            }
            if (target instanceof TableWriterNode.InsertReference insert) {
                return new TableWriterNode.InsertTarget(
                        metadata.beginInsert(this.session, insert.getHandle(), insert.getColumns()),
                        metadata.getTableName(this.session, insert.getHandle()).getSchemaTableName(),
                        target.supportsMultipleWritersPerPartition(metadata, this.session),
                        target.getMaxWriterTasks(metadata, this.session),
                        target.getWriterScalingOptions(metadata, this.session),
                        findSourceTableHandles(planNode));
            }
            if (target instanceof TableWriterNode.MergeTarget merge) {
                MergeHandle mergeHandle = metadata.beginMerge(this.session, merge.getHandle(), merge.getUpdateCaseColumnHandles());
                return new TableWriterNode.MergeTarget(
                        mergeHandle.tableHandle(),
                        Optional.of(mergeHandle),
                        merge.getSchemaTableName(),
                        merge.getMergeParadigmAndTypes(),
                        findSourceTableHandles(planNode),
                        merge.getUpdateCaseColumnHandles());
            }
            if (target instanceof TableWriterNode.RefreshMaterializedViewReference refreshMV) {
                return new TableWriterNode.RefreshMaterializedViewTarget(
                        refreshMV.getStorageTableHandle(),
                        metadata.beginRefreshMaterializedView(this.session, refreshMV.getStorageTableHandle(), refreshMV.getSourceTableHandles(), refreshMV.getRefreshType()),
                        metadata.getTableName(this.session, refreshMV.getStorageTableHandle()).getSchemaTableName(),
                        refreshMV.getSourceTableHandles(),
                        refreshMV.getSourceTableFunctions(),
                        refreshMV.getWriterScalingOptions(metadata, this.session));
            }
            if (target instanceof TableWriterNode.TableExecuteTarget tableExecute) {
                BeginTableExecuteResult<TableExecuteHandle, TableHandle> result = metadata.beginTableExecute(this.session, tableExecute.getExecuteHandle(), tableExecute.getMandatorySourceHandle());
                return new TableWriterNode.TableExecuteTarget(
                        result.getTableExecuteHandle(),
                        Optional.of(result.getSourceHandle()),
                        tableExecute.getSchemaTableName(),
                        tableExecute.getWriterScalingOptions());
            }
            throw new IllegalArgumentException("Unhandled target type: " + target.getClass().getSimpleName());
        }

        /** Returns the table handles of every {@link TableScanNode} under {@code startNode}. */
        private static List<TableHandle> findSourceTableHandles(PlanNode startNode) {
            return PlanNodeSearcher.searchFrom(startNode)
                    .where(TableScanNode.class::isInstance)
                    .findAll()
                    .stream()
                    .map(TableScanNode.class::cast)
                    .map(TableScanNode::getTable)
                    .collect(ImmutableList.toImmutableList());
        }

        /** Returns every {@link TableScanNode} under {@code startNode} that is flagged as an update target. */
        private static List<PlanNode> findUpdateTargetScans(PlanNode startNode) {
            return PlanNodeSearcher.searchFrom(startNode)
                    .where(node -> node instanceof TableScanNode scan && scan.isUpdateTarget())
                    .findAll();
        }

        /**
         * Returns the handle of the single update-target scan under {@code startNode}.
         *
         * @throws IllegalArgumentException unless exactly one such scan exists
         */
        private Optional<TableHandle> findTableScanHandleForTableExecute(PlanNode startNode) {
            List<PlanNode> tableScanNodes = findUpdateTargetScans(startNode);
            if (tableScanNodes.size() == 1) {
                return Optional.of(((TableScanNode) tableScanNodes.get(0)).getTable());
            }
            throw new IllegalArgumentException("Expected to find exactly one update target TableScanNode in plan but found: " + tableScanNodes);
        }

        /**
         * Returns the handle of the update-target scan under {@code startNode}.
         *
         * @return the handle, or empty if there is no such scan
         * @throws IllegalArgumentException if more than one such scan exists
         */
        private Optional<TableHandle> findTableScanHandleForMergeWriter(PlanNode startNode) {
            List<PlanNode> tableScanNodes = findUpdateTargetScans(startNode);
            if (tableScanNodes.isEmpty()) {
                return Optional.empty();
            }
            if (tableScanNodes.size() == 1) {
                return Optional.of(((TableScanNode) tableScanNodes.get(0)).getTable());
            }
            throw new IllegalArgumentException("Expected to find zero or one update target TableScanNode in plan but found: " + tableScanNodes);
        }

        /**
         * Replaces the table handle of every update-target {@link TableScanNode} under {@code node}
         * with {@code handle}.
         *
         * <p>After the rewrite, the method checks how many scans were modified.
         *
         * @param tableScanNotFoundIsOk if true, zero or one modified scan is accepted; otherwise
         *        exactly one is required
         * @throws com.google.common.base.VerifyException if the number of modified scans is not accepted
         */
        private PlanNode rewriteModifyTableScan(PlanNode node, TableHandle handle, boolean tableScanNotFoundIsOk) {
            AtomicInteger modifyCount = new AtomicInteger(0);
            PlanNode rewrittenNode = SimplePlanRewriter.rewriteWith(new SimplePlanRewriter<Void>() {
                @Override
                public PlanNode visitTableScan(TableScanNode scan, SimplePlanRewriter.RewriteContext<Void> context) {
                    if (!scan.isUpdateTarget()) {
                        return scan;
                    }
                    modifyCount.incrementAndGet();
                    // Only the table handle changes. All other scan properties, including the
                    // connector-node-partitioning flag, are carried over unchanged.
                    return new TableScanNode(
                            scan.getId(),
                            handle,
                            scan.getOutputSymbols(),
                            scan.getAssignments(),
                            scan.getEnforcedConstraint(),
                            scan.getStatistics(),
                            scan.isUpdateTarget(),
                            scan.getUseConnectorNodePartitioning());
                }
            }, node, null);
            int countFound = modifyCount.get();
            if (tableScanNotFoundIsOk) {
                Verify.verify(countFound == 0 || countFound == 1, "Expected to find zero or one update target TableScanNodes but found %s", countFound);
            }
            else {
                Verify.verify(countFound == 1, "Expected to find exactly one update target TableScanNode but found %s", countFound);
            }
            return rewrittenNode;
        }
    }
}

