/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.optimizations;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import io.trino.Session;
import io.trino.execution.warnings.WarningCollector;
import io.trino.metadata.Metadata;
import io.trino.metadata.TableHandle;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.optimizations.PlanOptimizer;
import io.trino.sql.planner.optimizations.QueryCardinalityUtil;
import io.trino.sql.planner.plan.ChildReplacer;
import io.trino.sql.planner.plan.DeleteNode;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.SemiJoinNode;
import io.trino.sql.planner.plan.SimplePlanRewriter;
import io.trino.sql.planner.plan.StatisticsWriterNode;
import io.trino.sql.planner.plan.TableFinishNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.sql.planner.plan.TableWriterNode;
import io.trino.sql.planner.plan.UnionNode;
import io.trino.sql.planner.plan.UpdateNode;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Plan optimizer that replaces the writer targets found under
 * {@link TableFinishNode}s (and the statistics target of
 * {@link StatisticsWriterNode}s) with the targets/handles returned by the
 * corresponding {@code Metadata.begin*} calls.
 *
 * <p>The rewrite invokes {@code begin*} methods on {@link Metadata} while it
 * walks the plan, so running it is not a pure transformation of the plan tree.
 */
public class BeginTableWrite
implements PlanOptimizer {
    private final Metadata metadata;

    public BeginTableWrite(Metadata metadata) {
        this.metadata = metadata;
    }

    /**
     * Rewrites {@code plan} with a fresh {@link Rewriter} bound to {@code session}.
     * The initial rewrite context carries no writer target. The remaining
     * parameters are not used by this optimizer.
     */
    @Override
    public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) {
        return SimplePlanRewriter.rewriteWith(new Rewriter(session), plan, Optional.empty());
    }

    /**
     * Returns the writer target stored in the rewrite context.
     *
     * @throws IllegalStateException if the context holds no target
     */
    private static TableWriterNode.WriterTarget getContextTarget(SimplePlanRewriter.RewriteContext<Optional<TableWriterNode.WriterTarget>> context) {
        return context.get().orElseThrow(() -> new IllegalStateException("WriterTarget not present"));
    }

    /**
     * Rewriter whose context is the (optional) writer target created by the
     * nearest enclosing {@link TableFinishNode}.
     */
    private class Rewriter
    extends SimplePlanRewriter<Optional<TableWriterNode.WriterTarget>> {
        private final Session session;

        public Rewriter(Session session) {
            this.session = session;
        }

        /**
         * Rebuilds the writer node with the target taken from the context and
         * with its source rewritten under the same context.
         */
        @Override
        public PlanNode visitTableWriter(TableWriterNode node, SimplePlanRewriter.RewriteContext<Optional<TableWriterNode.WriterTarget>> context) {
            TableWriterNode.WriterTarget writerTarget = getContextTarget(context);
            PlanNode rewrittenSource = context.rewrite(node.getSource(), context.get());
            return new TableWriterNode(
                    node.getId(),
                    rewrittenSource,
                    writerTarget,
                    node.getRowCountSymbol(),
                    node.getFragmentSymbol(),
                    node.getColumns(),
                    node.getColumnNames(),
                    node.getNotNullColumnSymbols(),
                    node.getPartitioningScheme(),
                    node.getPreferredPartitioningScheme(),
                    node.getStatisticsAggregation(),
                    node.getStatisticsAggregationDescriptor());
        }

        /**
         * Rebuilds the delete node with the context target (which must be a
         * {@code DeleteTarget}) and swaps that target's handle into the table
         * scan beneath the node.
         */
        @Override
        public PlanNode visitDelete(DeleteNode node, SimplePlanRewriter.RewriteContext<Optional<TableWriterNode.WriterTarget>> context) {
            TableWriterNode.DeleteTarget deleteTarget = (TableWriterNode.DeleteTarget) getContextTarget(context);
            PlanNode rewrittenSource = rewriteModifyTableScan(node.getSource(), deleteTarget.getHandleOrElseThrow());
            return new DeleteNode(node.getId(), rewrittenSource, deleteTarget, node.getRowId(), node.getOutputSymbols());
        }

        /**
         * Rebuilds the update node with the context target (which must be an
         * {@code UpdateTarget}) and swaps that target's handle into the table
         * scan beneath the node.
         */
        @Override
        public PlanNode visitUpdate(UpdateNode node, SimplePlanRewriter.RewriteContext<Optional<TableWriterNode.WriterTarget>> context) {
            TableWriterNode.UpdateTarget updateTarget = (TableWriterNode.UpdateTarget) getContextTarget(context);
            PlanNode rewrittenSource = rewriteModifyTableScan(node.getSource(), updateTarget.getHandleOrElseThrow());
            return new UpdateNode(node.getId(), rewrittenSource, updateTarget, node.getRowId(), node.getColumnValueAndRowIdSymbols(), node.getOutputSymbols());
        }

        /**
         * Rewrites the child first, then calls
         * {@code Metadata.beginStatisticsCollection} for the node's target
         * (expected to be a {@code WriteStatisticsReference}) and rebuilds the
         * node around the resulting handle.
         */
        @Override
        public PlanNode visitStatisticsWriterNode(StatisticsWriterNode node, SimplePlanRewriter.RewriteContext<Optional<TableWriterNode.WriterTarget>> context) {
            PlanNode child = context.rewrite(node.getSource(), context.get());
            StatisticsWriterNode.WriteStatisticsReference reference = (StatisticsWriterNode.WriteStatisticsReference) node.getTarget();
            StatisticsWriterNode.WriteStatisticsHandle analyzeHandle = new StatisticsWriterNode.WriteStatisticsHandle(
                    metadata.beginStatisticsCollection(session, reference.getHandle()));
            return new StatisticsWriterNode(node.getId(), child, analyzeHandle, node.getRowCountSymbol(), node.isRowCountEnabled(), node.getDescriptor());
        }

        /**
         * Derives the writer target from the child subtree, converts it into a
         * new target via {@link #createWriterTarget}, and rewrites the child
         * with that new target as context so the nodes below pick it up.
         */
        @Override
        public PlanNode visitTableFinish(TableFinishNode node, SimplePlanRewriter.RewriteContext<Optional<TableWriterNode.WriterTarget>> context) {
            TableWriterNode.WriterTarget originalTarget = getWriterTarget(node.getSource());
            TableWriterNode.WriterTarget newTarget = createWriterTarget(originalTarget);
            PlanNode child = context.rewrite(node.getSource(), Optional.of(newTarget));
            return new TableFinishNode(node.getId(), child, newTarget, node.getRowCountSymbol(), node.getStatisticsAggregation(), node.getStatisticsAggregationDescriptor());
        }

        /**
         * Finds the writer target described by {@code node}.
         *
         * <p>For delete/update nodes a new target is built whose handle is the
         * table handle of the scan located beneath the node. Exchange and union
         * nodes must yield exactly one distinct target across all their sources.
         *
         * @throws IllegalArgumentException if {@code node} is of any other type
         */
        public TableWriterNode.WriterTarget getWriterTarget(PlanNode node) {
            if (node instanceof TableWriterNode) {
                return ((TableWriterNode) node).getTarget();
            }
            if (node instanceof DeleteNode) {
                DeleteNode deleteNode = (DeleteNode) node;
                TableWriterNode.DeleteTarget delete = deleteNode.getTarget();
                return new TableWriterNode.DeleteTarget(
                        Optional.of(findTableScanHandle(deleteNode.getSource())),
                        delete.getSchemaTableName());
            }
            if (node instanceof UpdateNode) {
                UpdateNode updateNode = (UpdateNode) node;
                TableWriterNode.UpdateTarget update = updateNode.getTarget();
                return new TableWriterNode.UpdateTarget(
                        Optional.of(findTableScanHandle(updateNode.getSource())),
                        update.getSchemaTableName(),
                        update.getUpdatedColumns(),
                        update.getUpdatedColumnHandles());
            }
            if (node instanceof ExchangeNode || node instanceof UnionNode) {
                // Collect into a set so equal targets from different branches collapse;
                // getOnlyElement then fails unless exactly one distinct target remains.
                Set<TableWriterNode.WriterTarget> writerTargets = node.getSources().stream()
                        .map(this::getWriterTarget)
                        .collect(Collectors.toSet());
                return Iterables.getOnlyElement(writerTargets);
            }
            throw new IllegalArgumentException("Invalid child for TableCommitNode: " + node.getClass().getSimpleName());
        }

        /**
         * Converts {@code target} into the target to execute against by calling
         * the matching {@code Metadata.begin*} method.
         *
         * @throws IllegalArgumentException if the target type is not handled here
         */
        private TableWriterNode.WriterTarget createWriterTarget(TableWriterNode.WriterTarget target) {
            // NOTE: argument expressions below are kept in their original order so
            // the metadata calls happen in the same sequence as before.
            if (target instanceof TableWriterNode.CreateReference) {
                TableWriterNode.CreateReference create = (TableWriterNode.CreateReference) target;
                return new TableWriterNode.CreateTarget(
                        metadata.beginCreateTable(session, create.getCatalog(), create.getTableMetadata(), create.getLayout()),
                        create.getTableMetadata().getTable());
            }
            if (target instanceof TableWriterNode.InsertReference) {
                TableWriterNode.InsertReference insert = (TableWriterNode.InsertReference) target;
                return new TableWriterNode.InsertTarget(
                        metadata.beginInsert(session, insert.getHandle(), insert.getColumns()),
                        metadata.getTableMetadata(session, insert.getHandle()).getTable());
            }
            if (target instanceof TableWriterNode.DeleteTarget) {
                TableWriterNode.DeleteTarget delete = (TableWriterNode.DeleteTarget) target;
                TableHandle newHandle = metadata.beginDelete(session, delete.getHandleOrElseThrow());
                return new TableWriterNode.DeleteTarget(Optional.of(newHandle), delete.getSchemaTableName());
            }
            if (target instanceof TableWriterNode.UpdateTarget) {
                TableWriterNode.UpdateTarget update = (TableWriterNode.UpdateTarget) target;
                TableHandle newHandle = metadata.beginUpdate(session, update.getHandleOrElseThrow(), update.getUpdatedColumnHandles());
                return new TableWriterNode.UpdateTarget(
                        Optional.of(newHandle),
                        update.getSchemaTableName(),
                        update.getUpdatedColumns(),
                        update.getUpdatedColumnHandles());
            }
            if (target instanceof TableWriterNode.RefreshMaterializedViewReference) {
                TableWriterNode.RefreshMaterializedViewReference refreshMV = (TableWriterNode.RefreshMaterializedViewReference) target;
                return new TableWriterNode.RefreshMaterializedViewTarget(
                        refreshMV.getStorageTableHandle(),
                        metadata.beginRefreshMaterializedView(session, refreshMV.getStorageTableHandle(), refreshMV.getSourceTableHandles()),
                        metadata.getTableMetadata(session, refreshMV.getStorageTableHandle()).getTable(),
                        refreshMV.getSourceTableHandles());
            }
            throw new IllegalArgumentException("Unhandled target type: " + target.getClass().getSimpleName());
        }

        /**
         * Returns the table handle of the scan reached from {@code node} by
         * descending through filter, project and semi-join sources, and through
         * the left side of inner joins whose right side
         * {@link QueryCardinalityUtil#isAtMostScalar} accepts.
         *
         * @throws IllegalArgumentException if any other node shape is encountered
         */
        private TableHandle findTableScanHandle(PlanNode node) {
            if (node instanceof TableScanNode) {
                return ((TableScanNode) node).getTable();
            }
            if (node instanceof FilterNode) {
                return findTableScanHandle(((FilterNode) node).getSource());
            }
            if (node instanceof ProjectNode) {
                return findTableScanHandle(((ProjectNode) node).getSource());
            }
            if (node instanceof SemiJoinNode) {
                return findTableScanHandle(((SemiJoinNode) node).getSource());
            }
            if (node instanceof JoinNode) {
                JoinNode joinNode = (JoinNode) node;
                if (joinNode.getType() == JoinNode.Type.INNER && QueryCardinalityUtil.isAtMostScalar(joinNode.getRight())) {
                    return findTableScanHandle(joinNode.getLeft());
                }
            }
            throw new IllegalArgumentException("Invalid descendant for DeleteNode or UpdateNode: " + node.getClass().getName());
        }

        /**
         * Returns a copy of the subtree rooted at {@code node} in which the table
         * scan reached by the same traversal as {@link #findTableScanHandle}
         * uses {@code handle} instead of its original table handle. Other
         * children (semi-join filtering source, join right side) are reused
         * unchanged.
         *
         * @throws IllegalArgumentException if any other node shape is encountered
         */
        private PlanNode rewriteModifyTableScan(PlanNode node, TableHandle handle) {
            if (node instanceof TableScanNode) {
                TableScanNode scan = (TableScanNode) node;
                return new TableScanNode(
                        scan.getId(),
                        handle,
                        scan.getOutputSymbols(),
                        scan.getAssignments(),
                        scan.getEnforcedConstraint(),
                        scan.isUpdateTarget(),
                        scan.getUseConnectorNodePartitioning());
            }
            if (node instanceof FilterNode) {
                PlanNode source = rewriteModifyTableScan(((FilterNode) node).getSource(), handle);
                return ChildReplacer.replaceChildren(node, ImmutableList.of(source));
            }
            if (node instanceof ProjectNode) {
                PlanNode source = rewriteModifyTableScan(((ProjectNode) node).getSource(), handle);
                return ChildReplacer.replaceChildren(node, ImmutableList.of(source));
            }
            if (node instanceof SemiJoinNode) {
                SemiJoinNode semiJoin = (SemiJoinNode) node;
                PlanNode source = rewriteModifyTableScan(semiJoin.getSource(), handle);
                return ChildReplacer.replaceChildren(node, ImmutableList.of(source, semiJoin.getFilteringSource()));
            }
            if (node instanceof JoinNode) {
                JoinNode joinNode = (JoinNode) node;
                if (joinNode.getType() == JoinNode.Type.INNER && QueryCardinalityUtil.isAtMostScalar(joinNode.getRight())) {
                    PlanNode source = rewriteModifyTableScan(joinNode.getLeft(), handle);
                    return ChildReplacer.replaceChildren(node, ImmutableList.of(source, joinNode.getRight()));
                }
            }
            throw new IllegalArgumentException("Invalid descendant for DeleteNode or UpdateNode: " + node.getClass().getName());
        }
    }
}

