MockWebServer.java revision 706d53593cd8841d378dbe298a8d1940db1e71df
1/* 2 * Copyright (C) 2010 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17package tests.http; 18 19import java.io.BufferedInputStream; 20import java.io.BufferedOutputStream; 21import java.io.ByteArrayOutputStream; 22import java.io.IOException; 23import java.io.InputStream; 24import java.io.OutputStream; 25import java.net.MalformedURLException; 26import java.net.ServerSocket; 27import java.net.Socket; 28import java.net.URL; 29import java.util.ArrayList; 30import java.util.List; 31import java.util.concurrent.BlockingQueue; 32import java.util.concurrent.Callable; 33import java.util.concurrent.ExecutorService; 34import java.util.concurrent.Executors; 35import java.util.concurrent.LinkedBlockingDeque; 36import java.util.concurrent.LinkedBlockingQueue; 37import javax.net.ssl.SSLSocket; 38import javax.net.ssl.SSLSocketFactory; 39 40/** 41 * A scriptable web server. Callers supply canned responses and the server 42 * replays them upon request in sequence. 43 */ 44public final class MockWebServer { 45 46 static final String ASCII = "US-ASCII"; 47 48 private final BlockingQueue<RecordedRequest> requestQueue 49 = new LinkedBlockingQueue<RecordedRequest>(); 50 private final BlockingQueue<MockResponse> responseQueue 51 = new LinkedBlockingDeque<MockResponse>(); 52 private int bodyLimit = Integer.MAX_VALUE; 53 private SSLSocketFactory sslSocketFactory; 54 private boolean tunnelProxy; 55 private final ExecutorService executor = Executors.newCachedThreadPool(); 56 57 private int port = -1; 58 59 public int getPort() { 60 if (port == -1) { 61 throw new IllegalStateException("Cannot retrieve port before calling play()"); 62 } 63 return port; 64 } 65 66 /** 67 * Returns a URL for connecting to this server. 68 * 69 * @param path the request path, such as "/". 70 */ 71 public URL getUrl(String path) throws MalformedURLException { 72 return new URL("http://localhost:" + getPort() + path); 73 } 74 75 /** 76 * Sets the number of bytes of the POST body to keep in memory to the given 77 * limit. 78 */ 79 public void setBodyLimit(int maxBodyLength) { 80 this.bodyLimit = maxBodyLength; 81 } 82 83 /** 84 * Serve requests with HTTPS rather than otherwise. 85 * 86 * @param tunnelProxy whether to expect the HTTP CONNECT method before 87 * negotiating TLS. 88 */ 89 public void useHttps(SSLSocketFactory sslSocketFactory, boolean tunnelProxy) { 90 this.sslSocketFactory = sslSocketFactory; 91 this.tunnelProxy = tunnelProxy; 92 } 93 94 /** 95 * Awaits the next HTTP request, removes it, and returns it. Callers should 96 * use this to verify the request sent was as intended. 97 */ 98 public RecordedRequest takeRequest() throws InterruptedException { 99 return requestQueue.take(); 100 } 101 102 public void enqueue(MockResponse response) { 103 responseQueue.add(response); 104 } 105 106 /** 107 * Starts the server, serves all enqueued requests, and shuts the server 108 * down. 109 */ 110 public void play() throws IOException { 111 final ServerSocket ss; 112 ss = new ServerSocket(0); 113 ss.setReuseAddress(true); 114 115 port = ss.getLocalPort(); 116 executor.submit(new Callable<Void>() { 117 public Void call() throws Exception { 118 int count = 0; 119 while (true) { 120 if (count > 0 && responseQueue.isEmpty()) { 121 ss.close(); 122 executor.shutdown(); 123 return null; 124 } 125 126 serveConnection(ss.accept()); 127 count++; 128 } 129 } 130 }); 131 } 132 133 private void serveConnection(final Socket raw) { 134 executor.submit(new Callable<Void>() { 135 int sequenceNumber = 0; 136 137 public Void call() throws Exception { 138 Socket socket; 139 if (sslSocketFactory != null) { 140 if (tunnelProxy) { 141 if (!processOneRequest(raw.getInputStream(), raw.getOutputStream())) { 142 throw new IllegalStateException("Tunnel without any CONNECT!"); 143 } 144 } 145 socket = sslSocketFactory.createSocket( 146 raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true); 147 ((SSLSocket) socket).setUseClientMode(false); 148 } else { 149 socket = raw; 150 } 151 152 InputStream in = new BufferedInputStream(socket.getInputStream()); 153 OutputStream out = new BufferedOutputStream(socket.getOutputStream()); 154 155 if (!processOneRequest(in, out)) { 156 throw new IllegalStateException("Connection without any request!"); 157 } 158 while (processOneRequest(in, out)) {} 159 160 in.close(); 161 out.close(); 162 return null; 163 } 164 165 /** 166 * Reads a request and writes its response. Returns true if a request 167 * was processed. 168 */ 169 private boolean processOneRequest(InputStream in, OutputStream out) 170 throws IOException, InterruptedException { 171 RecordedRequest request = readRequest(in, sequenceNumber); 172 if (request == null) { 173 return false; 174 } 175 requestQueue.add(request); 176 writeResponse(out, computeResponse(request)); 177 sequenceNumber++; 178 return true; 179 } 180 }); 181 } 182 183 /** 184 * @param sequenceNumber the index of this request on this connection. 185 */ 186 private RecordedRequest readRequest(InputStream in, int sequenceNumber) throws IOException { 187 String request = readAsciiUntilCrlf(in); 188 if (request.isEmpty()) { 189 return null; // end of data; no more requests 190 } 191 192 List<String> headers = new ArrayList<String>(); 193 int contentLength = -1; 194 boolean chunked = false; 195 String header; 196 while (!(header = readAsciiUntilCrlf(in)).isEmpty()) { 197 headers.add(header); 198 String lowercaseHeader = header.toLowerCase(); 199 if (contentLength == -1 && lowercaseHeader.startsWith("content-length:")) { 200 contentLength = Integer.parseInt(header.substring(15).trim()); 201 } 202 if (lowercaseHeader.startsWith("transfer-encoding:") && 203 lowercaseHeader.substring(18).trim().equals("chunked")) { 204 chunked = true; 205 } 206 } 207 208 boolean hasBody = false; 209 TruncatingOutputStream requestBody = new TruncatingOutputStream(); 210 List<Integer> chunkSizes = new ArrayList<Integer>(); 211 if (contentLength != -1) { 212 hasBody = true; 213 transfer(contentLength, in, requestBody); 214 } else if (chunked) { 215 hasBody = true; 216 while (true) { 217 int chunkSize = Integer.parseInt(readAsciiUntilCrlf(in).trim(), 16); 218 if (chunkSize == 0) { 219 readEmptyLine(in); 220 break; 221 } 222 chunkSizes.add(chunkSize); 223 transfer(chunkSize, in, requestBody); 224 readEmptyLine(in); 225 } 226 } 227 228 if (request.startsWith("GET ") || request.startsWith("CONNECT ")) { 229 if (hasBody) { 230 throw new IllegalArgumentException("GET requests should not have a body!"); 231 } 232 } else if (request.startsWith("POST ")) { 233 if (!hasBody) { 234 throw new IllegalArgumentException("POST requests must have a body!"); 235 } 236 } else { 237 throw new UnsupportedOperationException("Unexpected method: " + request); 238 } 239 240 return new RecordedRequest(request, headers, chunkSizes, 241 requestBody.numBytesReceived, requestBody.toByteArray(), sequenceNumber); 242 } 243 244 /** 245 * Returns a response to satisfy {@code request}. 246 */ 247 private MockResponse computeResponse(RecordedRequest request) throws InterruptedException { 248 if (responseQueue.isEmpty()) { 249 throw new IllegalStateException("Unexpected request: " + request); 250 } 251 return responseQueue.take(); 252 } 253 254 private void writeResponse(OutputStream out, MockResponse response) throws IOException { 255 out.write((response.getStatus() + "\r\n").getBytes(ASCII)); 256 for (String header : response.getHeaders()) { 257 out.write((header + "\r\n").getBytes(ASCII)); 258 } 259 out.write(("\r\n").getBytes(ASCII)); 260 out.write(response.getBody()); 261 out.write(("\r\n").getBytes(ASCII)); 262 out.flush(); 263 } 264 265 /** 266 * Transfer bytes from {@code in} to {@code out} until either {@code length} 267 * bytes have been transferred or {@code in} is exhausted. 268 */ 269 private void transfer(int length, InputStream in, OutputStream out) throws IOException { 270 byte[] buffer = new byte[1024]; 271 while (length > 0) { 272 int count = in.read(buffer, 0, Math.min(buffer.length, length)); 273 if (count == -1) { 274 return; 275 } 276 out.write(buffer, 0, count); 277 length -= count; 278 } 279 } 280 281 /** 282 * Returns the text from {@code in} until the next "\r\n", or null if 283 * {@code in} is exhausted. 284 */ 285 private String readAsciiUntilCrlf(InputStream in) throws IOException { 286 StringBuilder builder = new StringBuilder(); 287 while (true) { 288 int c = in.read(); 289 if (c == '\n' && builder.length() > 0 && builder.charAt(builder.length() - 1) == '\r') { 290 builder.deleteCharAt(builder.length() - 1); 291 return builder.toString(); 292 } else if (c == -1) { 293 return builder.toString(); 294 } else { 295 builder.append((char) c); 296 } 297 } 298 } 299 300 private void readEmptyLine(InputStream in) throws IOException { 301 String line = readAsciiUntilCrlf(in); 302 if (!line.isEmpty()) { 303 throw new IllegalStateException("Expected empty but was: " + line); 304 } 305 } 306 307 /** 308 * An output stream that drops data after bodyLimit bytes. 309 */ 310 private class TruncatingOutputStream extends ByteArrayOutputStream { 311 private int numBytesReceived = 0; 312 @Override public void write(byte[] buffer, int offset, int len) { 313 numBytesReceived += len; 314 super.write(buffer, offset, Math.min(len, bodyLimit - count)); 315 } 316 @Override public void write(int oneByte) { 317 numBytesReceived++; 318 if (count < bodyLimit) { 319 super.write(oneByte); 320 } 321 } 322 } 323} 324