/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSetMultimap;
import com.google.common.collect.Multimap;
import io.trino.metadata.Metadata;
import io.trino.sql.ir.ComparisonExpression;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.IrUtils;
import io.trino.sql.ir.SymbolReference;
import io.trino.sql.planner.DeterminismEvaluator;
import io.trino.sql.planner.ExpressionNodeInliner;
import io.trino.sql.planner.NullabilityAnalyzer;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolsExtractor;
import io.trino.util.DisjointSet;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Stream;

public class EqualityInference {
    private final Comparator<Expression> canonicalComparator;
    private final Multimap<Expression, Expression> equalitySets;
    private final Map<Expression, Expression> canonicalMap;
    private final Set<Expression> derivedExpressions;
    private final Map<Expression, List<Expression>> expressionCache = new HashMap<Expression, List<Expression>>();
    private final Map<Expression, List<Symbol>> symbolsCache = new HashMap<Expression, List<Symbol>>();
    private final Map<Expression, Set<Symbol>> uniqueSymbolsCache = new HashMap<Expression, Set<Symbol>>();

    public EqualityInference(Metadata metadata, Expression ... expressions) {
        this(metadata, Arrays.asList(expressions));
    }

    public EqualityInference(Metadata metadata, Collection<Expression> expressions) {
        DisjointSet<Expression> equalities = new DisjointSet<Expression>();
        expressions.stream().flatMap(expression -> IrUtils.extractConjuncts(expression).stream()).filter(expression -> EqualityInference.isInferenceCandidate(metadata, expression)).forEach(expression -> {
            ComparisonExpression comparison = (ComparisonExpression)expression;
            Expression expression1 = comparison.getLeft();
            Expression expression2 = comparison.getRight();
            equalities.findAndUnion(expression1, expression2);
        });
        Collection equivalentClasses = equalities.getEquivalentClasses();
        LinkedHashMap byExpression = new LinkedHashMap();
        for (Set<Expression> set : equivalentClasses) {
            set.forEach(expression -> byExpression.put(expression, equivalence));
        }
        LinkedHashSet<Expression> derivedExpressions = new LinkedHashSet<Expression>();
        for (Expression expression2 : byExpression.keySet()) {
            if (derivedExpressions.contains(expression2)) continue;
            this.extractSubExpressions(expression2).stream().filter(e -> !e.equals(expression2)).forEach(subExpression -> ((Set)byExpression.getOrDefault(subExpression, ImmutableSet.of())).stream().filter(e -> !e.equals(subExpression)).forEach(equivalentSubExpression -> {
                Expression rewritten = ExpressionNodeInliner.replaceExpression(expression2, (Map<? extends Expression, ? extends Expression>)ImmutableMap.of((Object)subExpression, (Object)equivalentSubExpression));
                equalities.findAndUnion(expression2, rewritten);
                derivedExpressions.add(rewritten);
            }));
        }
        Comparator<Expression> comparator = Comparator.comparingInt(expression -> this.extractAllSymbols((Expression)expression).size()).thenComparingLong(expression -> this.extractSubExpressions((Expression)expression).size()).thenComparing(Object::toString);
        Multimap<Expression, Expression> equalitySets = EqualityInference.makeEqualitySets(equalities, comparator);
        ImmutableMap.Builder canonicalMappings = ImmutableMap.builder();
        for (Map.Entry entry : equalitySets.entries()) {
            Expression canonical = (Expression)entry.getKey();
            Expression expression3 = (Expression)entry.getValue();
            canonicalMappings.put((Object)expression3, (Object)canonical);
        }
        this.equalitySets = equalitySets;
        this.canonicalMap = canonicalMappings.buildOrThrow();
        this.derivedExpressions = derivedExpressions;
        this.canonicalComparator = comparator;
    }

    public Expression rewrite(Expression expression, Set<Symbol> scope) {
        return this.rewrite(expression, scope::contains, true);
    }

    public EqualityPartition generateEqualitiesPartitionedBy(Set<Symbol> scope) {
        ImmutableSet.Builder scopeEqualities = ImmutableSet.builder();
        ImmutableSet.Builder scopeComplementEqualities = ImmutableSet.builder();
        ImmutableSet.Builder scopeStraddlingEqualities = ImmutableSet.builder();
        for (Collection equalitySet : this.equalitySets.asMap().values()) {
            LinkedHashSet scopeExpressions = new LinkedHashSet();
            LinkedHashSet scopeComplementExpressions = new LinkedHashSet();
            LinkedHashSet scopeStraddlingExpressions = new LinkedHashSet();
            equalitySet.stream().filter(candidate -> !this.derivedExpressions.contains(candidate)).forEach(candidate -> {
                Expression scopeComplementRewritten;
                Expression scopeRewritten = this.rewrite((Expression)candidate, scope::contains, false);
                if (scopeRewritten != null) {
                    scopeExpressions.add(scopeRewritten);
                }
                if ((scopeComplementRewritten = this.rewrite((Expression)candidate, symbol -> !scope.contains(symbol), false)) != null) {
                    scopeComplementExpressions.add(scopeComplementRewritten);
                }
                if (scopeRewritten == null && scopeComplementRewritten == null) {
                    scopeStraddlingExpressions.add(candidate);
                }
            });
            Expression matchingCanonical = this.getCanonical(scopeExpressions.stream());
            if (scopeExpressions.size() >= 2) {
                scopeExpressions.stream().filter(expression -> !expression.equals(matchingCanonical)).map(expression -> new ComparisonExpression(ComparisonExpression.Operator.EQUAL, matchingCanonical, (Expression)expression)).forEach(arg_0 -> ((ImmutableSet.Builder)scopeEqualities).add(arg_0));
            }
            Expression complementCanonical = this.getCanonical(scopeComplementExpressions.stream());
            if (scopeComplementExpressions.size() >= 2) {
                scopeComplementExpressions.stream().filter(expression -> !expression.equals(complementCanonical)).map(expression -> new ComparisonExpression(ComparisonExpression.Operator.EQUAL, complementCanonical, (Expression)expression)).forEach(arg_0 -> ((ImmutableSet.Builder)scopeComplementEqualities).add(arg_0));
            }
            Optional<Expression> matchingConnecting = scopeExpressions.stream().filter(expression -> SymbolsExtractor.extractAll(expression).isEmpty() || this.rewrite((Expression)expression, symbol -> !scope.contains(symbol), false) == null).min(this.canonicalComparator);
            Optional<Expression> complementConnecting = scopeComplementExpressions.stream().filter(expression -> {
                if (SymbolsExtractor.extractAll(expression).isEmpty()) return true;
                if (this.rewrite((Expression)expression, scope::contains, false) != null) return false;
                return true;
            }).min(this.canonicalComparator);
            if (matchingConnecting.isPresent() && complementConnecting.isPresent() && !matchingConnecting.equals(complementConnecting)) {
                scopeStraddlingEqualities.add((Object)new ComparisonExpression(ComparisonExpression.Operator.EQUAL, matchingConnecting.get(), complementConnecting.get()));
            }
            ArrayList<Expression> straddlingExpressions = new ArrayList<Expression>();
            if (matchingCanonical != null) {
                straddlingExpressions.add(matchingCanonical);
            } else if (complementCanonical != null) {
                straddlingExpressions.add(complementCanonical);
            }
            straddlingExpressions.addAll(scopeStraddlingExpressions);
            Expression connectingCanonical = this.getCanonical(straddlingExpressions.stream());
            if (connectingCanonical == null) continue;
            straddlingExpressions.stream().filter(expression -> !expression.equals(connectingCanonical)).map(expression -> new ComparisonExpression(ComparisonExpression.Operator.EQUAL, connectingCanonical, (Expression)expression)).forEach(arg_0 -> ((ImmutableSet.Builder)scopeStraddlingEqualities).add(arg_0));
        }
        return new EqualityPartition((Iterable<Expression>)scopeEqualities.build(), (Iterable<Expression>)scopeComplementEqualities.build(), (Iterable<Expression>)scopeStraddlingEqualities.build());
    }

    public static boolean isInferenceCandidate(Metadata metadata, Expression expression) {
        if (expression instanceof ComparisonExpression) {
            ComparisonExpression comparison = (ComparisonExpression)expression;
            if (DeterminismEvaluator.isDeterministic(expression, metadata) && !NullabilityAnalyzer.mayReturnNullOnNonNullInput(expression) && comparison.getOperator() == ComparisonExpression.Operator.EQUAL) {
                return !comparison.getLeft().equals(comparison.getRight());
            }
        }
        return false;
    }

    public static Stream<Expression> nonInferrableConjuncts(Metadata metadata, Expression expression) {
        return IrUtils.extractConjuncts(expression).stream().filter(e -> !EqualityInference.isInferenceCandidate(metadata, e));
    }

    private Expression rewrite(Expression expression, Predicate<Symbol> symbolScope, boolean allowFullReplacement) {
        HashMap expressionRemap = new HashMap();
        this.extractSubExpressions(expression).stream().filter(allowFullReplacement ? subExpression -> true : subExpression -> !subExpression.equals(expression)).forEach(subExpression -> {
            Expression canonical = this.getScopedCanonical((Expression)subExpression, symbolScope);
            if (canonical != null) {
                expressionRemap.putIfAbsent(subExpression, canonical);
            }
        });
        Expression rewritten = ExpressionNodeInliner.replaceExpression(expression, expressionRemap);
        if (!this.isScoped(rewritten, symbolScope)) {
            return null;
        }
        return rewritten;
    }

    private Expression getCanonical(Stream<Expression> expressions) {
        return expressions.min(this.canonicalComparator).orElse(null);
    }

    @VisibleForTesting
    Expression getScopedCanonical(Expression expression, Predicate<Symbol> symbolScope) {
        Expression canonicalIndex = this.canonicalMap.get(expression);
        if (canonicalIndex == null) {
            return null;
        }
        Collection equivalences = this.equalitySets.get((Object)canonicalIndex);
        if (expression instanceof SymbolReference) {
            boolean inScope = equivalences.stream().filter(SymbolReference.class::isInstance).map(Symbol::from).anyMatch(symbolScope);
            if (!inScope) {
                return null;
            }
        }
        return this.getCanonical(equivalences.stream().filter(e -> this.isScoped((Expression)e, symbolScope)));
    }

    private boolean isScoped(Expression expression, Predicate<Symbol> symbolScope) {
        return this.extractUniqueSymbols(expression).stream().allMatch(symbolScope);
    }

    private static Multimap<Expression, Expression> makeEqualitySets(DisjointSet<Expression> equalities, Comparator<Expression> canonicalComparator) {
        ImmutableSetMultimap.Builder builder = ImmutableSetMultimap.builder();
        for (Set<Expression> equalityGroup : equalities.getEquivalentClasses()) {
            if (equalityGroup.isEmpty()) continue;
            builder.putAll((Object)equalityGroup.stream().min(canonicalComparator).get(), equalityGroup);
        }
        return builder.build();
    }

    private List<Expression> extractSubExpressions(Expression expression) {
        return this.expressionCache.computeIfAbsent(expression, e -> (List)IrUtils.preOrder(e).collect(ImmutableList.toImmutableList()));
    }

    private Set<Symbol> extractUniqueSymbols(Expression expression) {
        return this.uniqueSymbolsCache.computeIfAbsent(expression, e -> ImmutableSet.copyOf(this.extractAllSymbols(expression)));
    }

    private List<Symbol> extractAllSymbols(Expression expression) {
        return this.symbolsCache.computeIfAbsent(expression, SymbolsExtractor::extractAll);
    }

    public static class EqualityPartition {
        private final List<Expression> scopeEqualities;
        private final List<Expression> scopeComplementEqualities;
        private final List<Expression> scopeStraddlingEqualities;

        public EqualityPartition(Iterable<Expression> scopeEqualities, Iterable<Expression> scopeComplementEqualities, Iterable<Expression> scopeStraddlingEqualities) {
            this.scopeEqualities = ImmutableList.copyOf(Objects.requireNonNull(scopeEqualities, "scopeEqualities is null"));
            this.scopeComplementEqualities = ImmutableList.copyOf(Objects.requireNonNull(scopeComplementEqualities, "scopeComplementEqualities is null"));
            this.scopeStraddlingEqualities = ImmutableList.copyOf(Objects.requireNonNull(scopeStraddlingEqualities, "scopeStraddlingEqualities is null"));
        }

        public List<Expression> getScopeEqualities() {
            return this.scopeEqualities;
        }

        public List<Expression> getScopeComplementEqualities() {
            return this.scopeComplementEqualities;
        }

        public List<Expression> getScopeStraddlingEqualities() {
            return this.scopeStraddlingEqualities;
        }
    }
}

