MockWebServer.java revision f162edaa335461474b020027bb2e85eb3be2c179
1/* 2 * Copyright (C) 2010 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17package tests.http; 18 19import java.io.BufferedInputStream; 20import java.io.BufferedOutputStream; 21import java.io.ByteArrayOutputStream; 22import java.io.IOException; 23import java.io.InputStream; 24import java.io.OutputStream; 25import java.net.InetAddress; 26import java.net.InetSocketAddress; 27import java.net.MalformedURLException; 28import java.net.Proxy; 29import java.net.ServerSocket; 30import java.net.Socket; 31import java.net.SocketException; 32import java.net.URL; 33import java.net.UnknownHostException; 34import java.util.ArrayList; 35import java.util.Collections; 36import java.util.HashSet; 37import java.util.Iterator; 38import java.util.List; 39import java.util.Set; 40import java.util.concurrent.BlockingQueue; 41import java.util.concurrent.ExecutorService; 42import java.util.concurrent.Executors; 43import java.util.concurrent.LinkedBlockingDeque; 44import java.util.concurrent.LinkedBlockingQueue; 45import java.util.concurrent.atomic.AtomicInteger; 46import java.util.logging.Level; 47import java.util.logging.Logger; 48import javax.net.ssl.SSLSocket; 49import javax.net.ssl.SSLSocketFactory; 50import static tests.http.SocketPolicy.DISCONNECT_AT_START; 51 52/** 53 * A scriptable web server. Callers supply canned responses and the server 54 * replays them upon request in sequence. 55 */ 56public final class MockWebServer { 57 58 static final String ASCII = "US-ASCII"; 59 60 private static final Logger logger = Logger.getLogger(MockWebServer.class.getName()); 61 private final BlockingQueue<RecordedRequest> requestQueue 62 = new LinkedBlockingQueue<RecordedRequest>(); 63 private final BlockingQueue<MockResponse> responseQueue 64 = new LinkedBlockingDeque<MockResponse>(); 65 private final Set<Socket> openClientSockets 66 = Collections.synchronizedSet(new HashSet<Socket>()); 67 private boolean singleResponse; 68 private final AtomicInteger requestCount = new AtomicInteger(); 69 private int bodyLimit = Integer.MAX_VALUE; 70 private ServerSocket serverSocket; 71 private SSLSocketFactory sslSocketFactory; 72 private ExecutorService executor; 73 private boolean tunnelProxy; 74 75 private int port = -1; 76 77 public int getPort() { 78 if (port == -1) { 79 throw new IllegalStateException("Cannot retrieve port before calling play()"); 80 } 81 return port; 82 } 83 84 public Proxy toProxyAddress() { 85 return new Proxy(Proxy.Type.HTTP, new InetSocketAddress("localhost", getPort())); 86 } 87 88 /** 89 * Returns a URL for connecting to this server. 90 * 91 * @param path the request path, such as "/". 92 */ 93 public URL getUrl(String path) throws MalformedURLException, UnknownHostException { 94 String host = InetAddress.getLocalHost().getHostName(); 95 return sslSocketFactory != null 96 ? new URL("https://" + host + ":" + getPort() + path) 97 : new URL("http://" + host + ":" + getPort() + path); 98 } 99 100 /** 101 * Sets the number of bytes of the POST body to keep in memory to the given 102 * limit. 103 */ 104 public void setBodyLimit(int maxBodyLength) { 105 this.bodyLimit = maxBodyLength; 106 } 107 108 /** 109 * Serve requests with HTTPS rather than otherwise. 110 * 111 * @param tunnelProxy whether to expect the HTTP CONNECT method before 112 * negotiating TLS. 113 */ 114 public void useHttps(SSLSocketFactory sslSocketFactory, boolean tunnelProxy) { 115 this.sslSocketFactory = sslSocketFactory; 116 this.tunnelProxy = tunnelProxy; 117 } 118 119 /** 120 * Awaits the next HTTP request, removes it, and returns it. Callers should 121 * use this to verify the request sent was as intended. 122 */ 123 public RecordedRequest takeRequest() throws InterruptedException { 124 return requestQueue.take(); 125 } 126 127 /** 128 * Returns the number of HTTP requests received thus far by this server. 129 * This may exceed the number of HTTP connections when connection reuse is 130 * in practice. 131 */ 132 public int getRequestCount() { 133 return requestCount.get(); 134 } 135 136 public void enqueue(MockResponse response) { 137 responseQueue.add(response); 138 } 139 140 /** 141 * By default, this class processes requests coming in by adding them to a 142 * queue and serves responses by removing them from another queue. This mode 143 * is appropriate for correctness testing. 144 * 145 * <p>Serving a single response causes the server to be stateless: requests 146 * are not enqueued, and responses are not dequeued. This mode is appropriate 147 * for benchmarking. 148 */ 149 public void setSingleResponse(boolean singleResponse) { 150 this.singleResponse = singleResponse; 151 } 152 153 /** 154 * Starts the server, serves all enqueued requests, and shuts the server 155 * down. 156 */ 157 public void play() throws IOException { 158 executor = Executors.newCachedThreadPool(); 159 serverSocket = new ServerSocket(0); 160 serverSocket.setReuseAddress(true); 161 162 port = serverSocket.getLocalPort(); 163 executor.execute(namedRunnable("MockWebServer-accept-" + port, new Runnable() { 164 public void run() { 165 try { 166 acceptConnections(); 167 } catch (Throwable e) { 168 logger.log(Level.WARNING, "MockWebServer connection failed", e); 169 } 170 171 /* 172 * This gnarly block of code will release all sockets and 173 * all thread, even if any close fails. 174 */ 175 try { 176 serverSocket.close(); 177 } catch (Throwable e) { 178 logger.log(Level.WARNING, "MockWebServer server socket close failed", e); 179 } 180 for (Iterator<Socket> s = openClientSockets.iterator(); s.hasNext();) { 181 try { 182 s.next().close(); 183 s.remove(); 184 } catch (Throwable e) { 185 logger.log(Level.WARNING, "MockWebServer socket close failed", e); 186 } 187 } 188 try { 189 executor.shutdown(); 190 } catch (Throwable e) { 191 logger.log(Level.WARNING, "MockWebServer executor shutdown failed", e); 192 } 193 } 194 195 private void acceptConnections() throws Exception { 196 do { 197 Socket socket; 198 try { 199 socket = serverSocket.accept(); 200 } catch (SocketException ignored) { 201 continue; 202 } 203 MockResponse peek = responseQueue.peek(); 204 if (peek != null && peek.getSocketPolicy() == DISCONNECT_AT_START) { 205 responseQueue.take(); 206 socket.close(); 207 } else { 208 openClientSockets.add(socket); 209 serveConnection(socket); 210 } 211 } while (!responseQueue.isEmpty()); 212 } 213 })); 214 } 215 216 public void shutdown() throws IOException { 217 if (serverSocket != null) { 218 serverSocket.close(); // should cause acceptConnections() to break out 219 } 220 } 221 222 private void serveConnection(final Socket raw) { 223 String name = "MockWebServer-" + raw.getRemoteSocketAddress(); 224 executor.execute(namedRunnable(name, new Runnable() { 225 int sequenceNumber = 0; 226 227 public void run() { 228 try { 229 processConnection(); 230 } catch (Exception e) { 231 logger.log(Level.WARNING, "MockWebServer connection failed", e); 232 } 233 } 234 235 public void processConnection() throws Exception { 236 Socket socket; 237 if (sslSocketFactory != null) { 238 if (tunnelProxy) { 239 if (!processOneRequest(raw.getInputStream(), raw.getOutputStream(), raw)) { 240 throw new IllegalStateException("Tunnel without any CONNECT!"); 241 } 242 } 243 socket = sslSocketFactory.createSocket( 244 raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true); 245 ((SSLSocket) socket).setUseClientMode(false); 246 openClientSockets.add(socket); 247 openClientSockets.remove(raw); 248 } else { 249 socket = raw; 250 } 251 252 InputStream in = new BufferedInputStream(socket.getInputStream()); 253 OutputStream out = new BufferedOutputStream(socket.getOutputStream()); 254 255 while (!responseQueue.isEmpty() && processOneRequest(in, out, socket)) {} 256 257 if (sequenceNumber == 0) { 258 logger.warning("MockWebServer connection didn't make a request"); 259 } 260 261 in.close(); 262 out.close(); 263 socket.close(); 264 if (responseQueue.isEmpty()) { 265 shutdown(); 266 } 267 openClientSockets.remove(socket); 268 } 269 270 /** 271 * Reads a request and writes its response. Returns true if a request 272 * was processed. 273 */ 274 private boolean processOneRequest(InputStream in, OutputStream out, Socket socket) 275 throws IOException, InterruptedException { 276 RecordedRequest request = readRequest(in, sequenceNumber); 277 if (request == null) { 278 return false; 279 } 280 MockResponse response = dispatch(request); 281 writeResponse(out, response); 282 if (response.getSocketPolicy() == SocketPolicy.DISCONNECT_AT_END) { 283 in.close(); 284 out.close(); 285 } else if (response.getSocketPolicy() == SocketPolicy.SHUTDOWN_INPUT_AT_END) { 286 socket.shutdownInput(); 287 } else if (response.getSocketPolicy() == SocketPolicy.SHUTDOWN_OUTPUT_AT_END) { 288 socket.shutdownOutput(); 289 } 290 sequenceNumber++; 291 return true; 292 } 293 })); 294 } 295 296 /** 297 * @param sequenceNumber the index of this request on this connection. 298 */ 299 private RecordedRequest readRequest(InputStream in, int sequenceNumber) throws IOException { 300 String request; 301 try { 302 request = readAsciiUntilCrlf(in); 303 } catch (IOException streamIsClosed) { 304 return null; // no request because we closed the stream 305 } 306 if (request.isEmpty()) { 307 return null; // no request because the stream is exhausted 308 } 309 310 List<String> headers = new ArrayList<String>(); 311 int contentLength = -1; 312 boolean chunked = false; 313 String header; 314 while (!(header = readAsciiUntilCrlf(in)).isEmpty()) { 315 headers.add(header); 316 String lowercaseHeader = header.toLowerCase(); 317 if (contentLength == -1 && lowercaseHeader.startsWith("content-length:")) { 318 contentLength = Integer.parseInt(header.substring(15).trim()); 319 } 320 if (lowercaseHeader.startsWith("transfer-encoding:") && 321 lowercaseHeader.substring(18).trim().equals("chunked")) { 322 chunked = true; 323 } 324 } 325 326 boolean hasBody = false; 327 TruncatingOutputStream requestBody = new TruncatingOutputStream(); 328 List<Integer> chunkSizes = new ArrayList<Integer>(); 329 if (contentLength != -1) { 330 hasBody = true; 331 transfer(contentLength, in, requestBody); 332 } else if (chunked) { 333 hasBody = true; 334 while (true) { 335 int chunkSize = Integer.parseInt(readAsciiUntilCrlf(in).trim(), 16); 336 if (chunkSize == 0) { 337 readEmptyLine(in); 338 break; 339 } 340 chunkSizes.add(chunkSize); 341 transfer(chunkSize, in, requestBody); 342 readEmptyLine(in); 343 } 344 } 345 346 if (request.startsWith("GET ") || request.startsWith("CONNECT ")) { 347 if (hasBody) { 348 throw new IllegalArgumentException("GET requests should not have a body!"); 349 } 350 } else if (request.startsWith("POST ")) { 351 if (!hasBody) { 352 throw new IllegalArgumentException("POST requests must have a body!"); 353 } 354 } else { 355 throw new UnsupportedOperationException("Unexpected method: " + request); 356 } 357 358 return new RecordedRequest(request, headers, chunkSizes, 359 requestBody.numBytesReceived, requestBody.toByteArray(), sequenceNumber); 360 } 361 362 /** 363 * Returns a response to satisfy {@code request}. 364 */ 365 private MockResponse dispatch(RecordedRequest request) throws InterruptedException { 366 if (responseQueue.isEmpty()) { 367 throw new IllegalStateException("Unexpected request: " + request); 368 } 369 370 if (singleResponse) { 371 return responseQueue.peek(); 372 } else { 373 requestCount.incrementAndGet(); 374 requestQueue.add(request); 375 return responseQueue.take(); 376 } 377 } 378 379 private void writeResponse(OutputStream out, MockResponse response) throws IOException { 380 out.write((response.getStatus() + "\r\n").getBytes(ASCII)); 381 for (String header : response.getHeaders()) { 382 out.write((header + "\r\n").getBytes(ASCII)); 383 } 384 out.write(("\r\n").getBytes(ASCII)); 385 out.write(response.getBody()); 386 out.flush(); 387 } 388 389 /** 390 * Transfer bytes from {@code in} to {@code out} until either {@code length} 391 * bytes have been transferred or {@code in} is exhausted. 392 */ 393 private void transfer(int length, InputStream in, OutputStream out) throws IOException { 394 byte[] buffer = new byte[1024]; 395 while (length > 0) { 396 int count = in.read(buffer, 0, Math.min(buffer.length, length)); 397 if (count == -1) { 398 return; 399 } 400 out.write(buffer, 0, count); 401 length -= count; 402 } 403 } 404 405 /** 406 * Returns the text from {@code in} until the next "\r\n", or null if 407 * {@code in} is exhausted. 408 */ 409 private String readAsciiUntilCrlf(InputStream in) throws IOException { 410 StringBuilder builder = new StringBuilder(); 411 while (true) { 412 int c = in.read(); 413 if (c == '\n' && builder.length() > 0 && builder.charAt(builder.length() - 1) == '\r') { 414 builder.deleteCharAt(builder.length() - 1); 415 return builder.toString(); 416 } else if (c == -1) { 417 return builder.toString(); 418 } else { 419 builder.append((char) c); 420 } 421 } 422 } 423 424 private void readEmptyLine(InputStream in) throws IOException { 425 String line = readAsciiUntilCrlf(in); 426 if (!line.isEmpty()) { 427 throw new IllegalStateException("Expected empty but was: " + line); 428 } 429 } 430 431 /** 432 * An output stream that drops data after bodyLimit bytes. 433 */ 434 private class TruncatingOutputStream extends ByteArrayOutputStream { 435 private int numBytesReceived = 0; 436 @Override public void write(byte[] buffer, int offset, int len) { 437 numBytesReceived += len; 438 super.write(buffer, offset, Math.min(len, bodyLimit - count)); 439 } 440 @Override public void write(int oneByte) { 441 numBytesReceived++; 442 if (count < bodyLimit) { 443 super.write(oneByte); 444 } 445 } 446 } 447 448 private static Runnable namedRunnable(final String name, final Runnable runnable) { 449 return new Runnable() { 450 public void run() { 451 String originalName = Thread.currentThread().getName(); 452 Thread.currentThread().setName(name); 453 try { 454 runnable.run(); 455 } finally { 456 Thread.currentThread().setName(originalName); 457 } 458 } 459 }; 460 } 461} 462