1package com.google.inject.servlet;
2
3import static org.easymock.EasyMock.anyObject;
4import static org.easymock.EasyMock.createMock;
5import static org.easymock.EasyMock.expect;
6import static org.easymock.EasyMock.replay;
7import static org.easymock.EasyMock.verify;
8
9import com.google.common.collect.ImmutableMap;
10import com.google.common.collect.Sets;
11import com.google.inject.Binding;
12import com.google.inject.Injector;
13import com.google.inject.Key;
14import com.google.inject.spi.BindingScopingVisitor;
15
16import junit.framework.TestCase;
17
18import java.io.IOException;
19import java.util.Enumeration;
20import java.util.HashMap;
21import java.util.Map;
22
23import javax.servlet.Filter;
24import javax.servlet.FilterChain;
25import javax.servlet.FilterConfig;
26import javax.servlet.ServletContext;
27import javax.servlet.ServletException;
28import javax.servlet.ServletRequest;
29import javax.servlet.ServletResponse;
30import javax.servlet.http.HttpServletRequest;
31
32/**
33 * Tests the lifecycle of the encapsulated {@link FilterDefinition} class.
34 *
35 * @author Dhanji R. Prasanna (dhanji@gmail com)
36 */
37public class FilterDefinitionTest extends TestCase {
38
39  public final void testFilterInitAndConfig() throws ServletException {
40
41    Injector injector = createMock(Injector.class);
42    Binding binding = createMock(Binding.class);
43
44    final MockFilter mockFilter = new MockFilter();
45
46    expect(binding.acceptScopingVisitor((BindingScopingVisitor) anyObject()))
47        .andReturn(true);
48    expect(injector.getBinding(Key.get(Filter.class)))
49        .andReturn(binding);
50
51    expect(injector.getInstance(Key.get(Filter.class)))
52        .andReturn(mockFilter)
53        .anyTimes();
54
55    replay(binding, injector);
56
57    //some init params
58    //noinspection SSBasedInspection
59    final Map<String, String> initParams = new ImmutableMap.Builder<String, String>()
60      .put("ahsd", "asdas24dok")
61      .put("ahssd", "asdasd124ok").build();
62
63    ServletContext servletContext = createMock(ServletContext.class);
64    final String contextName = "thing__!@@44";
65    expect(servletContext.getServletContextName()).andReturn(contextName);
66
67    replay(servletContext);
68
69    String pattern = "/*";
70    final FilterDefinition filterDef = new FilterDefinition(pattern, Key.get(Filter.class),
71        UriPatternType.get(UriPatternType.SERVLET, pattern), initParams, null);
72    filterDef.init(servletContext, injector, Sets.<Filter>newIdentityHashSet());
73
74    assertTrue(filterDef.getFilter() instanceof MockFilter);
75    final FilterConfig filterConfig = mockFilter.getConfig();
76    assertTrue(null != filterConfig);
77    assertEquals(filterConfig.getServletContext().getServletContextName(), contextName);
78    assertEquals(filterConfig.getFilterName(), Key.get(Filter.class).toString());
79
80    final Enumeration names = filterConfig.getInitParameterNames();
81    while (names.hasMoreElements()) {
82      String name = (String) names.nextElement();
83
84      assertTrue(initParams.containsKey(name));
85      assertEquals(filterConfig.getInitParameter(name), initParams.get(name));
86    }
87
88    verify(binding, injector, servletContext);
89  }
90
91  public final void testFilterCreateDispatchDestroy() throws ServletException, IOException {
92    Injector injector = createMock(Injector.class);
93    Binding binding = createMock(Binding.class);
94    HttpServletRequest request = createMock(HttpServletRequest.class);
95
96    final MockFilter mockFilter = new MockFilter();
97
98    expect(binding.acceptScopingVisitor((BindingScopingVisitor) anyObject()))
99        .andReturn(true);
100    expect(injector.getBinding(Key.get(Filter.class)))
101        .andReturn(binding);
102
103    expect(injector.getInstance(Key.get(Filter.class)))
104        .andReturn(mockFilter)
105        .anyTimes();
106
107    expect(request.getRequestURI()).andReturn("/index.html");
108    expect(request.getContextPath())
109        .andReturn("")
110        .anyTimes();
111
112    replay(injector, binding, request);
113
114    String pattern = "/*";
115    final FilterDefinition filterDef = new FilterDefinition(pattern, Key.get(Filter.class),
116        UriPatternType.get(UriPatternType.SERVLET, pattern),
117        new HashMap<String, String>(), null);
118    //should fire on mockfilter now
119    filterDef.init(createMock(ServletContext.class), injector, Sets.<Filter>newIdentityHashSet());
120    assertTrue(filterDef.getFilter() instanceof MockFilter);
121
122    assertTrue("Init did not fire", mockFilter.isInit());
123
124    Filter matchingFilter = filterDef.getFilterIfMatching(request);
125    assertSame(mockFilter, matchingFilter);
126
127    final boolean proceed[] = new boolean[1];
128    matchingFilter.doFilter(request, null, new FilterChainInvocation(null, null, null) {
129      @Override
130      public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse) {
131        proceed[0] = true;
132      }
133    });
134
135    assertTrue("Filter did not proceed down chain", proceed[0]);
136
137    filterDef.destroy(Sets.<Filter>newIdentityHashSet());
138    assertTrue("Destroy did not fire", mockFilter.isDestroy());
139
140    verify(injector, request);
141
142  }
143
144  public final void testFilterCreateDispatchDestroySupressChain()
145      throws ServletException, IOException {
146
147    Injector injector = createMock(Injector.class);
148    Binding binding = createMock(Binding.class);
149    HttpServletRequest request = createMock(HttpServletRequest.class);
150
151    final MockFilter mockFilter = new MockFilter() {
152      @Override
153      public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse,
154          FilterChain filterChain) {
155        //suppress rest of chain...
156      }
157    };
158
159    expect(binding.acceptScopingVisitor((BindingScopingVisitor) anyObject()))
160        .andReturn(true);
161    expect(injector.getBinding(Key.get(Filter.class)))
162        .andReturn(binding);
163
164    expect(injector.getInstance(Key.get(Filter.class)))
165        .andReturn(mockFilter)
166        .anyTimes();
167
168    expect(request.getRequestURI()).andReturn("/index.html");
169    expect(request.getContextPath())
170        .andReturn("")
171        .anyTimes();
172
173    replay(injector, binding, request);
174
175    String pattern = "/*";
176    final FilterDefinition filterDef = new FilterDefinition(pattern, Key.get(Filter.class),
177        UriPatternType.get(UriPatternType.SERVLET, pattern),
178        new HashMap<String, String>(), null);
179    //should fire on mockfilter now
180    filterDef.init(createMock(ServletContext.class), injector, Sets.<Filter>newIdentityHashSet());
181    assertTrue(filterDef.getFilter() instanceof MockFilter);
182
183
184    assertTrue("init did not fire", mockFilter.isInit());
185
186    Filter matchingFilter = filterDef.getFilterIfMatching(request);
187    assertSame(mockFilter, matchingFilter);
188
189    final boolean proceed[] = new boolean[1];
190    matchingFilter.doFilter(request, null, new FilterChainInvocation(null, null, null) {
191      @Override
192      public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse) {
193        proceed[0] = true;
194      }
195    });
196
197    assertFalse("filter did not suppress chain", proceed[0]);
198
199    filterDef.destroy(Sets.<Filter>newIdentityHashSet());
200    assertTrue("destroy did not fire", mockFilter.isDestroy());
201
202    verify(injector, request);
203
204  }
205
206  public void testGetFilterIfMatching() throws ServletException {
207    String pattern = "/*";
208    final FilterDefinition filterDef = new FilterDefinition(pattern, Key.get(Filter.class),
209        UriPatternType.get(UriPatternType.SERVLET, pattern),
210        new HashMap<String, String>(), null);
211    HttpServletRequest servletRequest = createMock(HttpServletRequest.class);
212    ServletContext servletContext = createMock(ServletContext.class);
213    Injector injector = createMock(Injector.class);
214    Binding binding = createMock(Binding.class);
215
216    final MockFilter mockFilter = new MockFilter() {
217      @Override
218      public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse,
219          FilterChain filterChain) {
220        //suppress rest of chain...
221      }
222    };
223    expect(injector.getBinding(Key.get(Filter.class)))
224        .andReturn(binding);
225    expect(binding.acceptScopingVisitor((BindingScopingVisitor) anyObject()))
226        .andReturn(true);
227    expect(injector.getInstance(Key.get(Filter.class)))
228        .andReturn(mockFilter)
229        .anyTimes();
230
231    expect(servletRequest.getContextPath()).andReturn("/a_context_path");
232    expect(servletRequest.getRequestURI()).andReturn("/a_context_path/test.html");
233
234    replay(servletRequest, binding, injector);
235    filterDef.init(servletContext, injector, Sets.<Filter>newIdentityHashSet());
236    Filter filter = filterDef.getFilterIfMatching(servletRequest);
237    assertSame(filter, mockFilter);
238    verify(servletRequest, binding, injector);
239  }
240
241  public void testGetFilterIfMatchingNotMatching() throws ServletException {
242    String pattern = "/*";
243    final FilterDefinition filterDef = new FilterDefinition(pattern, Key.get(Filter.class),
244        UriPatternType.get(UriPatternType.SERVLET, pattern),
245        new HashMap<String, String>(), null);
246    HttpServletRequest servletRequest = createMock(HttpServletRequest.class);
247    ServletContext servletContext = createMock(ServletContext.class);
248    Injector injector = createMock(Injector.class);
249    Binding binding = createMock(Binding.class);
250
251    final MockFilter mockFilter = new MockFilter() {
252      @Override
253      public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse,
254          FilterChain filterChain) {
255        //suppress rest of chain...
256      }
257    };
258    expect(injector.getBinding(Key.get(Filter.class)))
259        .andReturn(binding);
260    expect(binding.acceptScopingVisitor((BindingScopingVisitor) anyObject()))
261        .andReturn(true);
262    expect(injector.getInstance(Key.get(Filter.class)))
263        .andReturn(mockFilter)
264        .anyTimes();
265
266    expect(servletRequest.getContextPath()).andReturn("/a_context_path");
267    expect(servletRequest.getRequestURI()).andReturn("/test.html");
268
269    replay(servletRequest, binding, injector);
270    filterDef.init(servletContext, injector, Sets.<Filter>newIdentityHashSet());
271    Filter filter = filterDef.getFilterIfMatching(servletRequest);
272    assertNull(filter);
273    verify(servletRequest, binding, injector);
274  }
275
276  private static class MockFilter implements Filter {
277    private boolean init;
278    private boolean destroy;
279    private FilterConfig config;
280
281    public void init(FilterConfig filterConfig) {
282      init = true;
283
284      this.config = filterConfig;
285    }
286
287    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse,
288        FilterChain filterChain) throws IOException, ServletException {
289      //proceed
290      filterChain.doFilter(servletRequest, servletResponse);
291    }
292
293    public void destroy() {
294      destroy = true;
295    }
296
297    public boolean isInit() {
298      return init;
299    }
300
301    public boolean isDestroy() {
302      return destroy;
303    }
304
305    public FilterConfig getConfig() {
306      return config;
307    }
308  }
309}
310