/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.Metadata;
import io.trino.sql.planner.DeterminismEvaluator;
import io.trino.sql.planner.ExpressionNodeInliner;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolsExtractor;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.ValuesNode;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.Row;
import io.trino.sql.tree.SymbolReference;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
 * Folds a {@link ProjectNode} into the {@link ValuesNode} directly beneath it. Each values row is
 * substituted into the projection's assignment expressions, producing a single {@link ValuesNode}
 * whose output symbols are the projection's outputs.
 *
 * <p>The pattern only matches values nodes whose rows are all {@link Row} literals (or that have no
 * rows expression at all). The rewrite is skipped when a non-deterministic value would be
 * referenced more than once by the projection, because inlining would then evaluate it several
 * times instead of once.
 */
public class MergeProjectWithValues implements Rule<ProjectNode> {
    private static final Capture<ValuesNode> VALUES = Capture.newCapture();
    private static final Pattern<ProjectNode> PATTERN = Patterns.project()
            .with(Patterns.source().matching(Patterns.values()
                    .matching(MergeProjectWithValues::isSupportedValues)
                    .capturedAs(VALUES)));

    private final Metadata metadata;

    public MergeProjectWithValues(Metadata metadata) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
    }

    @Override
    public Pattern<ProjectNode> getPattern() {
        return PATTERN;
    }

    @Override
    public boolean isEnabled(Session session) {
        return SystemSessionProperties.isMergeProjectWithValues(session);
    }

    @Override
    public Rule.Result apply(ProjectNode node, Captures captures, Rule.Context context) {
        ValuesNode valuesNode = captures.get(VALUES);

        // The projection prunes every symbol: only the row count survives.
        if (node.getOutputSymbols().isEmpty()) {
            return Rule.Result.ofPlanNode(new ValuesNode(valuesNode.getId(), valuesNode.getRowCount()));
        }

        // Snapshot the assignments once so that outputs and expressions share the same order.
        List<Map.Entry<Symbol, Expression>> assignments = ImmutableList.copyOf(node.getAssignments().entrySet());
        List<Symbol> outputs = assignments.stream()
                .map(Map.Entry::getKey)
                .collect(ImmutableList.toImmutableList());
        List<Expression> expressions = assignments.stream()
                .map(Map.Entry::getValue)
                .collect(ImmutableList.toImmutableList());

        // The values node exposes no symbols, so there is nothing to substitute. Every output row
        // is just the projection's expressions, repeated once per values row.
        if (valuesNode.getOutputSymbols().isEmpty()) {
            Row projectedRow = new Row(ImmutableList.copyOf(expressions));
            return Rule.Result.ofPlanNode(new ValuesNode(
                    valuesNode.getId(),
                    outputs,
                    Collections.<Expression>nCopies(valuesNode.getRowCount(), projectedRow)));
        }

        // NOTE(review): assumes a ValuesNode with output symbols always carries explicit rows
        // (the original code also called get() unconditionally here) — confirm against ValuesNode.
        List<Expression> rows = valuesNode.getRows().get();
        List<Symbol> valuesOutputs = valuesNode.getOutputSymbols();

        // Inlining copies a value into every place its symbol is referenced. For a
        // non-deterministic value referenced more than once, that would change semantics.
        Set<Symbol> nonDeterministicValuesOutputs = findNonDeterministicOutputs(valuesOutputs, rows);
        Set<Symbol> multipleReferencedSymbols = findMultipleReferencedSymbols(expressions);
        if (!Sets.intersection(nonDeterministicValuesOutputs, multipleReferencedSymbols).isEmpty()) {
            return Rule.Result.empty();
        }

        // Substitute each row's values into the projection's expressions.
        ImmutableList.Builder<Expression> projectedRows = ImmutableList.builder();
        for (Expression rowExpression : rows) {
            // Safe cast: the pattern only matches values nodes whose rows are all Row literals.
            Map<SymbolReference, Expression> mapping = buildMappings(valuesOutputs, (Row) rowExpression);
            List<Expression> projectedItems = expressions.stream()
                    .map(expression -> ExpressionNodeInliner.replaceExpression(expression, mapping))
                    .collect(ImmutableList.toImmutableList());
            projectedRows.add(new Row(projectedItems));
        }
        return Rule.Result.ofPlanNode(new ValuesNode(valuesNode.getId(), outputs, projectedRows.build()));
    }

    /**
     * Collects the values output symbols whose column holds a non-deterministic expression in at
     * least one row.
     *
     * @param symbols the values node's output symbols, in column order
     * @param rows the values node's rows; each is expected to be a {@link Row}
     * @return the symbols of columns containing a non-deterministic expression
     */
    private Set<Symbol> findNonDeterministicOutputs(List<Symbol> symbols, List<Expression> rows) {
        Set<Symbol> nonDeterministic = new HashSet<>();
        for (Expression rowExpression : rows) {
            List<Expression> items = ((Row) rowExpression).getItems();
            for (int i = 0; i < symbols.size(); i++) {
                if (!DeterminismEvaluator.isDeterministic(items.get(i), metadata)) {
                    nonDeterministic.add(symbols.get(i));
                }
            }
        }
        return nonDeterministic;
    }

    /**
     * Returns the symbols that occur more than once across all of the given expressions combined.
     * Repeated occurrences within a single expression count as well.
     */
    private static Set<Symbol> findMultipleReferencedSymbols(List<Expression> expressions) {
        return expressions.stream()
                .flatMap(expression -> SymbolsExtractor.extractAll(expression).stream())
                .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()))
                .entrySet().stream()
                .filter(entry -> entry.getValue() > 1)
                .map(Map.Entry::getKey)
                .collect(ImmutableSet.toImmutableSet());
    }

    /**
     * Pattern predicate: accepts values nodes without a rows expression, or whose rows are all
     * {@link Row} literals (so they can be taken apart column by column).
     */
    private static boolean isSupportedValues(ValuesNode valuesNode) {
        return valuesNode.getRows().isEmpty()
                || valuesNode.getRows().get().stream().allMatch(Row.class::isInstance);
    }

    /**
     * Maps each values output symbol's reference to the matching item of {@code row}.
     *
     * @param symbols the values node's output symbols, in column order
     * @param row one values row; its items are paired positionally with {@code symbols}
     * @return an immutable symbol-reference-to-value mapping
     */
    private static Map<SymbolReference, Expression> buildMappings(List<Symbol> symbols, Row row) {
        List<Expression> items = row.getItems();
        ImmutableMap.Builder<SymbolReference, Expression> mappings = ImmutableMap.builder();
        for (int i = 0; i < items.size(); i++) {
            mappings.put(symbols.get(i).toSymbolReference(), items.get(i));
        }
        return mappings.buildOrThrow();
    }
}

