/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.optimizations;

import com.google.common.base.Function;
import com.google.common.base.Functions;
import com.google.common.base.Preconditions;
import com.google.common.base.Predicate;
import com.google.common.base.Predicates;
import com.google.common.collect.ImmutableBiMap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import io.trino.Session;
import io.trino.metadata.ResolvedFunction;
import io.trino.metadata.ResolvedIndex;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.function.FunctionKind;
import io.trino.spi.predicate.TupleDomain;
import io.trino.sql.PlannerContext;
import io.trino.sql.ir.IrUtils;
import io.trino.sql.planner.DomainTranslator;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.optimizations.PlanOptimizer;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.IndexJoinNode;
import io.trino.sql.planner.plan.IndexSourceNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanVisitor;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.SimplePlanRewriter;
import io.trino.sql.planner.plan.SortNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.sql.planner.plan.WindowFrameType;
import io.trino.sql.planner.plan.WindowNode;
import io.trino.sql.tree.BooleanLiteral;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.SymbolReference;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;

public class IndexJoinOptimizer
implements PlanOptimizer {
    private final PlannerContext plannerContext;

    public IndexJoinOptimizer(PlannerContext plannerContext) {
        this.plannerContext = plannerContext;
    }

    @Override
    public PlanNode optimize(PlanNode plan, PlanOptimizer.Context context) {
        Objects.requireNonNull(plan, "plan is null");
        return SimplePlanRewriter.rewriteWith(new Rewriter(context.symbolAllocator(), context.idAllocator(), this.plannerContext, context.session()), plan, null);
    }

    private static class Rewriter
    extends SimplePlanRewriter<Void> {
        private final SymbolAllocator symbolAllocator;
        private final PlanNodeIdAllocator idAllocator;
        private final PlannerContext plannerContext;
        private final Session session;

        private Rewriter(SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, PlannerContext plannerContext, Session session) {
            this.symbolAllocator = Objects.requireNonNull(symbolAllocator, "symbolAllocator is null");
            this.idAllocator = Objects.requireNonNull(idAllocator, "idAllocator is null");
            this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
            this.session = Objects.requireNonNull(session, "session is null");
        }

        @Override
        public PlanNode visitJoin(JoinNode node, SimplePlanRewriter.RewriteContext<Void> context) {
            PlanNode leftRewritten = context.rewrite(node.getLeft());
            PlanNode rightRewritten = context.rewrite(node.getRight());
            if (!node.getCriteria().isEmpty()) {
                Optional<PlanNode> rightIndexCandidate;
                List leftJoinSymbols = Lists.transform(node.getCriteria(), JoinNode.EquiJoinClause::getLeft);
                List rightJoinSymbols = Lists.transform(node.getCriteria(), JoinNode.EquiJoinClause::getRight);
                Optional<PlanNode> leftIndexCandidate = IndexSourceRewriter.rewriteWithIndex(leftRewritten, (Set<Symbol>)ImmutableSet.copyOf((Collection)leftJoinSymbols), this.symbolAllocator, this.idAllocator, this.plannerContext, this.session);
                if (leftIndexCandidate.isPresent()) {
                    Map<Symbol, Symbol> trace = IndexKeyTracer.trace(leftIndexCandidate.get(), (Set<Symbol>)ImmutableSet.copyOf((Collection)leftJoinSymbols));
                    Preconditions.checkState((!trace.isEmpty() ? 1 : 0) != 0);
                }
                if ((rightIndexCandidate = IndexSourceRewriter.rewriteWithIndex(rightRewritten, (Set<Symbol>)ImmutableSet.copyOf((Collection)rightJoinSymbols), this.symbolAllocator, this.idAllocator, this.plannerContext, this.session)).isPresent()) {
                    Map<Symbol, Symbol> trace = IndexKeyTracer.trace(rightIndexCandidate.get(), (Set<Symbol>)ImmutableSet.copyOf((Collection)rightJoinSymbols));
                    Preconditions.checkState((!trace.isEmpty() ? 1 : 0) != 0);
                }
                switch (node.getType()) {
                    case INNER: {
                        PlanNode indexJoinNode = null;
                        if (rightIndexCandidate.isPresent()) {
                            indexJoinNode = new IndexJoinNode(this.idAllocator.getNextId(), IndexJoinNode.Type.INNER, leftRewritten, rightIndexCandidate.get(), Rewriter.createEquiJoinClause(leftJoinSymbols, rightJoinSymbols), Optional.empty(), Optional.empty());
                        } else if (leftIndexCandidate.isPresent()) {
                            indexJoinNode = new IndexJoinNode(this.idAllocator.getNextId(), IndexJoinNode.Type.INNER, rightRewritten, leftIndexCandidate.get(), Rewriter.createEquiJoinClause(rightJoinSymbols, leftJoinSymbols), Optional.empty(), Optional.empty());
                        }
                        if (indexJoinNode == null) break;
                        if (node.getFilter().isPresent()) {
                            indexJoinNode = new FilterNode(this.idAllocator.getNextId(), indexJoinNode, node.getFilter().get());
                        }
                        if (!indexJoinNode.getOutputSymbols().equals(node.getOutputSymbols())) {
                            indexJoinNode = new ProjectNode(this.idAllocator.getNextId(), indexJoinNode, Assignments.identity(node.getOutputSymbols()));
                        }
                        return indexJoinNode;
                    }
                    case LEFT: {
                        if (!node.getFilter().isEmpty() || !rightIndexCandidate.isPresent()) break;
                        return Rewriter.createIndexJoinWithExpectedOutputs(node.getOutputSymbols(), IndexJoinNode.Type.SOURCE_OUTER, leftRewritten, rightIndexCandidate.get(), Rewriter.createEquiJoinClause(leftJoinSymbols, rightJoinSymbols), this.idAllocator);
                    }
                    case RIGHT: {
                        if (!node.getFilter().isEmpty() || !leftIndexCandidate.isPresent()) break;
                        return Rewriter.createIndexJoinWithExpectedOutputs(node.getOutputSymbols(), IndexJoinNode.Type.SOURCE_OUTER, rightRewritten, leftIndexCandidate.get(), Rewriter.createEquiJoinClause(rightJoinSymbols, leftJoinSymbols), this.idAllocator);
                    }
                    case FULL: {
                        break;
                    }
                    default: {
                        throw new IllegalArgumentException("Unknown type: " + String.valueOf((Object)node.getType()));
                    }
                }
            }
            if (leftRewritten != node.getLeft() || rightRewritten != node.getRight()) {
                return new JoinNode(node.getId(), node.getType(), leftRewritten, rightRewritten, node.getCriteria(), node.getLeftOutputSymbols(), node.getRightOutputSymbols(), node.isMaySkipOutputDuplicates(), node.getFilter(), node.getLeftHashSymbol(), node.getRightHashSymbol(), node.getDistributionType(), node.isSpillable(), node.getDynamicFilters(), node.getReorderJoinStatsAndCost());
            }
            return node;
        }

        private static PlanNode createIndexJoinWithExpectedOutputs(List<Symbol> expectedOutputs, IndexJoinNode.Type type, PlanNode probe, PlanNode index, List<IndexJoinNode.EquiJoinClause> equiJoinClause, PlanNodeIdAllocator idAllocator) {
            PlanNode result = new IndexJoinNode(idAllocator.getNextId(), type, probe, index, equiJoinClause, Optional.empty(), Optional.empty());
            if (!((PlanNode)result).getOutputSymbols().equals(expectedOutputs)) {
                result = new ProjectNode(idAllocator.getNextId(), result, Assignments.identity(expectedOutputs));
            }
            return result;
        }

        private static List<IndexJoinNode.EquiJoinClause> createEquiJoinClause(List<Symbol> probeSymbols, List<Symbol> indexSymbols) {
            Preconditions.checkArgument((probeSymbols.size() == indexSymbols.size() ? 1 : 0) != 0);
            ImmutableList.Builder builder = ImmutableList.builder();
            for (int i = 0; i < probeSymbols.size(); ++i) {
                builder.add((Object)new IndexJoinNode.EquiJoinClause(probeSymbols.get(i), indexSymbols.get(i)));
            }
            return builder.build();
        }
    }

    public static final class IndexKeyTracer {
        public static Map<Symbol, Symbol> trace(PlanNode node, Set<Symbol> lookupSymbols) {
            return node.accept(new Visitor(), lookupSymbols);
        }

        private static class Visitor
        extends PlanVisitor<Map<Symbol, Symbol>, Set<Symbol>> {
            private Visitor() {
            }

            @Override
            protected Map<Symbol, Symbol> visitPlan(PlanNode node, Set<Symbol> lookupSymbols) {
                throw new UnsupportedOperationException("Node not expected to be part of Index pipeline: " + String.valueOf(node));
            }

            @Override
            public Map<Symbol, Symbol> visitProject(ProjectNode node, Set<Symbol> lookupSymbols) {
                Map directSymbolTranslationOutputMap = Maps.transformValues((Map)Maps.filterValues(node.getAssignments().getMap(), SymbolReference.class::isInstance), Symbol::from);
                Map outputToSourceMap = (Map)lookupSymbols.stream().filter(directSymbolTranslationOutputMap.keySet()::contains).collect(ImmutableMap.toImmutableMap(java.util.function.Function.identity(), directSymbolTranslationOutputMap::get));
                Preconditions.checkState((!outputToSourceMap.isEmpty() ? 1 : 0) != 0, (Object)"No lookup symbols were able to pass through the projection");
                Map<Symbol, Symbol> sourceToIndexMap = node.getSource().accept(this, ImmutableSet.copyOf(outputToSourceMap.values()));
                Map outputToIndexMap = Maps.transformValues((Map)Maps.filterValues((Map)outputToSourceMap, (Predicate)Predicates.in(sourceToIndexMap.keySet())), (Function)Functions.forMap(sourceToIndexMap));
                return ImmutableMap.copyOf((Map)outputToIndexMap);
            }

            @Override
            public Map<Symbol, Symbol> visitFilter(FilterNode node, Set<Symbol> lookupSymbols) {
                return node.getSource().accept(this, lookupSymbols);
            }

            @Override
            public Map<Symbol, Symbol> visitWindow(WindowNode node, Set<Symbol> lookupSymbols) {
                Set partitionByLookupSymbols = (Set)lookupSymbols.stream().filter(node.getPartitionBy()::contains).collect(ImmutableSet.toImmutableSet());
                Preconditions.checkState((!partitionByLookupSymbols.isEmpty() ? 1 : 0) != 0, (Object)"No lookup symbols were able to pass through the aggregation group by");
                return node.getSource().accept(this, partitionByLookupSymbols);
            }

            @Override
            public Map<Symbol, Symbol> visitIndexJoin(IndexJoinNode node, Set<Symbol> lookupSymbols) {
                Set probeLookupSymbols = (Set)lookupSymbols.stream().filter(node.getProbeSource().getOutputSymbols()::contains).collect(ImmutableSet.toImmutableSet());
                Preconditions.checkState((!probeLookupSymbols.isEmpty() ? 1 : 0) != 0, (Object)"No lookup symbols were able to pass through the index join probe source");
                return node.getProbeSource().accept(this, probeLookupSymbols);
            }

            @Override
            public Map<Symbol, Symbol> visitAggregation(AggregationNode node, Set<Symbol> lookupSymbols) {
                Set groupByLookupSymbols = (Set)lookupSymbols.stream().filter(node.getGroupingKeys()::contains).collect(ImmutableSet.toImmutableSet());
                Preconditions.checkState((!groupByLookupSymbols.isEmpty() ? 1 : 0) != 0, (Object)"No lookup symbols were able to pass through the aggregation group by");
                return node.getSource().accept(this, groupByLookupSymbols);
            }

            @Override
            public Map<Symbol, Symbol> visitSort(SortNode node, Set<Symbol> lookupSymbols) {
                return node.getSource().accept(this, lookupSymbols);
            }

            @Override
            public Map<Symbol, Symbol> visitIndexSource(IndexSourceNode node, Set<Symbol> lookupSymbols) {
                Preconditions.checkState((boolean)node.getLookupSymbols().equals(lookupSymbols), (Object)"lookupSymbols must be the same as IndexSource lookup symbols");
                return (Map)lookupSymbols.stream().collect(ImmutableMap.toImmutableMap(java.util.function.Function.identity(), java.util.function.Function.identity()));
            }
        }
    }

    private static class IndexSourceRewriter
    extends SimplePlanRewriter<Context> {
        private final SymbolAllocator symbolAllocator;
        private final PlanNodeIdAllocator idAllocator;
        private final PlannerContext plannerContext;
        private final DomainTranslator domainTranslator;
        private final Session session;

        private IndexSourceRewriter(SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, PlannerContext plannerContext, Session session) {
            this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
            this.domainTranslator = new DomainTranslator(plannerContext);
            this.symbolAllocator = Objects.requireNonNull(symbolAllocator, "symbolAllocator is null");
            this.idAllocator = Objects.requireNonNull(idAllocator, "idAllocator is null");
            this.session = Objects.requireNonNull(session, "session is null");
        }

        public static Optional<PlanNode> rewriteWithIndex(PlanNode planNode, Set<Symbol> lookupSymbols, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, PlannerContext plannerContext, Session session) {
            AtomicBoolean success = new AtomicBoolean();
            IndexSourceRewriter indexSourceRewriter = new IndexSourceRewriter(symbolAllocator, idAllocator, plannerContext, session);
            PlanNode rewritten = SimplePlanRewriter.rewriteWith(indexSourceRewriter, planNode, new Context(lookupSymbols, success));
            if (success.get()) {
                return Optional.of(rewritten);
            }
            return Optional.empty();
        }

        @Override
        public PlanNode visitPlan(PlanNode node, SimplePlanRewriter.RewriteContext<Context> context) {
            return node;
        }

        @Override
        public PlanNode visitTableScan(TableScanNode node, SimplePlanRewriter.RewriteContext<Context> context) {
            return this.planTableScan(node, (Expression)BooleanLiteral.TRUE_LITERAL, context.get());
        }

        private PlanNode planTableScan(TableScanNode node, Expression predicate, Context context) {
            DomainTranslator.ExtractionResult decomposedPredicate = DomainTranslator.getExtractionResult(this.plannerContext, this.session, predicate, this.symbolAllocator.getTypes());
            TupleDomain simplifiedConstraint = decomposedPredicate.getTupleDomain().transformKeys(node.getAssignments()::get).intersect(node.getEnforcedConstraint());
            Preconditions.checkState((boolean)node.getOutputSymbols().containsAll(context.getLookupSymbols()));
            Set lookupColumns = (Set)context.getLookupSymbols().stream().map(node.getAssignments()::get).collect(ImmutableSet.toImmutableSet());
            Set outputColumns = (Set)node.getOutputSymbols().stream().map(node.getAssignments()::get).collect(ImmutableSet.toImmutableSet());
            Optional<ResolvedIndex> optionalResolvedIndex = this.plannerContext.getMetadata().resolveIndex(this.session, node.getTable(), lookupColumns, outputColumns, (TupleDomain<ColumnHandle>)simplifiedConstraint);
            if (optionalResolvedIndex.isEmpty()) {
                return node;
            }
            ResolvedIndex resolvedIndex = optionalResolvedIndex.get();
            ImmutableBiMap inverseAssignments = ImmutableBiMap.copyOf(node.getAssignments()).inverse();
            PlanNode source = new IndexSourceNode(this.idAllocator.getNextId(), resolvedIndex.getIndexHandle(), node.getTable(), context.getLookupSymbols(), node.getOutputSymbols(), node.getAssignments());
            Expression[] expressionArray = new Expression[2];
            expressionArray[0] = this.domainTranslator.toPredicate((TupleDomain<Symbol>)resolvedIndex.getUnresolvedTupleDomain().transformKeys(((Map)inverseAssignments)::get));
            expressionArray[1] = decomposedPredicate.getRemainingExpression();
            Expression resultingPredicate = IrUtils.combineConjuncts(this.plannerContext.getMetadata(), expressionArray);
            if (!resultingPredicate.equals((Object)BooleanLiteral.TRUE_LITERAL)) {
                source = new FilterNode(this.idAllocator.getNextId(), source, resultingPredicate);
            }
            context.markSuccess();
            return source;
        }

        @Override
        public PlanNode visitProject(ProjectNode node, SimplePlanRewriter.RewriteContext<Context> context) {
            Set newLookupSymbols = (Set)context.get().getLookupSymbols().stream().map(node.getAssignments()::get).filter(SymbolReference.class::isInstance).map(Symbol::from).collect(ImmutableSet.toImmutableSet());
            if (newLookupSymbols.size() != context.get().getLookupSymbols().size()) {
                return node;
            }
            return context.defaultRewrite(node, new Context(newLookupSymbols, context.get().getSuccess()));
        }

        @Override
        public PlanNode visitFilter(FilterNode node, SimplePlanRewriter.RewriteContext<Context> context) {
            if (node.getSource() instanceof TableScanNode) {
                return this.planTableScan((TableScanNode)node.getSource(), node.getPredicate(), context.get());
            }
            return context.defaultRewrite(node, new Context(context.get().getLookupSymbols(), context.get().getSuccess()));
        }

        @Override
        public PlanNode visitWindow(WindowNode node, SimplePlanRewriter.RewriteContext<Context> context) {
            if (!node.getWindowFunctions().values().stream().map(WindowNode.Function::getResolvedFunction).map(ResolvedFunction::getFunctionKind).allMatch(arg_0 -> FunctionKind.AGGREGATE.equals(arg_0))) {
                return node;
            }
            if (node.getOrderingScheme().isPresent()) {
                return node;
            }
            if (node.getFrames().stream().map(WindowNode.Frame::getType).anyMatch(type -> type != WindowFrameType.RANGE)) {
                return node;
            }
            if (!node.getPartitionBy().containsAll(context.get().getLookupSymbols())) {
                return node;
            }
            return context.defaultRewrite(node, new Context(context.get().getLookupSymbols(), context.get().getSuccess()));
        }

        @Override
        public PlanNode visitIndexSource(IndexSourceNode node, SimplePlanRewriter.RewriteContext<Context> context) {
            throw new IllegalStateException("Should not be trying to generate an Index on something that has already been determined to use an Index");
        }

        @Override
        public PlanNode visitIndexJoin(IndexJoinNode node, SimplePlanRewriter.RewriteContext<Context> context) {
            if (!node.getProbeSource().getOutputSymbols().containsAll(context.get().getLookupSymbols())) {
                return node;
            }
            PlanNode rewrittenProbeSource = context.rewrite(node.getProbeSource(), new Context(context.get().getLookupSymbols(), context.get().getSuccess()));
            IndexJoinNode source = node;
            if (rewrittenProbeSource != node.getProbeSource()) {
                source = new IndexJoinNode(node.getId(), node.getType(), rewrittenProbeSource, node.getIndexSource(), node.getCriteria(), node.getProbeHashSymbol(), node.getIndexHashSymbol());
            }
            return source;
        }

        @Override
        public PlanNode visitAggregation(AggregationNode node, SimplePlanRewriter.RewriteContext<Context> context) {
            if (!node.getGroupingKeys().containsAll(context.get().getLookupSymbols())) {
                return node;
            }
            return context.defaultRewrite(node, new Context(context.get().getLookupSymbols(), context.get().getSuccess()));
        }

        @Override
        public PlanNode visitSort(SortNode node, SimplePlanRewriter.RewriteContext<Context> context) {
            return context.rewrite(node.getSource(), context.get());
        }

        public static class Context {
            private final Set<Symbol> lookupSymbols;
            private final AtomicBoolean success;

            public Context(Set<Symbol> lookupSymbols, AtomicBoolean success) {
                Preconditions.checkArgument((!lookupSymbols.isEmpty() ? 1 : 0) != 0, (Object)"lookupSymbols cannot be empty");
                this.lookupSymbols = ImmutableSet.copyOf((Collection)Objects.requireNonNull(lookupSymbols, "lookupSymbols is null"));
                this.success = Objects.requireNonNull(success, "success is null");
            }

            public Set<Symbol> getLookupSymbols() {
                return this.lookupSymbols;
            }

            public AtomicBoolean getSuccess() {
                return this.success;
            }

            public void markSuccess() {
                Preconditions.checkState((boolean)this.success.compareAndSet(false, true), (Object)"Can only have one success per context");
            }
        }
    }
}

