/*
 * Decompiled with CFR 0.152.
 */
package com.taotao.boot.data.mybatis.mybatisplus.interceptor.datascope.dataPermission.db;

import com.alibaba.ttl.TransmittableThreadLocal;
import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
import com.baomidou.mybatisplus.extension.parser.JsqlParserSupport;
import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor;
import com.taotao.boot.data.mybatis.mybatisplus.MpUtils;
import com.taotao.boot.data.mybatis.mybatisplus.interceptor.datascope.dataPermission.factory.DataPermissionRuleFactory;
import com.taotao.boot.data.mybatis.mybatisplus.interceptor.datascope.dataPermission.rule.DataPermissionRule;
import java.sql.Connection;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Deque;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import net.sf.jsqlparser.expression.BinaryExpression;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.NotExpression;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
import net.sf.jsqlparser.expression.operators.relational.ExistsExpression;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
import net.sf.jsqlparser.expression.operators.relational.ParenthesedExpressionList;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.delete.Delete;
import net.sf.jsqlparser.statement.select.FromItem;
import net.sf.jsqlparser.statement.select.Join;
import net.sf.jsqlparser.statement.select.ParenthesedSelect;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.select.SelectItem;
import net.sf.jsqlparser.statement.select.SetOperationList;
import net.sf.jsqlparser.statement.update.Update;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.dromara.hutool.core.collection.CollUtil;
import org.dromara.hutool.core.collection.ListUtil;

/**
 * MyBatis-Plus inner interceptor that appends data-permission conditions to SQL statements.
 *
 * <p>For every table a statement touches, each {@link DataPermissionRule} supplied by
 * {@link DataPermissionRuleFactory#getDataPermissionRule()} whose {@code getTableNames()} contains
 * that table contributes an expression. The expressions are AND-ed into the WHERE clause (main
 * table of a SELECT, UPDATE or DELETE) or into the ON clause (joined tables). Queries are rewritten
 * in {@link #beforeQuery}; UPDATE/DELETE statements in {@link #beforePrepare}.
 *
 * <p>A mapped statement whose SQL matched no rule table is recorded per rule class in a
 * {@link MappedStatementCache}; once every current rule class has recorded it, parsing is skipped.
 */
public class DataPermissionDatabaseInterceptor
extends JsqlParserSupport
implements InnerInterceptor {
    private final DataPermissionRuleFactory ruleFactory;
    private final MappedStatementCache mappedStatementCache = new MappedStatementCache();

    /**
     * @param ruleFactory supplies the rules; consulted for every query and for every UPDATE/DELETE
     *                    being prepared
     */
    public DataPermissionDatabaseInterceptor(DataPermissionRuleFactory ruleFactory) {
        this.ruleFactory = ruleFactory;
    }

    /** Returns the cache of mapped statements known not to need rewriting. */
    public MappedStatementCache getMappedStatementCache() {
        return this.mappedStatementCache;
    }

    /**
     * Rewrites the SQL of a query before it is executed, unless there are no rules or the mapped
     * statement is already known not to reference any rule table.
     */
    @Override
    public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) {
        List<DataPermissionRule> rules = this.ruleFactory.getDataPermissionRule();
        if (this.mappedStatementCache.noRewritable(ms, rules)) {
            return;
        }
        PluginUtils.MPBoundSql mpBs = PluginUtils.mpBoundSql(boundSql);
        try {
            ContextHolder.init(rules);
            mpBs.sql(this.parserSingle(mpBs.sql(), null));
        } finally {
            // NOTE(review): this also runs when parsing throws, so a statement that failed to parse
            // before touching any rule table is cached as non-rewritable — confirm this is intended.
            this.addMappedStatementCache(ms);
            ContextHolder.clear();
        }
    }

    /**
     * Rewrites UPDATE and DELETE statements while they are being prepared; other command types
     * are left untouched here.
     */
    @Override
    public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
        PluginUtils.MPStatementHandler mpSh = PluginUtils.mpStatementHandler(sh);
        MappedStatement ms = mpSh.mappedStatement();
        SqlCommandType sct = ms.getSqlCommandType();
        if (sct != SqlCommandType.UPDATE && sct != SqlCommandType.DELETE) {
            return;
        }
        List<DataPermissionRule> rules = this.ruleFactory.getDataPermissionRule();
        if (this.mappedStatementCache.noRewritable(ms, rules)) {
            return;
        }
        PluginUtils.MPBoundSql mpBs = mpSh.mPBoundSql();
        try {
            ContextHolder.init(rules);
            mpBs.sql(this.parserMulti(mpBs.sql(), null));
        } finally {
            // NOTE(review): runs on parse failure too — see the matching note in beforeQuery.
            this.addMappedStatementCache(ms);
            ContextHolder.clear();
        }
    }

    /** Applies the rules to a parsed SELECT; {@code index}, {@code sql} and {@code obj} are unused. */
    @Override
    protected void processSelect(Select select, int index, String sql, Object obj) {
        this.processSelectBody(select);
    }

    /**
     * Applies the rules to a select body, recursing through parenthesised selects and set
     * operations (UNION etc.). Any other {@link Select} subtype is left unchanged instead of being
     * blindly cast to {@link SetOperationList}.
     *
     * @param selectBody the body to process; {@code null} is ignored
     */
    protected void processSelectBody(Select selectBody) {
        if (selectBody == null) {
            return;
        }
        if (selectBody instanceof PlainSelect) {
            this.processPlainSelect((PlainSelect) selectBody);
        } else if (selectBody instanceof ParenthesedSelect) {
            // e.g. each branch of "(SELECT ...) UNION (SELECT ...)"
            this.processSelectBody(((ParenthesedSelect) selectBody).getSelect());
        } else if (selectBody instanceof SetOperationList) {
            List<Select> selects = ((SetOperationList) selectBody).getSelects();
            if (CollectionUtils.isNotEmpty(selects)) {
                selects.forEach(this::processSelectBody);
            }
        }
    }

    /** AND-s the rule conditions for the updated table into the UPDATE's WHERE clause. */
    @Override
    protected void processUpdate(Update update, int index, String sql, Object obj) {
        Table table = update.getTable();
        update.setWhere(this.builderExpression(update.getWhere(), table));
    }

    /** AND-s the rule conditions for the target table into the DELETE's WHERE clause. */
    @Override
    protected void processDelete(Delete delete, int index, String sql, Object obj) {
        delete.setWhere(this.builderExpression(delete.getWhere(), delete.getTable()));
    }

    /**
     * Applies the rules to one plain SELECT:
     * <ul>
     *   <li>sub-queries in the original WHERE, in select items and in a derived FROM are processed
     *       recursively;</li>
     *   <li>when the FROM item is a table, its rule conditions are AND-ed into the WHERE;</li>
     *   <li>joined tables receive their conditions in their ON clauses.</li>
     * </ul>
     */
    protected void processPlainSelect(PlainSelect plainSelect) {
        FromItem fromItem = plainSelect.getFromItem();
        Expression where = plainSelect.getWhere();
        this.processWhereSubSelect(where);
        if (fromItem instanceof Table) {
            plainSelect.setWhere(this.builderExpression(where, (Table) fromItem));
        } else {
            this.processFromItem(fromItem);
        }
        List<SelectItem<?>> selectItems = plainSelect.getSelectItems();
        if (CollectionUtils.isNotEmpty(selectItems)) {
            // scalar sub-queries in the projection, e.g. "SELECT (SELECT ...) AS x"
            for (SelectItem<?> selectItem : selectItems) {
                this.processWhereSubSelect(selectItem.getExpression());
            }
        }
        List<Join> joins = plainSelect.getJoins();
        if (CollectionUtils.isNotEmpty(joins)) {
            this.processJoins(joins);
        }
    }

    /**
     * Finds sub-queries nested in a condition expression and applies the rules inside them.
     *
     * <p>Previously a sub-query (a {@link FromItem}) was skipped and IN-lists were never descended
     * into, so tables referenced only from sub-queries escaped the data-permission rules.
     *
     * @param where the expression to scan; {@code null} is ignored
     */
    protected void processWhereSubSelect(Expression where) {
        if (where == null) {
            return;
        }
        if (where instanceof Select) {
            // a sub-query used as an expression, e.g. the right side of "x = (SELECT ...)"
            this.processSelectBody((Select) where);
            return;
        }
        if (where instanceof FromItem) {
            return;
        }
        // cheap pre-check on the deparsed text: only descend when a sub-query may be present
        if (where.toString().indexOf("SELECT") <= 0) {
            return;
        }
        if (where instanceof BinaryExpression) {
            BinaryExpression binary = (BinaryExpression) where;
            this.processWhereSubSelect(binary.getLeftExpression());
            this.processWhereSubSelect(binary.getRightExpression());
        } else if (where instanceof InExpression) {
            this.processWhereSubSelect(((InExpression) where).getRightExpression());
        } else if (where instanceof ExistsExpression) {
            this.processWhereSubSelect(((ExistsExpression) where).getRightExpression());
        } else if (where instanceof NotExpression) {
            this.processWhereSubSelect(((NotExpression) where).getExpression());
        } else if (where instanceof ParenthesedExpressionList) {
            for (Expression element : (ParenthesedExpressionList<?>) where) {
                this.processWhereSubSelect(element);
            }
        }
    }

    /** Applies the rules inside a derived table (a sub-select); other non-table items are ignored. */
    private void processFromItem(FromItem fromItem) {
        if (fromItem instanceof Select) {
            this.processSelectBody((Select) fromItem);
        }
    }

    /**
     * Adds rule conditions to the ON clauses of the given joins.
     *
     * <p>A join with exactly one ON expression is handled directly by {@link #processJoin(Join)}.
     * Tables of joins without their own ON expression are stacked; when a later join carries several
     * ON expressions, each expression (in order) is paired with the most recently stacked table.
     */
    private void processJoins(List<Join> joins) {
        Deque<Table> tables = new ArrayDeque<>();
        for (Join join : joins) {
            FromItem fromItem = join.getRightItem();
            if (!(fromItem instanceof Table)) {
                this.processFromItem(fromItem);
                continue;
            }
            Table fromTable = (Table) fromItem;
            Collection<Expression> originOnExpressions = join.getOnExpressions();
            int onCount = originOnExpressions == null ? 0 : originOnExpressions.size();
            if (onCount == 1) {
                this.processJoin(join);
                continue;
            }
            // NOTE(review): a join without any ON expression (presumably "FROM a, b" or USING joins)
            // only stacks its table; unless a later join has several ON expressions, no rule
            // condition is added for it — confirm this is acceptable.
            tables.push(fromTable);
            if (onCount <= 1) {
                continue;
            }
            List<Expression> onExpressions = new ArrayList<>(onCount);
            for (Expression originOnExpression : originOnExpressions) {
                Table currentTable = tables.poll();
                // more ON expressions than stacked tables: keep the expression rather than NPE
                onExpressions.add(currentTable == null
                        ? originOnExpression
                        : this.builderExpression(originOnExpression, currentTable));
            }
            join.setOnExpressions(onExpressions);
        }
    }

    /** AND-s the joined table's rule conditions into the join's (single) ON expression. */
    protected void processJoin(Join join) {
        FromItem fromItem = join.getRightItem();
        if (!(fromItem instanceof Table)) {
            return;
        }
        Expression originOnExpression = CollUtil.getFirst(join.getOnExpressions());
        Expression onExpression = this.builderExpression(originOnExpression, (Table) fromItem);
        join.setOnExpressions(ListUtil.of(onExpression));
    }

    /**
     * Combines an existing condition with the rule conditions for {@code table}.
     *
     * <p>OR operands are parenthesised: without this, {@code a OR b} AND-ed with a rule was rendered
     * as {@code a OR b AND rule}, which restricts only {@code b} and lets {@code a} rows through.
     *
     * @return {@code currentExpression} when no rule applies, the rule condition when
     *         {@code currentExpression} is {@code null}, otherwise their conjunction
     */
    protected Expression builderExpression(Expression currentExpression, Table table) {
        Expression permissionExpression = this.buildDataPermissionExpression(table);
        if (permissionExpression == null) {
            return currentExpression;
        }
        if (currentExpression == null) {
            return permissionExpression;
        }
        return new AndExpression(parenthesizeOr(currentExpression), parenthesizeOr(permissionExpression));
    }

    /**
     * Builds the conjunction of all rule expressions that apply to {@code table}, or {@code null}
     * when no rule applies. Marks the current statement as rewritten as soon as a rule matches.
     */
    private Expression buildDataPermissionExpression(Table table) {
        List<DataPermissionRule> rules = ContextHolder.getRules();
        if (CollUtil.isEmpty(rules)) {
            return null;
        }
        // NOTE(review): assumes MpUtils.getTableName strips quoting; the raw name is still accepted
        // so matching is never narrower than before.
        String tableName = MpUtils.getTableName(table);
        Expression allExpression = null;
        for (DataPermissionRule rule : rules) {
            if (!rule.getTableNames().contains(table.getName())
                    && !rule.getTableNames().contains(tableName)) {
                continue;
            }
            ContextHolder.setRewrite(true);
            Expression oneExpression = rule.getExpression(tableName, table.getAlias());
            if (oneExpression == null) {
                // the rule matched the table but has nothing to add for this invocation
                continue;
            }
            allExpression = allExpression == null
                    ? oneExpression
                    : new AndExpression(parenthesizeOr(allExpression), parenthesizeOr(oneExpression));
        }
        return allExpression;
    }

    /** Wraps an OR expression in parentheses so it binds correctly as an AND operand. */
    private static Expression parenthesizeOr(Expression expression) {
        return expression instanceof OrExpression
                ? new ParenthesedExpressionList<>(expression)
                : expression;
    }

    /** Records {@code ms} as non-rewritable for the current rules when nothing was rewritten. */
    private void addMappedStatementCache(MappedStatement ms) {
        if (ContextHolder.getRewrite()) {
            return;
        }
        this.mappedStatementCache.addNoRewritable(ms, ContextHolder.getRules());
    }

    /**
     * Thread-safe record, per rule class, of mapped-statement ids whose SQL referenced none of that
     * rule's tables. The interceptor is shared across threads, so both the map and the per-rule id
     * sets are concurrent collections.
     */
    static final class MappedStatementCache {
        private final Map<Class<? extends DataPermissionRule>, Set<String>> noRewritableMappedStatements = new ConcurrentHashMap<>();

        MappedStatementCache() {
        }

        /** Returns the live backing map (rule class to non-rewritable mapped-statement ids). */
        public Map<Class<? extends DataPermissionRule>, Set<String>> getNoRewritableMappedStatements() {
            return this.noRewritableMappedStatements;
        }

        /**
         * @return {@code true} when there are no rules, or when every rule's class has recorded
         *         {@code ms} as non-rewritable
         */
        public boolean noRewritable(MappedStatement ms, List<DataPermissionRule> rules) {
            if (CollUtil.isEmpty(rules)) {
                return true;
            }
            for (DataPermissionRule rule : rules) {
                Set<String> mappedStatementIds = this.noRewritableMappedStatements.get(rule.getClass());
                if (mappedStatementIds == null || !mappedStatementIds.contains(ms.getId())) {
                    return false;
                }
            }
            return true;
        }

        /** Records {@code ms} as non-rewritable for each rule's class; {@code null}/empty rules are ignored. */
        public void addNoRewritable(MappedStatement ms, List<DataPermissionRule> rules) {
            if (CollUtil.isEmpty(rules)) {
                return;
            }
            for (DataPermissionRule rule : rules) {
                // computeIfAbsent + concurrent set: no lost updates and no unsynchronised HashSet
                this.noRewritableMappedStatements
                        .computeIfAbsent(rule.getClass(), key -> ConcurrentHashMap.newKeySet())
                        .add(ms.getId());
            }
        }
    }

    /**
     * Per-thread state for the statement currently being rewritten: the active rules and whether
     * any rule matched. Populated by {@link #init(List)} and removed by {@link #clear()}.
     */
    static final class ContextHolder {
        private static final ThreadLocal<List<DataPermissionRule>> RULES = new TransmittableThreadLocal<>();
        private static final ThreadLocal<Boolean> REWRITE = new TransmittableThreadLocal<>();

        private ContextHolder() {
        }

        /** Starts a rewrite with the given rules and the rewritten flag reset to {@code false}. */
        public static void init(List<DataPermissionRule> rules) {
            RULES.set(rules);
            REWRITE.set(false);
        }

        /** Removes both thread-local values. */
        public static void clear() {
            RULES.remove();
            REWRITE.remove();
        }

        /** @return whether a rule matched; {@code false} (instead of an NPE) when not initialised */
        public static boolean getRewrite() {
            return Boolean.TRUE.equals(REWRITE.get());
        }

        public static void setRewrite(boolean rewrite) {
            REWRITE.set(rewrite);
        }

        /** @return the active rules, or {@code null} when not initialised */
        public static List<DataPermissionRule> getRules() {
            return RULES.get();
        }
    }
}

