MockWebServer.java revision 8ac847a52e72f0cefbb20a6850ae04468d433a9e
/*
 * Copyright (C) 2010 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17package tests.http;
18
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.MalformedURLException;
import java.net.Proxy;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.URL;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicInteger;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
48
49/**
50 * A scriptable web server. Callers supply canned responses and the server
51 * replays them upon request in sequence.
52 */
53public final class MockWebServer {
54
55    static final String ASCII = "US-ASCII";
56
57    private final BlockingQueue<RecordedRequest> requestQueue
58            = new LinkedBlockingQueue<RecordedRequest>();
59    private final BlockingQueue<MockResponse> responseQueue
60            = new LinkedBlockingDeque<MockResponse>();
61    private final Set<Socket> openClientSockets
62            = Collections.synchronizedSet(new HashSet<Socket>());
63    private boolean singleResponse;
64    private final AtomicInteger requestCount = new AtomicInteger();
65    private int bodyLimit = Integer.MAX_VALUE;
66    private ServerSocket serverSocket;
67    private SSLSocketFactory sslSocketFactory;
68    private ExecutorService executor;
69    private boolean tunnelProxy;
70
71    private int port = -1;
72
73    public int getPort() {
74        if (port == -1) {
75            throw new IllegalStateException("Cannot retrieve port before calling play()");
76        }
77        return port;
78    }
79
80    public Proxy toProxyAddress() {
81        return new Proxy(Proxy.Type.HTTP, new InetSocketAddress("localhost", getPort()));
82    }
83
84    /**
85     * Returns a URL for connecting to this server.
86     *
87     * @param path the request path, such as "/".
88     */
89    public URL getUrl(String path) throws MalformedURLException, UnknownHostException {
90        String host = InetAddress.getLocalHost().getHostName();
91        return sslSocketFactory != null
92                ? new URL("https://" + host + ":" + getPort() + path)
93                : new URL("http://" + host + ":" + getPort() + path);
94    }
95
96    /**
97     * Sets the number of bytes of the POST body to keep in memory to the given
98     * limit.
99     */
100    public void setBodyLimit(int maxBodyLength) {
101        this.bodyLimit = maxBodyLength;
102    }
103
104    /**
105     * Serve requests with HTTPS rather than otherwise.
106     *
107     * @param tunnelProxy whether to expect the HTTP CONNECT method before
108     *     negotiating TLS.
109     */
110    public void useHttps(SSLSocketFactory sslSocketFactory, boolean tunnelProxy) {
111        this.sslSocketFactory = sslSocketFactory;
112        this.tunnelProxy = tunnelProxy;
113    }
114
115    /**
116     * Awaits the next HTTP request, removes it, and returns it. Callers should
117     * use this to verify the request sent was as intended.
118     */
119    public RecordedRequest takeRequest() throws InterruptedException {
120        return requestQueue.take();
121    }
122
123    /**
124     * Returns the number of HTTP requests received thus far by this server.
125     * This may exceed the number of HTTP connections when connection reuse is
126     * in practice.
127     */
128    public int getRequestCount() {
129        return requestCount.get();
130    }
131
132    public void enqueue(MockResponse response) {
133        responseQueue.add(response);
134    }
135
136    /**
137     * By default, this class processes requests coming in by adding them to a
138     * queue and serves responses by removing them from another queue. This mode
139     * is appropriate for correctness testing.
140     *
141     * <p>Serving a single response causes the server to be stateless: requests
142     * are not enqueued, and responses are not dequeued. This mode is appropriate
143     * for benchmarking.
144     */
145    public void setSingleResponse(boolean singleResponse) {
146        this.singleResponse = singleResponse;
147    }
148
149    /**
150     * Starts the server, serves all enqueued requests, and shuts the server
151     * down.
152     */
153    public void play() throws IOException {
154        executor = Executors.newCachedThreadPool();
155        serverSocket = new ServerSocket(0);
156        serverSocket.setReuseAddress(true);
157
158        port = serverSocket.getLocalPort();
159        executor.submit(namedCallable("MockWebServer-accept-" + port, new Callable<Void>() {
160            public Void call() throws Exception {
161                List<Throwable> failures = new ArrayList<Throwable>();
162                try {
163                    acceptConnections();
164                } catch (Throwable e) {
165                    failures.add(e);
166                }
167
168                /*
169                 * This gnarly block of code will release all sockets and
170                 * all thread, even if any close fails.
171                 */
172                try {
173                    serverSocket.close();
174                } catch (Throwable e) {
175                    failures.add(e);
176                }
177                for (Iterator<Socket> s = openClientSockets.iterator(); s.hasNext(); ) {
178                    try {
179                        s.next().close();
180                        s.remove();
181                    } catch (Throwable e) {
182                        failures.add(e);
183                    }
184                }
185                try {
186                    executor.shutdown();
187                } catch (Throwable e) {
188                    failures.add(e);
189                }
190                if (!failures.isEmpty()) {
191                    Throwable thrown = failures.get(0);
192                    if (thrown instanceof Exception) {
193                        throw (Exception) thrown;
194                    } else {
195                        throw (Error) thrown;
196                    }
197                } else {
198                    return null;
199                }
200            }
201
202            public void acceptConnections() throws Exception {
203                int count = 0;
204                while (true) {
205                    if (count > 0 && responseQueue.isEmpty()) {
206                        return;
207                    }
208
209                    Socket socket = serverSocket.accept();
210                    if (responseQueue.peek().getDisconnectAtStart()) {
211                        responseQueue.take();
212                        socket.close();
213                        continue;
214                    }
215                    openClientSockets.add(socket);
216                    serveConnection(socket);
217                    count++;
218                }
219            }
220        }));
221    }
222
223    public void shutdown() throws IOException {
224        if (serverSocket != null) {
225            serverSocket.close(); // should cause acceptConnections() to break out
226        }
227    }
228
229    private void serveConnection(final Socket raw) {
230        String name = "MockWebServer-" + raw.getRemoteSocketAddress();
231        executor.submit(namedCallable(name, new Callable<Void>() {
232            int sequenceNumber = 0;
233
234            public Void call() throws Exception {
235                Socket socket;
236                if (sslSocketFactory != null) {
237                    if (tunnelProxy) {
238                        if (!processOneRequest(raw.getInputStream(), raw.getOutputStream())) {
239                            throw new IllegalStateException("Tunnel without any CONNECT!");
240                        }
241                    }
242                    socket = sslSocketFactory.createSocket(
243                            raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true);
244                    ((SSLSocket) socket).setUseClientMode(false);
245                    openClientSockets.add(socket);
246                    openClientSockets.remove(raw);
247                } else {
248                    socket = raw;
249                }
250
251                InputStream in = new BufferedInputStream(socket.getInputStream());
252                OutputStream out = new BufferedOutputStream(socket.getOutputStream());
253
254                if (!processOneRequest(in, out)) {
255                    throw new IllegalStateException("Connection without any request!");
256                }
257                while (processOneRequest(in, out)) {}
258
259                in.close();
260                out.close();
261                socket.close();
262                openClientSockets.remove(socket);
263                return null;
264            }
265
266            /**
267             * Reads a request and writes its response. Returns true if a request
268             * was processed.
269             */
270            private boolean processOneRequest(InputStream in, OutputStream out)
271                    throws IOException, InterruptedException {
272                RecordedRequest request = readRequest(in, sequenceNumber);
273                if (request == null) {
274                    return false;
275                }
276                MockResponse response = dispatch(request);
277                writeResponse(out, response);
278                if (response.getDisconnectAtEnd()) {
279                    in.close();
280                    out.close();
281                }
282                sequenceNumber++;
283                return true;
284            }
285        }));
286    }
287
288    /**
289     * @param sequenceNumber the index of this request on this connection.
290     */
291    private RecordedRequest readRequest(InputStream in, int sequenceNumber) throws IOException {
292        String request = readAsciiUntilCrlf(in);
293        if (request.isEmpty()) {
294            return null; // end of data; no more requests
295        }
296
297        List<String> headers = new ArrayList<String>();
298        int contentLength = -1;
299        boolean chunked = false;
300        String header;
301        while (!(header = readAsciiUntilCrlf(in)).isEmpty()) {
302            headers.add(header);
303            String lowercaseHeader = header.toLowerCase();
304            if (contentLength == -1 && lowercaseHeader.startsWith("content-length:")) {
305                contentLength = Integer.parseInt(header.substring(15).trim());
306            }
307            if (lowercaseHeader.startsWith("transfer-encoding:") &&
308                    lowercaseHeader.substring(18).trim().equals("chunked")) {
309                chunked = true;
310            }
311        }
312
313        boolean hasBody = false;
314        TruncatingOutputStream requestBody = new TruncatingOutputStream();
315        List<Integer> chunkSizes = new ArrayList<Integer>();
316        if (contentLength != -1) {
317            hasBody = true;
318            transfer(contentLength, in, requestBody);
319        } else if (chunked) {
320            hasBody = true;
321            while (true) {
322                int chunkSize = Integer.parseInt(readAsciiUntilCrlf(in).trim(), 16);
323                if (chunkSize == 0) {
324                    readEmptyLine(in);
325                    break;
326                }
327                chunkSizes.add(chunkSize);
328                transfer(chunkSize, in, requestBody);
329                readEmptyLine(in);
330            }
331        }
332
333        if (request.startsWith("GET ") || request.startsWith("CONNECT ")) {
334            if (hasBody) {
335                throw new IllegalArgumentException("GET requests should not have a body!");
336            }
337        } else if (request.startsWith("POST ")) {
338            if (!hasBody) {
339                throw new IllegalArgumentException("POST requests must have a body!");
340            }
341        } else {
342            throw new UnsupportedOperationException("Unexpected method: " + request);
343        }
344
345        return new RecordedRequest(request, headers, chunkSizes,
346                requestBody.numBytesReceived, requestBody.toByteArray(), sequenceNumber);
347    }
348
349    /**
350     * Returns a response to satisfy {@code request}.
351     */
352    private MockResponse dispatch(RecordedRequest request) throws InterruptedException {
353        if (responseQueue.isEmpty()) {
354            throw new IllegalStateException("Unexpected request: " + request);
355        }
356
357        if (singleResponse) {
358            return responseQueue.peek();
359        } else {
360            requestCount.incrementAndGet();
361            requestQueue.add(request);
362            return responseQueue.take();
363        }
364    }
365
366    private void writeResponse(OutputStream out, MockResponse response) throws IOException {
367        out.write((response.getStatus() + "\r\n").getBytes(ASCII));
368        for (String header : response.getHeaders()) {
369            out.write((header + "\r\n").getBytes(ASCII));
370        }
371        out.write(("\r\n").getBytes(ASCII));
372        out.write(response.getBody());
373        out.flush();
374    }
375
376    /**
377     * Transfer bytes from {@code in} to {@code out} until either {@code length}
378     * bytes have been transferred or {@code in} is exhausted.
379     */
380    private void transfer(int length, InputStream in, OutputStream out) throws IOException {
381        byte[] buffer = new byte[1024];
382        while (length > 0) {
383            int count = in.read(buffer, 0, Math.min(buffer.length, length));
384            if (count == -1) {
385                return;
386            }
387            out.write(buffer, 0, count);
388            length -= count;
389        }
390    }
391
392    /**
393     * Returns the text from {@code in} until the next "\r\n", or null if
394     * {@code in} is exhausted.
395     */
396    private String readAsciiUntilCrlf(InputStream in) throws IOException {
397        StringBuilder builder = new StringBuilder();
398        while (true) {
399            int c = in.read();
400            if (c == '\n' && builder.length() > 0 && builder.charAt(builder.length() - 1) == '\r') {
401                builder.deleteCharAt(builder.length() - 1);
402                return builder.toString();
403            } else if (c == -1) {
404                return builder.toString();
405            } else {
406                builder.append((char) c);
407            }
408        }
409    }
410
411    private void readEmptyLine(InputStream in) throws IOException {
412        String line = readAsciiUntilCrlf(in);
413        if (!line.isEmpty()) {
414            throw new IllegalStateException("Expected empty but was: " + line);
415        }
416    }
417
418    /**
419     * An output stream that drops data after bodyLimit bytes.
420     */
421    private class TruncatingOutputStream extends ByteArrayOutputStream {
422        private int numBytesReceived = 0;
423        @Override public void write(byte[] buffer, int offset, int len) {
424            numBytesReceived += len;
425            super.write(buffer, offset, Math.min(len, bodyLimit - count));
426        }
427        @Override public void write(int oneByte) {
428            numBytesReceived++;
429            if (count < bodyLimit) {
430                super.write(oneByte);
431            }
432        }
433    }
434
435    private static <T> Callable<T> namedCallable(final String name, final Callable<T> callable) {
436        return new Callable<T>() {
437            public T call() throws Exception {
438                String originalName = Thread.currentThread().getName();
439                Thread.currentThread().setName(name);
440                try {
441                    return callable.call();
442                } finally {
443                    Thread.currentThread().setName(originalName);
444                }
445            }
446        };
447    }
448}
449