Intercepting Write Operations in IBatis and MyBatis

Background

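Sometimes we need to record every write operation against the database and collect that information for analysis. IBatis and MyBatis call for different approaches, and both are described below.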

Intercepting write operations in MyBatis

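The underlying entry point for every MyBatis write operation (insert, update and delete alike) is org.apache.ibatis.executor.Executor#update,
so intercepting that one method is enough.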
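MyBatis builds its plugins on JDK dynamic proxies. Plugin.wrap places a java.lang.reflect.Proxy around the Executor, and that proxy passes a call to intercept only when the method matches one listed in @Signature. The following stdlib-only sketch imitates the idea so that it runs without MyBatis. MiniExecutor and UpdateLoggingProxy are hypothetical names, not MyBatis types. For brevity the sketch matches methods by name, whereas the real Plugin class compares complete method signatures.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Proxy;
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for org.apache.ibatis.executor.Executor (not the real interface).
interface MiniExecutor {
    int update(String statementId, Object parameter);

    List<Object> query(String statementId, Object parameter);
}

final class UpdateLoggingProxy {
    // Collects what the proxy intercepted, so the behaviour can be inspected.
    static final List<String> LOG = new ArrayList<>();

    private UpdateLoggingProxy() {
    }

    static MiniExecutor wrap(MiniExecutor target) {
        InvocationHandler handler = (proxy, method, args) -> {
            try {
                // Always delegate to the real executor first
                Object result = method.invoke(target, args);
                // Only "update" calls are recorded; everything else passes straight through
                if ("update".equals(method.getName())) {
                    LOG.add("sqlId=" + args[0] + ", result=" + result);
                }
                return result;
            } catch (InvocationTargetException e) {
                // Rethrow the executor's own exception rather than the reflection wrapper
                throw e.getCause();
            }
        };
        return (MiniExecutor) Proxy.newProxyInstance(
                MiniExecutor.class.getClassLoader(), new Class<?>[] {MiniExecutor.class}, handler);
    }
}
```

A query passes through untouched, while an update is recorded. This is why a single @Signature on update is enough to see every write.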

Configuration

Register the interceptor through the plugins property of SqlSessionFactoryBean in the Spring configuration file:

<bean id="sqlSessionFactory" class="org.mybatis.spring.SqlSessionFactoryBean">
    <property name="dataSource" ref="mySqlDataSource" />
    ...
    <property name="plugins">
        <array>
            <bean class="com.MyBatisExecutorIntercept"/>
        </array>
    </property>
</bean>

Interceptor

import com.alibaba.fastjson.JSONObject;
import java.util.Date;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Properties;

import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * MyBatis interceptor on Executor.update that records the SQL statement, parameters and result of every write operation
 *
 * @author guolinlin
 * @version V1.0
 * @since 2017-09-26 15:41
 */
@Intercepts({ @Signature(type = Executor.class, method = "update", args = { MappedStatement.class, Object.class }) })
public class MyBatisExecutorIntercept implements Interceptor {
    private static final Logger LOG = LoggerFactory.getLogger(MyBatisExecutorIntercept.class);

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        MappedStatement mappedStatement = (MappedStatement) invocation.getArgs()[0];
        Object parameter = null;
        if (invocation.getArgs().length > 1) {
            parameter = invocation.getArgs()[1];
        }
        String sqlId = mappedStatement.getId();
   
        BoundSql boundSql = mappedStatement.getBoundSql(parameter);
        Object returnValue = invocation.proceed();
        try {
            String messageId = DwCtxThreadLocalHolder.getDwCtxId();
            String sql = boundSql.getSql().replaceAll("[\\s]+", " ");
            String params = getParams(boundSql);
            LOG.info("messageId:[ {} ],sqlId:[ {} ],SQL:[ {} ],params:[ {} ],result:[ {} ]", messageId, sqlId, sql,
                params, returnValue == null ? null : JSONObject.toJSONString(returnValue));
        } catch (Exception e) {
            // Never let data collection break the write itself; log the failure with its cause
            LOG.error("Data collection failed!", e);
        }
        return returnValue;
    }

    private static Object getParameterValue(Object obj) {
        if (obj == null) {
            return null;
        }
        if (obj instanceof Date) {
            return TransformUtil.formatDate((Date) obj, "yyyy-MM-dd HH:mm:ss");
        }
        return obj;
    }

    public static String getParams(BoundSql boundSql) {
        Object parameterObject = boundSql.getParameterObject();
        List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
        Map<String, Object> paramMap = null;
        if (!parameterMappings.isEmpty() && parameterObject != null) {
            paramMap = new HashMap<String, Object>();
            if (parameterObject instanceof Map) {
                for (Map.Entry<?, ?> entry : ((Map<?, ?>) parameterObject).entrySet()) {
                    String key = String.valueOf(entry.getKey());
                    // Skip the generic aliases (param1, param2, ...) MyBatis adds next to named parameters
                    if (key.matches("param\\d+")) {
                        continue;
                    }
                    paramMap.put(key, getParameterValue(entry.getValue()));
                }
            } else {
                paramMap.put(parameterObject.getClass().getName(), getParameterValue(parameterObject));
            }
        }
        if (paramMap == null || paramMap.isEmpty()) {
            return null;
        }
        return JSONObject.toJSONString(paramMap);
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    @Override
    public void setProperties(Properties properties) {
        String dialect = properties.getProperty("dialect");
        LOG.info("mybatis intercept dialect:{}", dialect);
    }
}
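When a mapper method has several arguments or @Param annotations, MyBatis passes a map. That map holds each value under its name and again under a generic alias: param1, param2, and so on. A check such as key.startsWith("param") removes these duplicates, but it would also drop a legitimate name such as paramType.

Below is a minimal sketch of a stricter filter. It assumes that only keys of the form "param" followed by digits are aliases. ParamsSketch is a hypothetical name. SimpleDateFormat stands in for the article's TransformUtil.formatDate, whose code is not shown.

```java
import java.text.SimpleDateFormat;
import java.util.Date;
import java.util.LinkedHashMap;
import java.util.Map;

final class ParamsSketch {
    private ParamsSketch() {
    }

    // Hypothetical stand-in for TransformUtil.formatDate(date, "yyyy-MM-dd HH:mm:ss").
    static Object formatValue(Object value) {
        if (value instanceof Date) {
            // SimpleDateFormat is not thread-safe, so a new instance is created per call
            return new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format((Date) value);
        }
        return value;
    }

    // Copies the named parameters, dropping MyBatis's generic aliases param1, param2, ...
    static Map<String, Object> namedParams(Map<?, ?> raw) {
        Map<String, Object> result = new LinkedHashMap<>();
        for (Map.Entry<?, ?> entry : raw.entrySet()) {
            String key = String.valueOf(entry.getKey());
            if (key.matches("param\\d+")) {
                continue; // generic alias of a value that is already present under its name
            }
            result.put(key, formatValue(entry.getValue()));
        }
        return result;
    }
}
```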


Logged output

Sample of the collected data:

2017-09-27 19:29:23,057 [AkkaJavaSpring-qo-pinned-dispatcher-16] INFO  [c.g.h.r.i.MyBatisExecutorIntercept.intercept(64)]
messageId:[ 5d3b56ae-edd1-463e-b71d-61dea26ae0c5 ],
sqlId:[ Mapper.insert ],
SQL:[ insert into ***** ],
params:[ ******* ],
result:[ 1 ]

Intercepting all write operations in IBatis

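The basic idea is to find the place where IBatis executes SQL and collect the information we need there. Reading the IBatis source shows that the class ultimately responsible for executing SQL is com.ibatis.sqlmap.engine.execution.SqlExecutor, and that write operations enter it through the executeUpdate method.
Because IBatis has no plugin mechanism, we have to subclass SqlExecutor ourselves. Modifying the IBatis source is another option; this article takes the subclassing approach.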
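Subclassing works here because of dynamic dispatch. Code that holds a reference typed as SqlExecutor still runs the overriding executeUpdate. Replacing the instance the framework uses is therefore enough, and no IBatis code has to change. A minimal stdlib-only sketch of that idea follows. BaseExecutor and CollectingExecutor are hypothetical stand-ins, not IBatis classes.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical stand-in for com.ibatis.sqlmap.engine.execution.SqlExecutor.
class BaseExecutor {
    int executeUpdate(String sql, Object[] parameters) {
        return 1; // pretend one row was written
    }
}

// Subclass that keeps the original behaviour and collects data after the write.
class CollectingExecutor extends BaseExecutor {
    final List<String> records = new ArrayList<>();

    @Override
    int executeUpdate(String sql, Object[] parameters) {
        int result = super.executeUpdate(sql, parameters); // the real write happens here
        records.add(sql + " " + Arrays.toString(parameters) + " -> " + result);
        return result;
    }
}
```

The caller below only knows the base type, yet the collecting override runs.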

Overriding SqlExecutor


import com.alibaba.fastjson.JSONObject;
import com.ibatis.sqlmap.engine.execution.SqlExecutor;
import com.ibatis.sqlmap.engine.scope.StatementScope;

import java.sql.Connection;
import java.sql.SQLException;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Overrides SqlExecutor.executeUpdate to collect data about every write operation
 *
 * @author guolinlin
 * @version V1.0
 * @since 2017-09-27 16:09
 */
public class MySqlExecutor extends SqlExecutor {
    private static final Logger LOG = LoggerFactory.getLogger(MySqlExecutor.class);

    @Override
    public int executeUpdate(StatementScope statementScope, Connection conn, String sql, Object[] parameters)
        throws SQLException {
        // Run the real update first; a SQLException propagates to the caller unchanged
        int result = super.executeUpdate(statementScope, conn, sql, parameters);
        try {
            String sqlId = statementScope.getStatement().getId();
            // DwCtxThreadLocalHolder is a project helper that is not shown in this article
            String messageId = DwCtxThreadLocalHolder.getDwCtxId();
            String params = JSONObject.toJSONString(parameters);
            String normalizedSql = sql.replaceAll("[\\s]+", " ");
            LOG.info("messageId:[ {} ],sqlId:[ {} ],SQL:[ {} ],params:[ {} ],result:[ {} ]", messageId, sqlId,
                normalizedSql, params, result);
        } catch (Exception e) {
            // As in the MyBatis interceptor, data collection must never break the write itself
            LOG.error("Data collection failed!", e);
        }
        return result;
    }
}
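DwCtxThreadLocalHolder is used by both collectors to obtain messageId, but the article does not show it. A common way to build such a holder is a ThreadLocal that is set when a request or message starts and cleared when it ends. The sketch below assumes exactly that; MessageIdHolder is a hypothetical name.

Note that the two sample logs carry the same messageId on different threads. The real project therefore presumably passes the id along when work moves between threads, which a plain ThreadLocal like this one does not do by itself.

```java
import java.util.UUID;
import java.util.function.Supplier;

// Hypothetical stand-in for the article's DwCtxThreadLocalHolder (its code is not shown).
final class MessageIdHolder {
    private static final ThreadLocal<String> CURRENT = new ThreadLocal<>();

    private MessageIdHolder() {
    }

    static String get() {
        return CURRENT.get();
    }

    static String newId() {
        return UUID.randomUUID().toString();
    }

    // Runs the task with the id bound to the current thread, then restores the previous value
    static <T> T callWith(String messageId, Supplier<T> task) {
        String previous = CURRENT.get();
        CURRENT.set(messageId);
        try {
            return task.get();
        } finally {
            if (previous == null) {
                CURRENT.remove(); // avoid leaking ids into the next task on a pooled thread
            } else {
                CURRENT.set(previous);
            }
        }
    }
}
```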

Injecting the new executor

We inject the custom executor in BaseDAO.
Because sqlExecutor is a private field of com.ibatis.sqlmap.engine.impl.SqlMapExecutorDelegate and has no public setter, we use reflection to bypass Java access control and replace the delegate's sqlExecutor.

<bean id="mySqlExecutor" class="com.MySqlExecutor"/>
<bean abstract="true" id="baseDAO">
    <property name="sqlExecutor" ref="mySqlExecutor" />
    <property name="sqlMapClient" ref="mySqlMapClient" />
</bean>
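
BaseDAO:

import java.util.ArrayList;
import java.util.List;

import javax.annotation.Resource;

import com.ibatis.sqlmap.client.SqlMapClient;
import com.ibatis.sqlmap.engine.execution.SqlExecutor;
import com.ibatis.sqlmap.engine.impl.SqlMapClientImpl;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.orm.ibatis.support.SqlMapClientDaoSupport;

public class BaseDAO extends SqlMapClientDaoSupport {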

    private static final Logger LOG = LoggerFactory.getLogger(BaseDAO.class);

    @Resource(name = "mySqlExecutor")
    private SqlExecutor sqlExecutor;

    public SqlExecutor getSqlExecutor() {
        return sqlExecutor;
    }

    public void setSqlExecutor(SqlExecutor sqlExecutor) {
        this.sqlExecutor = sqlExecutor;
    }

    @Resource(name = "mySqlMapClient")
    protected void injectSqlMapClient(SqlMapClient sqlMapClient) {
        if (sqlMapClient instanceof SqlMapClientImpl) {
            // Use reflection to replace the delegate's sqlExecutor field with our own executor
            ReflectUtil.setFieldValue(((SqlMapClientImpl) sqlMapClient).getDelegate(), "sqlExecutor", SqlExecutor.class,
                sqlExecutor);
        }
        setSqlMapClient(sqlMapClient);
    }

    @SuppressWarnings("unchecked")
    protected <T> PageQuery<T> listByQuery(String statementName, PageQuery<T> query) {
        // Get the total number of records
        query.setTotalRecord(countByQuery(statementName + "_count", query));
        // Only fetch the page when there are records
        if (query.getTotalRecord() > 0) {
            List<T> result = (List<T>) getSqlMapClientTemplate().queryForList(statementName, query);
            if (result == null) {
                result = new ArrayList<T>();
            }
            query.setDataList(result);
        }
        return query;
    }
    ...
}

ReflectUtil:

import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Sets a field value via reflection, trying a setter first and falling back to direct field access
 *
 * @author guolinlin
 * @version V1.0
 * @since 2017-09-27 16:37
 */
public class ReflectUtil {

    private static final Log logger = LogFactory.getLog(ReflectUtil.class);

    public static void setFieldValue(Object target, String fname, Class<?> ftype, Object fvalue) {
        if (target == null || fname == null || "".equals(fname) || (fvalue != null && !ftype
            .isAssignableFrom(fvalue.getClass()))) {
            return;
        }
        Class<?> clazz = target.getClass();
        try {
            // Prefer a setter, even a non-public one, if the class declares it
            Method method = clazz
                .getDeclaredMethod("set" + Character.toUpperCase(fname.charAt(0)) + fname.substring(1), ftype);
            if (!Modifier.isPublic(method.getModifiers())) {
                method.setAccessible(true);
            }
            method.invoke(target, fvalue);

        } catch (Exception me) {
            if (logger.isDebugEnabled()) {
                logger.debug(me);
            }
            try {
                // No usable setter: write the field directly, bypassing access control
                Field field = clazz.getDeclaredField(fname);
                if (!Modifier.isPublic(field.getModifiers())) {
                    field.setAccessible(true);
                }
                field.set(target, fvalue);
            } catch (Exception fe) {
                // WARN, not DEBUG: if this fails, the custom executor is silently never installed
                logger.warn("Failed to set field " + fname + " on " + clazz.getName(), fe);
            }
        }
    }
}
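ReflectUtil calls getDeclaredMethod and getDeclaredField on the target's own class only. It would therefore find nothing if the delegate instance were a subclass of SqlMapExecutorDelegate.

The following stdlib-only sketch bypasses access control in the same way. It also searches superclasses, and it throws when the field cannot be found, so a missing field cannot go unnoticed. Delegate and FieldReplacer are hypothetical names used only for illustration.

```java
import java.lang.reflect.Field;

// Hypothetical stand-in for SqlMapExecutorDelegate: a private field and no setter.
class Delegate {
    private Object sqlExecutor = "default executor";

    Object currentExecutor() {
        return sqlExecutor;
    }
}

final class FieldReplacer {
    private FieldReplacer() {
    }

    // Finds the field on the object's class or any superclass and overwrites it, bypassing access control.
    static void replaceField(Object target, String name, Object value) throws ReflectiveOperationException {
        for (Class<?> c = target.getClass(); c != null; c = c.getSuperclass()) {
            try {
                Field field = c.getDeclaredField(name);
                field.setAccessible(true);
                field.set(target, value);
                return;
            } catch (NoSuchFieldException e) {
                // Not declared on this class; keep looking in the superclass
            }
        }
        throw new NoSuchFieldException(name);
    }
}
```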

Logged output

2017-09-27 19:29:23,002 [qtp1280551684-187] INFO  [c.g.h.b.dal.intercept.MySqlExecutor.executeUpdate(40)]
messageId:[ 5d3b56ae-edd1-463e-b71d-61dea26ae0c5 ],
sqlId:[ orderOffline.insert ],
SQL:[  insert into **** ],
params:[ ********** ],
result:[ 1 ]
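In the sample above, "SQL:[  insert" shows two spaces after the bracket. The most likely cause is that the statement text begins with a line break and indentation from the SQL map file. replaceAll("[\\s]+", " ") collapses that run into a single space instead of removing it, and the format string "[ {} ]" adds one more. A minimal sketch of a normalizer that also trims follows; SqlText is a hypothetical name.

```java
final class SqlText {
    private SqlText() {
    }

    // Collapses every run of whitespace (spaces, tabs, newlines) to one space and trims both ends
    static String normalize(String sql) {
        return sql == null ? null : sql.replaceAll("\\s+", " ").trim();
    }
}
```

Collapsing whitespace also changes string literals inside the SQL (for example 'a  b'). It is therefore only suitable for logging, never for the SQL that is actually executed.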


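    Original author: 冰零00
    Original URL: https://www.jianshu.com/p/98556a7d3969
    This article is reposted from the web solely to share knowledge; if it infringes any rights, please contact the blogger to have it removed.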