1/* 2 * Copyright (C) 2011 Google Inc. 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17package com.google.mockwebserver; 18 19import static com.google.mockwebserver.SocketPolicy.DISCONNECT_AT_START; 20import static com.google.mockwebserver.SocketPolicy.FAIL_HANDSHAKE; 21import java.io.BufferedInputStream; 22import java.io.BufferedOutputStream; 23import java.io.ByteArrayOutputStream; 24import java.io.IOException; 25import java.io.InputStream; 26import java.io.OutputStream; 27import java.net.InetAddress; 28import java.net.InetSocketAddress; 29import java.net.MalformedURLException; 30import java.net.Proxy; 31import java.net.ServerSocket; 32import java.net.Socket; 33import java.net.SocketException; 34import java.net.URL; 35import java.net.UnknownHostException; 36import java.security.cert.CertificateException; 37import java.security.cert.X509Certificate; 38import java.util.ArrayList; 39import java.util.Iterator; 40import java.util.List; 41import java.util.Map; 42import java.util.concurrent.BlockingQueue; 43import java.util.concurrent.ConcurrentHashMap; 44import java.util.concurrent.ExecutorService; 45import java.util.concurrent.Executors; 46import java.util.concurrent.LinkedBlockingQueue; 47import java.util.concurrent.atomic.AtomicInteger; 48import java.util.logging.Level; 49import java.util.logging.Logger; 50import javax.net.ssl.SSLContext; 51import javax.net.ssl.SSLSocket; 52import javax.net.ssl.SSLSocketFactory; 53import 
javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;

/**
 * A scriptable web server. Callers supply canned responses and the server
 * replays them upon request in sequence.
 */
public final class MockWebServer {

    static final String ASCII = "US-ASCII";

    private static final Logger logger = Logger.getLogger(MockWebServer.class.getName());
    private final BlockingQueue<RecordedRequest> requestQueue
            = new LinkedBlockingQueue<RecordedRequest>();
    /** All map values are Boolean.TRUE. (Collections.newSetFromMap isn't available in Froyo) */
    private final Map<Socket, Boolean> openClientSockets = new ConcurrentHashMap<Socket, Boolean>();
    private final AtomicInteger requestCount = new AtomicInteger();
    private int bodyLimit = Integer.MAX_VALUE;
    private ServerSocket serverSocket;
    private SSLSocketFactory sslSocketFactory;
    private ExecutorService acceptExecutor;
    private ExecutorService requestExecutor;
    private boolean tunnelProxy;
    private Dispatcher dispatcher = new QueueDispatcher();

    /** The locally bound port, or -1 until {@link #play(int)} has bound the server socket. */
    private int port = -1;
    private int workerThreads = Integer.MAX_VALUE;

    /**
     * Returns the port this server is listening on.
     *
     * @throws IllegalStateException if {@link #play()} has not been called yet.
     */
    public int getPort() {
        if (port == -1) {
            throw new IllegalStateException("Cannot retrieve port before calling play()");
        }
        return port;
    }

    /**
     * Returns the host name of the local machine, as reported by
     * {@link InetAddress#getLocalHost()}. URLs returned by
     * {@link #getUrl(String)} use this host name.
     */
    public String getHostName() {
        try {
            return InetAddress.getLocalHost().getHostName();
        } catch (UnknownHostException e) {
            throw new AssertionError(e);
        }
    }

    /**
     * Returns an HTTP proxy whose address is this server's host and port.
     *
     * @throws IllegalStateException if {@link #play()} has not been called yet.
     */
    public Proxy toProxyAddress() {
        return new Proxy(Proxy.Type.HTTP, new InetSocketAddress(getHostName(), getPort()));
    }

    /**
     * Returns a URL for connecting to this server. The scheme is "https" when
     * {@link #useHttps(SSLSocketFactory, boolean)} has been called, and "http"
     * otherwise.
     *
     * @param path the request path, such as "/".
     */
    public URL getUrl(String path) {
        try {
            return sslSocketFactory != null
                    ? new URL("https://" + getHostName() + ":" + getPort() + path)
                    : new URL("http://" + getHostName() + ":" + getPort() + path);
        } catch (MalformedURLException e) {
            throw new AssertionError(e);
        }
    }

    /**
     * Returns a cookie domain for this server. This returns the server's
     * non-loopback host name if it is known. Otherwise this returns ".local"
     * for this server's loopback name.
     */
    public String getCookieDomain() {
        String hostName = getHostName();
        return hostName.contains(".") ? hostName : ".local";
    }

    /**
     * Sets the number of threads used to serve connections. This value is read
     * by {@link #play(int)}, so it must be set before the server is started.
     * Use 1 to make the dispatcher see requests strictly in the order in which
     * they were accepted.
     */
    public void setWorkerThreads(int threads) {
        this.workerThreads = threads;
    }

    /**
     * Sets the maximum number of request-body bytes to keep in memory. Bytes
     * beyond this limit are still counted, but they are discarded.
     */
    public void setBodyLimit(int maxBodyLength) {
        this.bodyLimit = maxBodyLength;
    }

    /**
     * Serves requests with HTTPS instead of plain HTTP.
     *
     * @param sslSocketFactory the factory used to wrap each accepted socket in
     *     a server-mode TLS socket.
     * @param tunnelProxy whether to expect the HTTP CONNECT method before
     *     negotiating TLS.
     */
    public void useHttps(SSLSocketFactory sslSocketFactory, boolean tunnelProxy) {
        this.sslSocketFactory = sslSocketFactory;
        this.tunnelProxy = tunnelProxy;
    }

    /**
     * Awaits the next HTTP request, removes it, and returns it. This blocks
     * until a request is available. Callers should use this to verify that the
     * request sent was as intended.
     */
    public RecordedRequest takeRequest() throws InterruptedException {
        return requestQueue.take();
    }

    /**
     * Returns the number of HTTP requests received thus far by this server.
     * This may exceed the number of HTTP connections when connections are
     * reused. Connections dropped by the DISCONNECT_AT_START or FAIL_HANDSHAKE
     * socket policies are also counted.
     */
    public int getRequestCount() {
        return requestCount.get();
    }

    /**
     * Scripts {@code response} to be returned to a request made in sequence.
     * The first request is served by the first enqueued response; the second
     * request by the second enqueued response; and so on.
     *
     * @throws ClassCastException if the default dispatcher has been replaced
     *     with {@link #setDispatcher(Dispatcher)}.
     */
    public void enqueue(MockResponse response) {
        ((QueueDispatcher) dispatcher).enqueueResponse(response.clone());
    }

    /**
     * Equivalent to {@code play(0)}.
     */
    public void play() throws IOException {
        play(0);
    }

    /**
     * Starts the server in the background. Each accepted connection is handed
     * to a worker thread. This continues until {@link #shutdown()} is called,
     * which also releases the open sockets and threads.
     *
     * @param port the port to listen to, or 0 for any available port.
     *     Automated tests should always use port 0 to avoid flakiness when a
     *     specific port is unavailable.
     * @throws IllegalStateException if the server has already been started.
     */
    public void play(int port) throws IOException {
        if (acceptExecutor != null) {
            throw new IllegalStateException("play() already called");
        }
        // The acceptExecutor handles the Socket.accept() and hands each request off to the
        // requestExecutor. It also handles shutdown.
        acceptExecutor = Executors.newSingleThreadExecutor();
        // The requestExecutor has a fixed number of worker threads. In order to get strict
        // guarantees that requests are handled in the order in which they are accepted
        // workerThreads should be set to 1.
        requestExecutor = Executors.newFixedThreadPool(workerThreads);

        // SO_REUSEADDR only has a defined effect when enabled *before* the socket
        // is bound, so bind explicitly rather than via new ServerSocket(port).
        InetSocketAddress bindAddress = new InetSocketAddress(port);
        ServerSocket socket = new ServerSocket();
        try {
            socket.setReuseAddress(true);
            socket.bind(bindAddress);
        } catch (IOException e) {
            socket.close();
            throw e;
        }
        serverSocket = socket;

        this.port = serverSocket.getLocalPort();
        // Name the thread after the bound port; the 'port' parameter is 0 for ephemeral ports.
        acceptExecutor.execute(namedRunnable("MockWebServer-accept-" + this.port, new Runnable() {
            @Override public void run() {
                try {
                    acceptConnections();
                } catch (Throwable e) {
                    logger.log(Level.WARNING, "MockWebServer connection failed", e);
                }

                /*
                 * This gnarly block of code will release all sockets and
                 * all threads, even if any close fails.
                 */
                try {
                    serverSocket.close();
                } catch (Throwable e) {
                    logger.log(Level.WARNING, "MockWebServer server socket close failed", e);
                }
                for (Iterator<Socket> s = openClientSockets.keySet().iterator(); s.hasNext(); ) {
                    try {
                        s.next().close();
                        s.remove();
                    } catch (Throwable e) {
                        logger.log(Level.WARNING, "MockWebServer socket close failed", e);
                    }
                }
                try {
                    acceptExecutor.shutdown();
                } catch (Throwable e) {
                    logger.log(Level.WARNING, "MockWebServer acceptExecutor shutdown failed", e);
                }
                try {
                    requestExecutor.shutdown();
                } catch (Throwable e) {
                    logger.log(Level.WARNING, "MockWebServer requestExecutor shutdown failed", e);
                }
            }

            /**
             * Accepts connections until the server socket is closed. Depending on the
             * dispatcher's next socket policy, each connection is either dropped
             * immediately (DISCONNECT_AT_START) or handed to the request executor.
             */
            private void acceptConnections() throws Exception {
                while (true) {
                    Socket socket;
                    try {
                        socket = serverSocket.accept();
                    } catch (SocketException e) {
                        return; // the server socket was closed, e.g. by shutdown()
                    }
                    final SocketPolicy socketPolicy = dispatcher.peekSocketPolicy();
                    if (socketPolicy == DISCONNECT_AT_START) {
                        dispatchBookkeepingRequest(0, socket);
                        socket.close();
                    } else {
                        openClientSockets.put(socket, true);
                        serveConnection(socket);
                    }
                }
            }
        }));
    }

    /**
     * Stops the server by closing its server socket. The accept thread then
     * closes every open client socket and shuts down both executors. This
     * method returns immediately; it does not wait for in-flight requests.
     */
    public void shutdown() throws IOException {
        if (serverSocket != null) {
            serverSocket.close(); // should cause acceptConnections() to break out
        }
    }

    /**
     * Serves every request on the connection {@code raw} on a worker thread,
     * wrapping it in TLS first when HTTPS is enabled.
     */
    private void serveConnection(final Socket raw) {
        String name = "MockWebServer-" + raw.getRemoteSocketAddress();
        requestExecutor.execute(namedRunnable(name, new Runnable() {
            /** Index of the next request on this connection. */
            int sequenceNumber = 0;

            @Override public void run() {
                try {
                    processConnection();
                } catch (Exception e) {
                    logger.log(Level.WARNING, "MockWebServer connection failed", e);
                }
            }

            public void processConnection() throws Exception {
                Socket socket;
                if (sslSocketFactory != null) {
                    if (tunnelProxy) {
                        createTunnel();
                    }
                    final SocketPolicy socketPolicy = dispatcher.peekSocketPolicy();
                    if (socketPolicy == FAIL_HANDSHAKE) {
                        dispatchBookkeepingRequest(sequenceNumber, raw);
                        processHandshakeFailure(raw, sequenceNumber++);
                        // The SSL socket wrapping raw was created with autoClose=true and has
                        // been closed, so raw no longer needs tracking.
                        openClientSockets.remove(raw);
                        return;
                    }
                    socket = sslSocketFactory.createSocket(
                            raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true);
                    ((SSLSocket) socket).setUseClientMode(false);
                    openClientSockets.put(socket, true);
                    openClientSockets.remove(raw);
                } else {
                    socket = raw;
                }

                InputStream in = new BufferedInputStream(socket.getInputStream());
                OutputStream out = new BufferedOutputStream(socket.getOutputStream());

                while (processOneRequest(socket, in, out)) {
                }

                if (sequenceNumber == 0) {
                    logger.warning("MockWebServer connection didn't make a request");
                }

                in.close();
                out.close();
                socket.close();
                openClientSockets.remove(socket);
            }

            /**
             * Responds to CONNECT requests on the raw socket until a response
             * with the UPGRADE_TO_SSL_AT_END socket policy is dispatched.
             *
             * @throws IllegalStateException if the connection ends before that.
             */
            private void createTunnel() throws IOException, InterruptedException {
                while (true) {
                    final SocketPolicy socketPolicy = dispatcher.peekSocketPolicy();
                    if (!processOneRequest(raw, raw.getInputStream(), raw.getOutputStream())) {
                        throw new IllegalStateException("Tunnel without any CONNECT!");
                    }
                    if (socketPolicy == SocketPolicy.UPGRADE_TO_SSL_AT_END) {
                        return;
                    }
                }
            }

            /**
             * Reads a request and writes its response. Returns true if a request
             * was processed and the connection may carry another. Returns false if no
             * request could be read, or if the response's socket policy is
             * DISCONNECT_AFTER_READING_REQUEST.
             */
            private boolean processOneRequest(Socket socket, InputStream in, OutputStream out)
                    throws IOException, InterruptedException {
                RecordedRequest request = readRequest(socket, in, sequenceNumber);
                if (request == null) {
                    return false;
                }
                requestCount.incrementAndGet();
                requestQueue.add(request);
                MockResponse response = dispatcher.dispatch(request);
                if (response.getSocketPolicy() == SocketPolicy.DISCONNECT_AFTER_READING_REQUEST) {
                    logger.info("Received request: " + request + " and disconnected without responding");
                    return false;
                }
                writeResponse(out, response);

                // For socket policies that poison the socket after the response is written:
                // The client has received the response and will no longer be blocked after
                // writeResponse() has returned. A client can then re-use the connection before
                // the socket is poisoned (i.e. keep-alive / connection pooling). The second
                // request/response may fail at the beginning, middle, end, or even succeed
                // depending on scheduling. Delays can be required in tests to improve the chances
                // of sockets being in a known state when subsequent requests are made.
                //
                // For SHUTDOWN_OUTPUT_AT_END the client may detect a problem with its input socket
                // after the request has been made but before the server has chosen a response.
                // For clients that perform retries, this can cause the client to issue a retry
                // request. The retry handler may call dispatcher.dispatch(request) before the
                // initial, failed request handler does and cause non-obvious response ordering.
                // Setting workerThreads = 1 ensures that the dispatcher is called for requests in
                // the order they are received.

                if (response.getSocketPolicy() == SocketPolicy.DISCONNECT_AT_END) {
                    in.close();
                    out.close();
                } else if (response.getSocketPolicy() == SocketPolicy.SHUTDOWN_INPUT_AT_END) {
                    socket.shutdownInput();
                } else if (response.getSocketPolicy() == SocketPolicy.SHUTDOWN_OUTPUT_AT_END) {
                    socket.shutdownOutput();
                }
                logger.info("Received request: " + request + " and responded: " + response);
                sequenceNumber++;
                return true;
            }
        }));
    }

    /**
     * Makes the TLS handshake on {@code raw} fail by acting as a TLS server
     * whose trust manager rejects every client certificate chain. The SSL
     * socket, and with it {@code raw}, is always closed before returning.
     *
     * @throws AssertionError if the handshake unexpectedly succeeds.
     */
    private void processHandshakeFailure(Socket raw, int sequenceNumber) throws Exception {
        X509TrustManager untrusted = new X509TrustManager() {
            @Override public void checkClientTrusted(X509Certificate[] chain, String authType)
                    throws CertificateException {
                throw new CertificateException();
            }
            @Override public void checkServerTrusted(X509Certificate[] chain, String authType) {
                throw new AssertionError();
            }
            @Override public X509Certificate[] getAcceptedIssuers() {
                throw new AssertionError();
            }
        };
        SSLContext context = SSLContext.getInstance("TLS");
        context.init(null, new TrustManager[] { untrusted }, new java.security.SecureRandom());
        SSLSocketFactory sslSocketFactory = context.getSocketFactory();
        SSLSocket socket = (SSLSocket) sslSocketFactory.createSocket(
                raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true);
        try {
            socket.startHandshake(); // we're testing a handshake failure
            throw new AssertionError();
        } catch (IOException expected) {
            // The handshake is supposed to fail.
        } finally {
            socket.close();
        }
    }

    /**
     * Records a placeholder request (no request line, headers or body) for a
     * connection that is dropped before a real request is read. The request is
     * counted and passed to the dispatcher, so the dispatcher's next response
     * is consumed.
     */
    private void dispatchBookkeepingRequest(int sequenceNumber, Socket socket) throws InterruptedException {
        requestCount.incrementAndGet();
        RecordedRequest request = new RecordedRequest(null, null, null, -1, null, sequenceNumber,
                socket);
        dispatcher.dispatch(request);
        requestQueue.add(request);
    }

    /**
     * Reads one HTTP request (request line, headers and optional body) from
     * {@code in}.
     *
     * @param sequenceNumber the index of this request on this connection.
     * @return the recorded request, or null if the stream was closed or
     *     exhausted before a request line could be read.
     * @throws IllegalArgumentException if a method such as GET, which must not
     *     carry a body, has a Content-Length or chunked body.
     * @throws UnsupportedOperationException if the request method is not one
     *     this server recognizes.
     */
    private RecordedRequest readRequest(Socket socket, InputStream in, int sequenceNumber)
            throws IOException {
        String request;
        try {
            request = readAsciiUntilCrlf(in);
        } catch (IOException streamIsClosed) {
            return null; // no request because we closed the stream
        }
        if (request.length() == 0) {
            return null; // no request because the stream is exhausted
        }

        List<String> headers = new ArrayList<String>();
        int contentLength = -1;
        boolean chunked = false;
        String header;
        while ((header = readAsciiUntilCrlf(in)).length() != 0) {
            headers.add(header);
            // Header names are case-insensitive. Lowercase with a fixed locale so that a
            // default locale such as Turkish can't map 'I' to a dotless 'ı'.
            String lowercaseHeader = header.toLowerCase(java.util.Locale.US);
            if (contentLength == -1 && lowercaseHeader.startsWith("content-length:")) {
                contentLength = Integer.parseInt(header.substring(15).trim());
            }
            if (lowercaseHeader.startsWith("transfer-encoding:")
                    && lowercaseHeader.substring(18).trim().equals("chunked")) {
                chunked = true;
            }
        }

        boolean hasBody = false;
        TruncatingOutputStream requestBody = new TruncatingOutputStream();
        List<Integer> chunkSizes = new ArrayList<Integer>();
        if (contentLength != -1) {
            hasBody = true;
            transfer(contentLength, in, requestBody);
        } else if (chunked) {
            hasBody = true;
            while (true) {
                String chunkSizeLine = readAsciiUntilCrlf(in);
                // Ignore any chunk extensions (";name=value") that follow the hex size.
                int extensionStart = chunkSizeLine.indexOf(';');
                if (extensionStart != -1) {
                    chunkSizeLine = chunkSizeLine.substring(0, extensionStart);
                }
                int chunkSize = Integer.parseInt(chunkSizeLine.trim(), 16);
                if (chunkSize == 0) {
                    readEmptyLine(in);
                    break;
                }
                chunkSizes.add(chunkSize);
                transfer(chunkSize, in, requestBody);
                readEmptyLine(in);
            }
        }

        if (request.startsWith("OPTIONS ") || request.startsWith("GET ")
                || request.startsWith("HEAD ") || request.startsWith("DELETE ")
                || request.startsWith("TRACE ") || request.startsWith("CONNECT ")) {
            if (hasBody) {
                throw new IllegalArgumentException("Request must not have a body: " + request);
            }
        } else if (!request.startsWith("POST ") && !request.startsWith("PUT ")) {
            throw new UnsupportedOperationException("Unexpected method: " + request);
        }

        return new RecordedRequest(request, headers, chunkSizes,
                requestBody.numBytesReceived, requestBody.toByteArray(), sequenceNumber, socket);
    }

    /**
     * Writes the status line, headers and body of {@code response} to
     * {@code out}. The body is written in 1452-byte chunks. When the response
     * limits its bytes per second, the method sleeps between chunks.
     */
    private void writeResponse(OutputStream out, MockResponse response) throws IOException {
        out.write((response.getStatus() + "\r\n").getBytes(ASCII));
        for (String header : response.getHeaders()) {
            out.write((header + "\r\n").getBytes(ASCII));
        }
        out.write(("\r\n").getBytes(ASCII));
        out.flush();

        final InputStream in = response.getBodyStream();
        if (in == null) {
            return;
        }
        final int bytesPerSecond = response.getBytesPerSecond();

        // Stream data in MTU-sized increments
        final byte[] buffer = new byte[1452];
        final long delayMs;
        if (bytesPerSecond == Integer.MAX_VALUE) {
            delayMs = 0;
        } else {
            delayMs = (1000 * buffer.length) / bytesPerSecond;
        }

        int read;
        long sinceDelay = 0;
        while ((read = in.read(buffer)) != -1) {
            out.write(buffer, 0, read);
            out.flush();

            sinceDelay += read;
            if (sinceDelay >= buffer.length && delayMs > 0) {
                sinceDelay %= buffer.length;
                try {
                    Thread.sleep(delayMs);
                } catch (InterruptedException e) {
                    // Keep the interrupt visible to the worker thread, and keep the cause.
                    Thread.currentThread().interrupt();
                    throw new AssertionError(e);
                }
            }
        }
    }

    /**
     * Transfer bytes from {@code in} to {@code out} until either {@code length}
     * bytes have been transferred or {@code in} is exhausted.
     */
    private void transfer(int length, InputStream in, OutputStream out) throws IOException {
        byte[] buffer = new byte[1024];
        while (length > 0) {
            int count = in.read(buffer, 0, Math.min(buffer.length, length));
            if (count == -1) {
                return;
            }
            out.write(buffer, 0, count);
            length -= count;
        }
    }

    /**
     * Returns the text from {@code in} up to, but not including, the next
     * "\r\n". If {@code in} is exhausted first, returns whatever was read, which
     * is the empty string for an exhausted stream. Never returns null.
     */
    private String readAsciiUntilCrlf(InputStream in) throws IOException {
        StringBuilder builder = new StringBuilder();
        while (true) {
            int c = in.read();
            if (c == '\n' && builder.length() > 0 && builder.charAt(builder.length() - 1) == '\r') {
                builder.deleteCharAt(builder.length() - 1);
                return builder.toString();
            } else if (c == -1) {
                return builder.toString();
            } else {
                builder.append((char) c);
            }
        }
    }

    /**
     * Consumes one line from {@code in}.
     *
     * @throws IllegalStateException if that line is not empty.
     */
    private void readEmptyLine(InputStream in) throws IOException {
        String line = readAsciiUntilCrlf(in);
        if (line.length() != 0) {
            throw new IllegalStateException("Expected empty but was: " + line);
        }
    }

    /**
     * Sets the dispatcher used to match incoming requests to mock responses.
     * The default dispatcher simply serves a fixed sequence of responses from
     * a {@link #enqueue(MockResponse) queue}; custom dispatchers can vary the
     * response based on timing or the content of the request.
     *
     * @throws NullPointerException if {@code dispatcher} is null.
     */
    public void setDispatcher(Dispatcher dispatcher) {
        if (dispatcher == null) {
            throw new NullPointerException();
        }
        this.dispatcher = dispatcher;
    }

    /**
     * An output stream that drops data after bodyLimit bytes. It still counts
     * every byte offered to it in {@code numBytesReceived}.
     */
    private class TruncatingOutputStream extends ByteArrayOutputStream {
        private int numBytesReceived = 0;
        @Override public void write(byte[] buffer, int offset, int len) {
            numBytesReceived += len;
            // Clamp at zero so that a non-positive remaining capacity drops the data
            // (as write(int) does) instead of throwing IndexOutOfBoundsException.
            super.write(buffer, offset, Math.max(0, Math.min(len, bodyLimit - count)));
        }
        @Override public void write(int oneByte) {
            numBytesReceived++;
            if (count < bodyLimit) {
                super.write(oneByte);
            }
        }
    }

    /**
     * Wraps {@code runnable} so that the executing thread is renamed to
     * {@code name} while it runs. The original name is restored afterwards.
     */
    private static Runnable namedRunnable(final String name, final Runnable runnable) {
        return new Runnable() {
            @Override public void run() {
                String originalName = Thread.currentThread().getName();
                Thread.currentThread().setName(name);
                try {
                    runnable.run();
                } finally {
                    Thread.currentThread().setName(originalName);
                }
            }
        };
    }
}