MockWebServer.java revision e942f46f10bb9384a1b186b3d7b74f9704c57090
1/* 2 * Copyright (C) 2010 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17package tests.http; 18 19import java.io.BufferedInputStream; 20import java.io.BufferedOutputStream; 21import java.io.ByteArrayOutputStream; 22import java.io.IOException; 23import java.io.InputStream; 24import java.io.OutputStream; 25import java.net.InetAddress; 26import java.net.InetSocketAddress; 27import java.net.MalformedURLException; 28import java.net.Proxy; 29import java.net.ServerSocket; 30import java.net.Socket; 31import java.net.URL; 32import java.net.UnknownHostException; 33import java.util.ArrayList; 34import java.util.Collections; 35import java.util.HashSet; 36import java.util.Iterator; 37import java.util.List; 38import java.util.Set; 39import java.util.concurrent.BlockingQueue; 40import java.util.concurrent.Callable; 41import java.util.concurrent.ExecutorService; 42import java.util.concurrent.Executors; 43import java.util.concurrent.LinkedBlockingDeque; 44import java.util.concurrent.LinkedBlockingQueue; 45import java.util.concurrent.atomic.AtomicInteger; 46import javax.net.ssl.SSLSocket; 47import javax.net.ssl.SSLSocketFactory; 48import static tests.http.SocketPolicy.DISCONNECT_AT_START; 49 50/** 51 * A scriptable web server. Callers supply canned responses and the server 52 * replays them upon request in sequence. 53 */ 54public final class MockWebServer { 55 56 static final String ASCII = "US-ASCII"; 57 58 private final BlockingQueue<RecordedRequest> requestQueue 59 = new LinkedBlockingQueue<RecordedRequest>(); 60 private final BlockingQueue<MockResponse> responseQueue 61 = new LinkedBlockingDeque<MockResponse>(); 62 private final Set<Socket> openClientSockets 63 = Collections.synchronizedSet(new HashSet<Socket>()); 64 private boolean singleResponse; 65 private final AtomicInteger requestCount = new AtomicInteger(); 66 private int bodyLimit = Integer.MAX_VALUE; 67 private ServerSocket serverSocket; 68 private SSLSocketFactory sslSocketFactory; 69 private ExecutorService executor; 70 private boolean tunnelProxy; 71 72 private int port = -1; 73 74 public int getPort() { 75 if (port == -1) { 76 throw new IllegalStateException("Cannot retrieve port before calling play()"); 77 } 78 return port; 79 } 80 81 public Proxy toProxyAddress() { 82 return new Proxy(Proxy.Type.HTTP, new InetSocketAddress("localhost", getPort())); 83 } 84 85 /** 86 * Returns a URL for connecting to this server. 87 * 88 * @param path the request path, such as "/". 89 */ 90 public URL getUrl(String path) throws MalformedURLException, UnknownHostException { 91 String host = InetAddress.getLocalHost().getHostName(); 92 return sslSocketFactory != null 93 ? new URL("https://" + host + ":" + getPort() + path) 94 : new URL("http://" + host + ":" + getPort() + path); 95 } 96 97 /** 98 * Sets the number of bytes of the POST body to keep in memory to the given 99 * limit. 100 */ 101 public void setBodyLimit(int maxBodyLength) { 102 this.bodyLimit = maxBodyLength; 103 } 104 105 /** 106 * Serve requests with HTTPS rather than otherwise. 107 * 108 * @param tunnelProxy whether to expect the HTTP CONNECT method before 109 * negotiating TLS. 110 */ 111 public void useHttps(SSLSocketFactory sslSocketFactory, boolean tunnelProxy) { 112 this.sslSocketFactory = sslSocketFactory; 113 this.tunnelProxy = tunnelProxy; 114 } 115 116 /** 117 * Awaits the next HTTP request, removes it, and returns it. Callers should 118 * use this to verify the request sent was as intended. 119 */ 120 public RecordedRequest takeRequest() throws InterruptedException { 121 return requestQueue.take(); 122 } 123 124 /** 125 * Returns the number of HTTP requests received thus far by this server. 126 * This may exceed the number of HTTP connections when connection reuse is 127 * in practice. 128 */ 129 public int getRequestCount() { 130 return requestCount.get(); 131 } 132 133 public void enqueue(MockResponse response) { 134 responseQueue.add(response); 135 } 136 137 /** 138 * By default, this class processes requests coming in by adding them to a 139 * queue and serves responses by removing them from another queue. This mode 140 * is appropriate for correctness testing. 141 * 142 * <p>Serving a single response causes the server to be stateless: requests 143 * are not enqueued, and responses are not dequeued. This mode is appropriate 144 * for benchmarking. 145 */ 146 public void setSingleResponse(boolean singleResponse) { 147 this.singleResponse = singleResponse; 148 } 149 150 /** 151 * Starts the server, serves all enqueued requests, and shuts the server 152 * down. 153 */ 154 public void play() throws IOException { 155 executor = Executors.newCachedThreadPool(); 156 serverSocket = new ServerSocket(0); 157 serverSocket.setReuseAddress(true); 158 159 port = serverSocket.getLocalPort(); 160 executor.submit(namedCallable("MockWebServer-accept-" + port, new Callable<Void>() { 161 public Void call() throws Exception { 162 List<Throwable> failures = new ArrayList<Throwable>(); 163 try { 164 acceptConnections(); 165 } catch (Throwable e) { 166 failures.add(e); 167 } 168 169 /* 170 * This gnarly block of code will release all sockets and 171 * all thread, even if any close fails. 172 */ 173 try { 174 serverSocket.close(); 175 } catch (Throwable e) { 176 failures.add(e); 177 } 178 for (Iterator<Socket> s = openClientSockets.iterator(); s.hasNext(); ) { 179 try { 180 s.next().close(); 181 s.remove(); 182 } catch (Throwable e) { 183 failures.add(e); 184 } 185 } 186 try { 187 executor.shutdown(); 188 } catch (Throwable e) { 189 failures.add(e); 190 } 191 if (!failures.isEmpty()) { 192 Throwable thrown = failures.get(0); 193 if (thrown instanceof Exception) { 194 throw (Exception) thrown; 195 } else { 196 throw (Error) thrown; 197 } 198 } else { 199 return null; 200 } 201 } 202 203 public void acceptConnections() throws Exception { 204 int count = 0; 205 while (true) { 206 if (count > 0 && responseQueue.isEmpty()) { 207 return; 208 } 209 210 Socket socket = serverSocket.accept(); 211 if (responseQueue.peek().getSocketPolicy() == DISCONNECT_AT_START) { 212 responseQueue.take(); 213 socket.close(); 214 continue; 215 } 216 openClientSockets.add(socket); 217 serveConnection(socket); 218 count++; 219 } 220 } 221 })); 222 } 223 224 public void shutdown() throws IOException { 225 if (serverSocket != null) { 226 serverSocket.close(); // should cause acceptConnections() to break out 227 } 228 } 229 230 private void serveConnection(final Socket raw) { 231 String name = "MockWebServer-" + raw.getRemoteSocketAddress(); 232 executor.submit(namedCallable(name, new Callable<Void>() { 233 int sequenceNumber = 0; 234 235 public Void call() throws Exception { 236 Socket socket; 237 if (sslSocketFactory != null) { 238 if (tunnelProxy) { 239 if (!processOneRequest(raw.getInputStream(), raw.getOutputStream(), raw)) { 240 throw new IllegalStateException("Tunnel without any CONNECT!"); 241 } 242 } 243 socket = sslSocketFactory.createSocket( 244 raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true); 245 ((SSLSocket) socket).setUseClientMode(false); 246 openClientSockets.add(socket); 247 openClientSockets.remove(raw); 248 } else { 249 socket = raw; 250 } 251 252 InputStream in = new BufferedInputStream(socket.getInputStream()); 253 OutputStream out = new BufferedOutputStream(socket.getOutputStream()); 254 255 if (!processOneRequest(in, out, socket)) { 256 throw new IllegalStateException("Connection without any request!"); 257 } 258 while (processOneRequest(in, out, socket)) {} 259 260 in.close(); 261 out.close(); 262 socket.close(); 263 openClientSockets.remove(socket); 264 return null; 265 } 266 267 /** 268 * Reads a request and writes its response. Returns true if a request 269 * was processed. 270 */ 271 private boolean processOneRequest(InputStream in, OutputStream out, Socket socket) 272 throws IOException, InterruptedException { 273 RecordedRequest request = readRequest(in, sequenceNumber); 274 if (request == null) { 275 return false; 276 } 277 MockResponse response = dispatch(request); 278 writeResponse(out, response); 279 if (response.getSocketPolicy() == SocketPolicy.DISCONNECT_AT_END) { 280 in.close(); 281 out.close(); 282 } else if (response.getSocketPolicy() == SocketPolicy.SHUTDOWN_INPUT_AT_END) { 283 socket.shutdownInput(); 284 } else if (response.getSocketPolicy() == SocketPolicy.SHUTDOWN_OUTPUT_AT_END) { 285 socket.shutdownOutput(); 286 } 287 sequenceNumber++; 288 return true; 289 } 290 })); 291 } 292 293 /** 294 * @param sequenceNumber the index of this request on this connection. 295 */ 296 private RecordedRequest readRequest(InputStream in, int sequenceNumber) throws IOException { 297 String request = readAsciiUntilCrlf(in); 298 if (request.isEmpty()) { 299 return null; // end of data; no more requests 300 } 301 302 List<String> headers = new ArrayList<String>(); 303 int contentLength = -1; 304 boolean chunked = false; 305 String header; 306 while (!(header = readAsciiUntilCrlf(in)).isEmpty()) { 307 headers.add(header); 308 String lowercaseHeader = header.toLowerCase(); 309 if (contentLength == -1 && lowercaseHeader.startsWith("content-length:")) { 310 contentLength = Integer.parseInt(header.substring(15).trim()); 311 } 312 if (lowercaseHeader.startsWith("transfer-encoding:") && 313 lowercaseHeader.substring(18).trim().equals("chunked")) { 314 chunked = true; 315 } 316 } 317 318 boolean hasBody = false; 319 TruncatingOutputStream requestBody = new TruncatingOutputStream(); 320 List<Integer> chunkSizes = new ArrayList<Integer>(); 321 if (contentLength != -1) { 322 hasBody = true; 323 transfer(contentLength, in, requestBody); 324 } else if (chunked) { 325 hasBody = true; 326 while (true) { 327 int chunkSize = Integer.parseInt(readAsciiUntilCrlf(in).trim(), 16); 328 if (chunkSize == 0) { 329 readEmptyLine(in); 330 break; 331 } 332 chunkSizes.add(chunkSize); 333 transfer(chunkSize, in, requestBody); 334 readEmptyLine(in); 335 } 336 } 337 338 if (request.startsWith("GET ") || request.startsWith("CONNECT ")) { 339 if (hasBody) { 340 throw new IllegalArgumentException("GET requests should not have a body!"); 341 } 342 } else if (request.startsWith("POST ")) { 343 if (!hasBody) { 344 throw new IllegalArgumentException("POST requests must have a body!"); 345 } 346 } else { 347 throw new UnsupportedOperationException("Unexpected method: " + request); 348 } 349 350 return new RecordedRequest(request, headers, chunkSizes, 351 requestBody.numBytesReceived, requestBody.toByteArray(), sequenceNumber); 352 } 353 354 /** 355 * Returns a response to satisfy {@code request}. 356 */ 357 private MockResponse dispatch(RecordedRequest request) throws InterruptedException { 358 if (responseQueue.isEmpty()) { 359 throw new IllegalStateException("Unexpected request: " + request); 360 } 361 362 if (singleResponse) { 363 return responseQueue.peek(); 364 } else { 365 requestCount.incrementAndGet(); 366 requestQueue.add(request); 367 return responseQueue.take(); 368 } 369 } 370 371 private void writeResponse(OutputStream out, MockResponse response) throws IOException { 372 out.write((response.getStatus() + "\r\n").getBytes(ASCII)); 373 for (String header : response.getHeaders()) { 374 out.write((header + "\r\n").getBytes(ASCII)); 375 } 376 out.write(("\r\n").getBytes(ASCII)); 377 out.write(response.getBody()); 378 out.flush(); 379 } 380 381 /** 382 * Transfer bytes from {@code in} to {@code out} until either {@code length} 383 * bytes have been transferred or {@code in} is exhausted. 384 */ 385 private void transfer(int length, InputStream in, OutputStream out) throws IOException { 386 byte[] buffer = new byte[1024]; 387 while (length > 0) { 388 int count = in.read(buffer, 0, Math.min(buffer.length, length)); 389 if (count == -1) { 390 return; 391 } 392 out.write(buffer, 0, count); 393 length -= count; 394 } 395 } 396 397 /** 398 * Returns the text from {@code in} until the next "\r\n", or null if 399 * {@code in} is exhausted. 400 */ 401 private String readAsciiUntilCrlf(InputStream in) throws IOException { 402 StringBuilder builder = new StringBuilder(); 403 while (true) { 404 int c = in.read(); 405 if (c == '\n' && builder.length() > 0 && builder.charAt(builder.length() - 1) == '\r') { 406 builder.deleteCharAt(builder.length() - 1); 407 return builder.toString(); 408 } else if (c == -1) { 409 return builder.toString(); 410 } else { 411 builder.append((char) c); 412 } 413 } 414 } 415 416 private void readEmptyLine(InputStream in) throws IOException { 417 String line = readAsciiUntilCrlf(in); 418 if (!line.isEmpty()) { 419 throw new IllegalStateException("Expected empty but was: " + line); 420 } 421 } 422 423 /** 424 * An output stream that drops data after bodyLimit bytes. 425 */ 426 private class TruncatingOutputStream extends ByteArrayOutputStream { 427 private int numBytesReceived = 0; 428 @Override public void write(byte[] buffer, int offset, int len) { 429 numBytesReceived += len; 430 super.write(buffer, offset, Math.min(len, bodyLimit - count)); 431 } 432 @Override public void write(int oneByte) { 433 numBytesReceived++; 434 if (count < bodyLimit) { 435 super.write(oneByte); 436 } 437 } 438 } 439 440 private static <T> Callable<T> namedCallable(final String name, final Callable<T> callable) { 441 return new Callable<T>() { 442 public T call() throws Exception { 443 String originalName = Thread.currentThread().getName(); 444 Thread.currentThread().setName(name); 445 try { 446 return callable.call(); 447 } finally { 448 Thread.currentThread().setName(originalName); 449 } 450 } 451 }; 452 } 453} 454