Sunday, November 02, 2008

JDBC debugging

Recently I had to muck through a whole bunch of code to fix a connection pooling issue that caused some partial results to be committed into the database. The actual problem turned out to be a commit issued deep in the custom framework code while querying for the nextval of an Oracle sequence. As a by-product of that investigation, I wrote the following proxies to see which queries were executed by which connection. Usage is pretty simple: just wrap the connection with ConnectionInvocationHandler.createProxy(<your raw connection>) wherever you obtain the actual connection. The following classes only handle the cases that occur most frequently for me (on Oracle).
package org.foo.sql;

import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.sql.CallableStatement;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.Statement;

/**
 * Dynamic-proxy handler that wraps a JDBC {@link Connection} and prints connection-level activity
 * (auto-commit changes, commits, rollbacks and the SQL passed to {@code prepareCall}) to standard
 * output, prefixed with the underlying connection's {@code toString()}.
 *
 * <p>Statements returned by {@code prepareStatement} and {@code createStatement} are themselves
 * wrapped in logging proxies so the SQL they run is attributed to this connection.
 */
public class ConnectionInvocationHandler implements InvocationHandler {

  /** The real connection every call is forwarded to. */
  final Connection delegate;

  /**
   * Wraps a raw connection in a logging proxy.
   *
   * @param conn the connection to wrap
   * @return a {@link Connection} that forwards every call to {@code conn} and logs as described above
   */
  public static Connection createProxy(Connection conn) {
    return (Connection) Proxy.newProxyInstance(
        ConnectionInvocationHandler.class.getClassLoader(),
        new Class[]{ Connection.class },
        new ConnectionInvocationHandler(conn));
  }

  ConnectionInvocationHandler(Connection original) {
    delegate = original;
  }

  /**
   * Forwards {@code method} to the wrapped connection, logging transaction-control calls and
   * wrapping newly created statements in logging proxies.
   *
   * @throws Throwable whatever the wrapped connection throws (typically {@code SQLException}),
   *     rethrown unwrapped
   */
  public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
    String methodName = method.getName();

    // Log before delegating so the attempt is recorded even when the call itself fails.
    if ("setAutoCommit".equals(methodName)) {
      log("setAutoCommit(" + args[0] + ")");
    } else if ("commit".equals(methodName)) {
      log("commit()");
    } else if ("rollback".equals(methodName)) {
      boolean toSavepoint = args != null && args.length > 0;
      log(toSavepoint ? "rollback(" + args[0] + ")" : "rollback()");
    } else if ("prepareCall".equals(methodName)) {
      // CallableStatements are returned unwrapped below, so record their SQL here.
      log("prepareCall: " + args[0]);
    }

    Object result;
    try {
      result = method.invoke(delegate, args);
    } catch (java.lang.reflect.InvocationTargetException e) {
      // Rethrow the connection's own exception. The checked InvocationTargetException is not
      // declared by Connection and would otherwise reach the caller as an
      // UndeclaredThrowableException.
      throw e.getCause();
    }

    if ("prepareStatement".equals(methodName)) {
      result = PreparedStatementInvocationHandler.createProxy(
          (PreparedStatement) result, delegate.toString(), (String) args[0]);
    } else if ("createStatement".equals(methodName)) {
      result = StatementInvocationHandler.createProxy((Statement) result, delegate.toString());
    }
    return result;
  }

  /** Prints {@code msg} to standard output, prefixed with the wrapped connection's identity. */
  private void log(String msg) {
    System.out.println("[" + delegate.toString() + "] " + msg);
  }

}
package org.foo.sql;

import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.sql.Statement;

/**
 * Dynamic-proxy handler that wraps a plain JDBC {@link Statement} and prints every SQL string
 * passed to it (via {@code execute}, {@code executeQuery}, {@code executeUpdate},
 * {@code executeLargeUpdate} or {@code addBatch}) to standard output, prefixed with the owning
 * connection's id and the statement's {@code toString()}.
 */
public class StatementInvocationHandler implements InvocationHandler {

  /** Identifier of the owning connection, prefixed to every log line. */
  final String connId;
  /** The real statement every call is forwarded to. */
  final Statement delegate;

  /**
   * Wraps {@code stmt} in a logging proxy.
   *
   * @param stmt the statement to wrap
   * @param connId identifier of the connection that created {@code stmt}, used in log lines
   * @return a {@link Statement} that forwards every call to {@code stmt}
   */
  static Statement createProxy(Statement stmt, String connId) {
    return (Statement) Proxy.newProxyInstance(
        StatementInvocationHandler.class.getClassLoader(),
        new Class[]{ Statement.class },
        new StatementInvocationHandler(stmt, connId));
  }

  StatementInvocationHandler(Statement stmt, String connId) {
    delegate = stmt;
    this.connId = connId;
  }

  /**
   * Logs the SQL of any SQL-carrying call, then forwards the call to the wrapped statement.
   *
   * @throws Throwable whatever the wrapped statement throws (typically {@code SQLException}),
   *     rethrown unwrapped
   */
  public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
    // Log before delegating: statements that fail are exactly the ones worth seeing.
    if (carriesSql(method.getName(), args)) {
      log((String) args[0]);
    }
    try {
      return method.invoke(delegate, args);
    } catch (java.lang.reflect.InvocationTargetException e) {
      // Surface the statement's own exception instead of an UndeclaredThrowableException.
      throw e.getCause();
    }
  }

  /** True when {@code methodName} is an execution/batch method whose first argument is SQL text. */
  private static boolean carriesSql(String methodName, Object[] args) {
    if (args == null || args.length == 0 || !(args[0] instanceof String)) {
      return false;
    }
    return "execute".equals(methodName)
        || "executeQuery".equals(methodName)
        || "executeUpdate".equals(methodName)
        || "executeLargeUpdate".equals(methodName)
        || "addBatch".equals(methodName);
  }

  /** Prints {@code msg} to standard output, prefixed with the connection id and statement. */
  private void log(String msg) {
    System.out.println("[" + connId + "][" + delegate.toString() + "] " + msg);
  }
}
package org.foo.sql;

import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.sql.PreparedStatement;
import java.text.DecimalFormat;

/**
 * Dynamic-proxy handler that wraps a JDBC {@link PreparedStatement} and prints each executed
 * statement to standard output. The bound parameter values are substituted for the '?'
 * placeholders, rendered as Oracle-style SQL literals.
 *
 * <p>The original SQL text is never modified. Parameters are tracked by their JDBC index and the
 * SQL is re-rendered on every execution, so a statement re-executed with new values logs the new
 * values.
 */
public class PreparedStatementInvocationHandler implements InvocationHandler {

  /** Identifier of the owning connection, prefixed to every log line. */
  final String connId;
  /** SQL text as passed to prepareStatement, with its '?' placeholders intact. */
  final String query;
  /** The real statement every call is forwarded to. */
  final PreparedStatement delegate;
  /**
   * Rendered parameter values keyed by 1-based parameter index. As in JDBC, a binding stays in
   * force across executions until it is overwritten or clearParameters() is called.
   */
  private final java.util.Map<Integer, String> params = new java.util.HashMap<Integer, String>();

  /**
   * Wraps {@code pstmt} in a logging proxy.
   *
   * @param pstmt the prepared statement to wrap
   * @param connId identifier of the connection that created {@code pstmt}, used in log lines
   * @param query the SQL {@code pstmt} was prepared with
   * @return a {@link PreparedStatement} that forwards every call to {@code pstmt}
   */
  static PreparedStatement createProxy(PreparedStatement pstmt, String connId, String query) {
    return (PreparedStatement) Proxy.newProxyInstance(
        PreparedStatementInvocationHandler.class.getClassLoader(),
        new Class[]{ PreparedStatement.class },
        new PreparedStatementInvocationHandler(pstmt, connId, query));
  }

  PreparedStatementInvocationHandler(PreparedStatement original, String connId, String query) {
    delegate = original;
    this.connId = connId;
    this.query = query == null ? "" : query;
  }

  /**
   * Logs executions (before delegating, so failing statements are still visible), forwards the
   * call, and records parameter bindings after a successful setter call.
   *
   * @throws Throwable whatever the wrapped statement throws (typically {@code SQLException}),
   *     rethrown unwrapped
   */
  public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
    String methodName = method.getName();
    if (isExecution(methodName)) {
      boolean hasSqlArgument = args != null && args.length > 0 && args[0] instanceof String;
      log(hasSqlArgument ? (String) args[0] : renderQuery());
    }

    Object result;
    try {
      result = method.invoke(delegate, args);
    } catch (java.lang.reflect.InvocationTargetException e) {
      // Surface the statement's own exception instead of an UndeclaredThrowableException.
      throw e.getCause();
    }

    if ("clearParameters".equals(methodName)) {
      params.clear();
    } else if (isParameterSetter(methodName, args)) {
      bindParam((Integer) args[0], formatParam(methodName, args[1]));
    }
    return result;
  }

  /** True for the calls that run (or queue) the statement. */
  private static boolean isExecution(String methodName) {
    return "execute".equals(methodName)
        || "executeQuery".equals(methodName)
        || "executeUpdate".equals(methodName)
        || "executeLargeUpdate".equals(methodName)
        || "addBatch".equals(methodName);
  }

  /**
   * True for parameter-binding setters: setX(int parameterIndex, value, ...). This excludes
   * one-argument configuration setters such as setFetchSize or setMaxRows, which must not
   * consume a placeholder.
   */
  private static boolean isParameterSetter(String methodName, Object[] args) {
    return methodName.startsWith("set")
        && args != null
        && args.length >= 2
        && args[0] instanceof Integer;
  }

  /** Renders a bound value as a SQL literal; types not handled explicitly become '<Type>'. */
  private static String formatParam(String methodName, Object value) {
    if (value == null || "setNull".equals(methodName)) {
      return "NULL";
    }
    if ("setString".equals(methodName) || "setNString".equals(methodName)) {
      // Double embedded quotes so the logged SQL stays well-formed.
      return "'" + value.toString().replace("'", "''") + "'";
    }
    if ("setInt".equals(methodName) || "setLong".equals(methodName)
        || "setShort".equals(methodName) || "setByte".equals(methodName)) {
      return value.toString();
    }
    if ("setDouble".equals(methodName) || "setFloat".equals(methodName)) {
      return formatDecimal((Number) value);
    }
    if ("setBigDecimal".equals(methodName)) {
      return ((java.math.BigDecimal) value).toPlainString();
    }
    if ("setTimestamp".equals(methodName)) {
      String timestamp = value.toString();
      int dot = timestamp.indexOf('.');
      if (dot >= 0) {
        // Drop the fractional seconds; the 'ss' format element cannot parse them.
        timestamp = timestamp.substring(0, dot);
      }
      return "to_date('" + timestamp + "', 'yyyy-mm-dd hh24:mi:ss')";
    }
    return "'<" + methodName.substring(3) + ">'";
  }

  /**
   * Formats a floating-point value without exponent or trailing zeros, independent of the default
   * locale (a locale-specific decimal comma would produce invalid SQL).
   */
  private static String formatDecimal(Number value) {
    double d = value.doubleValue();
    String text = value.toString();
    if (Double.isNaN(d) || Double.isInfinite(d)) {
      return text;
    }
    return new java.math.BigDecimal(text).stripTrailingZeros().toPlainString();
  }

  /** Records the rendering for the 1-based parameter {@code index}, replacing any earlier one. */
  private void bindParam(int index, String param) {
    params.put(index, param);
  }

  /**
   * Binds {@code param} to the lowest parameter position that has no value yet. This is the
   * positional behaviour of earlier versions, kept for subclasses.
   */
  protected void addParam(String param) {
    int index = 1;
    while (params.containsKey(index)) {
      index++;
    }
    bindParam(index, param);
  }

  /**
   * Returns the original SQL with each '?' placeholder replaced by the value bound at that
   * position. Unbound placeholders stay as '?', and question marks inside single-quoted SQL
   * literals are not treated as placeholders.
   */
  private String renderQuery() {
    StringBuilder sql = new StringBuilder(query.length() + 16 * params.size());
    boolean inLiteral = false;
    int placeholder = 0;
    for (int i = 0; i < query.length(); i++) {
      char c = query.charAt(i);
      if (c == '\'') {
        // An escaped quote ('') toggles twice, leaving the state unchanged.
        inLiteral = !inLiteral;
        sql.append(c);
      } else if (c == '?' && !inLiteral) {
        placeholder++;
        String value = params.get(placeholder);
        sql.append(value != null ? value : "?");
      } else {
        sql.append(c);
      }
    }
    return sql.toString();
  }

  /** Prints {@code msg} to standard output, prefixed with the connection id and statement. */
  protected void log(String msg) {
    System.out.println("[" + connId + "][" + delegate.toString() + "] " + msg);
  }
}