/*
 * Decompiled with CFR 0.152.
 */
package io.trino.spi.connector;

import io.trino.spi.expression.Call;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.expression.FunctionName;
import io.trino.spi.expression.StandardFunctions;
import io.trino.spi.expression.Variable;
import java.util.ArrayDeque;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * A binary comparison whose left operand refers only to symbols of the left join input and whose
 * right operand refers only to symbols of the right join input.
 *
 * <p>Instances are usually obtained through {@link #from(ConnectorExpression, Set, Set)}, which
 * recognizes a two-argument {@link Call} whose function name maps to an {@link Operator}.
 */
@Deprecated
public final class JoinCondition {
    private final Operator operator;
    private final ConnectorExpression leftExpression;
    private final ConnectorExpression rightExpression;

    /**
     * Tries to interpret {@code expression} as a join condition between the two given symbol sets.
     *
     * <p>The result is empty when {@code expression} is not a {@link Call} with exactly two
     * arguments, when the call's function name does not correspond to any {@link Operator}, or
     * when the variables referenced by the two arguments cannot be attributed cleanly to opposite
     * sides. If the first argument refers only to right symbols and the second only to left
     * symbols, the operands are swapped and the operator is {@linkplain Operator#flip() flipped}.
     *
     * @param expression the candidate expression
     * @param leftSymbols names of the variables produced by the left join input
     * @param rightSymbols names of the variables produced by the right join input
     * @return the recognized join condition, or empty if the expression does not qualify
     * @throws IllegalArgumentException if {@code leftSymbols} and {@code rightSymbols} share a
     *         symbol; this is only checked once the call has been matched to an operator
     */
    public static Optional<JoinCondition> from(ConnectorExpression expression, Set<String> leftSymbols, Set<String> rightSymbols) {
        if (!(expression instanceof Call call) || call.getArguments().size() != 2) {
            return Optional.empty();
        }
        return Optional.ofNullable(Operator.byFunctionName.get(call.getFunctionName())).flatMap(operator -> {
            // The symbol sets must be disjoint, otherwise side attribution below is ambiguous.
            rightSymbols.stream().filter(leftSymbols::contains).findAny().ifPresent(symbol -> {
                throw new IllegalArgumentException("Left and right symbol sets overlap, are both include %s: %s, %s".formatted(symbol, leftSymbols, rightSymbols));
            });
            ConnectorExpression left = call.getArguments().get(0);
            ConnectorExpression right = call.getArguments().get(1);
            Set<String> leftExpressionSymbols = findVariableNames(left);
            Set<String> rightExpressionSymbols = findVariableNames(right);
            if (leftSymbols.containsAll(leftExpressionSymbols) && rightSymbols.containsAll(rightExpressionSymbols)) {
                return Optional.of(new JoinCondition(operator, left, right));
            }
            if (rightSymbols.containsAll(leftExpressionSymbols) && leftSymbols.containsAll(rightExpressionSymbols)) {
                // Operands are written in the opposite order; swap them and mirror the operator.
                return Optional.of(new JoinCondition(operator.flip(), right, left));
            }
            return Optional.empty();
        });
    }

    /**
     * Collects the names of all {@link Variable}s reachable from {@code expression} through
     * {@link ConnectorExpression#getChildren()}, visiting each distinct expression once.
     */
    private static Set<String> findVariableNames(ConnectorExpression expression) {
        Set<String> variableNames = new HashSet<>();
        Set<ConnectorExpression> visited = new HashSet<>();
        ArrayDeque<ConnectorExpression> pending = new ArrayDeque<>(List.of(expression));
        while (!pending.isEmpty()) {
            ConnectorExpression next = pending.remove();
            if (!visited.add(next)) {
                // Already expanded an equal expression; its variables are recorded.
                continue;
            }
            pending.addAll(next.getChildren());
            if (next instanceof Variable variable) {
                variableNames.add(variable.getName());
            }
        }
        return variableNames;
    }

    /**
     * Creates a join condition {@code leftExpression operator rightExpression}.
     *
     * @throws NullPointerException if any argument is null
     */
    public JoinCondition(Operator operator, ConnectorExpression leftExpression, ConnectorExpression rightExpression) {
        this.operator = Objects.requireNonNull(operator, "operator is null");
        this.leftExpression = Objects.requireNonNull(leftExpression, "leftExpression is null");
        this.rightExpression = Objects.requireNonNull(rightExpression, "rightExpression is null");
    }

    /** Returns the comparison operator. */
    public Operator getOperator() {
        return operator;
    }

    /** Returns the operand on the left side of the comparison. */
    public ConnectorExpression getLeftExpression() {
        return leftExpression;
    }

    /** Returns the operand on the right side of the comparison. */
    public ConnectorExpression getRightExpression() {
        return rightExpression;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        JoinCondition that = (JoinCondition) o;
        return operator == that.operator
                && Objects.equals(leftExpression, that.leftExpression)
                && Objects.equals(rightExpression, that.rightExpression);
    }

    @Override
    public int hashCode() {
        return Objects.hash(operator, leftExpression, rightExpression);
    }

    /** Renders the condition as {@code <left> <operator symbol> <right>}. */
    @Override
    public String toString() {
        return String.format("%s %s %s", leftExpression, operator.getValue(), rightExpression);
    }

    /**
     * Comparison operators that a {@link JoinCondition} can carry. Each constant pairs a display
     * symbol with the {@link FunctionName} used to recognize it in a {@link Call}.
     */
    @Deprecated
    public static enum Operator {
        EQUAL("=", StandardFunctions.EQUAL_OPERATOR_FUNCTION_NAME),
        NOT_EQUAL("<>", StandardFunctions.NOT_EQUAL_OPERATOR_FUNCTION_NAME),
        LESS_THAN("<", StandardFunctions.LESS_THAN_OPERATOR_FUNCTION_NAME),
        LESS_THAN_OR_EQUAL("<=", StandardFunctions.LESS_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME),
        GREATER_THAN(">", StandardFunctions.GREATER_THAN_OPERATOR_FUNCTION_NAME),
        GREATER_THAN_OR_EQUAL(">=", StandardFunctions.GREATER_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME),
        IDENTICAL("\u2261", StandardFunctions.IDENTICAL_OPERATOR_FUNCTION_NAME);

        // Lookup from call function name to operator; built once, after all constants exist.
        private static final Map<FunctionName, Operator> byFunctionName = Stream.of(values())
                .collect(Collectors.toUnmodifiableMap(operator -> operator.callFunctionName, Function.identity()));

        private final String value;
        private final FunctionName callFunctionName;

        private Operator(String value, FunctionName callFunctionName) {
            this.value = value;
            this.callFunctionName = callFunctionName;
        }

        /** Returns the display symbol of this operator, e.g. {@code "<="}. */
        public String getValue() {
            return value;
        }

        /**
         * Returns the operator that yields the same result when the two operands are swapped:
         * the ordering operators are mirrored, the remaining operators are returned unchanged.
         */
        public Operator flip() {
            // Switch on the constants themselves (not ordinals) so that reordering the enum cannot
            // change the mapping and adding a constant forces this method to be updated.
            return switch (this) {
                case EQUAL, NOT_EQUAL, IDENTICAL -> this;
                case LESS_THAN -> GREATER_THAN;
                case LESS_THAN_OR_EQUAL -> GREATER_THAN_OR_EQUAL;
                case GREATER_THAN -> LESS_THAN;
                case GREATER_THAN_OR_EQUAL -> LESS_THAN_OR_EQUAL;
            };
        }
    }
}

