MockWebServer.java revision e942f46f10bb9384a1b186b3d7b74f9704c57090
1/*
2 * Copyright (C) 2010 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package tests.http;
18
19import java.io.BufferedInputStream;
20import java.io.BufferedOutputStream;
21import java.io.ByteArrayOutputStream;
22import java.io.IOException;
23import java.io.InputStream;
24import java.io.OutputStream;
25import java.net.InetAddress;
26import java.net.InetSocketAddress;
27import java.net.MalformedURLException;
28import java.net.Proxy;
29import java.net.ServerSocket;
30import java.net.Socket;
31import java.net.URL;
32import java.net.UnknownHostException;
33import java.util.ArrayList;
34import java.util.Collections;
35import java.util.HashSet;
36import java.util.Iterator;
37import java.util.List;
38import java.util.Set;
39import java.util.concurrent.BlockingQueue;
40import java.util.concurrent.Callable;
41import java.util.concurrent.ExecutorService;
42import java.util.concurrent.Executors;
43import java.util.concurrent.LinkedBlockingDeque;
44import java.util.concurrent.LinkedBlockingQueue;
45import java.util.concurrent.atomic.AtomicInteger;
46import javax.net.ssl.SSLSocket;
47import javax.net.ssl.SSLSocketFactory;
48import static tests.http.SocketPolicy.DISCONNECT_AT_START;
49
50/**
51 * A scriptable web server. Callers supply canned responses and the server
52 * replays them upon request in sequence.
53 */
54public final class MockWebServer {
55
56    static final String ASCII = "US-ASCII";
57
58    private final BlockingQueue<RecordedRequest> requestQueue
59            = new LinkedBlockingQueue<RecordedRequest>();
60    private final BlockingQueue<MockResponse> responseQueue
61            = new LinkedBlockingDeque<MockResponse>();
62    private final Set<Socket> openClientSockets
63            = Collections.synchronizedSet(new HashSet<Socket>());
64    private boolean singleResponse;
65    private final AtomicInteger requestCount = new AtomicInteger();
66    private int bodyLimit = Integer.MAX_VALUE;
67    private ServerSocket serverSocket;
68    private SSLSocketFactory sslSocketFactory;
69    private ExecutorService executor;
70    private boolean tunnelProxy;
71
72    private int port = -1;
73
74    public int getPort() {
75        if (port == -1) {
76            throw new IllegalStateException("Cannot retrieve port before calling play()");
77        }
78        return port;
79    }
80
81    public Proxy toProxyAddress() {
82        return new Proxy(Proxy.Type.HTTP, new InetSocketAddress("localhost", getPort()));
83    }
84
85    /**
86     * Returns a URL for connecting to this server.
87     *
88     * @param path the request path, such as "/".
89     */
90    public URL getUrl(String path) throws MalformedURLException, UnknownHostException {
91        String host = InetAddress.getLocalHost().getHostName();
92        return sslSocketFactory != null
93                ? new URL("https://" + host + ":" + getPort() + path)
94                : new URL("http://" + host + ":" + getPort() + path);
95    }
96
97    /**
98     * Sets the number of bytes of the POST body to keep in memory to the given
99     * limit.
100     */
101    public void setBodyLimit(int maxBodyLength) {
102        this.bodyLimit = maxBodyLength;
103    }
104
105    /**
106     * Serve requests with HTTPS rather than otherwise.
107     *
108     * @param tunnelProxy whether to expect the HTTP CONNECT method before
109     *     negotiating TLS.
110     */
111    public void useHttps(SSLSocketFactory sslSocketFactory, boolean tunnelProxy) {
112        this.sslSocketFactory = sslSocketFactory;
113        this.tunnelProxy = tunnelProxy;
114    }
115
116    /**
117     * Awaits the next HTTP request, removes it, and returns it. Callers should
118     * use this to verify the request sent was as intended.
119     */
120    public RecordedRequest takeRequest() throws InterruptedException {
121        return requestQueue.take();
122    }
123
124    /**
125     * Returns the number of HTTP requests received thus far by this server.
126     * This may exceed the number of HTTP connections when connection reuse is
127     * in practice.
128     */
129    public int getRequestCount() {
130        return requestCount.get();
131    }
132
133    public void enqueue(MockResponse response) {
134        responseQueue.add(response);
135    }
136
137    /**
138     * By default, this class processes requests coming in by adding them to a
139     * queue and serves responses by removing them from another queue. This mode
140     * is appropriate for correctness testing.
141     *
142     * <p>Serving a single response causes the server to be stateless: requests
143     * are not enqueued, and responses are not dequeued. This mode is appropriate
144     * for benchmarking.
145     */
146    public void setSingleResponse(boolean singleResponse) {
147        this.singleResponse = singleResponse;
148    }
149
150    /**
151     * Starts the server, serves all enqueued requests, and shuts the server
152     * down.
153     */
154    public void play() throws IOException {
155        executor = Executors.newCachedThreadPool();
156        serverSocket = new ServerSocket(0);
157        serverSocket.setReuseAddress(true);
158
159        port = serverSocket.getLocalPort();
160        executor.submit(namedCallable("MockWebServer-accept-" + port, new Callable<Void>() {
161            public Void call() throws Exception {
162                List<Throwable> failures = new ArrayList<Throwable>();
163                try {
164                    acceptConnections();
165                } catch (Throwable e) {
166                    failures.add(e);
167                }
168
169                /*
170                 * This gnarly block of code will release all sockets and
171                 * all thread, even if any close fails.
172                 */
173                try {
174                    serverSocket.close();
175                } catch (Throwable e) {
176                    failures.add(e);
177                }
178                for (Iterator<Socket> s = openClientSockets.iterator(); s.hasNext(); ) {
179                    try {
180                        s.next().close();
181                        s.remove();
182                    } catch (Throwable e) {
183                        failures.add(e);
184                    }
185                }
186                try {
187                    executor.shutdown();
188                } catch (Throwable e) {
189                    failures.add(e);
190                }
191                if (!failures.isEmpty()) {
192                    Throwable thrown = failures.get(0);
193                    if (thrown instanceof Exception) {
194                        throw (Exception) thrown;
195                    } else {
196                        throw (Error) thrown;
197                    }
198                } else {
199                    return null;
200                }
201            }
202
203            public void acceptConnections() throws Exception {
204                int count = 0;
205                while (true) {
206                    if (count > 0 && responseQueue.isEmpty()) {
207                        return;
208                    }
209
210                    Socket socket = serverSocket.accept();
211                    if (responseQueue.peek().getSocketPolicy() == DISCONNECT_AT_START) {
212                        responseQueue.take();
213                        socket.close();
214                        continue;
215                    }
216                    openClientSockets.add(socket);
217                    serveConnection(socket);
218                    count++;
219                }
220            }
221        }));
222    }
223
224    public void shutdown() throws IOException {
225        if (serverSocket != null) {
226            serverSocket.close(); // should cause acceptConnections() to break out
227        }
228    }
229
230    private void serveConnection(final Socket raw) {
231        String name = "MockWebServer-" + raw.getRemoteSocketAddress();
232        executor.submit(namedCallable(name, new Callable<Void>() {
233            int sequenceNumber = 0;
234
235            public Void call() throws Exception {
236                Socket socket;
237                if (sslSocketFactory != null) {
238                    if (tunnelProxy) {
239                        if (!processOneRequest(raw.getInputStream(), raw.getOutputStream(), raw)) {
240                            throw new IllegalStateException("Tunnel without any CONNECT!");
241                        }
242                    }
243                    socket = sslSocketFactory.createSocket(
244                            raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true);
245                    ((SSLSocket) socket).setUseClientMode(false);
246                    openClientSockets.add(socket);
247                    openClientSockets.remove(raw);
248                } else {
249                    socket = raw;
250                }
251
252                InputStream in = new BufferedInputStream(socket.getInputStream());
253                OutputStream out = new BufferedOutputStream(socket.getOutputStream());
254
255                if (!processOneRequest(in, out, socket)) {
256                    throw new IllegalStateException("Connection without any request!");
257                }
258                while (processOneRequest(in, out, socket)) {}
259
260                in.close();
261                out.close();
262                socket.close();
263                openClientSockets.remove(socket);
264                return null;
265            }
266
267            /**
268             * Reads a request and writes its response. Returns true if a request
269             * was processed.
270             */
271            private boolean processOneRequest(InputStream in, OutputStream out, Socket socket)
272                    throws IOException, InterruptedException {
273                RecordedRequest request = readRequest(in, sequenceNumber);
274                if (request == null) {
275                    return false;
276                }
277                MockResponse response = dispatch(request);
278                writeResponse(out, response);
279                if (response.getSocketPolicy() == SocketPolicy.DISCONNECT_AT_END) {
280                    in.close();
281                    out.close();
282                } else if (response.getSocketPolicy() == SocketPolicy.SHUTDOWN_INPUT_AT_END) {
283                    socket.shutdownInput();
284                } else if (response.getSocketPolicy() == SocketPolicy.SHUTDOWN_OUTPUT_AT_END) {
285                    socket.shutdownOutput();
286                }
287                sequenceNumber++;
288                return true;
289            }
290        }));
291    }
292
293    /**
294     * @param sequenceNumber the index of this request on this connection.
295     */
296    private RecordedRequest readRequest(InputStream in, int sequenceNumber) throws IOException {
297        String request = readAsciiUntilCrlf(in);
298        if (request.isEmpty()) {
299            return null; // end of data; no more requests
300        }
301
302        List<String> headers = new ArrayList<String>();
303        int contentLength = -1;
304        boolean chunked = false;
305        String header;
306        while (!(header = readAsciiUntilCrlf(in)).isEmpty()) {
307            headers.add(header);
308            String lowercaseHeader = header.toLowerCase();
309            if (contentLength == -1 && lowercaseHeader.startsWith("content-length:")) {
310                contentLength = Integer.parseInt(header.substring(15).trim());
311            }
312            if (lowercaseHeader.startsWith("transfer-encoding:") &&
313                    lowercaseHeader.substring(18).trim().equals("chunked")) {
314                chunked = true;
315            }
316        }
317
318        boolean hasBody = false;
319        TruncatingOutputStream requestBody = new TruncatingOutputStream();
320        List<Integer> chunkSizes = new ArrayList<Integer>();
321        if (contentLength != -1) {
322            hasBody = true;
323            transfer(contentLength, in, requestBody);
324        } else if (chunked) {
325            hasBody = true;
326            while (true) {
327                int chunkSize = Integer.parseInt(readAsciiUntilCrlf(in).trim(), 16);
328                if (chunkSize == 0) {
329                    readEmptyLine(in);
330                    break;
331                }
332                chunkSizes.add(chunkSize);
333                transfer(chunkSize, in, requestBody);
334                readEmptyLine(in);
335            }
336        }
337
338        if (request.startsWith("GET ") || request.startsWith("CONNECT ")) {
339            if (hasBody) {
340                throw new IllegalArgumentException("GET requests should not have a body!");
341            }
342        } else if (request.startsWith("POST ")) {
343            if (!hasBody) {
344                throw new IllegalArgumentException("POST requests must have a body!");
345            }
346        } else {
347            throw new UnsupportedOperationException("Unexpected method: " + request);
348        }
349
350        return new RecordedRequest(request, headers, chunkSizes,
351                requestBody.numBytesReceived, requestBody.toByteArray(), sequenceNumber);
352    }
353
354    /**
355     * Returns a response to satisfy {@code request}.
356     */
357    private MockResponse dispatch(RecordedRequest request) throws InterruptedException {
358        if (responseQueue.isEmpty()) {
359            throw new IllegalStateException("Unexpected request: " + request);
360        }
361
362        if (singleResponse) {
363            return responseQueue.peek();
364        } else {
365            requestCount.incrementAndGet();
366            requestQueue.add(request);
367            return responseQueue.take();
368        }
369    }
370
371    private void writeResponse(OutputStream out, MockResponse response) throws IOException {
372        out.write((response.getStatus() + "\r\n").getBytes(ASCII));
373        for (String header : response.getHeaders()) {
374            out.write((header + "\r\n").getBytes(ASCII));
375        }
376        out.write(("\r\n").getBytes(ASCII));
377        out.write(response.getBody());
378        out.flush();
379    }
380
381    /**
382     * Transfer bytes from {@code in} to {@code out} until either {@code length}
383     * bytes have been transferred or {@code in} is exhausted.
384     */
385    private void transfer(int length, InputStream in, OutputStream out) throws IOException {
386        byte[] buffer = new byte[1024];
387        while (length > 0) {
388            int count = in.read(buffer, 0, Math.min(buffer.length, length));
389            if (count == -1) {
390                return;
391            }
392            out.write(buffer, 0, count);
393            length -= count;
394        }
395    }
396
397    /**
398     * Returns the text from {@code in} until the next "\r\n", or null if
399     * {@code in} is exhausted.
400     */
401    private String readAsciiUntilCrlf(InputStream in) throws IOException {
402        StringBuilder builder = new StringBuilder();
403        while (true) {
404            int c = in.read();
405            if (c == '\n' && builder.length() > 0 && builder.charAt(builder.length() - 1) == '\r') {
406                builder.deleteCharAt(builder.length() - 1);
407                return builder.toString();
408            } else if (c == -1) {
409                return builder.toString();
410            } else {
411                builder.append((char) c);
412            }
413        }
414    }
415
416    private void readEmptyLine(InputStream in) throws IOException {
417        String line = readAsciiUntilCrlf(in);
418        if (!line.isEmpty()) {
419            throw new IllegalStateException("Expected empty but was: " + line);
420        }
421    }
422
423    /**
424     * An output stream that drops data after bodyLimit bytes.
425     */
426    private class TruncatingOutputStream extends ByteArrayOutputStream {
427        private int numBytesReceived = 0;
428        @Override public void write(byte[] buffer, int offset, int len) {
429            numBytesReceived += len;
430            super.write(buffer, offset, Math.min(len, bodyLimit - count));
431        }
432        @Override public void write(int oneByte) {
433            numBytesReceived++;
434            if (count < bodyLimit) {
435                super.write(oneByte);
436            }
437        }
438    }
439
440    private static <T> Callable<T> namedCallable(final String name, final Callable<T> callable) {
441        return new Callable<T>() {
442            public T call() throws Exception {
443                String originalName = Thread.currentThread().getName();
444                Thread.currentThread().setName(name);
445                try {
446                    return callable.call();
447                } finally {
448                    Thread.currentThread().setName(originalName);
449                }
450            }
451        };
452    }
453}
454