/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.optimizations;

import com.google.common.base.Preconditions;
import com.google.common.base.Predicates;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.plan.ChildReplacer;
import io.trino.sql.planner.plan.PlanNode;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Predicate;

/**
 * Fluent helper for finding, removing and replacing nodes in a plan tree.
 *
 * <p>A search starts at the root given to {@link #searchFrom(PlanNode)} or
 * {@link #searchFrom(PlanNode, Lookup)}. Every visited node is first passed through the
 * configured {@link Lookup}'s {@code resolve} method and the resolved node is then tested
 * against two predicates:
 * <ul>
 *   <li>{@code where} selects the nodes an operation acts on;</li>
 *   <li>{@code recurseOnlyWhen} decides whether the traversal descends into a node's sources.</li>
 * </ul>
 * Both predicates accept every node until configured otherwise.
 *
 * <p>The configuration methods ({@link #where}, {@link #recurseOnlyWhen},
 * {@code whereIsInstanceOfAny}) mutate this instance and return {@code this} for chaining.
 */
public class PlanNodeSearcher {
    private final PlanNode node;
    private final Lookup lookup;
    private Predicate<PlanNode> where = Predicates.alwaysTrue();
    private Predicate<PlanNode> recurseOnlyWhen = Predicates.alwaysTrue();

    /**
     * Creates a searcher rooted at {@code node} that uses {@link Lookup#noLookup()}.
     *
     * @param node root of the plan to search; must not be null
     * @return a new searcher with match-everything / recurse-everywhere predicates
     */
    public static PlanNodeSearcher searchFrom(PlanNode node) {
        return searchFrom(node, Lookup.noLookup());
    }

    /**
     * Creates a searcher rooted at {@code node} that resolves every visited node through
     * {@code lookup} before testing it.
     *
     * @param node root of the plan to search; must not be null
     * @param lookup lookup used to resolve visited nodes; must not be null
     * @return a new searcher with match-everything / recurse-everywhere predicates
     */
    public static PlanNodeSearcher searchFrom(PlanNode node, Lookup lookup) {
        return new PlanNodeSearcher(node, lookup);
    }

    private PlanNodeSearcher(PlanNode node, Lookup lookup) {
        this.node = Objects.requireNonNull(node, "node is null");
        this.lookup = Objects.requireNonNull(lookup, "lookup is null");
    }

    /**
     * Sets the match predicate to "is an instance of at least one of {@code classes}".
     * With no classes, no node matches. Replaces any previously configured match predicate.
     *
     * @param classes node types to match
     * @return this searcher
     */
    @SafeVarargs
    public final PlanNodeSearcher whereIsInstanceOfAny(Class<? extends PlanNode>... classes) {
        return whereIsInstanceOfAny(Arrays.asList(classes));
    }

    /**
     * Sets the match predicate to "is an instance of at least one of {@code classes}".
     * With an empty list, no node matches. Replaces any previously configured match predicate.
     * The list is read eagerly, so later changes to it do not affect this searcher.
     *
     * @param classes node types to match
     * @return this searcher
     */
    public final PlanNodeSearcher whereIsInstanceOfAny(List<Class<? extends PlanNode>> classes) {
        // Fold the classes into one predicate: false OR isA(c1) OR isA(c2) ...
        Predicate<PlanNode> predicate = Predicates.alwaysFalse();
        for (Class<? extends PlanNode> clazz : classes) {
            predicate = predicate.or(clazz::isInstance);
        }
        return where(predicate);
    }

    /**
     * Sets the predicate that selects the nodes operations act on, replacing any previous one.
     *
     * @param where match predicate; must not be null
     * @return this searcher
     */
    public PlanNodeSearcher where(Predicate<PlanNode> where) {
        this.where = Objects.requireNonNull(where, "where is null");
        return this;
    }

    /**
     * Sets the predicate deciding whether the traversal descends into a node's sources.
     * Nodes rejected by it are still tested against the match predicate themselves.
     *
     * @param skipOnly recursion predicate; must not be null
     * @return this searcher
     */
    public PlanNodeSearcher recurseOnlyWhen(Predicate<PlanNode> skipOnly) {
        this.recurseOnlyWhen = Objects.requireNonNull(skipOnly, "skipOnly is null");
        return this;
    }

    /**
     * Returns the first matching node in depth-first pre-order, visiting sources in the order
     * returned by {@code getSources()}.
     *
     * @return the first match (as resolved through the lookup), or empty when nothing matches
     */
    public Optional<PlanNode> findFirst() {
        return findFirstRecursive(node);
    }

    private Optional<PlanNode> findFirstRecursive(PlanNode node) {
        PlanNode resolved = lookup.resolve(node);
        if (where.test(resolved)) {
            return Optional.of(resolved);
        }
        if (recurseOnlyWhen.test(resolved)) {
            for (PlanNode source : resolved.getSources()) {
                Optional<PlanNode> found = findFirstRecursive(source);
                if (found.isPresent()) {
                    return found;
                }
            }
        }
        return Optional.empty();
    }

    /**
     * Returns every matching node in depth-first pre-order. The sources of a matching node are
     * searched too (subject to {@code recurseOnlyWhen}), so nested matches are all reported.
     *
     * @return an immutable list of resolved matching nodes, possibly empty
     */
    public List<PlanNode> findAll() {
        ImmutableList.Builder<PlanNode> nodes = ImmutableList.builder();
        findAllRecursive(node, nodes);
        return nodes.build();
    }

    /**
     * Returns the single matching node.
     *
     * @return the only node reported by {@link #findAll()}
     * @throws RuntimeException (from {@code Iterables.getOnlyElement}) when there is not
     *         exactly one match
     */
    public PlanNode findOnlyElement() {
        return Iterables.getOnlyElement(findAll());
    }

    private void findAllRecursive(PlanNode node, ImmutableList.Builder<PlanNode> nodes) {
        PlanNode resolved = lookup.resolve(node);
        if (where.test(resolved)) {
            nodes.add(resolved);
        }
        if (recurseOnlyWhen.test(resolved)) {
            for (PlanNode source : resolved.getSources()) {
                findAllRecursive(source, nodes);
            }
        }
    }

    /**
     * Returns the plan with every matching node replaced by its single source. The source of a
     * removed node is returned as-is and not searched further.
     *
     * @return the rewritten plan root
     * @throws IllegalArgumentException if a matching node does not have exactly one source
     */
    public PlanNode removeAll() {
        return removeAllRecursive(node);
    }

    private PlanNode removeAllRecursive(PlanNode node) {
        PlanNode resolved = lookup.resolve(node);
        if (where.test(resolved)) {
            return onlySourceOfRemovedNode(resolved);
        }
        if (recurseOnlyWhen.test(resolved)) {
            List<PlanNode> sources = resolved.getSources().stream()
                    .map(this::removeAllRecursive)
                    .collect(ImmutableList.toImmutableList());
            return ChildReplacer.replaceChildren(resolved, sources);
        }
        return resolved;
    }

    /**
     * Returns the plan with the first matching node replaced by its single source. The traversal
     * only follows single-source chains.
     *
     * @return the rewritten plan root
     * @throws IllegalArgumentException if the matching node does not have exactly one source, or
     *         if a node with several sources is reached before a match
     */
    public PlanNode removeFirst() {
        return removeFirstRecursive(node);
    }

    private PlanNode removeFirstRecursive(PlanNode node) {
        PlanNode resolved = lookup.resolve(node);
        if (where.test(resolved)) {
            return onlySourceOfRemovedNode(resolved);
        }
        if (!recurseOnlyWhen.test(resolved)) {
            return resolved;
        }
        List<PlanNode> sources = resolved.getSources();
        if (sources.isEmpty()) {
            return resolved;
        }
        if (sources.size() == 1) {
            PlanNode newSource = removeFirstRecursive(Iterables.getOnlyElement(sources));
            return ChildReplacer.replaceChildren(resolved, ImmutableList.of(newSource));
        }
        throw new IllegalArgumentException("Unable to remove first node when a node has multiple children, use removeAll instead");
    }

    /** Returns the single source that takes the place of a removed node, validating there is exactly one. */
    private static PlanNode onlySourceOfRemovedNode(PlanNode matched) {
        List<PlanNode> sources = matched.getSources();
        Preconditions.checkArgument(sources.size() == 1, "Unable to remove plan node as it contains 0 or more than 1 children");
        return Iterables.getOnlyElement(sources);
    }

    /**
     * Returns the plan with every matching node replaced by {@code newPlanNode}. The subtree
     * below a replaced node is discarded and not searched.
     *
     * @param newPlanNode replacement inserted at every match
     * @return the rewritten plan root
     */
    public PlanNode replaceAll(PlanNode newPlanNode) {
        return replaceAllRecursive(node, newPlanNode);
    }

    private PlanNode replaceAllRecursive(PlanNode node, PlanNode replacement) {
        PlanNode resolved = lookup.resolve(node);
        if (where.test(resolved)) {
            return replacement;
        }
        if (recurseOnlyWhen.test(resolved)) {
            List<PlanNode> sources = resolved.getSources().stream()
                    .map(source -> replaceAllRecursive(source, replacement))
                    .collect(ImmutableList.toImmutableList());
            return ChildReplacer.replaceChildren(resolved, sources);
        }
        return resolved;
    }

    /**
     * Returns the plan with the first matching node replaced by {@code newPlanNode}. The traversal
     * only follows single-source chains and honors {@code recurseOnlyWhen}, like
     * {@link #removeFirst()}.
     *
     * @param newPlanNode replacement inserted at the first match
     * @return the rewritten plan root
     * @throws IllegalArgumentException if a node with several sources is reached before a match
     */
    public PlanNode replaceFirst(PlanNode newPlanNode) {
        return replaceFirstRecursive(node, newPlanNode);
    }

    private PlanNode replaceFirstRecursive(PlanNode node, PlanNode replacement) {
        PlanNode resolved = lookup.resolve(node);
        if (where.test(resolved)) {
            return replacement;
        }
        if (!recurseOnlyWhen.test(resolved)) {
            return resolved;
        }
        List<PlanNode> sources = resolved.getSources();
        if (sources.isEmpty()) {
            return resolved;
        }
        if (sources.size() == 1) {
            // Recurse into the single source, carrying the same replacement down the chain.
            PlanNode newSource = replaceFirstRecursive(Iterables.getOnlyElement(sources), replacement);
            return ChildReplacer.replaceChildren(resolved, ImmutableList.of(newSource));
        }
        throw new IllegalArgumentException("Unable to replace first node when a node has multiple children, use replaceAll instead");
    }

    /**
     * Returns whether any node matches.
     *
     * @return {@code true} if {@link #findFirst()} finds a node
     */
    public boolean matches() {
        return findFirst().isPresent();
    }

    /**
     * Returns the number of matching nodes, counting nested matches.
     *
     * @return the size of {@link #findAll()}
     */
    public int count() {
        return findAll().size();
    }
}

