/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.trino.connector.system.GlobalSystemConnector;
import io.trino.metadata.GlobalFunctionCatalog;
import io.trino.metadata.ResolvedFunction;
import io.trino.spi.Plugin;
import io.trino.spi.function.BoundSignature;
import io.trino.spi.function.FunctionId;
import io.trino.spi.function.FunctionKind;
import io.trino.spi.function.FunctionNullability;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.Type;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.assertions.ExpressionMatcher;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.iterative.rule.PruneSpatialJoinChildrenColumns;
import io.trino.sql.planner.iterative.rule.test.BaseRuleTest;
import io.trino.sql.planner.iterative.rule.test.PlanBuilder;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.SpatialJoinNode;
import io.trino.sql.tree.ComparisonExpression;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.FunctionCall;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import org.junit.jupiter.api.Test;

public class TestPruneSpatialJoinChildrenColumns
extends BaseRuleTest {
    public static final ResolvedFunction TEST_ST_DISTANCE_FUNCTION = new ResolvedFunction(new BoundSignature(GlobalFunctionCatalog.builtinFunctionName((String)"st_distance"), (Type)BigintType.BIGINT, (List)ImmutableList.of((Object)BigintType.BIGINT, (Object)BigintType.BIGINT)), GlobalSystemConnector.CATALOG_HANDLE, new FunctionId("st_distance"), FunctionKind.SCALAR, true, new FunctionNullability(false, (List)ImmutableList.of((Object)false, (Object)false)), (Map)ImmutableMap.of(), (Set)ImmutableSet.of());

    public TestPruneSpatialJoinChildrenColumns() {
        super(new Plugin[0]);
    }

    @Test
    public void testPruneOneChild() {
        this.tester().assertThat((Rule<?>)new PruneSpatialJoinChildrenColumns()).on(p -> {
            Symbol a = p.symbol("a");
            Symbol b = p.symbol("b");
            Symbol r = p.symbol("r");
            Symbol unused = p.symbol("unused");
            return p.spatialJoin(SpatialJoinNode.Type.INNER, (PlanNode)p.values(a, unused), (PlanNode)p.values(b, r), (List<Symbol>)ImmutableList.of((Object)a, (Object)b, (Object)r), (Expression)new ComparisonExpression(ComparisonExpression.Operator.LESS_THAN_OR_EQUAL, (Expression)new FunctionCall(TEST_ST_DISTANCE_FUNCTION.toQualifiedName(), (List)ImmutableList.of((Object)a.toSymbolReference(), (Object)b.toSymbolReference())), (Expression)r.toSymbolReference()));
        }).matches(PlanMatchPattern.spatialJoin("ST_Distance(a, b) <= r", Optional.empty(), Optional.of(ImmutableList.of((Object)"a", (Object)"b", (Object)"r")), PlanMatchPattern.strictProject((Map<String, ExpressionMatcher>)ImmutableMap.of((Object)"a", (Object)PlanMatchPattern.expression("a")), PlanMatchPattern.values("a", "unused")), PlanMatchPattern.values("b", "r")));
    }

    @Test
    public void testPruneBothChildren() {
        this.tester().assertThat((Rule<?>)new PruneSpatialJoinChildrenColumns()).on(p -> {
            Symbol a = p.symbol("a");
            Symbol b = p.symbol("b");
            Symbol r = p.symbol("r");
            Symbol unusedLeft = p.symbol("unused_left");
            Symbol unusedRight = p.symbol("unused_right");
            return p.spatialJoin(SpatialJoinNode.Type.INNER, (PlanNode)p.values(a, unusedLeft), (PlanNode)p.values(b, r, unusedRight), (List<Symbol>)ImmutableList.of((Object)a, (Object)b, (Object)r), (Expression)new ComparisonExpression(ComparisonExpression.Operator.LESS_THAN_OR_EQUAL, (Expression)new FunctionCall(TEST_ST_DISTANCE_FUNCTION.toQualifiedName(), (List)ImmutableList.of((Object)a.toSymbolReference(), (Object)b.toSymbolReference())), (Expression)r.toSymbolReference()));
        }).matches(PlanMatchPattern.spatialJoin("ST_Distance(a, b) <= r", Optional.empty(), Optional.of(ImmutableList.of((Object)"a", (Object)"b", (Object)"r")), PlanMatchPattern.strictProject((Map<String, ExpressionMatcher>)ImmutableMap.of((Object)"a", (Object)PlanMatchPattern.expression("a")), PlanMatchPattern.values("a", "unused_left")), PlanMatchPattern.strictProject((Map<String, ExpressionMatcher>)ImmutableMap.of((Object)"b", (Object)PlanMatchPattern.expression("b"), (Object)"r", (Object)PlanMatchPattern.expression("r")), PlanMatchPattern.values("b", "r", "unused_right"))));
    }

    @Test
    public void testDoNotPruneOneOutputOrFilterSymbols() {
        this.tester().assertThat((Rule<?>)new PruneSpatialJoinChildrenColumns()).on(p -> {
            Symbol a = p.symbol("a");
            Symbol b = p.symbol("b");
            Symbol r = p.symbol("r");
            Symbol output = p.symbol("output");
            return p.spatialJoin(SpatialJoinNode.Type.INNER, (PlanNode)p.values(a), (PlanNode)p.values(b, r, output), (List<Symbol>)ImmutableList.of((Object)output), PlanBuilder.expression("ST_Distance(a, b) <= r"));
        }).doesNotFire();
    }

    @Test
    public void testDoNotPrunePartitionSymbols() {
        this.tester().assertThat((Rule<?>)new PruneSpatialJoinChildrenColumns()).on(p -> {
            Symbol a = p.symbol("a");
            Symbol b = p.symbol("b");
            Symbol r = p.symbol("r");
            Symbol leftPartitionSymbol = p.symbol("left_partition_symbol");
            Symbol rightPartitionSymbol = p.symbol("right_partition_symbol");
            return p.spatialJoin(SpatialJoinNode.Type.INNER, (PlanNode)p.values(a, leftPartitionSymbol), (PlanNode)p.values(b, r, rightPartitionSymbol), (List<Symbol>)ImmutableList.of((Object)a, (Object)b, (Object)r), PlanBuilder.expression("ST_Distance(a, b) <= r"), Optional.of(leftPartitionSymbol), Optional.of(rightPartitionSymbol), Optional.of("some nice kdb tree"));
        }).doesNotFire();
    }
}

