/*
 * Decompiled with CFR 0.152.
 */
package org.apache.kylin.query.security;

import com.google.common.base.Preconditions;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.TreeSet;
import org.apache.calcite.sql.SqlAsOperator;
import org.apache.calcite.sql.SqlBasicCall;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlJoin;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.SqlSelect;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.calcite.sql.util.SqlBasicVisitor;
import org.apache.calcite.sql.util.SqlVisitor;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.lang.text.StrBuilder;
import org.apache.kylin.common.KylinConfig;
import org.apache.kylin.common.QueryContext;
import org.apache.kylin.common.exception.KylinRuntimeException;
import org.apache.kylin.common.util.Pair;
import org.apache.kylin.metadata.acl.AclTCRManager;
import org.apache.kylin.metadata.model.tool.CalciteParser;
import org.apache.kylin.query.IQueryTransformer;
import org.apache.kylin.source.adhocquery.IPushDownConverter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class RowFilter
implements IQueryTransformer,
IPushDownConverter {
    private static final Logger logger = LoggerFactory.getLogger(RowFilter.class);

    /**
     * Tells whether the row-filter rewrite should be skipped for a query.
     *
     * @param sql           the query text
     * @param defaultSchema the schema name supplied by the caller
     * @param cond          the map of row conditions; must not be {@code null} when the
     *                      earlier checks do not already short-circuit
     * @return {@code true} when the schema is empty, the SQL is empty, the SQL does not
     *         contain "from" (case-insensitively), or the condition map is empty
     */
    static boolean needEscape(String sql, String defaultSchema, Map<String, String> cond) {
        if (StringUtils.isEmpty(defaultSchema)) {
            return true;
        }
        if (StringUtils.isEmpty(sql)) {
            return true;
        }
        if (!StringUtils.containsIgnoreCase(sql, "from")) {
            return true;
        }
        return cond.isEmpty();
    }

    /**
     * Convenience overload of
     * {@link #whereClauseBracketsCompletion(String, String, Set, String)} that passes
     * {@code null} as the project.
     */
    static String whereClauseBracketsCompletion(String schema, String inputSQL, Set<String> candidateTables) {
        final String noProject = null;
        return whereClauseBracketsCompletion(schema, inputSQL, candidateTables, noProject);
    }

    /**
     * Wraps the existing WHERE expression of a select in one pair of brackets when at
     * least one of the tables collected for that select is listed in
     * {@code candidateTables}. Selects without a WHERE clause are left alone.
     *
     * @param schema          schema used to qualify table names that carry none
     * @param inputSQL        the SQL text to rewrite
     * @param candidateTables names of the tables whose selects should be bracketed
     * @param project         project handed to the SQL parser
     * @return the SQL text with the brackets inserted
     */
    static String whereClauseBracketsCompletion(String schema, String inputSQL, Set<String> candidateTables, String project) {
        List<Pair<Integer, String>> brackets = new ArrayList<>();
        for (Map.Entry<SqlSelect, List<Table>> entry : getSelectClausesWithTbls(inputSQL, schema, project).entrySet()) {
            SqlSelect select = entry.getKey();
            if (!select.hasWhere()) {
                continue;
            }
            boolean touchesCandidate = entry.getValue().stream()
                    .map(Table::getName)
                    .anyMatch(candidateTables::contains);
            if (!touchesCandidate) {
                continue;
            }
            // One bracket pair per select, no matter how many of its tables matched.
            Pair<Integer, Integer> wherePos = CalciteParser.getReplacePos(select.getWhere(), inputSQL);
            brackets.add(Pair.newPair(wherePos.getFirst(), "("));
            brackets.add(Pair.newPair(wherePos.getSecond(), ")"));
        }
        return afterInsertSQL(inputSQL, brackets);
    }

    /**
     * Convenience overload of {@link #rowFilter(String, String, Map, String)} that
     * passes {@code null} as the project.
     */
    static String rowFilter(String schema, String inputSQL, Map<String, String> whereCondWithTbls) {
        final String noProject = null;
        return rowFilter(schema, inputSQL, whereCondWithTbls, noProject);
    }

    /**
     * Inserts the row conditions of {@code whereCondWithTbls} into the selects of
     * {@code inputSQL}.
     *
     * @param schema            schema used to qualify table names that carry none
     * @param inputSQL          the SQL text to rewrite
     * @param whereCondWithTbls condition text keyed by table name
     * @param project           project handed to the SQL parser
     * @return the SQL text with the conditions inserted
     */
    static String rowFilter(String schema, String inputSQL, Map<String, String> whereCondWithTbls, String project) {
        Map<SqlSelect, List<Table>> tablesBySelect = getSelectClausesWithTbls(inputSQL, schema, project);
        return afterInsertSQL(inputSQL, getInsertPosAndExpr(inputSQL, whereCondWithTbls, tablesBySelect));
    }

    /**
     * Applies a set of text insertions to {@code inputSQL}.
     *
     * <p>The insertions are sorted in place by descending offset and applied in that
     * order, so inserting one snippet never shifts the offset of a snippet that is
     * still to be inserted.
     *
     * @param inputSQL                the original SQL text
     * @param toBeInsertedPosAndExprs pairs of (offset in {@code inputSQL}, text to insert);
     *                                this list is reordered by the call
     * @return the SQL text with every snippet inserted
     */
    private static String afterInsertSQL(String inputSQL, List<Pair<Integer, String>> toBeInsertedPosAndExprs) {
        // Integer.compare instead of subtraction: same ordering, but cannot overflow.
        // The sort is stable, so entries sharing an offset keep their relative order.
        Collections.sort(toBeInsertedPosAndExprs,
                (left, right) -> Integer.compare(right.getFirst(), left.getFirst()));
        StrBuilder convertedSQL = new StrBuilder(inputSQL);
        for (Pair<Integer, String> insertion : toBeInsertedPosAndExprs) {
            int offset = insertion.getFirst();
            String text = insertion.getSecond();
            convertedSQL.insert(offset, text);
        }
        return convertedSQL.toString();
    }

    /**
     * Builds, for every select that needs one, the pair (insert offset, condition text).
     *
     * <p>The condition text is computed first and the insert offset only when that text
     * is non-empty. Locating the offset can fail (for example {@code getInsertAfterNode}
     * rejects a join that has no condition), and such a failure should not abort the
     * rewrite for a select into which nothing would be inserted anyway.
     *
     * @param inputSQL              the SQL text the offsets refer to
     * @param whereCondWithTbls     condition text keyed by table name
     * @param selectClausesWithTbls the selects of the query with the tables found in each
     * @return one entry per select that receives a condition; empty when none does
     */
    private static List<Pair<Integer, String>> getInsertPosAndExpr(String inputSQL, Map<String, String> whereCondWithTbls, Map<SqlSelect, List<Table>> selectClausesWithTbls) {
        List<Pair<Integer, String>> insertions = new ArrayList<>();
        for (Map.Entry<SqlSelect, List<Table>> entry : selectClausesWithTbls.entrySet()) {
            SqlSelect select = entry.getKey();
            String whereCond = getToBeInsertCond(whereCondWithTbls, select, entry.getValue());
            if (whereCond.isEmpty()) {
                // Nothing to add to this select, so its insert position is irrelevant.
                continue;
            }
            int insertPos = getInsertPos(inputSQL, select);
            insertions.add(Pair.newPair(insertPos, whereCond));
        }
        return insertions;
    }

    /**
     * Computes the offset in {@code inputSQL} at which a row condition has to be inserted
     * for {@code select}.
     *
     * <p>The starting point is the end offset that {@code CalciteParser.getReplacePos}
     * reports for the node chosen by {@code getInsertAfterNode}. Opening brackets that
     * directly precede that node (only whitespace in between) are counted, and the
     * offset is then moved past the same number of closing brackets that directly
     * follow it, so the condition ends up outside those brackets.
     *
     * @param inputSQL the SQL text the offsets refer to
     * @param select   the select to find the insert position for
     * @return the offset at which to insert the condition
     */
    private static int getInsertPos(String inputSQL, SqlSelect select) {
        SqlNode insertAfter = getInsertAfterNode(select);
        Pair<Integer, Integer> pos = CalciteParser.getReplacePos(insertAfter, inputSQL);
        int nodeStart = pos.getFirst();
        int nodeEnd = pos.getSecond();

        // Walk left from the node and count directly enclosing '('. The scan is bounded
        // by the start of the text so it can never index below zero.
        int openBrackets = 0;
        for (int j = nodeStart - 1; j >= 0; --j) {
            char ch = inputSQL.charAt(j);
            if (isSkippableWhitespace(ch)) {
                continue;
            }
            if (ch != '(') {
                break;
            }
            ++openBrackets;
        }

        // Walk right from the node and step over as many ')' as were opened.
        int finalPos = nodeEnd;
        for (int i = nodeEnd; i < inputSQL.length() && openBrackets > 0; ++i) {
            char ch = inputSQL.charAt(i);
            if (isSkippableWhitespace(ch)) {
                continue;
            }
            if (ch != ')') {
                break;
            }
            finalPos = i + 1;
            --openBrackets;
        }
        return finalPos;
    }

    /**
     * Whitespace that may separate a bracket from the expression it encloses. Carriage
     * return is included so SQL with CRLF line endings is handled like LF-only SQL.
     */
    private static boolean isSkippableWhitespace(char ch) {
        return ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r';
    }

    /**
     * Concatenates the row conditions of all {@code tables} of one select into the text
     * to be inserted into that select.
     *
     * <p>Each condition is passed through {@code CalciteParser.insertAliasInExpr} with
     * the table's alias. When the select has no WHERE clause, the first condition is
     * introduced by {@code " WHERE "}; every other condition is introduced by
     * {@code " AND "}.
     *
     * @param whereCondWithTbls condition text keyed by table name
     * @param select            the select the conditions are meant for
     * @param tables            the tables found in that select
     * @return the text to insert, or an empty string when no table has a condition
     */
    private static String getToBeInsertCond(Map<String, String> whereCondWithTbls, SqlSelect select, List<Table> tables) {
        StringBuilder clause = new StringBuilder();
        boolean firstCondition = true;
        for (Table table : tables) {
            String rawCond = whereCondWithTbls.get(table.getName());
            if (StringUtils.isEmpty(rawCond)) {
                continue;
            }
            String aliasedCond = CalciteParser.insertAliasInExpr(rawCond, table.getAlias());
            boolean opensWhereClause = firstCondition && !select.hasWhere();
            clause.append(opensWhereClause ? " WHERE " : " AND ").append(aliasedCond);
            if (opensWhereClause) {
                firstCondition = false;
            }
        }
        return clause.toString();
    }

    /**
     * Picks the node after which a row condition is inserted: the WHERE expression when
     * the select has one, otherwise the join condition when the FROM clause is a join,
     * otherwise the FROM node itself.
     *
     * @param select the select to inspect
     * @return the node to insert after
     * @throws NullPointerException if the FROM clause is a join whose condition is null
     */
    private static SqlNode getInsertAfterNode(SqlSelect select) {
        if (select.hasWhere()) {
            return select.getWhere();
        }
        SqlNode from = select.getFrom();
        if (from instanceof SqlJoin) {
            SqlNode joinCondition = ((SqlJoin) from).getCondition();
            return Preconditions.checkNotNull(joinCondition, "Join without \"ON\"");
        }
        return from;
    }

    /**
     * Parses {@code inputSQL} and maps every select found in it to the tables collected
     * from its FROM clause. Selects for which no table is collected are left out.
     *
     * @param inputSQL the SQL text to parse
     * @param schema   schema used to qualify table names that carry none
     * @param project  project handed to the SQL parser
     * @return the selects with their non-empty table lists
     */
    private static Map<SqlSelect, List<Table>> getSelectClausesWithTbls(String inputSQL, String schema, String project) {
        Map<SqlSelect, List<Table>> tablesBySelect = new HashMap<>();
        List<SqlSelect> selects = SelectClauseFinder.getSelectClauses(inputSQL, project);
        for (SqlSelect select : selects) {
            List<Table> tables = getTblWithAlias(schema, select);
            if (!tables.isEmpty()) {
                tablesBySelect.put(select, tables);
            }
        }
        return tablesBySelect;
    }

    /**
     * Collects the tables of the FROM clause of {@code select} and prefixes every table
     * name that splits into exactly one dot-separated part with {@code schema + "."}.
     *
     * @param schema the schema to prepend to such names
     * @param select the select whose FROM clause is inspected
     * @return the tables with their aliases, names qualified as described
     */
    static List<Table> getTblWithAlias(String schema, SqlSelect select) {
        List<Table> tables = NonSubqueryTablesFinder.getTblsWithAlias(select.getFrom());
        tables.replaceAll(table -> {
            boolean unqualified = table.getName().split("\\.").length == 1;
            return unqualified ? new Table(schema + "." + table.getName(), table.getAlias()) : table;
        });
        return tables;
    }

    /**
     * Checks whether the ACL information describes an administrator.
     *
     * @param aclInfo the ACL information of the current query; may be {@code null}
     * @return {@code false} when {@code aclInfo} or its groups are {@code null};
     *         otherwise {@code true} when one of the groups equals {@code "ROLE_ADMIN"}
     *         or {@code aclInfo.isHasAdminPermission()} is set
     */
    private static boolean hasAdminPermission(QueryContext.AclInfo aclInfo) {
        // NOTE(review): null groups short-circuit to false before the admin flag is
        // consulted - confirm that this is intended.
        if (aclInfo == null || aclInfo.getGroups() == null) {
            return false;
        }
        boolean inAdminGroup = aclInfo.getGroups().stream().anyMatch(group -> "ROLE_ADMIN".equals(group));
        return inAdminGroup || aclInfo.isHasAdminPermission();
    }

    /**
     * Delegates to {@link #transform(String, String, String)} with the same arguments.
     *
     * @param originSql     the SQL text to rewrite
     * @param project       the project the query runs in
     * @param defaultSchema the schema used for unqualified table names
     * @return whatever {@code transform} returns for these arguments
     */
    public String convert(String originSql, String project, String defaultSchema) {
        String converted = transform(originSql, project, defaultSchema);
        return converted;
    }

    /**
     * Rewrites {@code sql} so that the row conditions configured for the current user
     * are part of the query.
     *
     * <p>The SQL is returned unchanged when ACL TCR is disabled in the configuration,
     * when the current ACL information denotes an administrator, or when
     * {@code needEscape} reports that the rewrite should be skipped. Otherwise existing
     * WHERE clauses of affected selects are bracketed first and the conditions are
     * inserted afterwards.
     *
     * @param sql           the SQL text to rewrite
     * @param project       the project the query runs in
     * @param defaultSchema the schema used for unqualified table names
     * @return the rewritten SQL, or {@code sql} itself when no rewrite applies
     */
    @Override
    public String transform(String sql, String project, String defaultSchema) {
        QueryContext.AclInfo aclInfo = QueryContext.current().getAclInfo();
        boolean aclDisabled = !KylinConfig.getInstanceFromEnv().isAclTCREnabled();
        if (aclDisabled || hasAdminPermission(aclInfo)) {
            return sql;
        }

        Map<String, String> rowConditions = getAllWhereCondWithTbls(project, aclInfo);
        if (needEscape(sql, defaultSchema, rowConditions)) {
            return sql;
        }

        logger.debug("\nStart to transform SQL with row ACL\n");
        Set<String> candidateTables = getCandidateTables(rowConditions);
        String bracketed = whereClauseBracketsCompletion(defaultSchema, sql, candidateTables, project);
        String filtered = rowFilter(defaultSchema, bracketed, rowConditions, project);
        logger.debug("\nFinish transforming SQL with row ACL.\n");
        return filtered;
    }

    /**
     * Copies the table names (the keys of the condition map) into a sorted set that
     * compares names case-insensitively.
     *
     * @param allWhereCondWithTbls condition text keyed by table name
     * @return a new case-insensitive set holding the keys
     */
    private Set<String> getCandidateTables(Map<String, String> allWhereCondWithTbls) {
        Set<String> tableNames = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
        tableNames.addAll(allWhereCondWithTbls.keySet());
        return tableNames;
    }

    /**
     * Asks the {@code AclTCRManager} of {@code project} for the concatenated row
     * conditions of the user and groups found in {@code aclInfo}.
     *
     * @param project the project whose ACL manager is queried
     * @param aclInfo the ACL information of the current query; when {@code null}, both
     *                user and groups are passed on as {@code null}
     * @return the map returned by {@code getTableColumnConcatWhereCondition}
     */
    private Map<String, String> getAllWhereCondWithTbls(String project, QueryContext.AclInfo aclInfo) {
        String user = null;
        // Kept raw: the element type of getGroups() is not visible from this file.
        Set groups = null;
        if (aclInfo != null) {
            user = aclInfo.getUsername();
            groups = aclInfo.getGroups();
        }
        AclTCRManager manager = AclTCRManager.getInstance(KylinConfig.getInstanceFromEnv(), project);
        return manager.getTableColumnConcatWhereCondition(user, groups);
    }

    /** A table name together with the alias it is referred to by in a select. */
    static class Table {
        private String name;
        private String alias;

        /**
         * @param name  the table name
         * @param alias the alias of the table
         */
        public Table(String name, String alias) {
            this.name = name;
            this.alias = alias;
        }

        /** @return the table name */
        public String getName() {
            return name;
        }

        /** @return the alias of the table */
        public String getAlias() {
            return alias;
        }

        /**
         * Replaces both name and alias.
         *
         * @param tableWithAlias pair whose first element is the new name and whose
         *                       second element is the new alias
         */
        public void setTable(Pair<String, String> tableWithAlias) {
            String newName = tableWithAlias.getFirst();
            String newAlias = tableWithAlias.getSecond();
            this.name = newName;
            this.alias = newAlias;
        }
    }

    static class NonSubqueryTablesFinder
    extends SqlBasicVisitor<SqlNode> {
        private List<Table> tablesWithAlias = new ArrayList<Table>();

        private NonSubqueryTablesFinder() {
        }

        static List<Table> getTblsWithAlias(SqlNode fromNode) {
            NonSubqueryTablesFinder sv = new NonSubqueryTablesFinder();
            fromNode.accept((SqlVisitor)sv);
            return sv.getTblsWithAlias();
        }

        private List<Table> getTblsWithAlias() {
            return this.tablesWithAlias;
        }

        public SqlNode visit(SqlNodeList nodeList) {
            return null;
        }

        public SqlNode visit(SqlCall call) {
            if (!(call instanceof SqlSelect)) {
                if (call instanceof SqlBasicCall) {
                    SqlBasicCall node = (SqlBasicCall)call;
                    if (node.getOperator() instanceof SqlAsOperator && node.getOperands()[0] instanceof SqlIdentifier) {
                        SqlIdentifier id0 = (SqlIdentifier)((SqlBasicCall)call).getOperands()[0];
                        SqlIdentifier id1 = (SqlIdentifier)((SqlBasicCall)call).getOperands()[1];
                        String table = id0.toString();
                        String alais = CalciteParser.getLastNthName((SqlIdentifier)id1, (int)1);
                        this.tablesWithAlias.add(new Table(table, alais));
                    }
                } else if (call instanceof SqlJoin) {
                    SqlJoin node = (SqlJoin)call;
                    node.getLeft().accept((SqlVisitor)this);
                    node.getRight().accept((SqlVisitor)this);
                } else {
                    for (SqlNode operand : call.getOperandList()) {
                        if (operand == null) continue;
                        operand.accept((SqlVisitor)this);
                    }
                }
            }
            return null;
        }

        public SqlNode visit(SqlIdentifier id) {
            String[] dotSplits = id.toString().toUpperCase(Locale.ROOT).split("\\.");
            String table = dotSplits[dotSplits.length - 1];
            this.tablesWithAlias.add(new Table(id.toString().toUpperCase(Locale.ROOT), table));
            return null;
        }
    }

    static class SelectClauseFinder
    extends SqlBasicVisitor<SqlNode> {
        private List<SqlSelect> selects = new ArrayList<SqlSelect>();

        SelectClauseFinder() {
        }

        static List<SqlSelect> getSelectClauses(String inputSQL, String project) {
            SqlNode node = null;
            try {
                node = CalciteParser.parse((String)inputSQL, (String)project);
            }
            catch (SqlParseException e) {
                throw new KylinRuntimeException("Failed to parse SQL '" + inputSQL + "', please make sure the SQL is valid");
            }
            SelectClauseFinder sv = new SelectClauseFinder();
            node.accept((SqlVisitor)sv);
            return sv.getSelectClauses();
        }

        private List<SqlSelect> getSelectClauses() {
            return this.selects;
        }

        public SqlNode visit(SqlNodeList nodeList) {
            for (int i = 0; i < nodeList.size(); ++i) {
                SqlNode node = nodeList.get(i);
                node.accept((SqlVisitor)this);
            }
            return null;
        }

        public SqlNode visit(SqlCall call) {
            if (call instanceof SqlSelect) {
                SqlSelect select = (SqlSelect)call;
                this.selects.add(select);
            }
            for (SqlNode operand : call.getOperandList()) {
                if (operand == null) continue;
                operand.accept((SqlVisitor)this);
            }
            return null;
        }
    }
}

