// MockWebServer.java revision 8ac847a52e72f0cefbb20a6850ae04468d433a9e
1/* 2 * Copyright (C) 2010 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17package tests.http; 18 19import java.io.BufferedInputStream; 20import java.io.BufferedOutputStream; 21import java.io.ByteArrayOutputStream; 22import java.io.IOException; 23import java.io.InputStream; 24import java.io.OutputStream; 25import java.net.InetAddress; 26import java.net.InetSocketAddress; 27import java.net.MalformedURLException; 28import java.net.Proxy; 29import java.net.ServerSocket; 30import java.net.Socket; 31import java.net.URL; 32import java.net.UnknownHostException; 33import java.util.ArrayList; 34import java.util.Collections; 35import java.util.HashSet; 36import java.util.Iterator; 37import java.util.List; 38import java.util.Set; 39import java.util.concurrent.BlockingQueue; 40import java.util.concurrent.Callable; 41import java.util.concurrent.ExecutorService; 42import java.util.concurrent.Executors; 43import java.util.concurrent.LinkedBlockingDeque; 44import java.util.concurrent.LinkedBlockingQueue; 45import java.util.concurrent.atomic.AtomicInteger; 46import javax.net.ssl.SSLSocket; 47import javax.net.ssl.SSLSocketFactory; 48 49/** 50 * A scriptable web server. Callers supply canned responses and the server 51 * replays them upon request in sequence. 
52 */ 53public final class MockWebServer { 54 55 static final String ASCII = "US-ASCII"; 56 57 private final BlockingQueue<RecordedRequest> requestQueue 58 = new LinkedBlockingQueue<RecordedRequest>(); 59 private final BlockingQueue<MockResponse> responseQueue 60 = new LinkedBlockingDeque<MockResponse>(); 61 private final Set<Socket> openClientSockets 62 = Collections.synchronizedSet(new HashSet<Socket>()); 63 private boolean singleResponse; 64 private final AtomicInteger requestCount = new AtomicInteger(); 65 private int bodyLimit = Integer.MAX_VALUE; 66 private ServerSocket serverSocket; 67 private SSLSocketFactory sslSocketFactory; 68 private ExecutorService executor; 69 private boolean tunnelProxy; 70 71 private int port = -1; 72 73 public int getPort() { 74 if (port == -1) { 75 throw new IllegalStateException("Cannot retrieve port before calling play()"); 76 } 77 return port; 78 } 79 80 public Proxy toProxyAddress() { 81 return new Proxy(Proxy.Type.HTTP, new InetSocketAddress("localhost", getPort())); 82 } 83 84 /** 85 * Returns a URL for connecting to this server. 86 * 87 * @param path the request path, such as "/". 88 */ 89 public URL getUrl(String path) throws MalformedURLException, UnknownHostException { 90 String host = InetAddress.getLocalHost().getHostName(); 91 return sslSocketFactory != null 92 ? new URL("https://" + host + ":" + getPort() + path) 93 : new URL("http://" + host + ":" + getPort() + path); 94 } 95 96 /** 97 * Sets the number of bytes of the POST body to keep in memory to the given 98 * limit. 99 */ 100 public void setBodyLimit(int maxBodyLength) { 101 this.bodyLimit = maxBodyLength; 102 } 103 104 /** 105 * Serve requests with HTTPS rather than otherwise. 106 * 107 * @param tunnelProxy whether to expect the HTTP CONNECT method before 108 * negotiating TLS. 
109 */ 110 public void useHttps(SSLSocketFactory sslSocketFactory, boolean tunnelProxy) { 111 this.sslSocketFactory = sslSocketFactory; 112 this.tunnelProxy = tunnelProxy; 113 } 114 115 /** 116 * Awaits the next HTTP request, removes it, and returns it. Callers should 117 * use this to verify the request sent was as intended. 118 */ 119 public RecordedRequest takeRequest() throws InterruptedException { 120 return requestQueue.take(); 121 } 122 123 /** 124 * Returns the number of HTTP requests received thus far by this server. 125 * This may exceed the number of HTTP connections when connection reuse is 126 * in practice. 127 */ 128 public int getRequestCount() { 129 return requestCount.get(); 130 } 131 132 public void enqueue(MockResponse response) { 133 responseQueue.add(response); 134 } 135 136 /** 137 * By default, this class processes requests coming in by adding them to a 138 * queue and serves responses by removing them from another queue. This mode 139 * is appropriate for correctness testing. 140 * 141 * <p>Serving a single response causes the server to be stateless: requests 142 * are not enqueued, and responses are not dequeued. This mode is appropriate 143 * for benchmarking. 144 */ 145 public void setSingleResponse(boolean singleResponse) { 146 this.singleResponse = singleResponse; 147 } 148 149 /** 150 * Starts the server, serves all enqueued requests, and shuts the server 151 * down. 
152 */ 153 public void play() throws IOException { 154 executor = Executors.newCachedThreadPool(); 155 serverSocket = new ServerSocket(0); 156 serverSocket.setReuseAddress(true); 157 158 port = serverSocket.getLocalPort(); 159 executor.submit(namedCallable("MockWebServer-accept-" + port, new Callable<Void>() { 160 public Void call() throws Exception { 161 List<Throwable> failures = new ArrayList<Throwable>(); 162 try { 163 acceptConnections(); 164 } catch (Throwable e) { 165 failures.add(e); 166 } 167 168 /* 169 * This gnarly block of code will release all sockets and 170 * all thread, even if any close fails. 171 */ 172 try { 173 serverSocket.close(); 174 } catch (Throwable e) { 175 failures.add(e); 176 } 177 for (Iterator<Socket> s = openClientSockets.iterator(); s.hasNext(); ) { 178 try { 179 s.next().close(); 180 s.remove(); 181 } catch (Throwable e) { 182 failures.add(e); 183 } 184 } 185 try { 186 executor.shutdown(); 187 } catch (Throwable e) { 188 failures.add(e); 189 } 190 if (!failures.isEmpty()) { 191 Throwable thrown = failures.get(0); 192 if (thrown instanceof Exception) { 193 throw (Exception) thrown; 194 } else { 195 throw (Error) thrown; 196 } 197 } else { 198 return null; 199 } 200 } 201 202 public void acceptConnections() throws Exception { 203 int count = 0; 204 while (true) { 205 if (count > 0 && responseQueue.isEmpty()) { 206 return; 207 } 208 209 Socket socket = serverSocket.accept(); 210 if (responseQueue.peek().getDisconnectAtStart()) { 211 responseQueue.take(); 212 socket.close(); 213 continue; 214 } 215 openClientSockets.add(socket); 216 serveConnection(socket); 217 count++; 218 } 219 } 220 })); 221 } 222 223 public void shutdown() throws IOException { 224 if (serverSocket != null) { 225 serverSocket.close(); // should cause acceptConnections() to break out 226 } 227 } 228 229 private void serveConnection(final Socket raw) { 230 String name = "MockWebServer-" + raw.getRemoteSocketAddress(); 231 executor.submit(namedCallable(name, new 
Callable<Void>() { 232 int sequenceNumber = 0; 233 234 public Void call() throws Exception { 235 Socket socket; 236 if (sslSocketFactory != null) { 237 if (tunnelProxy) { 238 if (!processOneRequest(raw.getInputStream(), raw.getOutputStream())) { 239 throw new IllegalStateException("Tunnel without any CONNECT!"); 240 } 241 } 242 socket = sslSocketFactory.createSocket( 243 raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true); 244 ((SSLSocket) socket).setUseClientMode(false); 245 openClientSockets.add(socket); 246 openClientSockets.remove(raw); 247 } else { 248 socket = raw; 249 } 250 251 InputStream in = new BufferedInputStream(socket.getInputStream()); 252 OutputStream out = new BufferedOutputStream(socket.getOutputStream()); 253 254 if (!processOneRequest(in, out)) { 255 throw new IllegalStateException("Connection without any request!"); 256 } 257 while (processOneRequest(in, out)) {} 258 259 in.close(); 260 out.close(); 261 socket.close(); 262 openClientSockets.remove(socket); 263 return null; 264 } 265 266 /** 267 * Reads a request and writes its response. Returns true if a request 268 * was processed. 269 */ 270 private boolean processOneRequest(InputStream in, OutputStream out) 271 throws IOException, InterruptedException { 272 RecordedRequest request = readRequest(in, sequenceNumber); 273 if (request == null) { 274 return false; 275 } 276 MockResponse response = dispatch(request); 277 writeResponse(out, response); 278 if (response.getDisconnectAtEnd()) { 279 in.close(); 280 out.close(); 281 } 282 sequenceNumber++; 283 return true; 284 } 285 })); 286 } 287 288 /** 289 * @param sequenceNumber the index of this request on this connection. 
290 */ 291 private RecordedRequest readRequest(InputStream in, int sequenceNumber) throws IOException { 292 String request = readAsciiUntilCrlf(in); 293 if (request.isEmpty()) { 294 return null; // end of data; no more requests 295 } 296 297 List<String> headers = new ArrayList<String>(); 298 int contentLength = -1; 299 boolean chunked = false; 300 String header; 301 while (!(header = readAsciiUntilCrlf(in)).isEmpty()) { 302 headers.add(header); 303 String lowercaseHeader = header.toLowerCase(); 304 if (contentLength == -1 && lowercaseHeader.startsWith("content-length:")) { 305 contentLength = Integer.parseInt(header.substring(15).trim()); 306 } 307 if (lowercaseHeader.startsWith("transfer-encoding:") && 308 lowercaseHeader.substring(18).trim().equals("chunked")) { 309 chunked = true; 310 } 311 } 312 313 boolean hasBody = false; 314 TruncatingOutputStream requestBody = new TruncatingOutputStream(); 315 List<Integer> chunkSizes = new ArrayList<Integer>(); 316 if (contentLength != -1) { 317 hasBody = true; 318 transfer(contentLength, in, requestBody); 319 } else if (chunked) { 320 hasBody = true; 321 while (true) { 322 int chunkSize = Integer.parseInt(readAsciiUntilCrlf(in).trim(), 16); 323 if (chunkSize == 0) { 324 readEmptyLine(in); 325 break; 326 } 327 chunkSizes.add(chunkSize); 328 transfer(chunkSize, in, requestBody); 329 readEmptyLine(in); 330 } 331 } 332 333 if (request.startsWith("GET ") || request.startsWith("CONNECT ")) { 334 if (hasBody) { 335 throw new IllegalArgumentException("GET requests should not have a body!"); 336 } 337 } else if (request.startsWith("POST ")) { 338 if (!hasBody) { 339 throw new IllegalArgumentException("POST requests must have a body!"); 340 } 341 } else { 342 throw new UnsupportedOperationException("Unexpected method: " + request); 343 } 344 345 return new RecordedRequest(request, headers, chunkSizes, 346 requestBody.numBytesReceived, requestBody.toByteArray(), sequenceNumber); 347 } 348 349 /** 350 * Returns a response to satisfy 
{@code request}. 351 */ 352 private MockResponse dispatch(RecordedRequest request) throws InterruptedException { 353 if (responseQueue.isEmpty()) { 354 throw new IllegalStateException("Unexpected request: " + request); 355 } 356 357 if (singleResponse) { 358 return responseQueue.peek(); 359 } else { 360 requestCount.incrementAndGet(); 361 requestQueue.add(request); 362 return responseQueue.take(); 363 } 364 } 365 366 private void writeResponse(OutputStream out, MockResponse response) throws IOException { 367 out.write((response.getStatus() + "\r\n").getBytes(ASCII)); 368 for (String header : response.getHeaders()) { 369 out.write((header + "\r\n").getBytes(ASCII)); 370 } 371 out.write(("\r\n").getBytes(ASCII)); 372 out.write(response.getBody()); 373 out.flush(); 374 } 375 376 /** 377 * Transfer bytes from {@code in} to {@code out} until either {@code length} 378 * bytes have been transferred or {@code in} is exhausted. 379 */ 380 private void transfer(int length, InputStream in, OutputStream out) throws IOException { 381 byte[] buffer = new byte[1024]; 382 while (length > 0) { 383 int count = in.read(buffer, 0, Math.min(buffer.length, length)); 384 if (count == -1) { 385 return; 386 } 387 out.write(buffer, 0, count); 388 length -= count; 389 } 390 } 391 392 /** 393 * Returns the text from {@code in} until the next "\r\n", or null if 394 * {@code in} is exhausted. 
395 */ 396 private String readAsciiUntilCrlf(InputStream in) throws IOException { 397 StringBuilder builder = new StringBuilder(); 398 while (true) { 399 int c = in.read(); 400 if (c == '\n' && builder.length() > 0 && builder.charAt(builder.length() - 1) == '\r') { 401 builder.deleteCharAt(builder.length() - 1); 402 return builder.toString(); 403 } else if (c == -1) { 404 return builder.toString(); 405 } else { 406 builder.append((char) c); 407 } 408 } 409 } 410 411 private void readEmptyLine(InputStream in) throws IOException { 412 String line = readAsciiUntilCrlf(in); 413 if (!line.isEmpty()) { 414 throw new IllegalStateException("Expected empty but was: " + line); 415 } 416 } 417 418 /** 419 * An output stream that drops data after bodyLimit bytes. 420 */ 421 private class TruncatingOutputStream extends ByteArrayOutputStream { 422 private int numBytesReceived = 0; 423 @Override public void write(byte[] buffer, int offset, int len) { 424 numBytesReceived += len; 425 super.write(buffer, offset, Math.min(len, bodyLimit - count)); 426 } 427 @Override public void write(int oneByte) { 428 numBytesReceived++; 429 if (count < bodyLimit) { 430 super.write(oneByte); 431 } 432 } 433 } 434 435 private static <T> Callable<T> namedCallable(final String name, final Callable<T> callable) { 436 return new Callable<T>() { 437 public T call() throws Exception { 438 String originalName = Thread.currentThread().getName(); 439 Thread.currentThread().setName(name); 440 try { 441 return callable.call(); 442 } finally { 443 Thread.currentThread().setName(originalName); 444 } 445 } 446 }; 447 } 448} 449