1 package com.imcode.db.mock; 2 3 import com.imcode.db.Database; 4 import com.imcode.db.DatabaseCommand; 5 import com.imcode.db.DatabaseException; 6 import junit.framework.Assert; 7 import org.apache.commons.collections.CollectionUtils; 8 import org.apache.commons.collections.Predicate; 9 import org.apache.commons.dbutils.ResultSetHandler; 10 import org.apache.commons.lang.ArrayUtils; 11 import org.apache.commons.lang.StringUtils; 12 13 import java.sql.ResultSet; 14 import java.sql.SQLException; 15 import java.util.ArrayList; 16 import java.util.Arrays; 17 import java.util.Iterator; 18 import java.util.List; 19 import java.util.Map; 20 import java.util.regex.Matcher; 21 import java.util.regex.Pattern; 22 23 public class MockDatabase implements Database { 24 25 private List sqlCalls = new ArrayList(); 26 private List expectedSqlCalls = new ArrayList(); 27 28 public int executeUpdate(String sqlStr, Object[] parameters) { 29 getResultForSqlCall(sqlStr, parameters); 30 return 0; 31 } 32 33 34 public Object executeQuery(String sqlQuery, Object[] parameters, ResultSetHandler resultSetHandler) { 35 ResultSet resultSet = (ResultSet) getResultForSqlCall(sqlQuery, parameters); 36 if (null == resultSet ) { 37 resultSet = new MockResultSet(new Object[0][]) ; 38 } 39 try { 40 return resultSetHandler.handle(resultSet) ; 41 } catch ( SQLException e ) { 42 throw DatabaseException.fromSQLException("", e); 43 } 44 } 45 46 public Object execute(DatabaseCommand databaseCommand) throws DatabaseException { 47 return databaseCommand.executeOn(new MockDatabaseConnection(this)); 48 } 49 50 public Object executeCommand(DatabaseCommand databaseCommand) throws DatabaseException { 51 return execute(databaseCommand); 52 } 53 54 public void addExpectedSqlCall(final SqlCallPredicate sqlCallPredicate, final Object result) { 55 expectedSqlCalls.add(new Map.Entry() { 56 public Object getKey() { 57 return sqlCallPredicate; 58 } 59 60 public Object getValue() { 61 return result; 62 } 63 64 public Object setValue(Object value) { 65 throw new UnsupportedOperationException(); 66 } 67 68 public String toString() { 69 return sqlCallPredicate + ": " + result; 70 } 71 }); 72 } 73 74 public void assertExpectedSqlCalls() { 75 if (!expectedSqlCalls.isEmpty()) { 76 Assert.fail("Remaining expected sql calls: " + expectedSqlCalls.toString()); 77 } 78 } 79 80 public int getSqlCallCount() { 81 return sqlCalls.size(); 82 } 83 84 Object getResultForSqlCall(String sql, Object[] params) { 85 SqlCall sqlCall = new SqlCall(sql, params); 86 sqlCalls.add(sqlCall); 87 Object result = null; 88 if (!expectedSqlCalls.isEmpty()) { 89 Map.Entry entry = (Map.Entry) expectedSqlCalls.get(0); 90 SqlCallPredicate predicate = (SqlCallPredicate) entry.getKey(); 91 if (predicate.evaluateSqlCall(sqlCall)) { 92 result = entry.getValue(); 93 expectedSqlCalls.remove(0); 94 } 95 } 96 return result; 97 } 98 99 public static class SqlCall { 100 101 private String string; 102 private Object[] parameters; 103 104 public SqlCall(String string, Object[] parameters) { 105 this.string = string; 106 this.parameters = parameters; 107 } 108 109 public String getString() { 110 return string; 111 } 112 113 public Object[] getParameters() { 114 return parameters; 115 } 116 117 public String toString() { 118 return getString() + " " + StringUtils.join(getParameters(), ", "); 119 } 120 121 } 122 123 public void assertCalled(SqlCallPredicate predicate) { 124 assertCalled(null, predicate); 125 } 126 127 public void assertCalledInOrder(SqlCallPredicate[] sqlCallPredicates) { 128 int sqlCallPredicatesIndex = 0 ; 129 for ( Iterator iterator = sqlCalls.iterator(); iterator.hasNext(); ) { 130 SqlCall sqlCall = (SqlCall) iterator.next(); 131 if (sqlCallPredicates[sqlCallPredicatesIndex].evaluateSqlCall(sqlCall)) { 132 sqlCallPredicatesIndex++ ; 133 if (sqlCallPredicatesIndex == sqlCallPredicates.length) { 134 break ; 135 } 136 } 137 } 138 if (sqlCallPredicatesIndex < sqlCallPredicates.length) { 139 String failureMessage = "Expected sql call \"" + sqlCallPredicates[sqlCallPredicatesIndex].getFailureMessage()+"\""; 140 if (sqlCallPredicatesIndex > 0) { 141 failureMessage += " after sql call \""+sqlCallPredicates[sqlCallPredicatesIndex-1]+"\"" ; 142 } 143 Assert.fail(failureMessage) ; 144 } 145 } 146 147 public void assertCalled(String message, SqlCallPredicate predicate) { 148 if (!called(predicate)) { 149 String messagePrefix = null == message ? "" : message + " "; 150 Assert.fail(messagePrefix + "Expected at least one sql call: " + predicate.getFailureMessage()); 151 } 152 } 153 154 private boolean called(SqlCallPredicate predicate) { 155 return CollectionUtils.exists(sqlCalls, predicate); 156 } 157 158 public void assertNotCalled(SqlCallPredicate sqlCallPredicate) { 159 assertNotCalled(null, sqlCallPredicate); 160 } 161 162 public void assertNotCalled(String message, SqlCallPredicate predicate) { 163 if (called(predicate)) { 164 String messagePrefix = null == message ? "" : message + " "; 165 Assert.fail(messagePrefix + "Got unexpected sql call: " + predicate.getFailureMessage()); 166 } 167 } 168 169 public void assertCallCount(int expectedCount, SqlCallPredicate predicate) { 170 int actualCount = CollectionUtils.countMatches(sqlCalls, predicate); 171 if (expectedCount != actualCount) { 172 Assert.fail("Expected " + expectedCount + ", but got " + actualCount + " sql calls: " + predicate.getFailureMessage()); 173 } 174 } 175 176 public abstract static class SqlCallPredicate implements Predicate { 177 178 public final boolean evaluate(Object object) { 179 return evaluateSqlCall((MockDatabase.SqlCall) object); 180 } 181 182 abstract boolean evaluateSqlCall(MockDatabase.SqlCall sqlCall); 183 184 abstract String getFailureMessage(); 185 186 public String toString() { 187 return getFailureMessage(); 188 } 189 } 190 191 public static class UpdateTableSqlCallPredicate extends SqlCallPredicate { 192 193 private String tableName; 194 private Object parameter; 195 196 public UpdateTableSqlCallPredicate(String tableName, Object parameter) { 197 this.tableName = tableName; 198 this.parameter = parameter; 199 } 200 201 boolean evaluateSqlCall(MockDatabase.SqlCall sqlCall) { 202 boolean stringMatchesUpdateTableName = Pattern.compile("^update//s+//b" + tableName+"//b").matcher(sqlCall.getString().toLowerCase()).find(); 203 boolean parametersContainsParameter = ArrayUtils.contains(sqlCall.getParameters(), parameter); 204 return stringMatchesUpdateTableName && parametersContainsParameter; 205 } 206 207 String getFailureMessage() { 208 return "update of table " + tableName + " with one parameter = " + parameter; 209 } 210 } 211 212 public static class InsertIntoTableSqlCallPredicate extends MatchesRegexSqlCallPredicate { 213 214 private String tableName; 215 216 public InsertIntoTableSqlCallPredicate(String tableName) { 217 super("^insert//s+(?:into//s+)?//b" + tableName+"//b") ; 218 this.tableName = tableName; 219 } 220 221 String getFailureMessage() { 222 return "insert into table " + tableName ; 223 } 224 } 225 226 public static class InsertIntoTableWithParameterSqlCallPredicate extends InsertIntoTableSqlCallPredicate { 227 228 private String parameter; 229 230 public InsertIntoTableWithParameterSqlCallPredicate(String tableName, String parameter) { 231 super(tableName); 232 this.parameter = parameter; 233 } 234 235 boolean evaluateSqlCall(MockDatabase.SqlCall sqlCall) { 236 return super.evaluateSqlCall(sqlCall) && ArrayUtils.contains(sqlCall.getParameters(), parameter); 237 } 238 239 String getFailureMessage() { 240 return super.getFailureMessage() + " with one parameter = \"" + parameter + "\""; 241 } 242 } 243 244 public static class MatchesRegexSqlCallPredicate extends SqlCallPredicate { 245 246 private String regex; 247 248 public MatchesRegexSqlCallPredicate(String regex) { 249 this.regex = regex; 250 } 251 252 boolean evaluateSqlCall(MockDatabase.SqlCall sqlCall) { 253 Pattern pattern = Pattern.compile(regex, Pattern.CASE_INSENSITIVE); 254 Matcher matcher = pattern.matcher(sqlCall.getString()); 255 return matcher.find(); 256 } 257 258 String getFailureMessage() { 259 return "Expected call to match regex " + regex; 260 } 261 } 262 263 public static class EqualsSqlCallPredicate extends SqlCallPredicate { 264 265 String sql; 266 267 public EqualsSqlCallPredicate(String sql) { 268 this.sql = sql; 269 } 270 271 boolean evaluateSqlCall(SqlCall sqlCall) { 272 return sql.equalsIgnoreCase(sqlCall.getString()); 273 } 274 275 String getFailureMessage() { 276 return "sql \"" + sql + "\""; 277 } 278 } 279 280 public static class StartsWithSqlCallPredicate extends SqlCallPredicate { 281 282 private String prefix; 283 284 public StartsWithSqlCallPredicate(String prefix) { 285 this.prefix = prefix; 286 } 287 288 boolean evaluateSqlCall(SqlCall sqlCall) { 289 return sqlCall.getString().startsWith(prefix); 290 } 291 292 String getFailureMessage() { 293 return "start with " + prefix; 294 } 295 } 296 297 public static class EqualsWithParametersSqlCallPredicate extends EqualsSqlCallPredicate { 298 299 private String[] parameters; 300 301 public EqualsWithParametersSqlCallPredicate(String sql, String[] parameters) { 302 super(sql); 303 this.parameters = parameters; 304 } 305 306 boolean evaluateSqlCall(SqlCall sqlCall) { 307 return super.evaluateSqlCall(sqlCall) && Arrays.equals(parameters, sqlCall.getParameters()); 308 } 309 310 String getFailureMessage() { 311 return super.getFailureMessage() + " with parameters " + ArrayUtils.toString(parameters); 312 } 313 } 314 315 public static class DeleteFromTableSqlCallPredicate extends MatchesRegexSqlCallPredicate { 316 317 private String tableName; 318 319 public DeleteFromTableSqlCallPredicate(String tableName) { 320 super("^delete//s+from//s+//b" + tableName+"//b") ; 321 this.tableName = tableName; 322 } 323 324 String getFailureMessage() { 325 return "delete from "+tableName; 326 } 327 328 } 329 }