View Javadoc

1   package com.imcode.db.mock;
2   
3   import com.imcode.db.AbstractDatabase;
4   import com.imcode.db.DatabaseCommand;
5   import com.imcode.db.DatabaseException;
6   import junit.framework.Assert;
7   import org.apache.commons.collections.CollectionUtils;
8   import org.apache.commons.collections.Predicate;
9   import org.apache.commons.dbutils.ResultSetHandler;
10  import org.apache.commons.lang.ArrayUtils;
11  import org.apache.commons.lang.StringUtils;
12  
13  import java.sql.ResultSet;
14  import java.sql.SQLException;
15  import java.util.ArrayList;
16  import java.util.Arrays;
17  import java.util.Iterator;
18  import java.util.List;
19  import java.util.Map;
20  import java.util.regex.Matcher;
21  import java.util.regex.Pattern;
22  
23  public class MockDatabase extends AbstractDatabase {
24  
25      private List sqlCalls = new ArrayList();
26      private List expectedSqlCalls = new ArrayList();
27  
28      public int executeUpdate(String sqlStr, Object[] parameters) {
29          getResultForSqlCall(sqlStr, parameters);
30          return 0;
31      }
32  
33  
34      public Object executeQuery(String sqlQuery, Object[] parameters, ResultSetHandler resultSetHandler) {
35          ResultSet resultSet = (ResultSet) getResultForSqlCall(sqlQuery, parameters);
36          if (null == resultSet ) {
37              resultSet = new MockResultSet(new Object[0][]) ;
38          }
39          try {
40              return resultSetHandler.handle(resultSet) ;
41          } catch ( SQLException e ) {
42              throw DatabaseException.fromSQLException("", e);
43          }
44      }
45  
46      public Object execute(DatabaseCommand databaseCommand) throws DatabaseException {
47          return databaseCommand.executeOn(new MockDatabaseConnection(this));
48      }
49  
50      public void addExpectedSqlCall(final SqlCallPredicate sqlCallPredicate, final Object result) {
51          expectedSqlCalls.add(new Map.Entry() {
52              public Object getKey() {
53                  return sqlCallPredicate;
54              }
55  
56              public Object getValue() {
57                  return result;
58              }
59  
60              public Object setValue(Object value) {
61                  throw new UnsupportedOperationException();
62              }
63  
64              public String toString() {
65                  return sqlCallPredicate + ": " + result;
66              }
67          });
68      }
69  
70      public void assertExpectedSqlCalls() {
71          if (!expectedSqlCalls.isEmpty()) {
72              Assert.fail("Remaining expected sql calls: " + expectedSqlCalls.toString());
73          }
74      }
75  
76      public int getSqlCallCount() {
77          return sqlCalls.size();
78      }
79  
80      Object getResultForSqlCall(String sql, Object[] params) {
81          SqlCall sqlCall = new SqlCall(sql, params);
82          sqlCalls.add(sqlCall);
83          Object result = null;
84          if (!expectedSqlCalls.isEmpty()) {
85              Map.Entry entry = (Map.Entry) expectedSqlCalls.get(0);
86              SqlCallPredicate predicate = (SqlCallPredicate) entry.getKey();
87              if (predicate.evaluateSqlCall(sqlCall)) {
88                  result = entry.getValue();
89                  expectedSqlCalls.remove(0);
90              }
91          }
92          return result;
93      }
94  
95      public static class SqlCall {
96  
97          private String string;
98          private Object[] parameters;
99  
100         public SqlCall(String string, Object[] parameters) {
101             this.string = string;
102             this.parameters = parameters;
103         }
104 
105         public String getString() {
106             return string;
107         }
108 
109         public Object[] getParameters() {
110             return parameters;
111         }
112 
113         public String toString() {
114             return getString() + " " + StringUtils.join(getParameters(), ", ");
115         }
116 
117     }
118 
119     public void assertCalled(SqlCallPredicate predicate) {
120         assertCalled(null, predicate);
121     }
122 
123     public void assertCalledInOrder(SqlCallPredicate[] sqlCallPredicates) {
124         int sqlCallPredicatesIndex = 0 ;
125         for ( Iterator iterator = sqlCalls.iterator(); iterator.hasNext(); ) {
126             SqlCall sqlCall = (SqlCall) iterator.next();
127             if (sqlCallPredicates[sqlCallPredicatesIndex].evaluateSqlCall(sqlCall)) {
128                 sqlCallPredicatesIndex++ ;
129                 if (sqlCallPredicatesIndex == sqlCallPredicates.length) {
130                     break ;
131                 }
132             }
133         }
134         if (sqlCallPredicatesIndex < sqlCallPredicates.length) {
135             String failureMessage = "Expected sql call \"" + sqlCallPredicates[sqlCallPredicatesIndex].getFailureMessage()+"\"";
136             if (sqlCallPredicatesIndex > 0) {
137                 failureMessage += " after sql call \""+sqlCallPredicates[sqlCallPredicatesIndex-1]+"\"" ;
138             }
139             Assert.fail(failureMessage) ;
140         }
141     }
142 
143     public void assertCalled(String message, SqlCallPredicate predicate) {
144         if (!called(predicate)) {
145             String messagePrefix = null == message ? "" : message + " ";
146             Assert.fail(messagePrefix + "Expected at least one sql call: " + predicate.getFailureMessage());
147         }
148     }
149 
150     private boolean called(SqlCallPredicate predicate) {
151         return CollectionUtils.exists(sqlCalls, predicate);
152     }
153 
154     public void assertNotCalled(SqlCallPredicate sqlCallPredicate) {
155         assertNotCalled(null, sqlCallPredicate);
156     }
157 
158     public void assertNotCalled(String message, SqlCallPredicate predicate) {
159         if (called(predicate)) {
160             String messagePrefix = null == message ? "" : message + " ";
161             Assert.fail(messagePrefix + "Got unexpected sql call: " + predicate.getFailureMessage());
162         }
163     }
164 
165     public void assertCallCount(int expectedCount, SqlCallPredicate predicate) {
166         int actualCount = CollectionUtils.countMatches(sqlCalls, predicate);
167         if (expectedCount != actualCount) {
168             Assert.fail("Expected " + expectedCount + ", but got " + actualCount + " sql calls: " + predicate.getFailureMessage());
169         }
170     }
171 
172     public abstract static class SqlCallPredicate implements Predicate {
173 
174         public final boolean evaluate(Object object) {
175             return evaluateSqlCall((MockDatabase.SqlCall) object);
176         }
177 
178         abstract boolean evaluateSqlCall(MockDatabase.SqlCall sqlCall);
179 
180         abstract String getFailureMessage();
181 
182         public String toString() {
183             return getFailureMessage();
184         }
185     }
186 
187     public static class UpdateTableSqlCallPredicate extends SqlCallPredicate {
188 
189         private String tableName;
190         private Object parameter;
191 
192         public UpdateTableSqlCallPredicate(String tableName, Object parameter) {
193             this.tableName = tableName;
194             this.parameter = parameter;
195         }
196 
197         boolean evaluateSqlCall(MockDatabase.SqlCall sqlCall) {
198             boolean stringMatchesUpdateTableName = Pattern.compile("^update//s+//b" + tableName+"//b").matcher(sqlCall.getString().toLowerCase()).find();
199             boolean parametersContainsParameter = ArrayUtils.contains(sqlCall.getParameters(), parameter);
200             return stringMatchesUpdateTableName && parametersContainsParameter;
201         }
202 
203         String getFailureMessage() {
204             return "update of table " + tableName + " with one parameter = " + parameter;
205         }
206     }
207 
208     public static class InsertIntoTableSqlCallPredicate extends MatchesRegexSqlCallPredicate {
209 
210         private String tableName;
211 
212         public InsertIntoTableSqlCallPredicate(String tableName) {
213             super("^insert//s+(?:into//s+)?//b" + tableName+"//b") ;
214             this.tableName = tableName;
215         }
216 
217         String getFailureMessage() {
218             return "insert into table " + tableName ;
219         }
220     }
221 
222     public static class InsertIntoTableWithParameterSqlCallPredicate extends InsertIntoTableSqlCallPredicate {
223 
224         private String parameter;
225 
226         public InsertIntoTableWithParameterSqlCallPredicate(String tableName, String parameter) {
227             super(tableName);
228             this.parameter = parameter;
229         }
230 
231         boolean evaluateSqlCall(MockDatabase.SqlCall sqlCall) {
232             return super.evaluateSqlCall(sqlCall) && ArrayUtils.contains(sqlCall.getParameters(), parameter);
233         }
234 
235         String getFailureMessage() {
236             return super.getFailureMessage() + " with one parameter = \"" + parameter + "\"";
237         }
238     }
239 
240     public static class MatchesRegexSqlCallPredicate extends SqlCallPredicate {
241 
242         private String regex;
243 
244         public MatchesRegexSqlCallPredicate(String regex) {
245             this.regex = regex;
246         }
247 
248         boolean evaluateSqlCall(MockDatabase.SqlCall sqlCall) {
249             Pattern pattern = Pattern.compile(regex, Pattern.CASE_INSENSITIVE);
250             Matcher matcher = pattern.matcher(sqlCall.getString());
251             return matcher.find();
252         }
253 
254         String getFailureMessage() {
255             return "Expected call to match regex " + regex;
256         }
257     }
258 
259     public static class EqualsSqlCallPredicate extends SqlCallPredicate {
260 
261         String sql;
262 
263         public EqualsSqlCallPredicate(String sql) {
264             this.sql = sql;
265         }
266 
267         boolean evaluateSqlCall(SqlCall sqlCall) {
268             return sql.equalsIgnoreCase(sqlCall.getString());
269         }
270 
271         String getFailureMessage() {
272             return "sql \"" + sql + "\"";
273         }
274     }
275 
276     public static class StartsWithSqlCallPredicate extends SqlCallPredicate {
277 
278         private String prefix;
279 
280         public StartsWithSqlCallPredicate(String prefix) {
281             this.prefix = prefix;
282         }
283 
284         boolean evaluateSqlCall(SqlCall sqlCall) {
285             return sqlCall.getString().startsWith(prefix);
286         }
287 
288         String getFailureMessage() {
289             return "start with " + prefix;
290         }
291     }
292 
293     public static class EqualsWithParametersSqlCallPredicate extends EqualsSqlCallPredicate {
294 
295         private String[] parameters;
296 
297         public EqualsWithParametersSqlCallPredicate(String sql, String[] parameters) {
298             super(sql);
299             this.parameters = parameters;
300         }
301 
302         boolean evaluateSqlCall(SqlCall sqlCall) {
303             return super.evaluateSqlCall(sqlCall) && Arrays.equals(parameters, sqlCall.getParameters());
304         }
305 
306         String getFailureMessage() {
307             return super.getFailureMessage() + " with parameters " + ArrayUtils.toString(parameters);
308         }
309     }
310 
311     public static class DeleteFromTableSqlCallPredicate extends MatchesRegexSqlCallPredicate {
312 
313         private String tableName;
314 
315         public DeleteFromTableSqlCallPredicate(String tableName) {
316             super("^delete//s+from//s+//b" + tableName+"//b") ;
317             this.tableName = tableName;
318         }
319 
320         String getFailureMessage() {
321             return "delete from "+tableName;
322         }
323 
324     }
325 }