View Javadoc

1   package com.imcode.db.mock;
2   
3   import com.imcode.db.Database;
4   import com.imcode.db.DatabaseCommand;
5   import com.imcode.db.DatabaseException;
6   import junit.framework.Assert;
7   import org.apache.commons.collections.CollectionUtils;
8   import org.apache.commons.collections.Predicate;
9   import org.apache.commons.dbutils.ResultSetHandler;
10  import org.apache.commons.lang.ArrayUtils;
11  import org.apache.commons.lang.StringUtils;
12  
13  import java.sql.ResultSet;
14  import java.sql.SQLException;
15  import java.util.ArrayList;
16  import java.util.Arrays;
17  import java.util.Iterator;
18  import java.util.List;
19  import java.util.Map;
20  import java.util.regex.Matcher;
21  import java.util.regex.Pattern;
22  
23  public class MockDatabase implements Database {
24  
25      private List sqlCalls = new ArrayList();
26      private List expectedSqlCalls = new ArrayList();
27  
28      public int executeUpdate(String sqlStr, Object[] parameters) {
29          getResultForSqlCall(sqlStr, parameters);
30          return 0;
31      }
32  
33  
34      public Object executeQuery(String sqlQuery, Object[] parameters, ResultSetHandler resultSetHandler) {
35          ResultSet resultSet = (ResultSet) getResultForSqlCall(sqlQuery, parameters);
36          if (null == resultSet ) {
37              resultSet = new MockResultSet(new Object[0][]) ;
38          }
39          try {
40              return resultSetHandler.handle(resultSet) ;
41          } catch ( SQLException e ) {
42              throw DatabaseException.fromSQLException("", e);
43          }
44      }
45  
46      public Object execute(DatabaseCommand databaseCommand) throws DatabaseException {
47          return databaseCommand.executeOn(new MockDatabaseConnection(this));
48      }
49  
50      public Object executeCommand(DatabaseCommand databaseCommand) throws DatabaseException {
51          return execute(databaseCommand);
52      }
53  
54      public void addExpectedSqlCall(final SqlCallPredicate sqlCallPredicate, final Object result) {
55          expectedSqlCalls.add(new Map.Entry() {
56              public Object getKey() {
57                  return sqlCallPredicate;
58              }
59  
60              public Object getValue() {
61                  return result;
62              }
63  
64              public Object setValue(Object value) {
65                  throw new UnsupportedOperationException();
66              }
67  
68              public String toString() {
69                  return sqlCallPredicate + ": " + result;
70              }
71          });
72      }
73  
74      public void assertExpectedSqlCalls() {
75          if (!expectedSqlCalls.isEmpty()) {
76              Assert.fail("Remaining expected sql calls: " + expectedSqlCalls.toString());
77          }
78      }
79  
80      public int getSqlCallCount() {
81          return sqlCalls.size();
82      }
83  
84      Object getResultForSqlCall(String sql, Object[] params) {
85          SqlCall sqlCall = new SqlCall(sql, params);
86          sqlCalls.add(sqlCall);
87          Object result = null;
88          if (!expectedSqlCalls.isEmpty()) {
89              Map.Entry entry = (Map.Entry) expectedSqlCalls.get(0);
90              SqlCallPredicate predicate = (SqlCallPredicate) entry.getKey();
91              if (predicate.evaluateSqlCall(sqlCall)) {
92                  result = entry.getValue();
93                  expectedSqlCalls.remove(0);
94              }
95          }
96          return result;
97      }
98  
99      public static class SqlCall {
100 
101         private String string;
102         private Object[] parameters;
103 
104         public SqlCall(String string, Object[] parameters) {
105             this.string = string;
106             this.parameters = parameters;
107         }
108 
109         public String getString() {
110             return string;
111         }
112 
113         public Object[] getParameters() {
114             return parameters;
115         }
116 
117         public String toString() {
118             return getString() + " " + StringUtils.join(getParameters(), ", ");
119         }
120 
121     }
122 
123     public void assertCalled(SqlCallPredicate predicate) {
124         assertCalled(null, predicate);
125     }
126 
127     public void assertCalledInOrder(SqlCallPredicate[] sqlCallPredicates) {
128         int sqlCallPredicatesIndex = 0 ;
129         for ( Iterator iterator = sqlCalls.iterator(); iterator.hasNext(); ) {
130             SqlCall sqlCall = (SqlCall) iterator.next();
131             if (sqlCallPredicates[sqlCallPredicatesIndex].evaluateSqlCall(sqlCall)) {
132                 sqlCallPredicatesIndex++ ;
133                 if (sqlCallPredicatesIndex == sqlCallPredicates.length) {
134                     break ;
135                 }
136             }
137         }
138         if (sqlCallPredicatesIndex < sqlCallPredicates.length) {
139             String failureMessage = "Expected sql call \"" + sqlCallPredicates[sqlCallPredicatesIndex].getFailureMessage()+"\"";
140             if (sqlCallPredicatesIndex > 0) {
141                 failureMessage += " after sql call \""+sqlCallPredicates[sqlCallPredicatesIndex-1]+"\"" ;
142             }
143             Assert.fail(failureMessage) ;
144         }
145     }
146 
147     public void assertCalled(String message, SqlCallPredicate predicate) {
148         if (!called(predicate)) {
149             String messagePrefix = null == message ? "" : message + " ";
150             Assert.fail(messagePrefix + "Expected at least one sql call: " + predicate.getFailureMessage());
151         }
152     }
153 
154     private boolean called(SqlCallPredicate predicate) {
155         return CollectionUtils.exists(sqlCalls, predicate);
156     }
157 
158     public void assertNotCalled(SqlCallPredicate sqlCallPredicate) {
159         assertNotCalled(null, sqlCallPredicate);
160     }
161 
162     public void assertNotCalled(String message, SqlCallPredicate predicate) {
163         if (called(predicate)) {
164             String messagePrefix = null == message ? "" : message + " ";
165             Assert.fail(messagePrefix + "Got unexpected sql call: " + predicate.getFailureMessage());
166         }
167     }
168 
169     public void assertCallCount(int expectedCount, SqlCallPredicate predicate) {
170         int actualCount = CollectionUtils.countMatches(sqlCalls, predicate);
171         if (expectedCount != actualCount) {
172             Assert.fail("Expected " + expectedCount + ", but got " + actualCount + " sql calls: " + predicate.getFailureMessage());
173         }
174     }
175 
176     public abstract static class SqlCallPredicate implements Predicate {
177 
178         public final boolean evaluate(Object object) {
179             return evaluateSqlCall((MockDatabase.SqlCall) object);
180         }
181 
182         abstract boolean evaluateSqlCall(MockDatabase.SqlCall sqlCall);
183 
184         abstract String getFailureMessage();
185 
186         public String toString() {
187             return getFailureMessage();
188         }
189     }
190 
191     public static class UpdateTableSqlCallPredicate extends SqlCallPredicate {
192 
193         private String tableName;
194         private Object parameter;
195 
196         public UpdateTableSqlCallPredicate(String tableName, Object parameter) {
197             this.tableName = tableName;
198             this.parameter = parameter;
199         }
200 
201         boolean evaluateSqlCall(MockDatabase.SqlCall sqlCall) {
202             boolean stringMatchesUpdateTableName = Pattern.compile("^update//s+//b" + tableName+"//b").matcher(sqlCall.getString().toLowerCase()).find();
203             boolean parametersContainsParameter = ArrayUtils.contains(sqlCall.getParameters(), parameter);
204             return stringMatchesUpdateTableName && parametersContainsParameter;
205         }
206 
207         String getFailureMessage() {
208             return "update of table " + tableName + " with one parameter = " + parameter;
209         }
210     }
211 
212     public static class InsertIntoTableSqlCallPredicate extends MatchesRegexSqlCallPredicate {
213 
214         private String tableName;
215 
216         public InsertIntoTableSqlCallPredicate(String tableName) {
217             super("^insert//s+(?:into//s+)?//b" + tableName+"//b") ;
218             this.tableName = tableName;
219         }
220 
221         String getFailureMessage() {
222             return "insert into table " + tableName ;
223         }
224     }
225 
226     public static class InsertIntoTableWithParameterSqlCallPredicate extends InsertIntoTableSqlCallPredicate {
227 
228         private String parameter;
229 
230         public InsertIntoTableWithParameterSqlCallPredicate(String tableName, String parameter) {
231             super(tableName);
232             this.parameter = parameter;
233         }
234 
235         boolean evaluateSqlCall(MockDatabase.SqlCall sqlCall) {
236             return super.evaluateSqlCall(sqlCall) && ArrayUtils.contains(sqlCall.getParameters(), parameter);
237         }
238 
239         String getFailureMessage() {
240             return super.getFailureMessage() + " with one parameter = \"" + parameter + "\"";
241         }
242     }
243 
244     public static class MatchesRegexSqlCallPredicate extends SqlCallPredicate {
245 
246         private String regex;
247 
248         public MatchesRegexSqlCallPredicate(String regex) {
249             this.regex = regex;
250         }
251 
252         boolean evaluateSqlCall(MockDatabase.SqlCall sqlCall) {
253             Pattern pattern = Pattern.compile(regex, Pattern.CASE_INSENSITIVE);
254             Matcher matcher = pattern.matcher(sqlCall.getString());
255             return matcher.find();
256         }
257 
258         String getFailureMessage() {
259             return "Expected call to match regex " + regex;
260         }
261     }
262 
263     public static class EqualsSqlCallPredicate extends SqlCallPredicate {
264 
265         String sql;
266 
267         public EqualsSqlCallPredicate(String sql) {
268             this.sql = sql;
269         }
270 
271         boolean evaluateSqlCall(SqlCall sqlCall) {
272             return sql.equalsIgnoreCase(sqlCall.getString());
273         }
274 
275         String getFailureMessage() {
276             return "sql \"" + sql + "\"";
277         }
278     }
279 
280     public static class StartsWithSqlCallPredicate extends SqlCallPredicate {
281 
282         private String prefix;
283 
284         public StartsWithSqlCallPredicate(String prefix) {
285             this.prefix = prefix;
286         }
287 
288         boolean evaluateSqlCall(SqlCall sqlCall) {
289             return sqlCall.getString().startsWith(prefix);
290         }
291 
292         String getFailureMessage() {
293             return "start with " + prefix;
294         }
295     }
296 
297     public static class EqualsWithParametersSqlCallPredicate extends EqualsSqlCallPredicate {
298 
299         private String[] parameters;
300 
301         public EqualsWithParametersSqlCallPredicate(String sql, String[] parameters) {
302             super(sql);
303             this.parameters = parameters;
304         }
305 
306         boolean evaluateSqlCall(SqlCall sqlCall) {
307             return super.evaluateSqlCall(sqlCall) && Arrays.equals(parameters, sqlCall.getParameters());
308         }
309 
310         String getFailureMessage() {
311             return super.getFailureMessage() + " with parameters " + ArrayUtils.toString(parameters);
312         }
313     }
314 
315     public static class DeleteFromTableSqlCallPredicate extends MatchesRegexSqlCallPredicate {
316 
317         private String tableName;
318 
319         public DeleteFromTableSqlCallPredicate(String tableName) {
320             super("^delete//s+from//s+//b" + tableName+"//b") ;
321             this.tableName = tableName;
322         }
323 
324         String getFailureMessage() {
325             return "delete from "+tableName;
326         }
327 
328     }
329 }