1/* 2 * Copyright (C) 2011 Google Inc. 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17package com.google.mockwebserver; 18 19import static com.google.mockwebserver.SocketPolicy.DISCONNECT_AT_START; 20import java.io.BufferedInputStream; 21import java.io.BufferedOutputStream; 22import java.io.ByteArrayOutputStream; 23import java.io.IOException; 24import java.io.InputStream; 25import java.io.OutputStream; 26import java.net.HttpURLConnection; 27import java.net.InetAddress; 28import java.net.InetSocketAddress; 29import java.net.MalformedURLException; 30import java.net.Proxy; 31import java.net.ServerSocket; 32import java.net.Socket; 33import java.net.SocketException; 34import java.net.URL; 35import java.net.UnknownHostException; 36import java.util.ArrayList; 37import java.util.Collections; 38import java.util.Iterator; 39import java.util.List; 40import java.util.Set; 41import java.util.concurrent.BlockingQueue; 42import java.util.concurrent.ConcurrentHashMap; 43import java.util.concurrent.ExecutorService; 44import java.util.concurrent.Executors; 45import java.util.concurrent.LinkedBlockingDeque; 46import java.util.concurrent.LinkedBlockingQueue; 47import java.util.concurrent.atomic.AtomicInteger; 48import java.util.logging.Level; 49import java.util.logging.Logger; 50import javax.net.ssl.SSLSocket; 51import javax.net.ssl.SSLSocketFactory; 52 53/** 54 * A scriptable web server. Callers supply canned responses and the server 55 * replays them upon request in sequence. 56 */ 57public final class MockWebServer { 58 59 static final String ASCII = "US-ASCII"; 60 61 private static final Logger logger = Logger.getLogger(MockWebServer.class.getName()); 62 private final BlockingQueue<RecordedRequest> requestQueue 63 = new LinkedBlockingQueue<RecordedRequest>(); 64 private final BlockingQueue<MockResponse> responseQueue 65 = new LinkedBlockingDeque<MockResponse>(); 66 private final Set<Socket> openClientSockets 67 = Collections.newSetFromMap(new ConcurrentHashMap<Socket, Boolean>()); 68 private boolean singleResponse; 69 private final AtomicInteger requestCount = new AtomicInteger(); 70 private int bodyLimit = Integer.MAX_VALUE; 71 private ServerSocket serverSocket; 72 private SSLSocketFactory sslSocketFactory; 73 private ExecutorService executor; 74 private boolean tunnelProxy; 75 76 private int port = -1; 77 78 public int getPort() { 79 if (port == -1) { 80 throw new IllegalStateException("Cannot retrieve port before calling play()"); 81 } 82 return port; 83 } 84 85 public String getHostName() { 86 try { 87 return InetAddress.getLocalHost().getHostName(); 88 } catch (UnknownHostException e) { 89 throw new AssertionError(); 90 } 91 } 92 93 public Proxy toProxyAddress() { 94 return new Proxy(Proxy.Type.HTTP, new InetSocketAddress(getHostName(), getPort())); 95 } 96 97 /** 98 * Returns a URL for connecting to this server. 99 * 100 * @param path the request path, such as "/". 101 */ 102 public URL getUrl(String path) throws MalformedURLException, UnknownHostException { 103 return sslSocketFactory != null 104 ? new URL("https://" + getHostName() + ":" + getPort() + path) 105 : new URL("http://" + getHostName() + ":" + getPort() + path); 106 } 107 108 /** 109 * Sets the number of bytes of the POST body to keep in memory to the given 110 * limit. 111 */ 112 public void setBodyLimit(int maxBodyLength) { 113 this.bodyLimit = maxBodyLength; 114 } 115 116 /** 117 * Serve requests with HTTPS rather than otherwise. 118 * 119 * @param tunnelProxy whether to expect the HTTP CONNECT method before 120 * negotiating TLS. 121 */ 122 public void useHttps(SSLSocketFactory sslSocketFactory, boolean tunnelProxy) { 123 this.sslSocketFactory = sslSocketFactory; 124 this.tunnelProxy = tunnelProxy; 125 } 126 127 /** 128 * Awaits the next HTTP request, removes it, and returns it. Callers should 129 * use this to verify the request sent was as intended. 130 */ 131 public RecordedRequest takeRequest() throws InterruptedException { 132 return requestQueue.take(); 133 } 134 135 /** 136 * Returns the number of HTTP requests received thus far by this server. 137 * This may exceed the number of HTTP connections when connection reuse is 138 * in practice. 139 */ 140 public int getRequestCount() { 141 return requestCount.get(); 142 } 143 144 public void enqueue(MockResponse response) { 145 responseQueue.add(response.clone()); 146 } 147 148 /** 149 * By default, this class processes requests coming in by adding them to a 150 * queue and serves responses by removing them from another queue. This mode 151 * is appropriate for correctness testing. 152 * 153 * <p>Serving a single response causes the server to be stateless: requests 154 * are not enqueued, and responses are not dequeued. This mode is appropriate 155 * for benchmarking. 156 */ 157 public void setSingleResponse(boolean singleResponse) { 158 this.singleResponse = singleResponse; 159 } 160 161 /** 162 * Equivalent to {@code play(0)}. 163 */ 164 public void play() throws IOException { 165 play(0); 166 } 167 168 /** 169 * Starts the server, serves all enqueued requests, and shuts the server 170 * down. 171 * 172 * @param port the port to listen to, or 0 for any available port. 173 * Automated tests should always use port 0 to avoid flakiness when a 174 * specific port is unavailable. 175 */ 176 public void play(int port) throws IOException { 177 executor = Executors.newCachedThreadPool(); 178 serverSocket = new ServerSocket(port); 179 serverSocket.setReuseAddress(true); 180 181 this.port = serverSocket.getLocalPort(); 182 executor.execute(namedRunnable("MockWebServer-accept-" + port, new Runnable() { 183 public void run() { 184 try { 185 acceptConnections(); 186 } catch (Throwable e) { 187 logger.log(Level.WARNING, "MockWebServer connection failed", e); 188 } 189 190 /* 191 * This gnarly block of code will release all sockets and 192 * all thread, even if any close fails. 193 */ 194 try { 195 serverSocket.close(); 196 } catch (Throwable e) { 197 logger.log(Level.WARNING, "MockWebServer server socket close failed", e); 198 } 199 for (Iterator<Socket> s = openClientSockets.iterator(); s.hasNext();) { 200 try { 201 s.next().close(); 202 s.remove(); 203 } catch (Throwable e) { 204 logger.log(Level.WARNING, "MockWebServer socket close failed", e); 205 } 206 } 207 try { 208 executor.shutdown(); 209 } catch (Throwable e) { 210 logger.log(Level.WARNING, "MockWebServer executor shutdown failed", e); 211 } 212 } 213 214 private void acceptConnections() throws Exception { 215 do { 216 Socket socket; 217 try { 218 socket = serverSocket.accept(); 219 } catch (SocketException ignored) { 220 continue; 221 } 222 MockResponse peek = responseQueue.peek(); 223 if (peek != null && peek.getSocketPolicy() == DISCONNECT_AT_START) { 224 responseQueue.take(); 225 socket.close(); 226 } else { 227 openClientSockets.add(socket); 228 serveConnection(socket); 229 } 230 } while (!responseQueue.isEmpty()); 231 } 232 })); 233 } 234 235 public void shutdown() throws IOException { 236 if (serverSocket != null) { 237 serverSocket.close(); // should cause acceptConnections() to break out 238 } 239 } 240 241 private void serveConnection(final Socket raw) { 242 String name = "MockWebServer-" + raw.getRemoteSocketAddress(); 243 executor.execute(namedRunnable(name, new Runnable() { 244 int sequenceNumber = 0; 245 246 public void run() { 247 try { 248 processConnection(); 249 } catch (Exception e) { 250 logger.log(Level.WARNING, "MockWebServer connection failed", e); 251 } 252 } 253 254 public void processConnection() throws Exception { 255 Socket socket; 256 if (sslSocketFactory != null) { 257 if (tunnelProxy) { 258 createTunnel(); 259 } 260 socket = sslSocketFactory.createSocket( 261 raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true); 262 ((SSLSocket) socket).setUseClientMode(false); 263 openClientSockets.add(socket); 264 openClientSockets.remove(raw); 265 } else { 266 socket = raw; 267 } 268 269 InputStream in = new BufferedInputStream(socket.getInputStream()); 270 OutputStream out = new BufferedOutputStream(socket.getOutputStream()); 271 272 while (!responseQueue.isEmpty() && processOneRequest(in, out, socket)) {} 273 274 if (sequenceNumber == 0) { 275 logger.warning("MockWebServer connection didn't make a request"); 276 } 277 278 in.close(); 279 out.close(); 280 socket.close(); 281 if (responseQueue.isEmpty()) { 282 shutdown(); 283 } 284 openClientSockets.remove(socket); 285 } 286 287 /** 288 * Respond to CONNECT requests until a SWITCH_TO_SSL_AT_END response 289 * is dispatched. 290 */ 291 private void createTunnel() throws IOException, InterruptedException { 292 while (true) { 293 MockResponse connect = responseQueue.peek(); 294 if (!processOneRequest(raw.getInputStream(), raw.getOutputStream(), raw)) { 295 throw new IllegalStateException("Tunnel without any CONNECT!"); 296 } 297 if (connect.getSocketPolicy() == SocketPolicy.UPGRADE_TO_SSL_AT_END) { 298 return; 299 } 300 } 301 } 302 303 /** 304 * Reads a request and writes its response. Returns true if a request 305 * was processed. 306 */ 307 private boolean processOneRequest(InputStream in, OutputStream out, Socket socket) 308 throws IOException, InterruptedException { 309 RecordedRequest request = readRequest(in, sequenceNumber); 310 if (request == null) { 311 return false; 312 } 313 MockResponse response = dispatch(request); 314 writeResponse(out, response); 315 if (response.getSocketPolicy() == SocketPolicy.DISCONNECT_AT_END) { 316 in.close(); 317 out.close(); 318 } else if (response.getSocketPolicy() == SocketPolicy.SHUTDOWN_INPUT_AT_END) { 319 socket.shutdownInput(); 320 } else if (response.getSocketPolicy() == SocketPolicy.SHUTDOWN_OUTPUT_AT_END) { 321 socket.shutdownOutput(); 322 } 323 sequenceNumber++; 324 return true; 325 } 326 })); 327 } 328 329 /** 330 * @param sequenceNumber the index of this request on this connection. 331 */ 332 private RecordedRequest readRequest(InputStream in, int sequenceNumber) throws IOException { 333 String request; 334 try { 335 request = readAsciiUntilCrlf(in); 336 } catch (IOException streamIsClosed) { 337 return null; // no request because we closed the stream 338 } 339 if (request.isEmpty()) { 340 return null; // no request because the stream is exhausted 341 } 342 343 List<String> headers = new ArrayList<String>(); 344 int contentLength = -1; 345 boolean chunked = false; 346 String header; 347 while (!(header = readAsciiUntilCrlf(in)).isEmpty()) { 348 headers.add(header); 349 String lowercaseHeader = header.toLowerCase(); 350 if (contentLength == -1 && lowercaseHeader.startsWith("content-length:")) { 351 contentLength = Integer.parseInt(header.substring(15).trim()); 352 } 353 if (lowercaseHeader.startsWith("transfer-encoding:") && 354 lowercaseHeader.substring(18).trim().equals("chunked")) { 355 chunked = true; 356 } 357 } 358 359 boolean hasBody = false; 360 TruncatingOutputStream requestBody = new TruncatingOutputStream(); 361 List<Integer> chunkSizes = new ArrayList<Integer>(); 362 if (contentLength != -1) { 363 hasBody = true; 364 transfer(contentLength, in, requestBody); 365 } else if (chunked) { 366 hasBody = true; 367 while (true) { 368 int chunkSize = Integer.parseInt(readAsciiUntilCrlf(in).trim(), 16); 369 if (chunkSize == 0) { 370 readEmptyLine(in); 371 break; 372 } 373 chunkSizes.add(chunkSize); 374 transfer(chunkSize, in, requestBody); 375 readEmptyLine(in); 376 } 377 } 378 379 if (request.startsWith("OPTIONS ") || request.startsWith("GET ") 380 || request.startsWith("HEAD ") || request.startsWith("DELETE ") 381 || request .startsWith("TRACE ") || request.startsWith("CONNECT ")) { 382 if (hasBody) { 383 throw new IllegalArgumentException("Request must not have a body: " + request); 384 } 385 } else if (request.startsWith("POST ") || request.startsWith("PUT ")) { 386 if (!hasBody) { 387 throw new IllegalArgumentException("Request must have a body: " + request); 388 } 389 } else { 390 throw new UnsupportedOperationException("Unexpected method: " + request); 391 } 392 393 return new RecordedRequest(request, headers, chunkSizes, 394 requestBody.numBytesReceived, requestBody.toByteArray(), sequenceNumber); 395 } 396 397 /** 398 * Returns a response to satisfy {@code request}. 399 */ 400 private MockResponse dispatch(RecordedRequest request) throws InterruptedException { 401 if (responseQueue.isEmpty()) { 402 throw new IllegalStateException("Unexpected request: " + request); 403 } 404 405 // to permit interactive/browser testing, ignore requests for favicons 406 if (request.getRequestLine().equals("GET /favicon.ico HTTP/1.1")) { 407 System.out.println("served " + request.getRequestLine()); 408 return new MockResponse() 409 .setResponseCode(HttpURLConnection.HTTP_NOT_FOUND); 410 } 411 412 if (singleResponse) { 413 return responseQueue.peek(); 414 } else { 415 requestCount.incrementAndGet(); 416 requestQueue.add(request); 417 return responseQueue.take(); 418 } 419 } 420 421 private void writeResponse(OutputStream out, MockResponse response) throws IOException { 422 out.write((response.getStatus() + "\r\n").getBytes(ASCII)); 423 for (String header : response.getHeaders()) { 424 out.write((header + "\r\n").getBytes(ASCII)); 425 } 426 out.write(("\r\n").getBytes(ASCII)); 427 out.write(response.getBody()); 428 out.flush(); 429 } 430 431 /** 432 * Transfer bytes from {@code in} to {@code out} until either {@code length} 433 * bytes have been transferred or {@code in} is exhausted. 434 */ 435 private void transfer(int length, InputStream in, OutputStream out) throws IOException { 436 byte[] buffer = new byte[1024]; 437 while (length > 0) { 438 int count = in.read(buffer, 0, Math.min(buffer.length, length)); 439 if (count == -1) { 440 return; 441 } 442 out.write(buffer, 0, count); 443 length -= count; 444 } 445 } 446 447 /** 448 * Returns the text from {@code in} until the next "\r\n", or null if 449 * {@code in} is exhausted. 450 */ 451 private String readAsciiUntilCrlf(InputStream in) throws IOException { 452 StringBuilder builder = new StringBuilder(); 453 while (true) { 454 int c = in.read(); 455 if (c == '\n' && builder.length() > 0 && builder.charAt(builder.length() - 1) == '\r') { 456 builder.deleteCharAt(builder.length() - 1); 457 return builder.toString(); 458 } else if (c == -1) { 459 return builder.toString(); 460 } else { 461 builder.append((char) c); 462 } 463 } 464 } 465 466 private void readEmptyLine(InputStream in) throws IOException { 467 String line = readAsciiUntilCrlf(in); 468 if (!line.isEmpty()) { 469 throw new IllegalStateException("Expected empty but was: " + line); 470 } 471 } 472 473 /** 474 * An output stream that drops data after bodyLimit bytes. 475 */ 476 private class TruncatingOutputStream extends ByteArrayOutputStream { 477 private int numBytesReceived = 0; 478 @Override public void write(byte[] buffer, int offset, int len) { 479 numBytesReceived += len; 480 super.write(buffer, offset, Math.min(len, bodyLimit - count)); 481 } 482 @Override public void write(int oneByte) { 483 numBytesReceived++; 484 if (count < bodyLimit) { 485 super.write(oneByte); 486 } 487 } 488 } 489 490 private static Runnable namedRunnable(final String name, final Runnable runnable) { 491 return new Runnable() { 492 public void run() { 493 String originalName = Thread.currentThread().getName(); 494 Thread.currentThread().setName(name); 495 try { 496 runnable.run(); 497 } finally { 498 Thread.currentThread().setName(originalName); 499 } 500 } 501 }; 502 } 503} 504