1 package com.imcode.db.mock; 2 3 import com.imcode.db.AbstractDatabase; 4 import com.imcode.db.DatabaseCommand; 5 import com.imcode.db.DatabaseException; 6 import junit.framework.Assert; 7 import org.apache.commons.collections.CollectionUtils; 8 import org.apache.commons.collections.Predicate; 9 import org.apache.commons.dbutils.ResultSetHandler; 10 import org.apache.commons.lang.ArrayUtils; 11 import org.apache.commons.lang.StringUtils; 12 13 import java.sql.ResultSet; 14 import java.sql.SQLException; 15 import java.util.ArrayList; 16 import java.util.Arrays; 17 import java.util.Iterator; 18 import java.util.List; 19 import java.util.Map; 20 import java.util.regex.Matcher; 21 import java.util.regex.Pattern; 22 23 public class MockDatabase extends AbstractDatabase { 24 25 private List sqlCalls = new ArrayList(); 26 private List expectedSqlCalls = new ArrayList(); 27 28 public int executeUpdate(String sqlStr, Object[] parameters) { 29 getResultForSqlCall(sqlStr, parameters); 30 return 0; 31 } 32 33 34 public Object executeQuery(String sqlQuery, Object[] parameters, ResultSetHandler resultSetHandler) { 35 ResultSet resultSet = (ResultSet) getResultForSqlCall(sqlQuery, parameters); 36 if (null == resultSet ) { 37 resultSet = new MockResultSet(new Object[0][]) ; 38 } 39 try { 40 return resultSetHandler.handle(resultSet) ; 41 } catch ( SQLException e ) { 42 throw DatabaseException.fromSQLException("", e); 43 } 44 } 45 46 public Object execute(DatabaseCommand databaseCommand) throws DatabaseException { 47 return databaseCommand.executeOn(new MockDatabaseConnection(this)); 48 } 49 50 public void addExpectedSqlCall(final SqlCallPredicate sqlCallPredicate, final Object result) { 51 expectedSqlCalls.add(new Map.Entry() { 52 public Object getKey() { 53 return sqlCallPredicate; 54 } 55 56 public Object getValue() { 57 return result; 58 } 59 60 public Object setValue(Object value) { 61 throw new UnsupportedOperationException(); 62 } 63 64 public String toString() { 65 return sqlCallPredicate + ": " + result; 66 } 67 }); 68 } 69 70 public void assertExpectedSqlCalls() { 71 if (!expectedSqlCalls.isEmpty()) { 72 Assert.fail("Remaining expected sql calls: " + expectedSqlCalls.toString()); 73 } 74 } 75 76 public int getSqlCallCount() { 77 return sqlCalls.size(); 78 } 79 80 Object getResultForSqlCall(String sql, Object[] params) { 81 SqlCall sqlCall = new SqlCall(sql, params); 82 sqlCalls.add(sqlCall); 83 Object result = null; 84 if (!expectedSqlCalls.isEmpty()) { 85 Map.Entry entry = (Map.Entry) expectedSqlCalls.get(0); 86 SqlCallPredicate predicate = (SqlCallPredicate) entry.getKey(); 87 if (predicate.evaluateSqlCall(sqlCall)) { 88 result = entry.getValue(); 89 expectedSqlCalls.remove(0); 90 } 91 } 92 return result; 93 } 94 95 public static class SqlCall { 96 97 private String string; 98 private Object[] parameters; 99 100 public SqlCall(String string, Object[] parameters) { 101 this.string = string; 102 this.parameters = parameters; 103 } 104 105 public String getString() { 106 return string; 107 } 108 109 public Object[] getParameters() { 110 return parameters; 111 } 112 113 public String toString() { 114 return getString() + " " + StringUtils.join(getParameters(), ", "); 115 } 116 117 } 118 119 public void assertCalled(SqlCallPredicate predicate) { 120 assertCalled(null, predicate); 121 } 122 123 public void assertCalledInOrder(SqlCallPredicate[] sqlCallPredicates) { 124 int sqlCallPredicatesIndex = 0 ; 125 for ( Iterator iterator = sqlCalls.iterator(); iterator.hasNext(); ) { 126 SqlCall sqlCall = (SqlCall) iterator.next(); 127 if (sqlCallPredicates[sqlCallPredicatesIndex].evaluateSqlCall(sqlCall)) { 128 sqlCallPredicatesIndex++ ; 129 if (sqlCallPredicatesIndex == sqlCallPredicates.length) { 130 break ; 131 } 132 } 133 } 134 if (sqlCallPredicatesIndex < sqlCallPredicates.length) { 135 String failureMessage = "Expected sql call \"" + sqlCallPredicates[sqlCallPredicatesIndex].getFailureMessage()+"\""; 136 if (sqlCallPredicatesIndex > 0) { 137 failureMessage += " after sql call \""+sqlCallPredicates[sqlCallPredicatesIndex-1]+"\"" ; 138 } 139 Assert.fail(failureMessage) ; 140 } 141 } 142 143 public void assertCalled(String message, SqlCallPredicate predicate) { 144 if (!called(predicate)) { 145 String messagePrefix = null == message ? "" : message + " "; 146 Assert.fail(messagePrefix + "Expected at least one sql call: " + predicate.getFailureMessage()); 147 } 148 } 149 150 private boolean called(SqlCallPredicate predicate) { 151 return CollectionUtils.exists(sqlCalls, predicate); 152 } 153 154 public void assertNotCalled(SqlCallPredicate sqlCallPredicate) { 155 assertNotCalled(null, sqlCallPredicate); 156 } 157 158 public void assertNotCalled(String message, SqlCallPredicate predicate) { 159 if (called(predicate)) { 160 String messagePrefix = null == message ? "" : message + " "; 161 Assert.fail(messagePrefix + "Got unexpected sql call: " + predicate.getFailureMessage()); 162 } 163 } 164 165 public void assertCallCount(int expectedCount, SqlCallPredicate predicate) { 166 int actualCount = CollectionUtils.countMatches(sqlCalls, predicate); 167 if (expectedCount != actualCount) { 168 Assert.fail("Expected " + expectedCount + ", but got " + actualCount + " sql calls: " + predicate.getFailureMessage()); 169 } 170 } 171 172 public abstract static class SqlCallPredicate implements Predicate { 173 174 public final boolean evaluate(Object object) { 175 return evaluateSqlCall((MockDatabase.SqlCall) object); 176 } 177 178 abstract boolean evaluateSqlCall(MockDatabase.SqlCall sqlCall); 179 180 abstract String getFailureMessage(); 181 182 public String toString() { 183 return getFailureMessage(); 184 } 185 } 186 187 public static class UpdateTableSqlCallPredicate extends SqlCallPredicate { 188 189 private String tableName; 190 private Object parameter; 191 192 public UpdateTableSqlCallPredicate(String tableName, Object parameter) { 193 this.tableName = tableName; 194 this.parameter = parameter; 195 } 196 197 boolean evaluateSqlCall(MockDatabase.SqlCall sqlCall) { 198 boolean stringMatchesUpdateTableName = Pattern.compile("^update//s+//b" + tableName+"//b").matcher(sqlCall.getString().toLowerCase()).find(); 199 boolean parametersContainsParameter = ArrayUtils.contains(sqlCall.getParameters(), parameter); 200 return stringMatchesUpdateTableName && parametersContainsParameter; 201 } 202 203 String getFailureMessage() { 204 return "update of table " + tableName + " with one parameter = " + parameter; 205 } 206 } 207 208 public static class InsertIntoTableSqlCallPredicate extends MatchesRegexSqlCallPredicate { 209 210 private String tableName; 211 212 public InsertIntoTableSqlCallPredicate(String tableName) { 213 super("^insert//s+(?:into//s+)?//b" + tableName+"//b") ; 214 this.tableName = tableName; 215 } 216 217 String getFailureMessage() { 218 return "insert into table " + tableName ; 219 } 220 } 221 222 public static class InsertIntoTableWithParameterSqlCallPredicate extends InsertIntoTableSqlCallPredicate { 223 224 private String parameter; 225 226 public InsertIntoTableWithParameterSqlCallPredicate(String tableName, String parameter) { 227 super(tableName); 228 this.parameter = parameter; 229 } 230 231 boolean evaluateSqlCall(MockDatabase.SqlCall sqlCall) { 232 return super.evaluateSqlCall(sqlCall) && ArrayUtils.contains(sqlCall.getParameters(), parameter); 233 } 234 235 String getFailureMessage() { 236 return super.getFailureMessage() + " with one parameter = \"" + parameter + "\""; 237 } 238 } 239 240 public static class MatchesRegexSqlCallPredicate extends SqlCallPredicate { 241 242 private String regex; 243 244 public MatchesRegexSqlCallPredicate(String regex) { 245 this.regex = regex; 246 } 247 248 boolean evaluateSqlCall(MockDatabase.SqlCall sqlCall) { 249 Pattern pattern = Pattern.compile(regex, Pattern.CASE_INSENSITIVE); 250 Matcher matcher = pattern.matcher(sqlCall.getString()); 251 return matcher.find(); 252 } 253 254 String getFailureMessage() { 255 return "Expected call to match regex " + regex; 256 } 257 } 258 259 public static class EqualsSqlCallPredicate extends SqlCallPredicate { 260 261 String sql; 262 263 public EqualsSqlCallPredicate(String sql) { 264 this.sql = sql; 265 } 266 267 boolean evaluateSqlCall(SqlCall sqlCall) { 268 return sql.equalsIgnoreCase(sqlCall.getString()); 269 } 270 271 String getFailureMessage() { 272 return "sql \"" + sql + "\""; 273 } 274 } 275 276 public static class StartsWithSqlCallPredicate extends SqlCallPredicate { 277 278 private String prefix; 279 280 public StartsWithSqlCallPredicate(String prefix) { 281 this.prefix = prefix; 282 } 283 284 boolean evaluateSqlCall(SqlCall sqlCall) { 285 return sqlCall.getString().startsWith(prefix); 286 } 287 288 String getFailureMessage() { 289 return "start with " + prefix; 290 } 291 } 292 293 public static class EqualsWithParametersSqlCallPredicate extends EqualsSqlCallPredicate { 294 295 private String[] parameters; 296 297 public EqualsWithParametersSqlCallPredicate(String sql, String[] parameters) { 298 super(sql); 299 this.parameters = parameters; 300 } 301 302 boolean evaluateSqlCall(SqlCall sqlCall) { 303 return super.evaluateSqlCall(sqlCall) && Arrays.equals(parameters, sqlCall.getParameters()); 304 } 305 306 String getFailureMessage() { 307 return super.getFailureMessage() + " with parameters " + ArrayUtils.toString(parameters); 308 } 309 } 310 311 public static class DeleteFromTableSqlCallPredicate extends MatchesRegexSqlCallPredicate { 312 313 private String tableName; 314 315 public DeleteFromTableSqlCallPredicate(String tableName) { 316 super("^delete//s+from//s+//b" + tableName+"//b") ; 317 this.tableName = tableName; 318 } 319 320 String getFailureMessage() { 321 return "delete from "+tableName; 322 } 323 324 } 325 }