/*
 * Decompiled with CFR 0.152.
 */
package org.apache.shardingsphere.sharding.route.engine.validator.dml.impl;

import java.util.List;
import java.util.Optional;
import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.exception.ShardingSphereException;
import org.apache.shardingsphere.infra.metadata.model.ShardingSphereMetaData;
import org.apache.shardingsphere.infra.route.context.RouteContext;
import org.apache.shardingsphere.sharding.route.engine.validator.dml.ShardingDMLStatementValidator;
import org.apache.shardingsphere.sharding.rule.ShardingRule;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.AssignmentSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BinaryOperationExpression;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.InExpression;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ListExpression;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.LiteralExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate.WhereSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SimpleTableSegment;
import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.UpdateStatement;
import org.apache.shardingsphere.sql.parser.sql.dialect.handler.dml.UpdateStatementHandler;

/**
 * Sharding validator for UPDATE statements.
 *
 * <p>Before routing, it rejects assignments that could change a sharding column's value. After routing,
 * it rejects {@code UPDATE ... LIMIT} statements that were routed to more than one route unit.</p>
 */
public final class ShardingUpdateStatementValidator extends ShardingDMLStatementValidator<UpdateStatement> {
    
    /**
     * Rejects UPDATE statements that may change the value of a sharding column.
     *
     * <p>An assignment to a sharding column of the (first) updated table is accepted only when:
     * the WHERE clause pins that column to exactly one value, and the assigned value is present and
     * equal to it. Only then can no matched row receive a different sharding value.</p>
     *
     * @param shardingRule rule used to decide whether an assigned column is a sharding column of the table
     * @param sqlStatementContext context of the UPDATE statement being validated
     * @param parameters values bound to the statement's parameter markers, indexed by marker index
     * @param metaData meta data (not read by this method)
     * @throws ShardingSphereException if a sharding column assignment cannot be shown to keep its value
     */
    @Override
    public void preValidate(final ShardingRule shardingRule, final SQLStatementContext<UpdateStatement> sqlStatementContext, final List<Object> parameters, final ShardingSphereMetaData metaData) {
        validateShardingMultipleTable(shardingRule, sqlStatementContext);
        UpdateStatement sqlStatement = sqlStatementContext.getSqlStatement();
        String tableName = ((SimpleTableSegment) sqlStatementContext.getTablesContext().getTables().iterator().next()).getTableName().getIdentifier().getValue();
        for (AssignmentSegment each : sqlStatement.getSetAssignment().getAssignments()) {
            String shardingColumn = each.getColumn().getIdentifier().getValue();
            if (!shardingRule.isShardingColumn(shardingColumn, tableName)) {
                continue;
            }
            Optional<Object> assignedValue = getExpressionValue(each.getValue(), parameters);
            Optional<Object> shardingValue = sqlStatement.getWhere().flatMap(where -> getShardingValue(where.getExpr(), parameters, shardingColumn));
            // Optional.equals is true only if both are present with equal values (assignedValue is known present here).
            if (!assignedValue.isPresent() || !assignedValue.equals(shardingValue)) {
                throw new ShardingSphereException("Can not update sharding key, logic table: [%s], column: [%s].", tableName, shardingColumn);
            }
        }
    }
    
    /**
     * Resolves the single value a predicate forces the sharding column to take.
     *
     * <p>Rules:</p>
     * <ul>
     *   <li>{@code col = v} pins {@code v}.</li>
     *   <li>{@code col IN (...)} pins a value only if every item resolves to the same value.</li>
     *   <li>{@code AND} pins whatever either side pins (left side first).</li>
     *   <li>{@code OR} pins a value only if both sides pin the same value.</li>
     *   <li>Range operators and anything else pin nothing, because they do not restrict matched rows to
     *       one value.</li>
     * </ul>
     *
     * @param expression predicate expression to inspect
     * @param parameters bound parameter values
     * @param shardingColumn sharding column name, compared case-insensitively
     * @return the pinned value, or empty if the predicate does not pin the column to a single value
     */
    private Optional<Object> getShardingValue(final ExpressionSegment expression, final List<Object> parameters, final String shardingColumn) {
        if (expression instanceof InExpression) {
            InExpression inExpression = (InExpression) expression;
            // NOTE(review): if InExpression carries a NOT flag in this version, NOT IN is not distinguished here — confirm.
            return isColumn(inExpression.getLeft(), shardingColumn) ? getPredicateInShardingValue(inExpression.getRight(), parameters) : Optional.empty();
        }
        if (!(expression instanceof BinaryOperationExpression)) {
            return Optional.empty();
        }
        BinaryOperationExpression binaryExpression = (BinaryOperationExpression) expression;
        String operator = binaryExpression.getOperator();
        if ("=".equals(operator)) {
            return isColumn(binaryExpression.getLeft(), shardingColumn) ? getExpressionValue(binaryExpression.getRight(), parameters) : Optional.empty();
        }
        if ("AND".equalsIgnoreCase(operator) || "&&".equals(operator)) {
            Optional<Object> leftValue = getShardingValue(binaryExpression.getLeft(), parameters, shardingColumn);
            return leftValue.isPresent() ? leftValue : getShardingValue(binaryExpression.getRight(), parameters, shardingColumn);
        }
        if ("OR".equalsIgnoreCase(operator) || "||".equals(operator)) {
            // Rows matched by either branch must all carry the same value, so both branches must agree.
            Optional<Object> leftValue = getShardingValue(binaryExpression.getLeft(), parameters, shardingColumn);
            Optional<Object> rightValue = getShardingValue(binaryExpression.getRight(), parameters, shardingColumn);
            return leftValue.isPresent() && leftValue.equals(rightValue) ? leftValue : Optional.empty();
        }
        return Optional.empty();
    }
    
    /**
     * Resolves the value pinned by the right-hand side of an IN predicate.
     *
     * @param segment right-hand side of the IN predicate
     * @param parameters bound parameter values
     * @return the common value, if the segment is a non-empty list whose items all resolve to one value; otherwise empty
     */
    private Optional<Object> getPredicateInShardingValue(final ExpressionSegment segment, final List<Object> parameters) {
        if (!(segment instanceof ListExpression)) {
            return Optional.empty();
        }
        Optional<Object> result = Optional.empty();
        for (ExpressionSegment each : ((ListExpression) segment).getItems()) {
            Optional<Object> itemValue = getExpressionValue(each, parameters);
            // An unresolvable or differing item means matched rows may hold other values.
            if (!itemValue.isPresent() || result.isPresent() && !result.equals(itemValue)) {
                return Optional.empty();
            }
            result = itemValue;
        }
        return result;
    }
    
    /**
     * Returns the concrete value of a parameter marker or literal expression.
     *
     * @param segment expression to evaluate
     * @param parameters bound parameter values
     * @return the value; empty for other expression kinds, for out-of-range marker indexes, or for null values
     */
    private Optional<Object> getExpressionValue(final ExpressionSegment segment, final List<Object> parameters) {
        if (segment instanceof ParameterMarkerExpressionSegment) {
            int index = ((ParameterMarkerExpressionSegment) segment).getParameterMarkerIndex();
            return index >= 0 && index < parameters.size() ? Optional.ofNullable(parameters.get(index)) : Optional.empty();
        }
        if (segment instanceof LiteralExpressionSegment) {
            return Optional.ofNullable(((LiteralExpressionSegment) segment).getLiterals());
        }
        return Optional.empty();
    }
    
    /**
     * Checks whether an expression is a column reference with the given name (case-insensitive).
     *
     * @param segment expression to check
     * @param columnName expected column name
     * @return true if {@code segment} is a {@link ColumnSegment} whose identifier matches {@code columnName}
     */
    private boolean isColumn(final ExpressionSegment segment, final String columnName) {
        return segment instanceof ColumnSegment && columnName.equalsIgnoreCase(((ColumnSegment) segment).getIdentifier().getValue());
    }
    
    /**
     * Rejects {@code UPDATE ... LIMIT} statements that were routed to more than one route unit.
     *
     * @param sqlStatement the routed UPDATE statement
     * @param routeContext routing result
     * @throws ShardingSphereException if a LIMIT segment is present and there is more than one route unit
     */
    @Override
    public void postValidate(final UpdateStatement sqlStatement, final RouteContext routeContext) {
        if (UpdateStatementHandler.getLimitSegment(sqlStatement).isPresent() && routeContext.getRouteUnits().size() > 1) {
            throw new ShardingSphereException("UPDATE ... LIMIT can not support sharding route to multiple data nodes.");
        }
    }
}

