1/* 2 * Copyright (C) 2011 Google Inc. 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17package com.google.mockwebserver; 18 19import static com.google.mockwebserver.SocketPolicy.DISCONNECT_AT_START; 20import static com.google.mockwebserver.SocketPolicy.FAIL_HANDSHAKE; 21import java.io.BufferedInputStream; 22import java.io.BufferedOutputStream; 23import java.io.ByteArrayOutputStream; 24import java.io.IOException; 25import java.io.InputStream; 26import java.io.OutputStream; 27import java.net.InetAddress; 28import java.net.InetSocketAddress; 29import java.net.MalformedURLException; 30import java.net.Proxy; 31import java.net.ServerSocket; 32import java.net.Socket; 33import java.net.SocketException; 34import java.net.URL; 35import java.net.UnknownHostException; 36import java.nio.charset.StandardCharsets; 37import java.security.SecureRandom; 38import java.security.cert.CertificateException; 39import java.security.cert.X509Certificate; 40import java.util.ArrayList; 41import java.util.Iterator; 42import java.util.List; 43import java.util.Locale; 44import java.util.Map; 45import java.util.concurrent.BlockingQueue; 46import java.util.concurrent.ConcurrentHashMap; 47import java.util.concurrent.ExecutorService; 48import java.util.concurrent.Executors; 49import java.util.concurrent.LinkedBlockingQueue; 50import java.util.concurrent.atomic.AtomicInteger; 51import java.util.logging.Level; 52import java.util.logging.Logger; 53import javax.net.ssl.SSLContext; 54import javax.net.ssl.SSLSocket; 55import javax.net.ssl.SSLSocketFactory; 56import javax.net.ssl.TrustManager; 57import javax.net.ssl.X509TrustManager; 58 59/** 60 * A scriptable web server. Callers supply canned responses and the server 61 * replays them upon request in sequence. 62 */ 63public final class MockWebServer { 64 private static final X509TrustManager UNTRUSTED_TRUST_MANAGER = new X509TrustManager() { 65 @Override public void checkClientTrusted(X509Certificate[] chain, String authType) 66 throws CertificateException { 67 throw new CertificateException(); 68 } 69 70 @Override public void checkServerTrusted(X509Certificate[] chain, String authType) { 71 throw new AssertionError(); 72 } 73 74 @Override public X509Certificate[] getAcceptedIssuers() { 75 throw new AssertionError(); 76 } 77 }; 78 79 private static final Logger logger = Logger.getLogger(MockWebServer.class.getName()); 80 81 private final BlockingQueue<RecordedRequest> requestQueue 82 = new LinkedBlockingQueue<RecordedRequest>(); 83 84 /** All map values are Boolean.TRUE. (Collections.newSetFromMap isn't available in Froyo) */ 85 private final Map<Socket, Boolean> openClientSockets = new ConcurrentHashMap<Socket, Boolean>(); 86 private final AtomicInteger requestCount = new AtomicInteger(); 87 private int bodyLimit = Integer.MAX_VALUE; 88 private ServerSocket serverSocket; 89 private SSLSocketFactory sslSocketFactory; 90 private ExecutorService acceptExecutor; 91 private ExecutorService requestExecutor; 92 private boolean tunnelProxy; 93 private Dispatcher dispatcher = new QueueDispatcher(); 94 95 private int port = -1; 96 private int workerThreads = Integer.MAX_VALUE; 97 98 public int getPort() { 99 if (port == -1) { 100 throw new IllegalStateException("Cannot retrieve port before calling play()"); 101 } 102 return port; 103 } 104 105 public String getHostName() { 106 try { 107 return InetAddress.getLocalHost().getHostName(); 108 } catch (UnknownHostException e) { 109 throw new AssertionError(e); 110 } 111 } 112 113 public Proxy toProxyAddress() { 114 return new Proxy(Proxy.Type.HTTP, new InetSocketAddress(getHostName(), getPort())); 115 } 116 117 /** 118 * Returns a URL for connecting to this server. 119 * 120 * @param path the request path, such as "/". 121 */ 122 public URL getUrl(String path) { 123 try { 124 return sslSocketFactory != null 125 ? new URL("https://" + getHostName() + ":" + getPort() + path) 126 : new URL("http://" + getHostName() + ":" + getPort() + path); 127 } catch (MalformedURLException e) { 128 throw new AssertionError(e); 129 } 130 } 131 132 /** 133 * Returns a cookie domain for this server. This returns the server's 134 * non-loopback host name if it is known. Otherwise this returns ".local" 135 * for this server's loopback name. 136 */ 137 public String getCookieDomain() { 138 String hostName = getHostName(); 139 return hostName.contains(".") ? hostName : ".local"; 140 } 141 142 public void setWorkerThreads(int threads) { 143 this.workerThreads = threads; 144 } 145 146 /** 147 * Sets the number of bytes of the POST body to keep in memory to the given 148 * limit. 149 */ 150 public void setBodyLimit(int maxBodyLength) { 151 this.bodyLimit = maxBodyLength; 152 } 153 154 /** 155 * Serve requests with HTTPS rather than otherwise. 156 * 157 * @param tunnelProxy whether to expect the HTTP CONNECT method before 158 * negotiating TLS. 159 */ 160 public void useHttps(SSLSocketFactory sslSocketFactory, boolean tunnelProxy) { 161 this.sslSocketFactory = sslSocketFactory; 162 this.tunnelProxy = tunnelProxy; 163 } 164 165 /** 166 * Awaits the next HTTP request, removes it, and returns it. Callers should 167 * use this to verify the request sent was as intended. 168 */ 169 public RecordedRequest takeRequest() throws InterruptedException { 170 return requestQueue.take(); 171 } 172 173 /** 174 * Returns the number of HTTP requests received thus far by this server. 175 * This may exceed the number of HTTP connections when connection reuse is 176 * in practice. 177 */ 178 public int getRequestCount() { 179 return requestCount.get(); 180 } 181 182 /** 183 * Scripts {@code response} to be returned to a request made in sequence. 184 * The first request is served by the first enqueued response; the second 185 * request by the second enqueued response; and so on. 186 * 187 * @throws ClassCastException if the default dispatcher has been replaced 188 * with {@link #setDispatcher(Dispatcher)}. 189 */ 190 public void enqueue(MockResponse response) { 191 ((QueueDispatcher) dispatcher).enqueueResponse(response.clone()); 192 } 193 194 /** 195 * Equivalent to {@code play(0)}. 196 */ 197 public void play() throws IOException { 198 play(0); 199 } 200 201 /** 202 * Starts the server, serves all enqueued requests, and shuts the server 203 * down. 204 * 205 * @param port the port to listen to, or 0 for any available port. 206 * Automated tests should always use port 0 to avoid flakiness when a 207 * specific port is unavailable. 208 */ 209 public void play(int port) throws IOException { 210 if (acceptExecutor != null) { 211 throw new IllegalStateException("play() already called"); 212 } 213 // The acceptExecutor handles the Socket.accept() and hands each request off to the 214 // requestExecutor. It also handles shutdown. 215 acceptExecutor = Executors.newSingleThreadExecutor(); 216 // The requestExecutor has a fixed number of worker threads. In order to get strict 217 // guarantees that requests are handled in the order in which they are accepted 218 // workerThreads should be set to 1. 219 requestExecutor = Executors.newFixedThreadPool(workerThreads); 220 serverSocket = new ServerSocket(port); 221 serverSocket.setReuseAddress(true); 222 223 this.port = serverSocket.getLocalPort(); 224 acceptExecutor.execute(namedRunnable("MockWebServer-accept-" + port, new Runnable() { 225 public void run() { 226 try { 227 acceptConnections(); 228 } catch (Throwable e) { 229 logger.log(Level.WARNING, "MockWebServer connection failed", e); 230 } 231 232 /* 233 * This gnarly block of code will release all sockets and 234 * all thread, even if any close fails. 235 */ 236 try { 237 serverSocket.close(); 238 } catch (Throwable e) { 239 logger.log(Level.WARNING, "MockWebServer server socket close failed", e); 240 } 241 for (Iterator<Socket> s = openClientSockets.keySet().iterator(); s.hasNext(); ) { 242 try { 243 s.next().close(); 244 s.remove(); 245 } catch (Throwable e) { 246 logger.log(Level.WARNING, "MockWebServer socket close failed", e); 247 } 248 } 249 try { 250 acceptExecutor.shutdown(); 251 } catch (Throwable e) { 252 logger.log(Level.WARNING, "MockWebServer acceptExecutor shutdown failed", e); 253 } 254 try { 255 requestExecutor.shutdown(); 256 } catch (Throwable e) { 257 logger.log(Level.WARNING, "MockWebServer requestExecutor shutdown failed", e); 258 } 259 } 260 261 private void acceptConnections() throws Exception { 262 while (true) { 263 Socket socket; 264 try { 265 socket = serverSocket.accept(); 266 } catch (SocketException e) { 267 return; 268 } 269 SocketPolicy socketPolicy = dispatcher.peek().getSocketPolicy(); 270 if (socketPolicy == DISCONNECT_AT_START) { 271 dispatchBookkeepingRequest(0, socket); 272 socket.close(); 273 } else { 274 openClientSockets.put(socket, true); 275 serveConnection(socket); 276 } 277 } 278 } 279 })); 280 } 281 282 public void shutdown() throws IOException { 283 if (serverSocket != null) { 284 serverSocket.close(); // should cause acceptConnections() to break out 285 } 286 } 287 288 private void serveConnection(final Socket raw) { 289 String name = "MockWebServer-" + raw.getRemoteSocketAddress(); 290 requestExecutor.execute(namedRunnable(name, new Runnable() { 291 int sequenceNumber = 0; 292 293 public void run() { 294 try { 295 processConnection(); 296 } catch (Exception e) { 297 logger.log(Level.WARNING, "MockWebServer connection failed", e); 298 } 299 } 300 301 public void processConnection() throws Exception { 302 Socket socket; 303 if (sslSocketFactory != null) { 304 if (tunnelProxy) { 305 createTunnel(); 306 } 307 SocketPolicy socketPolicy = dispatcher.peek().getSocketPolicy(); 308 if (socketPolicy == FAIL_HANDSHAKE) { 309 dispatchBookkeepingRequest(sequenceNumber, raw); 310 processHandshakeFailure(raw); 311 return; 312 } 313 socket = sslSocketFactory.createSocket( 314 raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true); 315 SSLSocket sslSocket = (SSLSocket) socket; 316 sslSocket.setUseClientMode(false); 317 openClientSockets.put(socket, true); 318 319 sslSocket.startHandshake(); 320 321 openClientSockets.remove(raw); 322 } else { 323 socket = raw; 324 } 325 326 InputStream in = new BufferedInputStream(socket.getInputStream()); 327 OutputStream out = new BufferedOutputStream(socket.getOutputStream()); 328 329 while (processOneRequest(socket, in, out)) { 330 } 331 332 if (sequenceNumber == 0) { 333 logger.warning("MockWebServer connection didn't make a request"); 334 } 335 336 in.close(); 337 out.close(); 338 socket.close(); 339 openClientSockets.remove(socket); 340 } 341 342 /** 343 * Respond to CONNECT requests until a SWITCH_TO_SSL_AT_END response 344 * is dispatched. 345 */ 346 private void createTunnel() throws IOException, InterruptedException { 347 while (true) { 348 SocketPolicy socketPolicy = dispatcher.peek().getSocketPolicy(); 349 if (!processOneRequest(raw, raw.getInputStream(), raw.getOutputStream())) { 350 throw new IllegalStateException("Tunnel without any CONNECT!"); 351 } 352 if (socketPolicy == SocketPolicy.UPGRADE_TO_SSL_AT_END) return; 353 } 354 } 355 356 /** 357 * Reads a request and writes its response. Returns true if a request 358 * was processed. 359 */ 360 private boolean processOneRequest(Socket socket, InputStream in, OutputStream out) 361 throws IOException, InterruptedException { 362 RecordedRequest request = readRequest(socket, in, out, sequenceNumber); 363 if (request == null) { 364 return false; 365 } 366 requestCount.incrementAndGet(); 367 requestQueue.add(request); 368 MockResponse response = dispatcher.dispatch(request); 369 if (response.getSocketPolicy() == SocketPolicy.DISCONNECT_AFTER_READING_REQUEST) { 370 logger.info("Received request: " + request + " and disconnected without responding"); 371 return false; 372 } 373 writeResponse(out, response); 374 375 // For socket policies that poison the socket after the response is written: 376 // The client has received the response and will no longer be blocked after 377 // writeResponse() has returned. A client can then re-use the connection before 378 // the socket is poisoned (i.e. keep-alive / connection pooling). The second 379 // request/response may fail at the beginning, middle, end, or even succeed 380 // depending on scheduling. Delays can be required in tests to improve the chances 381 // of sockets being in a known state when subsequent requests are made. 382 // 383 // For SHUTDOWN_OUTPUT_AT_END the client may detect a problem with its input socket 384 // after the request has been made but before the server has chosen a response. 385 // For clients that perform retries, this can cause the client to issue a retry 386 // request. The retry handler may call dispatcher.dispatch(request) before the 387 // initial, failed request handler does and cause non-obvious response ordering. 388 // Setting workerThreads = 1 ensures that the dispatcher is called for requests in 389 // the order they are received. 390 391 if (response.getSocketPolicy() == SocketPolicy.DISCONNECT_AT_END) { 392 in.close(); 393 out.close(); 394 } else if (response.getSocketPolicy() == SocketPolicy.SHUTDOWN_INPUT_AT_END) { 395 socket.shutdownInput(); 396 } else if (response.getSocketPolicy() == SocketPolicy.SHUTDOWN_OUTPUT_AT_END) { 397 socket.shutdownOutput(); 398 } 399 logger.info("Received request: " + request + " and responded: " + response); 400 sequenceNumber++; 401 return true; 402 } 403 })); 404 } 405 406 private void processHandshakeFailure(Socket raw) throws Exception { 407 SSLContext context = SSLContext.getInstance("TLS"); 408 context.init(null, new TrustManager[] { UNTRUSTED_TRUST_MANAGER }, new SecureRandom()); 409 SSLSocketFactory sslSocketFactory = context.getSocketFactory(); 410 SSLSocket socket = (SSLSocket) sslSocketFactory.createSocket( 411 raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true); 412 try { 413 socket.startHandshake(); // we're testing a handshake failure 414 throw new AssertionError(); 415 } catch (IOException expected) { 416 } 417 socket.close(); 418 } 419 420 private void dispatchBookkeepingRequest(int sequenceNumber, Socket socket) throws InterruptedException { 421 requestCount.incrementAndGet(); 422 RecordedRequest request = new RecordedRequest(null, null, null, -1, null, sequenceNumber, 423 socket); 424 dispatcher.dispatch(request); 425 } 426 427 /** @param sequenceNumber the index of this request on this connection. */ 428 private RecordedRequest readRequest(Socket socket, InputStream in, OutputStream out, 429 int sequenceNumber) throws IOException { 430 String request; 431 try { 432 request = readAsciiUntilCrlf(in); 433 } catch (IOException streamIsClosed) { 434 return null; // no request because we closed the stream 435 } 436 if (request.length() == 0) { 437 return null; // no request because the stream is exhausted 438 } 439 440 List<String> headers = new ArrayList<String>(); 441 long contentLength = -1; 442 boolean chunked = false; 443 boolean expectContinue = false; 444 String header; 445 while ((header = readAsciiUntilCrlf(in)).length() != 0) { 446 headers.add(header); 447 String lowercaseHeader = header.toLowerCase(Locale.US); 448 if (contentLength == -1 && lowercaseHeader.startsWith("content-length:")) { 449 contentLength = Long.parseLong(header.substring(15).trim()); 450 } 451 if (lowercaseHeader.startsWith("transfer-encoding:") 452 && lowercaseHeader.substring(18).trim().equals("chunked")) { 453 chunked = true; 454 } 455 if (lowercaseHeader.startsWith("expect:") 456 && lowercaseHeader.substring(7).trim().equals("100-continue")) { 457 expectContinue = true; 458 } 459 } 460 461 if (expectContinue) { 462 out.write(("HTTP/1.1 100 Continue\r\n").getBytes(StandardCharsets.US_ASCII)); 463 out.write(("Content-Length: 0\r\n").getBytes(StandardCharsets.US_ASCII)); 464 out.write(("\r\n").getBytes(StandardCharsets.US_ASCII)); 465 out.flush(); 466 } 467 468 boolean hasBody = false; 469 TruncatingOutputStream requestBody = new TruncatingOutputStream(); 470 List<Integer> chunkSizes = new ArrayList<Integer>(); 471 MockResponse throttlePolicy = dispatcher.peek(); 472 if (contentLength != -1) { 473 hasBody = true; 474 throttledTransfer(throttlePolicy, in, requestBody, contentLength); 475 } else if (chunked) { 476 hasBody = true; 477 while (true) { 478 int chunkSize = Integer.parseInt(readAsciiUntilCrlf(in).trim(), 16); 479 if (chunkSize == 0) { 480 readEmptyLine(in); 481 break; 482 } 483 chunkSizes.add(chunkSize); 484 throttledTransfer(throttlePolicy, in, requestBody, chunkSize); 485 readEmptyLine(in); 486 } 487 } 488 489 if (request.startsWith("OPTIONS ") 490 || request.startsWith("GET ") 491 || request.startsWith("HEAD ") 492 || request.startsWith("TRACE ") 493 || request.startsWith("CONNECT ")) { 494 if (hasBody) { 495 throw new IllegalArgumentException("Request must not have a body: " + request); 496 } 497 } else if (!request.startsWith("POST ") 498 && !request.startsWith("PUT ") 499 && !request.startsWith("PATCH ") 500 && !request.startsWith("DELETE ")) { // Permitted as spec is ambiguous. 501 throw new UnsupportedOperationException("Unexpected method: " + request); 502 } 503 504 return new RecordedRequest(request, headers, chunkSizes, requestBody.numBytesReceived, 505 requestBody.toByteArray(), sequenceNumber, socket); 506 } 507 508 private void writeResponse(OutputStream out, MockResponse response) throws IOException { 509 out.write((response.getStatus() + "\r\n").getBytes(StandardCharsets.US_ASCII)); 510 List<String> headers = response.getHeaders(); 511 for (int i = 0, size = headers.size(); i < size; i++) { 512 String header = headers.get(i); 513 out.write((header + "\r\n").getBytes(StandardCharsets.US_ASCII)); 514 } 515 out.write(("\r\n").getBytes(StandardCharsets.US_ASCII)); 516 out.flush(); 517 518 InputStream in = response.getBodyStream(); 519 if (in == null) return; 520 throttledTransfer(response, in, out, Long.MAX_VALUE); 521 } 522 523 /** 524 * Transfer bytes from {@code in} to {@code out} until either {@code length} 525 * bytes have been transferred or {@code in} is exhausted. The transfer is 526 * throttled according to {@code throttlePolicy}. 527 */ 528 private void throttledTransfer(MockResponse throttlePolicy, InputStream in, OutputStream out, 529 long limit) throws IOException { 530 byte[] buffer = new byte[1024]; 531 int bytesPerPeriod = throttlePolicy.getThrottleBytesPerPeriod(); 532 long delayMs = throttlePolicy.getThrottleUnit().toMillis(throttlePolicy.getThrottlePeriod()); 533 534 while (true) { 535 for (int b = 0; b < bytesPerPeriod; ) { 536 int toRead = (int) Math.min(Math.min(buffer.length, limit), bytesPerPeriod - b); 537 int read = in.read(buffer, 0, toRead); 538 if (read == -1) return; 539 540 out.write(buffer, 0, read); 541 out.flush(); 542 b += read; 543 limit -= read; 544 545 if (limit == 0) return; 546 } 547 548 try { 549 if (delayMs != 0) Thread.sleep(delayMs); 550 } catch (InterruptedException e) { 551 throw new AssertionError(); 552 } 553 } 554 } 555 556 /** 557 * Returns the text from {@code in} until the next "\r\n", or null if 558 * {@code in} is exhausted. 559 */ 560 private String readAsciiUntilCrlf(InputStream in) throws IOException { 561 StringBuilder builder = new StringBuilder(); 562 while (true) { 563 int c = in.read(); 564 if (c == '\n' && builder.length() > 0 && builder.charAt(builder.length() - 1) == '\r') { 565 builder.deleteCharAt(builder.length() - 1); 566 return builder.toString(); 567 } else if (c == -1) { 568 return builder.toString(); 569 } else { 570 builder.append((char) c); 571 } 572 } 573 } 574 575 private void readEmptyLine(InputStream in) throws IOException { 576 String line = readAsciiUntilCrlf(in); 577 if (line.length() != 0) { 578 throw new IllegalStateException("Expected empty but was: " + line); 579 } 580 } 581 582 /** 583 * Sets the dispatcher used to match incoming requests to mock responses. 584 * The default dispatcher simply serves a fixed sequence of responses from 585 * a {@link #enqueue(MockResponse) queue}; custom dispatchers can vary the 586 * response based on timing or the content of the request. 587 */ 588 public void setDispatcher(Dispatcher dispatcher) { 589 if (dispatcher == null) { 590 throw new NullPointerException(); 591 } 592 this.dispatcher = dispatcher; 593 } 594 595 /** 596 * An output stream that drops data after bodyLimit bytes. 597 */ 598 private class TruncatingOutputStream extends ByteArrayOutputStream { 599 private int numBytesReceived = 0; 600 @Override public void write(byte[] buffer, int offset, int len) { 601 numBytesReceived += len; 602 super.write(buffer, offset, Math.min(len, bodyLimit - count)); 603 } 604 @Override public void write(int oneByte) { 605 numBytesReceived++; 606 if (count < bodyLimit) { 607 super.write(oneByte); 608 } 609 } 610 } 611 612 private static Runnable namedRunnable(final String name, final Runnable runnable) { 613 return new Runnable() { 614 public void run() { 615 String originalName = Thread.currentThread().getName(); 616 Thread.currentThread().setName(name); 617 try { 618 runnable.run(); 619 } finally { 620 Thread.currentThread().setName(originalName); 621 } 622 } 623 }; 624 } 625} 626