/*
 * Decompiled with CFR 0.152.
 */
package io.trino.plugin.deltalake.expression;

import com.google.common.base.Preconditions;
import com.google.common.base.Verify;
import io.trino.plugin.deltalake.expression.ArithmeticBinaryExpression;
import io.trino.plugin.deltalake.expression.BooleanLiteral;
import io.trino.plugin.deltalake.expression.ComparisonExpression;
import io.trino.plugin.deltalake.expression.Identifier;
import io.trino.plugin.deltalake.expression.LogicalExpression;
import io.trino.plugin.deltalake.expression.LongLiteral;
import io.trino.plugin.deltalake.expression.ParsingException;
import io.trino.plugin.deltalake.expression.SparkExpression;
import io.trino.plugin.deltalake.expression.SparkExpressionBaseBaseVisitor;
import io.trino.plugin.deltalake.expression.SparkExpressionBaseParser;
import io.trino.plugin.deltalake.expression.StringLiteral;
import java.util.HexFormat;
import org.antlr.v4.runtime.ParserRuleContext;
import org.antlr.v4.runtime.Token;
import org.antlr.v4.runtime.tree.ParseTree;
import org.antlr.v4.runtime.tree.TerminalNode;

/**
 * ANTLR visitor that converts a parse tree produced by {@code SparkExpressionBaseParser} into the
 * {@link SparkExpression} node classes of this package: literals, identifiers, comparisons,
 * arithmetic expressions and AND/OR logical expressions.
 *
 * <p>Every {@code visit*} override returns the node built for its rule. Rules without an override
 * fall back to ANTLR's default child traversal, whose per-child results are combined by
 * {@link #aggregateResult(Object, Object)}.
 */
public class SparkExpressionBuilder
extends SparkExpressionBaseBaseVisitor<Object> {
    /** Character that introduces an escape sequence inside a unicode string literal. */
    private static final char STRING_LITERAL_ESCAPE_CHARACTER = '\\';

    @Override
    public Object visitStandaloneExpression(SparkExpressionBaseParser.StandaloneExpressionContext context) {
        return visit(context.expression());
    }

    @Override
    public Object visitPredicated(SparkExpressionBaseParser.PredicatedContext context) {
        // Use the predicate when the parser attached one; otherwise this is a plain value expression.
        if (context.predicate() != null) {
            return visit(context.predicate());
        }
        return visit(context.valueExpression);
    }

    @Override
    public Object visitArithmeticBinary(SparkExpressionBaseParser.ArithmeticBinaryContext context) {
        return new ArithmeticBinaryExpression(
                getArithmeticBinaryOperator(context.operator),
                visit(context.left, SparkExpression.class),
                visit(context.right, SparkExpression.class));
    }

    /**
     * Maps the token of a binary arithmetic operator to its {@link ArithmeticBinaryExpression.Operator}.
     *
     * @throws UnsupportedOperationException if the token type is not a supported arithmetic operator
     */
    private static ArithmeticBinaryExpression.Operator getArithmeticBinaryOperator(Token operator) {
        // NOTE(review): these numeric ids look like the generated parser's token constants inlined by
        // the compiler; confirm them against SparkExpressionBaseParser before replacing with names.
        return switch (operator.getType()) {
            case 11 -> ArithmeticBinaryExpression.Operator.ADD;
            case 12 -> ArithmeticBinaryExpression.Operator.SUBTRACT;
            case 13 -> ArithmeticBinaryExpression.Operator.MULTIPLY;
            case 14 -> ArithmeticBinaryExpression.Operator.DIVIDE;
            case 15 -> ArithmeticBinaryExpression.Operator.MODULUS;
            case 16 -> ArithmeticBinaryExpression.Operator.BITWISE_AND;
            case 17 -> ArithmeticBinaryExpression.Operator.BITWISE_XOR;
            default -> throw new UnsupportedOperationException("Unsupported operator: " + operator.getText());
        };
    }

    @Override
    public Object visitComparison(SparkExpressionBaseParser.ComparisonContext context) {
        // The comparison operator rule wraps a single terminal token holding the operator symbol.
        Token operatorSymbol = ((TerminalNode) context.comparisonOperator().getChild(0)).getSymbol();
        return new ComparisonExpression(
                getComparisonOperator(operatorSymbol),
                visit(context.value, SparkExpression.class),
                visit(context.right, SparkExpression.class));
    }

    @Override
    public SparkExpression visitAnd(SparkExpressionBaseParser.AndContext context) {
        // Template formatting keeps the message lazy: it is only rendered when verification fails.
        Verify.verify(context.booleanExpression().size() == 2, "AND operator expects two expressions: %s", context.booleanExpression());
        return new LogicalExpression(LogicalExpression.Operator.AND, visit(context.left, SparkExpression.class), visit(context.right, SparkExpression.class));
    }

    @Override
    public Object visitOr(SparkExpressionBaseParser.OrContext context) {
        Verify.verify(context.booleanExpression().size() == 2, "OR operator expects two expressions: %s", context.booleanExpression());
        return new LogicalExpression(LogicalExpression.Operator.OR, visit(context.left, SparkExpression.class), visit(context.right, SparkExpression.class));
    }

    @Override
    public Object visitColumnReference(SparkExpressionBaseParser.ColumnReferenceContext context) {
        return visit(context.identifier());
    }

    /**
     * Maps the token of a comparison operator to its {@link ComparisonExpression.Operator}.
     *
     * @throws IllegalArgumentException if the token type is not a supported comparison operator
     */
    private static ComparisonExpression.Operator getComparisonOperator(Token symbol) {
        // NOTE(review): numeric token ids, presumably inlined parser constants; confirm before renaming.
        return switch (symbol.getType()) {
            case 5 -> ComparisonExpression.Operator.EQUAL;
            case 6 -> ComparisonExpression.Operator.NOT_EQUAL;
            case 7 -> ComparisonExpression.Operator.LESS_THAN;
            case 8 -> ComparisonExpression.Operator.LESS_THAN_OR_EQUAL;
            case 9 -> ComparisonExpression.Operator.GREATER_THAN;
            case 10 -> ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL;
            default -> throw new IllegalArgumentException("Unsupported operator: " + symbol.getText());
        };
    }

    @Override
    public Object visitBooleanLiteral(SparkExpressionBaseParser.BooleanLiteralContext context) {
        return new BooleanLiteral(context.getText());
    }

    @Override
    public SparkExpression visitIntegerLiteral(SparkExpressionBaseParser.IntegerLiteralContext context) {
        return new LongLiteral(context.getText());
    }

    @Override
    public Object visitUnicodeStringLiteral(SparkExpressionBaseParser.UnicodeStringLiteralContext context) {
        return new StringLiteral(decodeUnicodeLiteral(context));
    }

    /**
     * Unquotes a unicode string literal and expands its backslash escape sequences.
     *
     * <p>Recognized escapes: a doubled backslash yields one backslash; a backslash followed by
     * {@code u} and 4 hex digits, or by {@code U} and 8 hex digits, yields that Unicode code point.
     *
     * @throws ParsingException if an escape introducer is followed by an unsupported character, or
     *         an escape sequence is interrupted by a non-hex character or left unfinished
     * @throws IllegalStateException if an escape names a value beyond {@link Character#MAX_CODE_POINT}
     *         or has an unexpected length
     */
    private static String decodeUnicodeLiteral(SparkExpressionBaseParser.UnicodeStringLiteralContext context) {
        String rawContent = unquote(context.getText());
        StringBuilder decoded = new StringBuilder(rawContent.length());
        StringBuilder escapeDigits = new StringBuilder();
        int digitsNeeded = 0;
        UnicodeDecodeState state = UnicodeDecodeState.BASE;
        for (int i = 0; i < rawContent.length(); i++) {
            char ch = rawContent.charAt(i);
            switch (state) {
                case BASE -> {
                    if (ch == STRING_LITERAL_ESCAPE_CHARACTER) {
                        state = UnicodeDecodeState.ESCAPED;
                    }
                    else {
                        decoded.append(ch);
                    }
                }
                case ESCAPED -> {
                    if (ch == STRING_LITERAL_ESCAPE_CHARACTER) {
                        decoded.append(STRING_LITERAL_ESCAPE_CHARACTER);
                        state = UnicodeDecodeState.BASE;
                    }
                    else if (ch == 'u') {
                        state = UnicodeDecodeState.UNICODE_SEQUENCE;
                        digitsNeeded = 4;
                    }
                    else if (ch == 'U') {
                        state = UnicodeDecodeState.UNICODE_SEQUENCE;
                        digitsNeeded = 8;
                    }
                    else if (HexFormat.isHexDigit(ch)) {
                        // NOTE(review): unlike the u/U branches this never sets digitsNeeded, so the
                        // sequence cannot complete; a second hex digit trips the "Unexpected escape
                        // sequence length" check below. Confirm the intended escape form before fixing.
                        state = UnicodeDecodeState.UNICODE_SEQUENCE;
                        escapeDigits.append(ch);
                    }
                    else {
                        throw new ParsingException("Invalid hexadecimal digit: " + ch);
                    }
                }
                case UNICODE_SEQUENCE -> {
                    // Reject rather than skip: silently dropping the character would corrupt the literal.
                    if (!HexFormat.isHexDigit(ch)) {
                        throw new ParsingException(String.format("Incomplete escape sequence '%s' before '%s' in %s literal", escapeDigits, ch, context.getText()));
                    }
                    escapeDigits.append(ch);
                    if (escapeDigits.length() == digitsNeeded) {
                        decoded.appendCodePoint(parseEscapedCodePoint(escapeDigits.toString()));
                        escapeDigits.setLength(0);
                        state = UnicodeDecodeState.BASE;
                        digitsNeeded = -1;
                    }
                    else {
                        Preconditions.checkState(digitsNeeded > escapeDigits.length(), "Unexpected escape sequence length: %s", escapeDigits.length());
                    }
                }
            }
        }
        if (state != UnicodeDecodeState.BASE) {
            throw new ParsingException(String.format("Incomplete escape sequence '%s' at the end of %s literal", escapeDigits, context.getText()));
        }
        return decoded.toString();
    }

    /**
     * Parses the hex digits of a completed escape sequence into a Unicode code point.
     *
     * @throws IllegalStateException if the value exceeds {@link Character#MAX_CODE_POINT}
     */
    private static int parseEscapedCodePoint(String hexDigits) {
        // Eight hex digits can exceed Integer.MAX_VALUE (e.g. FFFFFFFF), which would make
        // Integer.parseInt throw NumberFormatException; parse as long and range-check instead.
        long codePoint = Long.parseLong(hexDigits, 16);
        Preconditions.checkState(codePoint <= Character.MAX_CODE_POINT, "Invalid escaped character: %s", hexDigits);
        return (int) codePoint;
    }

    /**
     * Strips the surrounding double or single quotes from a literal and collapses doubled
     * quote characters inside it.
     *
     * @throws IllegalArgumentException if the value is not enclosed in matching quotes
     */
    private static String unquote(String value) {
        // A lone quote character both starts and ends with a quote; require two characters
        // so it is rejected here instead of failing inside substring().
        if (value.length() >= 2) {
            if (value.startsWith("\"") && value.endsWith("\"")) {
                return value.substring(1, value.length() - 1).replace("\"\"", "\"");
            }
            if (value.startsWith("'") && value.endsWith("'")) {
                return value.substring(1, value.length() - 1).replace("''", "'");
            }
        }
        throw new IllegalArgumentException("Unexpected value: " + value);
    }

    @Override
    public SparkExpression visitUnquotedIdentifier(SparkExpressionBaseParser.UnquotedIdentifierContext context) {
        return new Identifier(context.getText());
    }

    @Override
    public Object visitBackQuotedIdentifier(SparkExpressionBaseParser.BackQuotedIdentifierContext context) {
        // Drop the enclosing backticks and un-double any escaped backtick inside the name.
        String token = context.getText();
        String identifier = token.substring(1, token.length() - 1).replace("``", "`");
        return new Identifier(identifier);
    }

    /**
     * Visits {@code context} and casts the result to {@code expected}.
     *
     * @throws ClassCastException if the visit produced a value of a different type
     */
    private <T> T visit(ParserRuleContext context, Class<T> expected) {
        return expected.cast(super.visit(context));
    }

    /**
     * Combines child results during ANTLR's default child traversal. At most one non-null child
     * result is accepted and passed through. A null child result means no override handled that
     * child, and a second non-null result cannot be combined. Both cases are rejected.
     *
     * @throws UnsupportedOperationException if {@code nextResult} is null or both results are non-null
     */
    @Override
    protected Object aggregateResult(Object aggregate, Object nextResult) {
        if (nextResult == null) {
            throw new UnsupportedOperationException("not yet implemented");
        }
        if (aggregate == null) {
            return nextResult;
        }
        throw new UnsupportedOperationException(String.format("Cannot combine %s and %s", aggregate, nextResult));
    }

    /** States of the escape-sequence decoder used by {@code decodeUnicodeLiteral}. */
    private enum UnicodeDecodeState {
        /** Copying ordinary characters. */
        BASE,
        /** Just consumed the escape character; the next character selects the escape kind. */
        ESCAPED,
        /** Collecting the hex digits of a code-point escape. */
        UNICODE_SEQUENCE,
    }
}

