1/**
2 * Copyright (C) 2008 Google Inc.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16package com.google.inject.servlet;
17
18import com.google.common.base.Preconditions;
19import com.google.common.collect.Lists;
20import com.google.common.collect.Sets;
21import com.google.inject.Binding;
22import com.google.inject.Inject;
23import com.google.inject.Injector;
24import com.google.inject.Singleton;
25import com.google.inject.TypeLiteral;
26
27import java.io.IOException;
28import java.util.List;
29import java.util.Set;
30
31import javax.servlet.RequestDispatcher;
32import javax.servlet.ServletContext;
33import javax.servlet.ServletException;
34import javax.servlet.ServletRequest;
35import javax.servlet.ServletResponse;
36import javax.servlet.http.HttpServlet;
37import javax.servlet.http.HttpServletRequest;
38import javax.servlet.http.HttpServletRequestWrapper;
39
40/**
41 * A wrapping dispatcher for servlets, in much the same way as {@link ManagedFilterPipeline} is for
42 * filters.
43 *
44 * @author dhanji@gmail.com (Dhanji R. Prasanna)
45 */
46@Singleton
47class ManagedServletPipeline {
48  private final ServletDefinition[] servletDefinitions;
49  private static final TypeLiteral<ServletDefinition> SERVLET_DEFS =
50      TypeLiteral.get(ServletDefinition.class);
51
52  @Inject
53  public ManagedServletPipeline(Injector injector) {
54    this.servletDefinitions = collectServletDefinitions(injector);
55  }
56
57  boolean hasServletsMapped() {
58    return servletDefinitions.length > 0;
59  }
60
61  /**
62   * Introspects the injector and collects all instances of bound {@code List<ServletDefinition>}
63   * into a master list.
64   *
65   * We have a guarantee that {@link com.google.inject.Injector#getBindings()} returns a map
66   * that preserves insertion order in entry-set iterators.
67   */
68  private ServletDefinition[] collectServletDefinitions(Injector injector) {
69    List<ServletDefinition> servletDefinitions = Lists.newArrayList();
70    for (Binding<ServletDefinition> entry : injector.findBindingsByType(SERVLET_DEFS)) {
71        servletDefinitions.add(entry.getProvider().get());
72    }
73
74    // Copy to a fixed size array for speed.
75    return servletDefinitions.toArray(new ServletDefinition[servletDefinitions.size()]);
76  }
77
78  public void init(ServletContext servletContext, Injector injector) throws ServletException {
79    Set<HttpServlet> initializedSoFar = Sets.newIdentityHashSet();
80
81    for (ServletDefinition servletDefinition : servletDefinitions) {
82      servletDefinition.init(servletContext, injector, initializedSoFar);
83    }
84  }
85
86  public boolean service(ServletRequest request, ServletResponse response)
87      throws IOException, ServletException {
88
89    //stop at the first matching servlet and service
90    for (ServletDefinition servletDefinition : servletDefinitions) {
91      if (servletDefinition.service(request, response)) {
92        return true;
93      }
94    }
95
96    //there was no match...
97    return false;
98  }
99
100  public void destroy() {
101    Set<HttpServlet> destroyedSoFar = Sets.newIdentityHashSet();
102    for (ServletDefinition servletDefinition : servletDefinitions) {
103      servletDefinition.destroy(destroyedSoFar);
104    }
105  }
106
107  /**
108   * @return Returns a request dispatcher wrapped with a servlet mapped to
109   * the given path or null if no mapping was found.
110   */
111  RequestDispatcher getRequestDispatcher(String path) {
112    final String newRequestUri = path;
113
114    // TODO(dhanji): check servlet spec to see if the following is legal or not.
115    // Need to strip query string if requested...
116
117    for (final ServletDefinition servletDefinition : servletDefinitions) {
118      if (servletDefinition.shouldServe(path)) {
119        return new RequestDispatcher() {
120          public void forward(ServletRequest servletRequest, ServletResponse servletResponse)
121              throws ServletException, IOException {
122            Preconditions.checkState(!servletResponse.isCommitted(),
123                "Response has been committed--you can only call forward before"
124                + " committing the response (hint: don't flush buffers)");
125
126            // clear buffer before forwarding
127            servletResponse.resetBuffer();
128
129            ServletRequest requestToProcess;
130            if (servletRequest instanceof HttpServletRequest) {
131               requestToProcess = wrapRequest((HttpServletRequest)servletRequest, newRequestUri);
132            } else {
133              // This should never happen, but instead of throwing an exception
134              // we will allow a happy case pass thru for maximum tolerance to
135              // legacy (and internal) code.
136              requestToProcess = servletRequest;
137            }
138
139            // now dispatch to the servlet
140            doServiceImpl(servletDefinition, requestToProcess, servletResponse);
141          }
142
143          public void include(ServletRequest servletRequest, ServletResponse servletResponse)
144              throws ServletException, IOException {
145            // route to the target servlet
146            doServiceImpl(servletDefinition, servletRequest, servletResponse);
147          }
148
149          private void doServiceImpl(ServletDefinition servletDefinition, ServletRequest servletRequest,
150              ServletResponse servletResponse) throws ServletException, IOException {
151            servletRequest.setAttribute(REQUEST_DISPATCHER_REQUEST, Boolean.TRUE);
152
153            try {
154              servletDefinition.doService(servletRequest, servletResponse);
155            } finally {
156              servletRequest.removeAttribute(REQUEST_DISPATCHER_REQUEST);
157            }
158          }
159        };
160      }
161    }
162
163    //otherwise, can't process
164    return null;
165  }
166
167  // visible for testing
168  static HttpServletRequest wrapRequest(HttpServletRequest request, String newUri) {
169    return new RequestDispatcherRequestWrapper(request, newUri);
170  }
171
172  /**
173   * A Marker constant attribute that when present in the request indicates to Guice servlet that
174   * this request has been generated by a request dispatcher rather than the servlet pipeline.
175   * In accordance with section 8.4.2 of the Servlet 2.4 specification.
176   */
177  public static final String REQUEST_DISPATCHER_REQUEST = "javax.servlet.forward.servlet_path";
178
179  private static class RequestDispatcherRequestWrapper extends HttpServletRequestWrapper {
180    private final String newRequestUri;
181
182    public RequestDispatcherRequestWrapper(HttpServletRequest servletRequest, String newRequestUri) {
183      super(servletRequest);
184      this.newRequestUri = newRequestUri;
185    }
186
187    @Override
188    public String getRequestURI() {
189      return newRequestUri;
190    }
191
192    @Override
193    public StringBuffer getRequestURL() {
194      StringBuffer url = new StringBuffer();
195      String scheme = getScheme();
196      int port = getServerPort();
197
198      url.append(scheme);
199      url.append("://");
200      url.append(getServerName());
201      // port might be -1 in some cases (see java.net.URL.getPort)
202      if (port > 0 &&
203          (("http".equals(scheme) && (port != 80)) ||
204           ("https".equals(scheme) && (port != 443)))) {
205        url.append(':');
206        url.append(port);
207      }
208      url.append(getRequestURI());
209
210      return (url);
211    }
212  }
213}
214