1/**
2 * Copyright (C) 2008 Google Inc.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package com.google.inject.servlet;
18
19import static com.google.inject.servlet.ManagedServletPipeline.REQUEST_DISPATCHER_REQUEST;
20import static org.easymock.EasyMock.createMock;
21import static org.easymock.EasyMock.expect;
22import static org.easymock.EasyMock.replay;
23import static org.easymock.EasyMock.verify;
24
25import com.google.inject.Guice;
26import com.google.inject.Injector;
27import com.google.inject.Key;
28import com.google.inject.Singleton;
29
30import junit.framework.TestCase;
31
32import java.io.IOException;
33
34import javax.servlet.Filter;
35import javax.servlet.FilterChain;
36import javax.servlet.FilterConfig;
37import javax.servlet.ServletConfig;
38import javax.servlet.ServletException;
39import javax.servlet.ServletRequest;
40import javax.servlet.ServletResponse;
41import javax.servlet.http.HttpServlet;
42import javax.servlet.http.HttpServletRequest;
43import javax.servlet.http.HttpServletResponse;
44
45/**
 * Tests the FilterPipeline that dispatches to guice-managed servlets. This is a full
 * integration test that uses a real injector.
48 *
49 * @author Dhanji R. Prasanna (dhanji gmail com)
50 */
51public class ServletDispatchIntegrationTest extends TestCase {
52  private static int inits, services, destroys, doFilters;
53
54  @Override
55  public void setUp() {
56    inits = 0;
57    services = 0;
58    destroys = 0;
59    doFilters = 0;
60
61    GuiceFilter.reset();
62  }
63
64  public final void testDispatchRequestToManagedPipelineServlets()
65      throws ServletException, IOException {
66    final Injector injector = Guice.createInjector(new ServletModule() {
67
68      @Override
69      protected void configureServlets() {
70        serve("/*").with(TestServlet.class);
71
72        // These servets should never fire... (ordering test)
73        serve("*.html").with(NeverServlet.class);
74        serve("/test/*").with(Key.get(NeverServlet.class));
75        serve("/index/*").with(Key.get(NeverServlet.class));
76        serve("*.jsp").with(Key.get(NeverServlet.class));
77      }
78    });
79
80    final FilterPipeline pipeline = injector.getInstance(FilterPipeline.class);
81
82    pipeline.initPipeline(null);
83
84    //create ourselves a mock request with test URI
85    HttpServletRequest requestMock = createMock(HttpServletRequest.class);
86
87    expect(requestMock.getRequestURI())
88        .andReturn("/index.html")
89        .times(1);
90    expect(requestMock.getContextPath())
91        .andReturn("")
92        .anyTimes();
93
94    //dispatch request
95    replay(requestMock);
96
97    pipeline.dispatch(requestMock, null, createMock(FilterChain.class));
98
99    pipeline.destroyPipeline();
100
101    verify(requestMock);
102
103    assertTrue("lifecycle states did not fire correct number of times-- inits: " + inits + "; dos: "
104            + services + "; destroys: " + destroys,
105        inits == 2 && services == 1 && destroys == 2);
106  }
107
108  public final void testDispatchRequestToManagedPipelineWithFilter()
109      throws ServletException, IOException {
110    final Injector injector = Guice.createInjector(new ServletModule() {
111
112      @Override
113      protected void configureServlets() {
114        filter("/*").through(TestFilter.class);
115
116        serve("/*").with(TestServlet.class);
117
118        // These servets should never fire...
119        serve("*.html").with(NeverServlet.class);
120        serve("/test/*").with(Key.get(NeverServlet.class));
121        serve("/index/*").with(Key.get(NeverServlet.class));
122        serve("*.jsp").with(Key.get(NeverServlet.class));
123
124      }
125    });
126
127    final FilterPipeline pipeline = injector.getInstance(FilterPipeline.class);
128
129    pipeline.initPipeline(null);
130
131    //create ourselves a mock request with test URI
132    HttpServletRequest requestMock = createMock(HttpServletRequest.class);
133
134    expect(requestMock.getRequestURI())
135        .andReturn("/index.html")
136        .times(2);
137    expect(requestMock.getContextPath())
138        .andReturn("")
139        .anyTimes();
140
141    //dispatch request
142    replay(requestMock);
143
144    pipeline.dispatch(requestMock, null, createMock(FilterChain.class));
145
146    pipeline.destroyPipeline();
147
148    verify(requestMock);
149
150    assertTrue("lifecycle states did not fire correct number of times-- inits: " + inits + "; dos: "
151            + services + "; destroys: " + destroys + "; doFilters: " + doFilters,
152        inits == 3 && services == 1 && destroys == 3 && doFilters == 1);
153  }
154
155  @Singleton
156  public static class TestServlet extends HttpServlet {
157    public void init(ServletConfig filterConfig) throws ServletException {
158      inits++;
159    }
160
161    public void service(ServletRequest servletRequest, ServletResponse servletResponse)
162        throws IOException, ServletException {
163      services++;
164    }
165
166    public void destroy() {
167      destroys++;
168    }
169  }
170
171  @Singleton
172  public static class NeverServlet extends HttpServlet {
173    public void init(ServletConfig filterConfig) throws ServletException {
174      inits++;
175    }
176
177    public void service(ServletRequest servletRequest, ServletResponse servletResponse)
178        throws IOException, ServletException {
179      fail("NeverServlet was fired, when it should not have been.");
180    }
181
182    public void destroy() {
183      destroys++;
184    }
185  }
186
187  @Singleton
188  public static class TestFilter implements Filter {
189    public void init(FilterConfig filterConfig) throws ServletException {
190      inits++;
191    }
192
193    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse,
194        FilterChain filterChain) throws IOException, ServletException {
195      doFilters++;
196      filterChain.doFilter(servletRequest, servletResponse);
197    }
198
199    public void destroy() {
200      destroys++;
201    }
202  }
203
204
205  @Singleton
206  public static class ForwardingServlet extends HttpServlet {
207    public void service(ServletRequest servletRequest, ServletResponse servletResponse)
208        throws IOException, ServletException {
209      final HttpServletRequest request = (HttpServletRequest) servletRequest;
210
211      request.getRequestDispatcher("/blah.jsp")
212          .forward(servletRequest, servletResponse);
213    }
214  }
215
216  @Singleton
217  public static class ForwardedServlet extends HttpServlet {
218    static int forwardedTo = 0;
219
220    // Reset for test.
221    public ForwardedServlet() {
222      forwardedTo = 0;
223    }
224
225    public void service(ServletRequest servletRequest, ServletResponse servletResponse)
226        throws IOException, ServletException {
227      final HttpServletRequest request = (HttpServletRequest) servletRequest;
228
229      assertTrue((Boolean) request.getAttribute(REQUEST_DISPATCHER_REQUEST));
230      forwardedTo++;
231    }
232  }
233
234  public void testForwardUsingRequestDispatcher() throws IOException, ServletException {
235    Guice.createInjector(new ServletModule() {
236      @Override
237      protected void configureServlets() {
238        serve("/").with(ForwardingServlet.class);
239        serve("/blah.jsp").with(ForwardedServlet.class);
240      }
241    });
242
243    final HttpServletRequest requestMock = createMock(HttpServletRequest.class);
244    HttpServletResponse responseMock = createMock(HttpServletResponse.class);
245    expect(requestMock.getRequestURI())
246        .andReturn("/")
247        .anyTimes();
248    expect(requestMock.getContextPath())
249        .andReturn("")
250        .anyTimes();
251
252    requestMock.setAttribute(REQUEST_DISPATCHER_REQUEST, true);
253    expect(requestMock.getAttribute(REQUEST_DISPATCHER_REQUEST)).andReturn(true);
254    requestMock.removeAttribute(REQUEST_DISPATCHER_REQUEST);
255
256    expect(responseMock.isCommitted()).andReturn(false);
257    responseMock.resetBuffer();
258
259    replay(requestMock, responseMock);
260
261    new GuiceFilter()
262        .doFilter(requestMock, responseMock,
263            createMock(FilterChain.class));
264
265    assertEquals("Incorrect number of forwards", 1, ForwardedServlet.forwardedTo);
266    verify(requestMock, responseMock);
267  }
268
269  public final void testQueryInRequestUri_regex() throws Exception {
270    final Injector injector = Guice.createInjector(new ServletModule() {
271
272      @Override
273      protected void configureServlets() {
274        filterRegex("(.)*\\.html").through(TestFilter.class);
275
276        serveRegex("(.)*\\.html").with(TestServlet.class);
277      }
278    });
279
280    final FilterPipeline pipeline = injector.getInstance(FilterPipeline.class);
281
282    pipeline.initPipeline(null);
283
284    //create ourselves a mock request with test URI
285    HttpServletRequest requestMock = createMock(HttpServletRequest.class);
286
287    expect(requestMock.getRequestURI())
288        .andReturn("/index.html?query=params")
289        .atLeastOnce();
290    expect(requestMock.getContextPath())
291        .andReturn("")
292        .anyTimes();
293
294    //dispatch request
295    replay(requestMock);
296
297    pipeline.dispatch(requestMock, null, createMock(FilterChain.class));
298
299    pipeline.destroyPipeline();
300
301    verify(requestMock);
302
303    assertEquals(1, doFilters);
304    assertEquals(1, services);
305  }
306
307  public final void testQueryInRequestUri() throws Exception {
308    final Injector injector = Guice.createInjector(new ServletModule() {
309
310      @Override
311      protected void configureServlets() {
312        filter("/index.html").through(TestFilter.class);
313
314        serve("/index.html").with(TestServlet.class);
315      }
316    });
317
318    final FilterPipeline pipeline = injector.getInstance(FilterPipeline.class);
319
320    pipeline.initPipeline(null);
321
322    //create ourselves a mock request with test URI
323    HttpServletRequest requestMock = createMock(HttpServletRequest.class);
324
325    expect(requestMock.getRequestURI())
326        .andReturn("/index.html?query=params")
327        .atLeastOnce();
328    expect(requestMock.getContextPath())
329        .andReturn("")
330        .anyTimes();
331
332    //dispatch request
333    replay(requestMock);
334
335    pipeline.dispatch(requestMock, null, createMock(FilterChain.class));
336
337    pipeline.destroyPipeline();
338
339    verify(requestMock);
340
341    assertEquals(1, doFilters);
342    assertEquals(1, services);
343  }
344}
345