/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ListMultimap;
import io.trino.SessionTestUtils;
import io.trino.execution.warnings.WarningCollector;
import io.trino.metadata.ResolvedFunction;
import io.trino.metadata.TestingFunctionResolution;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.DateType;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.sql.ir.Cast;
import io.trino.sql.ir.Expression;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.TestingPlannerContext;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.DataOrganizationSpecification;
import io.trino.sql.planner.plan.FrameBoundType;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.sql.planner.plan.UnionNode;
import io.trino.sql.planner.plan.WindowFrameType;
import io.trino.sql.planner.plan.WindowNode;
import io.trino.sql.planner.sanity.TypeValidator;
import io.trino.testing.TestingHandles;
import io.trino.testing.TestingMetadata;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import org.assertj.core.api.AbstractThrowableAssert;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

/**
 * Tests for {@link TypeValidator}.
 *
 * <p>Each test builds a small plan on top of a shared table scan over five typed columns and runs
 * the validator on it. The "valid" tests expect validation to succeed. The "invalid" tests pair a
 * symbol's declared type with an expression or function whose type disagrees. They expect an
 * {@link IllegalArgumentException} whose message names the symbol, the expected type and the
 * actual type.
 */
public class TestTypeValidator {
    private static final TypeValidator TYPE_VALIDATOR = new TypeValidator();

    private final TestingFunctionResolution functionResolution = new TestingFunctionResolution();
    private final SymbolAllocator symbolAllocator = new SymbolAllocator();

    private final Symbol columnA = symbolAllocator.newSymbol("a", BigintType.BIGINT);
    private final Symbol columnB = symbolAllocator.newSymbol("b", IntegerType.INTEGER);
    private final Symbol columnC = symbolAllocator.newSymbol("c", DoubleType.DOUBLE);
    private final Symbol columnD = symbolAllocator.newSymbol("d", DateType.DATE);
    private final Symbol columnE = symbolAllocator.newSymbol("e", VarcharType.createVarcharType(3));

    // Must be declared after columnA..columnE: field initializers run in textual order.
    private final TableScanNode baseTableScan = createBaseTableScan();

    @Test
    public void testValidProject() {
        Expression expression1 = new Cast(columnB.toSymbolReference(), BigintType.BIGINT);
        Expression expression2 = new Cast(columnC.toSymbolReference(), BigintType.BIGINT);
        Assignments assignments = Assignments.builder()
                .put(symbolAllocator.newSymbol(expression1), expression1)
                .put(symbolAllocator.newSymbol(expression2), expression2)
                .build();
        ProjectNode node = new ProjectNode(newId(), baseTableScan, assignments);

        assertTypesValid(node);
    }

    @Test
    public void testValidUnion() {
        Symbol outputSymbol = symbolAllocator.newSymbol("output", DateType.DATE);
        UnionNode node = unionNode(outputSymbol, columnD, columnD);

        assertTypesValid(node);
    }

    @Test
    public void testValidWindow() {
        Symbol windowSymbol = symbolAllocator.newSymbol("sum", DoubleType.DOUBLE);
        WindowNode node = doubleSumWindowNode(windowSymbol, columnC);

        assertTypesValid(node);
    }

    @Test
    public void testValidAggregation() {
        Symbol aggregationSymbol = symbolAllocator.newSymbol("sum", DoubleType.DOUBLE);
        AggregationNode node = doubleSumAggregationNode(aggregationSymbol, columnC);

        assertTypesValid(node);
    }

    @Test
    public void testInvalidAggregationFunctionCall() {
        Symbol aggregationSymbol = symbolAllocator.newSymbol("sum", DoubleType.DOUBLE);
        // sum(double) is fed the bigint column a.
        AggregationNode node = doubleSumAggregationNode(aggregationSymbol, columnA);

        assertTypeMismatch(node, "type of symbol 'sum(_[0-9]+)?' is expected to be double, but the actual type is bigint");
    }

    @Test
    public void testInvalidAggregationFunctionSignature() {
        // The output symbol is declared bigint, but sum(double) produces double.
        Symbol aggregationSymbol = symbolAllocator.newSymbol("sum", BigintType.BIGINT);
        AggregationNode node = doubleSumAggregationNode(aggregationSymbol, columnC);

        assertTypeMismatch(node, "type of symbol 'sum(_[0-9]+)?' is expected to be bigint, but the actual type is double");
    }

    @Test
    public void testInvalidWindowFunctionCall() {
        Symbol windowSymbol = symbolAllocator.newSymbol("sum", DoubleType.DOUBLE);
        // sum(double) is fed the bigint column a.
        WindowNode node = doubleSumWindowNode(windowSymbol, columnA);

        assertTypeMismatch(node, "type of symbol 'sum(_[0-9]+)?' is expected to be double, but the actual type is bigint");
    }

    @Test
    public void testInvalidWindowFunctionSignature() {
        // The output symbol is declared bigint, but sum(double) produces double.
        Symbol windowSymbol = symbolAllocator.newSymbol("sum", BigintType.BIGINT);
        WindowNode node = doubleSumWindowNode(windowSymbol, columnC);

        assertTypeMismatch(node, "type of symbol 'sum(_[0-9]+)?' is expected to be bigint, but the actual type is double");
    }

    @Test
    public void testInvalidUnion() {
        Symbol outputSymbol = symbolAllocator.newSymbol("output", DateType.DATE);
        // The second union input maps the bigint column a onto a date output.
        UnionNode node = unionNode(outputSymbol, columnD, columnA);

        assertTypeMismatch(node, "type of symbol 'output(_[0-9]+)?' is expected to be date, but the actual type is bigint");
    }

    /** Builds the table scan over columns a..e shared by every test. */
    private TableScanNode createBaseTableScan() {
        // Built once so the scan's output list and its assignments cannot drift apart.
        Map<Symbol, ColumnHandle> assignments = ImmutableMap.<Symbol, ColumnHandle>builder()
                .put(columnA, new TestingMetadata.TestingColumnHandle("a"))
                .put(columnB, new TestingMetadata.TestingColumnHandle("b"))
                .put(columnC, new TestingMetadata.TestingColumnHandle("c"))
                .put(columnD, new TestingMetadata.TestingColumnHandle("d"))
                .put(columnE, new TestingMetadata.TestingColumnHandle("e"))
                .buildOrThrow();
        return new TableScanNode(
                newId(),
                TestingHandles.TEST_TABLE_HANDLE,
                ImmutableList.copyOf(assignments.keySet()),
                assignments,
                TupleDomain.all(),
                Optional.empty(),
                false,
                Optional.empty());
    }

    /** Resolves the {@code sum} function for a single {@code double} argument. */
    private ResolvedFunction resolveDoubleSum() {
        return functionResolution.resolveFunction("sum", TypeSignatureProvider.fromTypes(DoubleType.DOUBLE));
    }

    /** Builds a single unpartitioned window computing sum(double) of {@code argument} into {@code outputSymbol}. */
    private WindowNode doubleSumWindowNode(Symbol outputSymbol, Symbol argument) {
        ResolvedFunction resolvedFunction = resolveDoubleSum();
        WindowNode.Frame frame = new WindowNode.Frame(
                WindowFrameType.RANGE,
                FrameBoundType.UNBOUNDED_PRECEDING,
                Optional.empty(),
                Optional.empty(),
                FrameBoundType.UNBOUNDED_FOLLOWING,
                Optional.empty(),
                Optional.empty());
        WindowNode.Function function = new WindowNode.Function(
                resolvedFunction,
                ImmutableList.of(argument.toSymbolReference()),
                Optional.empty(),
                frame,
                false,
                false);
        DataOrganizationSpecification specification = new DataOrganizationSpecification(ImmutableList.of(), Optional.empty());
        return new WindowNode(
                newId(),
                baseTableScan,
                specification,
                ImmutableMap.of(outputSymbol, function),
                Optional.empty(),
                ImmutableSet.of(),
                0);
    }

    /** Builds a single-step aggregation, grouped by (a, b), computing sum(double) of {@code argument} into {@code outputSymbol}. */
    private AggregationNode doubleSumAggregationNode(Symbol outputSymbol, Symbol argument) {
        AggregationNode.Aggregation aggregation = new AggregationNode.Aggregation(
                resolveDoubleSum(),
                ImmutableList.of(argument.toSymbolReference()),
                false,
                Optional.empty(),
                Optional.empty(),
                Optional.empty());
        return AggregationNode.singleAggregation(
                newId(),
                baseTableScan,
                ImmutableMap.of(outputSymbol, aggregation),
                AggregationNode.singleGroupingSet(ImmutableList.of(columnA, columnB)));
    }

    /** Builds a union of two copies of the base scan, mapping {@code firstInput}/{@code secondInput} onto {@code outputSymbol}. */
    private UnionNode unionNode(Symbol outputSymbol, Symbol firstInput, Symbol secondInput) {
        ListMultimap<Symbol, Symbol> mappings = ImmutableListMultimap.<Symbol, Symbol>builder()
                .put(outputSymbol, firstInput)
                .put(outputSymbol, secondInput)
                .build();
        return new UnionNode(
                newId(),
                ImmutableList.of(baseTableScan, baseTableScan),
                mappings,
                ImmutableList.copyOf(mappings.keySet()));
    }

    /** Asserts validation of {@code node} fails with an IllegalArgumentException whose message matches {@code messageRegex}. */
    private void assertTypeMismatch(PlanNode node, String messageRegex) {
        Assertions.assertThatThrownBy(() -> assertTypesValid(node))
                .isInstanceOf(IllegalArgumentException.class)
                .hasMessageMatching(messageRegex);
    }

    /** Runs the type validator on {@code node} with the test session and planner context. */
    private void assertTypesValid(PlanNode node) {
        TYPE_VALIDATOR.validate(node, SessionTestUtils.TEST_SESSION, TestingPlannerContext.PLANNER_CONTEXT, WarningCollector.NOOP);
    }

    /** Returns a fresh, random plan-node id. */
    private static PlanNodeId newId() {
        return new PlanNodeId(UUID.randomUUID().toString());
    }
}

