/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import io.trino.cost.PlanCostEstimate;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.GroupReference;
import io.trino.sql.planner.iterative.Memo;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanNodeId;
import java.util.List;
import java.util.Optional;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

/**
 * Unit tests for {@link Memo}. They cover the number of groups created when a plan is loaded and
 * after a group is {@code replace}d, the shape of the tree returned by {@code extract()}, and the
 * eviction of stored stats and cost estimates when a group is replaced.
 *
 * <p>All plans are built from {@link GenericNode}, a structure-only plan node. The assertions
 * therefore compare only node classes, ids and source counts.
 */
public class TestMemo {
    private final PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator();

    @Test
    public void testInitialization() {
        GenericNode plan = node(node());
        Memo memo = new Memo(idAllocator, plan);
        Assertions.assertThat(memo.getGroupCount()).isEqualTo(2);
        // Argument order is (actual, expected), matching assertMatchesStructure's signature,
        // so failure messages label the two trees correctly.
        assertMatchesStructure(memo.extract(), plan);
    }

    @Test
    public void testReplaceSubtree() {
        GenericNode plan = node(node(node()));
        Memo memo = new Memo(idAllocator, plan);
        Assertions.assertThat(memo.getGroupCount()).isEqualTo(3);

        // Replace the root's child group with a brand-new two-node subtree.
        GenericNode transformed = node(node());
        memo.replace(getChildGroup(memo, memo.getRootGroup()), transformed, "rule");

        Assertions.assertThat(memo.getGroupCount()).isEqualTo(3);
        assertMatchesStructure(memo.extract(), node(plan.getId(), transformed));
    }

    @Test
    public void testReplaceNode() {
        GenericNode z = node();
        GenericNode y = node(z);
        GenericNode x = node(y);
        Memo memo = new Memo(idAllocator, x);
        Assertions.assertThat(memo.getGroupCount()).isEqualTo(3);

        // Swap y for a new node whose only source is the reference the memo already holds for z.
        int yGroup = getChildGroup(memo, memo.getRootGroup());
        GroupReference zRef = (GroupReference) Iterables.getOnlyElement(memo.getNode(yGroup).getSources());
        GenericNode transformed = node(zRef);
        memo.replace(yGroup, transformed, "rule");

        Assertions.assertThat(memo.getGroupCount()).isEqualTo(3);
        assertMatchesStructure(memo.extract(), node(x.getId(), node(transformed.getId(), z)));
    }

    @Test
    public void testReplaceNonLeafSubtree() {
        GenericNode w = node();
        GenericNode z = node(w);
        GenericNode y = node(z);
        GenericNode x = node(y);
        Memo memo = new Memo(idAllocator, x);
        Assertions.assertThat(memo.getGroupCount()).isEqualTo(4);

        // Rebuild the y -> z chain on top of the memo's existing source of z (w's group).
        int yGroup = getChildGroup(memo, memo.getRootGroup());
        int zGroup = getChildGroup(memo, yGroup);
        PlanNode rewrittenW = memo.getNode(zGroup).getSources().get(0);
        GenericNode newZ = node(rewrittenW);
        GenericNode newY = node(newZ);
        memo.replace(yGroup, newY, "rule");

        Assertions.assertThat(memo.getGroupCount()).isEqualTo(4);
        assertMatchesStructure(
                memo.extract(),
                node(x.getId(), node(newY.getId(), node(newZ.getId(), node(w.getId())))));
    }

    @Test
    public void testRemoveNode() {
        GenericNode z = node();
        GenericNode y = node(z);
        GenericNode x = node(y);
        Memo memo = new Memo(idAllocator, x);
        Assertions.assertThat(memo.getGroupCount()).isEqualTo(3);

        // Replacing y's group with y's own source takes y out of the plan.
        int yGroup = getChildGroup(memo, memo.getRootGroup());
        memo.replace(yGroup, memo.getNode(yGroup).getSources().get(0), "rule");

        Assertions.assertThat(memo.getGroupCount()).isEqualTo(2);
        assertMatchesStructure(memo.extract(), node(x.getId(), node(z.getId())));
    }

    @Test
    public void testInsertNode() {
        GenericNode z = node();
        GenericNode x = node(z);
        Memo memo = new Memo(idAllocator, x);
        Assertions.assertThat(memo.getGroupCount()).isEqualTo(2);

        // Wrap the node currently stored in z's group in a new parent y and store y in that group.
        int zGroup = getChildGroup(memo, memo.getRootGroup());
        GenericNode y = node(memo.getNode(zGroup));
        memo.replace(zGroup, y, "rule");

        Assertions.assertThat(memo.getGroupCount()).isEqualTo(3);
        assertMatchesStructure(memo.extract(), node(x.getId(), node(y.getId(), node(z.getId()))));
    }

    @Test
    public void testMultipleReferences() {
        GenericNode z = node();
        GenericNode y = node(z);
        GenericNode x = node(y);
        Memo memo = new Memo(idAllocator, x);
        Assertions.assertThat(memo.getGroupCount()).isEqualTo(3);

        // Two new children that both use the same source, y's existing reference to z.
        int yGroup = getChildGroup(memo, memo.getRootGroup());
        PlanNode rewrittenZ = memo.getNode(yGroup).getSources().get(0);
        GenericNode y1 = node(rewrittenZ);
        GenericNode y2 = node(rewrittenZ);
        GenericNode newX = node(y1, y2);
        memo.replace(memo.getRootGroup(), newX, "rule");

        Assertions.assertThat(memo.getGroupCount()).isEqualTo(4);
        assertMatchesStructure(
                memo.extract(),
                node(newX.getId(),
                        node(y1.getId(), node(z.getId())),
                        node(y2.getId(), node(z.getId()))));
    }

    @Test
    public void testEvictStatsOnReplace() {
        GenericNode y = node();
        GenericNode x = node(y);
        Memo memo = new Memo(idAllocator, x);
        int xGroup = memo.getRootGroup();
        int yGroup = getChildGroup(memo, memo.getRootGroup());
        PlanNodeStatsEstimate xStats = PlanNodeStatsEstimate.builder().setOutputRowCount(42.0).build();
        PlanNodeStatsEstimate yStats = PlanNodeStatsEstimate.builder().setOutputRowCount(55.0).build();
        memo.storeStats(yGroup, yStats);
        memo.storeStats(xGroup, xStats);
        Assertions.assertThat(memo.getStats(yGroup)).isEqualTo(Optional.of(yStats));
        Assertions.assertThat(memo.getStats(xGroup)).isEqualTo(Optional.of(xStats));

        // Replacing y must drop the stats stored for y and for its parent x.
        memo.replace(yGroup, node(), "rule");

        Assertions.assertThat(memo.getStats(yGroup)).isEqualTo(Optional.empty());
        Assertions.assertThat(memo.getStats(xGroup)).isEqualTo(Optional.empty());
    }

    @Test
    public void testEvictCostOnReplace() {
        GenericNode y = node();
        GenericNode x = node(y);
        Memo memo = new Memo(idAllocator, x);
        int xGroup = memo.getRootGroup();
        int yGroup = getChildGroup(memo, memo.getRootGroup());
        PlanCostEstimate yCost = new PlanCostEstimate(42.0, 0.0, 0.0, 0.0);
        PlanCostEstimate xCost = new PlanCostEstimate(42.0, 0.0, 0.0, 37.0);
        memo.storeCost(yGroup, yCost);
        memo.storeCost(xGroup, xCost);
        Assertions.assertThat(memo.getCost(yGroup)).isEqualTo(Optional.of(yCost));
        Assertions.assertThat(memo.getCost(xGroup)).isEqualTo(Optional.of(xCost));

        // Replacing y must drop the cost stored for y and for its parent x.
        memo.replace(yGroup, node(), "rule");

        Assertions.assertThat(memo.getCost(yGroup)).isEqualTo(Optional.empty());
        Assertions.assertThat(memo.getCost(xGroup)).isEqualTo(Optional.empty());
    }

    /**
     * Recursively asserts that two plan trees match position by position. At each position the
     * node class, the {@link PlanNodeId} and the number of sources must be equal.
     *
     * @param actual the tree produced by the code under test
     * @param expected the tree the test expects
     */
    private static void assertMatchesStructure(PlanNode actual, PlanNode expected) {
        Assertions.assertThat(actual.getClass()).isEqualTo(expected.getClass());
        Assertions.assertThat(actual.getId()).isEqualTo(expected.getId());
        Assertions.assertThat(actual.getSources()).hasSize(expected.getSources().size());
        for (int i = 0; i < actual.getSources().size(); i++) {
            assertMatchesStructure(actual.getSources().get(i), expected.getSources().get(i));
        }
    }

    /**
     * Returns the group id referenced by the first source of the node stored in {@code group}.
     *
     * @throws ClassCastException if that first source is not a {@link GroupReference}
     */
    private int getChildGroup(Memo memo, int group) {
        PlanNode node = memo.getNode(group);
        GroupReference child = (GroupReference) node.getSources().get(0);
        return child.getGroupId();
    }

    /** Creates a {@link GenericNode} with the given id and children. */
    private GenericNode node(PlanNodeId id, PlanNode... children) {
        return new GenericNode(id, ImmutableList.copyOf(children));
    }

    /** Creates a {@link GenericNode} with the given children and a fresh id from {@link #idAllocator}. */
    private GenericNode node(PlanNode... children) {
        return node(idAllocator.getNextId(), children);
    }

    /** Structure-only plan node: an id, an immutable list of sources, and no output symbols. */
    private static final class GenericNode
            extends PlanNode {
        private final List<PlanNode> sources;

        public GenericNode(PlanNodeId id, List<PlanNode> sources) {
            super(id);
            this.sources = ImmutableList.copyOf(sources);
        }

        @Override
        public List<PlanNode> getSources() {
            return sources;
        }

        @Override
        public List<Symbol> getOutputSymbols() {
            return ImmutableList.of();
        }

        @Override
        public PlanNode replaceChildren(List<PlanNode> newChildren) {
            // Keep this node's id so structural comparisons still match after a rewrite.
            return new GenericNode(getId(), newChildren);
        }
    }
}

