/*
 * Copyright 1999-2011 Alibaba Group Holding Ltd.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.alibaba.druid.wall;

import java.sql.DatabaseMetaData;
import java.sql.SQLException;
import java.sql.Wrapper;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;

import com.alibaba.druid.DruidRuntimeException;
import com.alibaba.druid.filter.FilterAdapter;
import com.alibaba.druid.filter.FilterChain;
import com.alibaba.druid.proxy.jdbc.CallableStatementProxy;
import com.alibaba.druid.proxy.jdbc.ConnectionProxy;
import com.alibaba.druid.proxy.jdbc.DataSourceProxy;
import com.alibaba.druid.proxy.jdbc.PreparedStatementProxy;
import com.alibaba.druid.proxy.jdbc.ResultSetProxy;
import com.alibaba.druid.proxy.jdbc.StatementProxy;
import com.alibaba.druid.support.logging.Log;
import com.alibaba.druid.support.logging.LogFactory;
import com.alibaba.druid.util.JdbcUtils;
import com.alibaba.druid.wall.spi.MySqlWallProvider;
import com.alibaba.druid.wall.spi.OracleWallProvider;
import com.alibaba.druid.wall.spi.PGWallProvider;
import com.alibaba.druid.wall.spi.SQLServerWallProvider;
import com.alibaba.druid.wall.violation.SyntaxErrorViolation;

public class WallFilter extends FilterAdapter implements WallFilterMBean {

    private final static Log   LOG            = LogFactory.getLog(WallFilter.class);

    private boolean            inited         = false;

    private WallProvider       provider;

    private String             dbType;

    private WallConfig         config;

    private volatile boolean   logViolation   = false;
    private volatile boolean   throwException = true;

    public final static String ATTR_SQL_STAT  = "wall.sqlStat";

    @Override
    public void init(DataSourceProxy dataSource) {
        if (this.dbType == null || this.dbType.trim().length() == 0) {
            if (dataSource != null && dataSource.getDbType() != null) {
                this.dbType = dataSource.getDbType();
            } else {
                this.dbType = JdbcUtils.getDbType(dataSource.getRawJdbcUrl(), "");
            }
        }

        if (dbType == null) {
            dbType = JdbcUtils.getDbType(dataSource.getUrl(), null);
        }

        if (JdbcUtils.MYSQL.equals(dbType) || JdbcUtils.H2.equals(dbType)) {
            if (config == null) {
                config = new WallConfig(MySqlWallProvider.DEFAULT_CONFIG_DIR);
            }

            provider = new MySqlWallProvider(config);
        } else if (JdbcUtils.ORACLE.equals(dbType) || JdbcUtils.ALI_ORACLE.equals(dbType)) {
            if (config == null) {
                config = new WallConfig(OracleWallProvider.DEFAULT_CONFIG_DIR);
            }

            provider = new OracleWallProvider(config);
        } else if (JdbcUtils.SQL_SERVER.equals(dbType) || JdbcUtils.JTDS.equals(dbType)) {
            if (config == null) {
                config = new WallConfig(SQLServerWallProvider.DEFAULT_CONFIG_DIR);
            }

            provider = new SQLServerWallProvider(config);
        } else if (JdbcUtils.POSTGRESQL.equals(dbType)) {
            if (config == null) {
                config = new WallConfig(PGWallProvider.DEFAULT_CONFIG_DIR);
            }

            provider = new PGWallProvider(config);
        } else {
            throw new IllegalStateException("dbType not support : " + dbType + ", url " + dataSource.getUrl());
        }

        this.inited = true;
    }

    public String getDbType() {
        return dbType;
    }

    public void setDbType(String dbType) {
        this.dbType = dbType;
    }

    public boolean isLogViolation() {
        return logViolation;
    }

    public void setLogViolation(boolean logViolation) {
        this.logViolation = logViolation;
    }

    public boolean isThrowException() {
        return throwException;
    }

    public void setThrowException(boolean throwException) {
        this.throwException = throwException;
    }

    public void clearProviderCache() {
        if (provider != null) {
            provider.clearCache();
        }
    }

    public Set<String> getProviderWhiteList() {
        if (provider == null) {
            return Collections.emptySet();
        }

        return provider.getWhiteList();
    }

    public WallProvider getProvider() {
        return provider;
    }

    public WallConfig getConfig() {
        return config;
    }

    public void setConfig(WallConfig config) {
        this.config = config;
    }

    public boolean isInited() {
        return inited;
    }

    public void checkInit() {
        if (inited) {
            throw new DruidRuntimeException("wall filter is inited");
        }
    }

    @Override
    public void statement_addBatch(FilterChain chain, StatementProxy statement, String sql) throws SQLException {
        createWallContext(statement);
        try {
            sql = check(sql);
            chain.statement_addBatch(statement, sql);
        } finally {
            WallContext.clearContext();
        }
    }

    @Override
    public PreparedStatementProxy connection_prepareStatement(FilterChain chain, ConnectionProxy connection, String sql)
                                                                                                                        throws SQLException {
        String dbType = connection.getDirectDataSource().getDbType();
        WallContext.create(dbType);
        try {
            sql = check(sql);
            PreparedStatementProxy stmt = chain.connection_prepareStatement(connection, sql);
            setSqlStatAttribute(stmt);
            return stmt;
        } finally {
            WallContext.clearContext();
        }
    }

    @Override
    public PreparedStatementProxy connection_prepareStatement(FilterChain chain, ConnectionProxy connection,
                                                              String sql, int autoGeneratedKeys) throws SQLException {
        String dbType = connection.getDirectDataSource().getDbType();
        WallContext.create(dbType);
        try {
            sql = check(sql);
            PreparedStatementProxy stmt = chain.connection_prepareStatement(connection, sql, autoGeneratedKeys);
            setSqlStatAttribute(stmt);
            return stmt;
        } finally {
            WallContext.clearContext();
        }
    }

    @Override
    public PreparedStatementProxy connection_prepareStatement(FilterChain chain, ConnectionProxy connection,
                                                              String sql, int resultSetType, int resultSetConcurrency)
                                                                                                                      throws SQLException {
        String dbType = connection.getDirectDataSource().getDbType();
        WallContext.create(dbType);
        try {
            sql = check(sql);
            PreparedStatementProxy stmt = chain.connection_prepareStatement(connection, sql, resultSetType,
                                                                            resultSetConcurrency);
            setSqlStatAttribute(stmt);
            return stmt;
        } finally {
            WallContext.clearContext();
        }
    }

    @Override
    public PreparedStatementProxy connection_prepareStatement(FilterChain chain, ConnectionProxy connection,
                                                              String sql, int resultSetType, int resultSetConcurrency,
                                                              int resultSetHoldability) throws SQLException {
        String dbType = connection.getDirectDataSource().getDbType();
        WallContext.create(dbType);
        try {
            sql = check(sql);
            PreparedStatementProxy stmt = chain.connection_prepareStatement(connection, sql, resultSetType,
                                                                            resultSetConcurrency, resultSetHoldability);
            setSqlStatAttribute(stmt);
            return stmt;
        } finally {
            WallContext.clearContext();
        }
    }

    @Override
    public PreparedStatementProxy connection_prepareStatement(FilterChain chain, ConnectionProxy connection,
                                                              String sql, int[] columnIndexes) throws SQLException {
        String dbType = connection.getDirectDataSource().getDbType();
        WallContext.create(dbType);
        try {
            sql = check(sql);
            PreparedStatementProxy stmt = chain.connection_prepareStatement(connection, sql, columnIndexes);
            setSqlStatAttribute(stmt);
            return stmt;
        } finally {
            WallContext.clearContext();
        }
    }

    @Override
    public PreparedStatementProxy connection_prepareStatement(FilterChain chain, ConnectionProxy connection,
                                                              String sql, String[] columnNames) throws SQLException {
        String dbType = connection.getDirectDataSource().getDbType();
        WallContext.create(dbType);
        try {
            sql = check(sql);
            PreparedStatementProxy stmt = chain.connection_prepareStatement(connection, sql, columnNames);
            setSqlStatAttribute(stmt);
            return stmt;
        } finally {
            WallContext.clearContext();
        }
    }

    @Override
    public CallableStatementProxy connection_prepareCall(FilterChain chain, ConnectionProxy connection, String sql)
                                                                                                                   throws SQLException {
        String dbType = connection.getDirectDataSource().getDbType();
        WallContext.create(dbType);
        try {
            sql = check(sql);
            CallableStatementProxy stmt = chain.connection_prepareCall(connection, sql);
            setSqlStatAttribute(stmt);
            return stmt;
        } finally {
            WallContext.clearContext();
        }
    }

    @Override
    public CallableStatementProxy connection_prepareCall(FilterChain chain, ConnectionProxy connection, String sql,
                                                         int resultSetType, int resultSetConcurrency)
                                                                                                     throws SQLException {
        String dbType = connection.getDirectDataSource().getDbType();
        WallContext.create(dbType);
        try {
            sql = check(sql);
            CallableStatementProxy stmt = chain.connection_prepareCall(connection, sql, resultSetType,
                                                                       resultSetConcurrency);
            setSqlStatAttribute(stmt);
            return stmt;
        } finally {
            WallContext.clearContext();
        }
    }

    @Override
    public CallableStatementProxy connection_prepareCall(FilterChain chain, ConnectionProxy connection, String sql,
                                                         int resultSetType, int resultSetConcurrency,
                                                         int resultSetHoldability) throws SQLException {
        String dbType = connection.getDirectDataSource().getDbType();
        WallContext.create(dbType);
        try {
            sql = check(sql);
            CallableStatementProxy stmt = chain.connection_prepareCall(connection, sql, resultSetType,
                                                                       resultSetConcurrency, resultSetHoldability);
            setSqlStatAttribute(stmt);
            return stmt;
        } finally {
            WallContext.clearContext();
        }
    }

    // //////////////

    @Override
    public boolean statement_execute(FilterChain chain, StatementProxy statement, String sql) throws SQLException {
        WallContext originalContext = WallContext.current();
        try {
            createWallContext(statement);
            sql = check(sql);
            boolean firstResult = chain.statement_execute(statement, sql);
            if (!firstResult) {
                int updateCount = statement.getUpdateCount();
                statExecuteUpdate(updateCount);
            }
            return firstResult;
        } finally {
            if (originalContext != null) {
                WallContext.setContext(originalContext);
            }
        }
    }

    @Override
    public boolean statement_execute(FilterChain chain, StatementProxy statement, String sql, int autoGeneratedKeys)
                                                                                                                    throws SQLException {
        createWallContext(statement);
        try {
            sql = check(sql);
            boolean firstResult = chain.statement_execute(statement, sql, autoGeneratedKeys);
            if (!firstResult) {
                int updateCount = statement.getUpdateCount();
                statExecuteUpdate(updateCount);
            }
            return firstResult;
        } finally {
            WallContext.clearContext();
        }
    }

    @Override
    public boolean statement_execute(FilterChain chain, StatementProxy statement, String sql, int columnIndexes[])
                                                                                                                  throws SQLException {
        createWallContext(statement);
        try {
            sql = check(sql);
            boolean firstResult = chain.statement_execute(statement, sql, columnIndexes);
            if (!firstResult) {
                int updateCount = statement.getUpdateCount();
                statExecuteUpdate(updateCount);
            }
            return firstResult;
        } finally {
            WallContext.clearContext();
        }
    }

    @Override
    public boolean statement_execute(FilterChain chain, StatementProxy statement, String sql, String columnNames[])
                                                                                                                   throws SQLException {
        createWallContext(statement);
        try {
            sql = check(sql);
            boolean firstResult = chain.statement_execute(statement, sql, columnNames);
            if (!firstResult) {
                int updateCount = statement.getUpdateCount();
                statExecuteUpdate(updateCount);
            }
            return firstResult;
        } finally {
            WallContext.clearContext();
        }
    }

    @Override
    public int[] statement_executeBatch(FilterChain chain, StatementProxy statement) throws SQLException {
        WallSqlStat sqlStat = (WallSqlStat) statement.getAttributes().get(ATTR_SQL_STAT);
        try {
            int[] updateCounts = chain.statement_executeBatch(statement);
            for (int i = 0; i < updateCounts.length; ++i) {
                int updateCount = updateCounts[i];
                statExecuteUpdate(sqlStat, updateCount);
            }
            return updateCounts;
        } finally {
            WallContext.clearContext();
        }
    }

    @Override
    public ResultSetProxy statement_executeQuery(FilterChain chain, StatementProxy statement, String sql)
                                                                                                         throws SQLException {
        createWallContext(statement);
        try {
            sql = check(sql);
            return chain.statement_executeQuery(statement, sql);
        } finally {
            WallContext.clearContext();
        }
    }

    @Override
    public int statement_executeUpdate(FilterChain chain, StatementProxy statement, String sql) throws SQLException {
        createWallContext(statement);
        try {
            sql = check(sql);
            int updateCount = chain.statement_executeUpdate(statement, sql);
            statExecuteUpdate(updateCount);
            return updateCount;
        } finally {
            WallContext.clearContext();
        }
    }

    @Override
    public int statement_executeUpdate(FilterChain chain, StatementProxy statement, String sql, int autoGeneratedKeys)
                                                                                                                      throws SQLException {
        createWallContext(statement);
        try {
            sql = check(sql);
            int updateCount = chain.statement_executeUpdate(statement, sql, autoGeneratedKeys);
            statExecuteUpdate(updateCount);
            return updateCount;
        } finally {
            WallContext.clearContext();
        }
    }

    @Override
    public int statement_executeUpdate(FilterChain chain, StatementProxy statement, String sql, int columnIndexes[])
                                                                                                                    throws SQLException {
        createWallContext(statement);
        try {
            sql = check(sql);
            int updateCount = chain.statement_executeUpdate(statement, sql, columnIndexes);
            statExecuteUpdate(updateCount);
            return updateCount;
        } finally {
            WallContext.clearContext();
        }
    }

    public String getDbType(StatementProxy statement) {
        return statement.getConnectionProxy().getDirectDataSource().getDbType();
    }

    private WallContext createWallContext(StatementProxy statement) {
        String dbType = getDbType(statement);
        WallContext context = WallContext.create(dbType);
        return context;
    }

    @Override
    public int statement_executeUpdate(FilterChain chain, StatementProxy statement, String sql, String columnNames[])
                                                                                                                     throws SQLException {
        createWallContext(statement);
        try {
            sql = check(sql);
            int updateCount = chain.statement_executeUpdate(statement, sql, columnNames);
            statExecuteUpdate(updateCount);
            return updateCount;
        } finally {
            WallContext.clearContext();
        }
    }

    @Override
    public boolean preparedStatement_execute(FilterChain chain, PreparedStatementProxy statement) throws SQLException {
        boolean firstResult = chain.preparedStatement_execute(statement);

        if (!firstResult) {
            WallSqlStat sqlStat = (WallSqlStat) statement.getAttributes().get(ATTR_SQL_STAT);
            int updateCount = statement.getUpdateCount();
            statExecuteUpdate(sqlStat, updateCount);
        }

        return firstResult;
    }

    @Override
    public ResultSetProxy preparedStatement_executeQuery(FilterChain chain, PreparedStatementProxy statement)
                                                                                                             throws SQLException {
        return chain.preparedStatement_executeQuery(statement);
    }

    @Override
    public int preparedStatement_executeUpdate(FilterChain chain, PreparedStatementProxy statement) throws SQLException {
        int updateCount = chain.preparedStatement_executeUpdate(statement);
        statExecuteUpdate(updateCount);
        return updateCount;
    }

    public void setSqlStatAttribute(PreparedStatementProxy stmt) {
        WallContext context = WallContext.current();
        if (context == null) {
            return;
        }

        WallSqlStat sqlStat = context.getSqlStat();
        if (sqlStat == null) {
            return;
        }

        stmt.getAttributes().put(ATTR_SQL_STAT, sqlStat);
    }

    public void statExecuteUpdate(int updateCount) {
        WallContext context = WallContext.current();
        if (context == null) {
            return;
        }

        WallSqlStat sqlStat = context.getSqlStat();
        if (sqlStat == null) {
            return;
        }

        if (updateCount > 0) {
            statExecuteUpdate(sqlStat, updateCount);
        }
    }

    private void statExecuteUpdate(WallSqlStat sqlStat, int updateCount) {
        if (sqlStat == null) {
            return;
        }
        
        Map<String, WallSqlTableStat> sqlTableStats = sqlStat.getTableStats();
        if (sqlTableStats == null) {
            return;
        }

        for (Map.Entry<String, WallSqlTableStat> entry : sqlTableStats.entrySet()) {
            String tableName = entry.getKey();
            WallTableStat tableStat = provider.getTableStat(tableName);
            if (tableStat == null) {
                continue;
            }

            WallSqlTableStat sqlTableStat = entry.getValue();

            if (sqlTableStat.getDeleteCount() > 0) {
                tableStat.addDeleteDataCount(updateCount);
            } else if (sqlTableStat.getUpdateCount() > 0) {
                tableStat.addUpdateDataCount(updateCount);
            }
        }
    }

    public String check(String sql) throws SQLException {
        WallCheckResult checkResult = provider.check(sql);
        List<Violation> violations = checkResult.getViolations();

        if (violations.size() > 0) {
            Violation firstViolation = violations.get(0);
            if (isLogViolation()) {
                LOG.error("sql injection violation, " + firstViolation.getMessage() + " : " + sql);
            }

            if (throwException) {
                if (violations.get(0) instanceof SyntaxErrorViolation) {
                    SyntaxErrorViolation violation = (SyntaxErrorViolation) violations.get(0);
                    throw new SQLException("sql injection violation, " + firstViolation.getMessage() + " : " + sql,
                                           violation.getException());
                } else {
                    throw new SQLException("sql injection violation, " + firstViolation.getMessage() + " : " + sql);
                }
            }
        }

        return sql;
    }

    @Override
    public boolean isWrapperFor(FilterChain chain, Wrapper wrapper, Class<?> iface) throws SQLException {
        if (config.isDoPrivilegedAllow() && WallProvider.ispPivileged()) {
            return chain.isWrapperFor(wrapper, iface);
        }

        if (!this.provider.getConfig().isWrapAllow()) {
            return false;
        }
        return chain.isWrapperFor(wrapper, iface);
    }

    @Override
    public <T> T unwrap(FilterChain chain, Wrapper wrapper, Class<T> iface) throws SQLException {
        if (config.isDoPrivilegedAllow() && WallProvider.ispPivileged()) {
            return chain.unwrap(wrapper, iface);
        }

        if (!this.provider.getConfig().isWrapAllow()) {
            return null;
        }

        return chain.unwrap(wrapper, iface);
    }

    @Override
    public DatabaseMetaData connection_getMetaData(FilterChain chain, ConnectionProxy connection) throws SQLException {
        if (config.isDoPrivilegedAllow() && WallProvider.ispPivileged()) {
            return chain.connection_getMetaData(connection);
        }

        if (!this.provider.getConfig().isMetadataAllow()) {
            if (isLogViolation()) {
                LOG.error("not support method : Connection.getMetdataData");
            }

            if (throwException) {
                throw new WallSQLException("not support method : Connection.getMetdataData");
            } else {

            }
        }

        return chain.connection_getMetaData(connection);
    }

    @Override
    public void resultSet_close(FilterChain chain, ResultSetProxy resultSet) throws SQLException {
        chain.resultSet_close(resultSet);
        int fetchRowCount = resultSet.getFetchRowCount();

        WallSqlStat sqlStat = (WallSqlStat) resultSet.getStatementProxy().getAttributes().get(ATTR_SQL_STAT);
        if (sqlStat == null) {
            return;
        }

        Map<String, WallSqlTableStat> sqlTableStats = sqlStat.getTableStats();
        if (sqlTableStats == null) {
            return;
        }

        for (Map.Entry<String, WallSqlTableStat> entry : sqlTableStats.entrySet()) {
            String tableName = entry.getKey();
            WallTableStat tableStat = provider.getTableStat(tableName);
            if (tableStat == null) {
                continue;
            }

            WallSqlTableStat sqlTableStat = entry.getValue();

            if (sqlTableStat.getSelectCount() > 0) {
                tableStat.addFetchRowCount(fetchRowCount);
            }
        }
    }

    public long getViolationCount() {
        return this.provider.getViolationCount();
    }

    public void resetViolationCount() {
        this.provider.reset();
    }

    public void clearWhiteList() {
        this.provider.clearCache();
    }

    public boolean checkValid(String sql) {
        return provider.checkValid(sql);
    }
}
