// MockWebServer.java revision f162edaa335461474b020027bb2e85eb3be2c179
/*
 * Copyright (C) 2010 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17package tests.http;
18
19import java.io.BufferedInputStream;
20import java.io.BufferedOutputStream;
21import java.io.ByteArrayOutputStream;
22import java.io.IOException;
23import java.io.InputStream;
24import java.io.OutputStream;
25import java.net.InetAddress;
26import java.net.InetSocketAddress;
27import java.net.MalformedURLException;
28import java.net.Proxy;
29import java.net.ServerSocket;
30import java.net.Socket;
31import java.net.SocketException;
32import java.net.URL;
33import java.net.UnknownHostException;
34import java.util.ArrayList;
35import java.util.Collections;
36import java.util.HashSet;
37import java.util.Iterator;
38import java.util.List;
39import java.util.Set;
40import java.util.concurrent.BlockingQueue;
41import java.util.concurrent.ExecutorService;
42import java.util.concurrent.Executors;
43import java.util.concurrent.LinkedBlockingDeque;
44import java.util.concurrent.LinkedBlockingQueue;
45import java.util.concurrent.atomic.AtomicInteger;
46import java.util.logging.Level;
47import java.util.logging.Logger;
48import javax.net.ssl.SSLSocket;
49import javax.net.ssl.SSLSocketFactory;
50import static tests.http.SocketPolicy.DISCONNECT_AT_START;
51
52/**
53 * A scriptable web server. Callers supply canned responses and the server
54 * replays them upon request in sequence.
55 */
56public final class MockWebServer {
57
58    static final String ASCII = "US-ASCII";
59
60    private static final Logger logger = Logger.getLogger(MockWebServer.class.getName());
61    private final BlockingQueue<RecordedRequest> requestQueue
62            = new LinkedBlockingQueue<RecordedRequest>();
63    private final BlockingQueue<MockResponse> responseQueue
64            = new LinkedBlockingDeque<MockResponse>();
65    private final Set<Socket> openClientSockets
66            = Collections.synchronizedSet(new HashSet<Socket>());
67    private boolean singleResponse;
68    private final AtomicInteger requestCount = new AtomicInteger();
69    private int bodyLimit = Integer.MAX_VALUE;
70    private ServerSocket serverSocket;
71    private SSLSocketFactory sslSocketFactory;
72    private ExecutorService executor;
73    private boolean tunnelProxy;
74
75    private int port = -1;
76
77    public int getPort() {
78        if (port == -1) {
79            throw new IllegalStateException("Cannot retrieve port before calling play()");
80        }
81        return port;
82    }
83
84    public Proxy toProxyAddress() {
85        return new Proxy(Proxy.Type.HTTP, new InetSocketAddress("localhost", getPort()));
86    }
87
88    /**
89     * Returns a URL for connecting to this server.
90     *
91     * @param path the request path, such as "/".
92     */
93    public URL getUrl(String path) throws MalformedURLException, UnknownHostException {
94        String host = InetAddress.getLocalHost().getHostName();
95        return sslSocketFactory != null
96                ? new URL("https://" + host + ":" + getPort() + path)
97                : new URL("http://" + host + ":" + getPort() + path);
98    }
99
100    /**
101     * Sets the number of bytes of the POST body to keep in memory to the given
102     * limit.
103     */
104    public void setBodyLimit(int maxBodyLength) {
105        this.bodyLimit = maxBodyLength;
106    }
107
108    /**
109     * Serve requests with HTTPS rather than otherwise.
110     *
111     * @param tunnelProxy whether to expect the HTTP CONNECT method before
112     *     negotiating TLS.
113     */
114    public void useHttps(SSLSocketFactory sslSocketFactory, boolean tunnelProxy) {
115        this.sslSocketFactory = sslSocketFactory;
116        this.tunnelProxy = tunnelProxy;
117    }
118
119    /**
120     * Awaits the next HTTP request, removes it, and returns it. Callers should
121     * use this to verify the request sent was as intended.
122     */
123    public RecordedRequest takeRequest() throws InterruptedException {
124        return requestQueue.take();
125    }
126
127    /**
128     * Returns the number of HTTP requests received thus far by this server.
129     * This may exceed the number of HTTP connections when connection reuse is
130     * in practice.
131     */
132    public int getRequestCount() {
133        return requestCount.get();
134    }
135
136    public void enqueue(MockResponse response) {
137        responseQueue.add(response);
138    }
139
140    /**
141     * By default, this class processes requests coming in by adding them to a
142     * queue and serves responses by removing them from another queue. This mode
143     * is appropriate for correctness testing.
144     *
145     * <p>Serving a single response causes the server to be stateless: requests
146     * are not enqueued, and responses are not dequeued. This mode is appropriate
147     * for benchmarking.
148     */
149    public void setSingleResponse(boolean singleResponse) {
150        this.singleResponse = singleResponse;
151    }
152
153    /**
154     * Starts the server, serves all enqueued requests, and shuts the server
155     * down.
156     */
157    public void play() throws IOException {
158        executor = Executors.newCachedThreadPool();
159        serverSocket = new ServerSocket(0);
160        serverSocket.setReuseAddress(true);
161
162        port = serverSocket.getLocalPort();
163        executor.execute(namedRunnable("MockWebServer-accept-" + port, new Runnable() {
164            public void run() {
165                try {
166                    acceptConnections();
167                } catch (Throwable e) {
168                    logger.log(Level.WARNING, "MockWebServer connection failed", e);
169                }
170
171                /*
172                 * This gnarly block of code will release all sockets and
173                 * all thread, even if any close fails.
174                 */
175                try {
176                    serverSocket.close();
177                } catch (Throwable e) {
178                    logger.log(Level.WARNING, "MockWebServer server socket close failed", e);
179                }
180                for (Iterator<Socket> s = openClientSockets.iterator(); s.hasNext();) {
181                    try {
182                        s.next().close();
183                        s.remove();
184                    } catch (Throwable e) {
185                        logger.log(Level.WARNING, "MockWebServer socket close failed", e);
186                    }
187                }
188                try {
189                    executor.shutdown();
190                } catch (Throwable e) {
191                    logger.log(Level.WARNING, "MockWebServer executor shutdown failed", e);
192                }
193            }
194
195            private void acceptConnections() throws Exception {
196                do {
197                    Socket socket;
198                    try {
199                        socket = serverSocket.accept();
200                    } catch (SocketException ignored) {
201                        continue;
202                    }
203                    MockResponse peek = responseQueue.peek();
204                    if (peek != null && peek.getSocketPolicy() == DISCONNECT_AT_START) {
205                        responseQueue.take();
206                        socket.close();
207                    } else {
208                        openClientSockets.add(socket);
209                        serveConnection(socket);
210                    }
211                } while (!responseQueue.isEmpty());
212            }
213        }));
214    }
215
216    public void shutdown() throws IOException {
217        if (serverSocket != null) {
218            serverSocket.close(); // should cause acceptConnections() to break out
219        }
220    }
221
222    private void serveConnection(final Socket raw) {
223        String name = "MockWebServer-" + raw.getRemoteSocketAddress();
224        executor.execute(namedRunnable(name, new Runnable() {
225            int sequenceNumber = 0;
226
227            public void run() {
228                try {
229                    processConnection();
230                } catch (Exception e) {
231                    logger.log(Level.WARNING, "MockWebServer connection failed", e);
232                }
233            }
234
235            public void processConnection() throws Exception {
236                Socket socket;
237                if (sslSocketFactory != null) {
238                    if (tunnelProxy) {
239                        if (!processOneRequest(raw.getInputStream(), raw.getOutputStream(), raw)) {
240                            throw new IllegalStateException("Tunnel without any CONNECT!");
241                        }
242                    }
243                    socket = sslSocketFactory.createSocket(
244                            raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true);
245                    ((SSLSocket) socket).setUseClientMode(false);
246                    openClientSockets.add(socket);
247                    openClientSockets.remove(raw);
248                } else {
249                    socket = raw;
250                }
251
252                InputStream in = new BufferedInputStream(socket.getInputStream());
253                OutputStream out = new BufferedOutputStream(socket.getOutputStream());
254
255                while (!responseQueue.isEmpty() && processOneRequest(in, out, socket)) {}
256
257                if (sequenceNumber == 0) {
258                    logger.warning("MockWebServer connection didn't make a request");
259                }
260
261                in.close();
262                out.close();
263                socket.close();
264                if (responseQueue.isEmpty()) {
265                    shutdown();
266                }
267                openClientSockets.remove(socket);
268            }
269
270            /**
271             * Reads a request and writes its response. Returns true if a request
272             * was processed.
273             */
274            private boolean processOneRequest(InputStream in, OutputStream out, Socket socket)
275                    throws IOException, InterruptedException {
276                RecordedRequest request = readRequest(in, sequenceNumber);
277                if (request == null) {
278                    return false;
279                }
280                MockResponse response = dispatch(request);
281                writeResponse(out, response);
282                if (response.getSocketPolicy() == SocketPolicy.DISCONNECT_AT_END) {
283                    in.close();
284                    out.close();
285                } else if (response.getSocketPolicy() == SocketPolicy.SHUTDOWN_INPUT_AT_END) {
286                    socket.shutdownInput();
287                } else if (response.getSocketPolicy() == SocketPolicy.SHUTDOWN_OUTPUT_AT_END) {
288                    socket.shutdownOutput();
289                }
290                sequenceNumber++;
291                return true;
292            }
293        }));
294    }
295
296    /**
297     * @param sequenceNumber the index of this request on this connection.
298     */
299    private RecordedRequest readRequest(InputStream in, int sequenceNumber) throws IOException {
300        String request;
301        try {
302            request = readAsciiUntilCrlf(in);
303        } catch (IOException streamIsClosed) {
304            return null; // no request because we closed the stream
305        }
306        if (request.isEmpty()) {
307            return null; // no request because the stream is exhausted
308        }
309
310        List<String> headers = new ArrayList<String>();
311        int contentLength = -1;
312        boolean chunked = false;
313        String header;
314        while (!(header = readAsciiUntilCrlf(in)).isEmpty()) {
315            headers.add(header);
316            String lowercaseHeader = header.toLowerCase();
317            if (contentLength == -1 && lowercaseHeader.startsWith("content-length:")) {
318                contentLength = Integer.parseInt(header.substring(15).trim());
319            }
320            if (lowercaseHeader.startsWith("transfer-encoding:") &&
321                    lowercaseHeader.substring(18).trim().equals("chunked")) {
322                chunked = true;
323            }
324        }
325
326        boolean hasBody = false;
327        TruncatingOutputStream requestBody = new TruncatingOutputStream();
328        List<Integer> chunkSizes = new ArrayList<Integer>();
329        if (contentLength != -1) {
330            hasBody = true;
331            transfer(contentLength, in, requestBody);
332        } else if (chunked) {
333            hasBody = true;
334            while (true) {
335                int chunkSize = Integer.parseInt(readAsciiUntilCrlf(in).trim(), 16);
336                if (chunkSize == 0) {
337                    readEmptyLine(in);
338                    break;
339                }
340                chunkSizes.add(chunkSize);
341                transfer(chunkSize, in, requestBody);
342                readEmptyLine(in);
343            }
344        }
345
346        if (request.startsWith("GET ") || request.startsWith("CONNECT ")) {
347            if (hasBody) {
348                throw new IllegalArgumentException("GET requests should not have a body!");
349            }
350        } else if (request.startsWith("POST ")) {
351            if (!hasBody) {
352                throw new IllegalArgumentException("POST requests must have a body!");
353            }
354        } else {
355            throw new UnsupportedOperationException("Unexpected method: " + request);
356        }
357
358        return new RecordedRequest(request, headers, chunkSizes,
359                requestBody.numBytesReceived, requestBody.toByteArray(), sequenceNumber);
360    }
361
362    /**
363     * Returns a response to satisfy {@code request}.
364     */
365    private MockResponse dispatch(RecordedRequest request) throws InterruptedException {
366        if (responseQueue.isEmpty()) {
367            throw new IllegalStateException("Unexpected request: " + request);
368        }
369
370        if (singleResponse) {
371            return responseQueue.peek();
372        } else {
373            requestCount.incrementAndGet();
374            requestQueue.add(request);
375            return responseQueue.take();
376        }
377    }
378
379    private void writeResponse(OutputStream out, MockResponse response) throws IOException {
380        out.write((response.getStatus() + "\r\n").getBytes(ASCII));
381        for (String header : response.getHeaders()) {
382            out.write((header + "\r\n").getBytes(ASCII));
383        }
384        out.write(("\r\n").getBytes(ASCII));
385        out.write(response.getBody());
386        out.flush();
387    }
388
389    /**
390     * Transfer bytes from {@code in} to {@code out} until either {@code length}
391     * bytes have been transferred or {@code in} is exhausted.
392     */
393    private void transfer(int length, InputStream in, OutputStream out) throws IOException {
394        byte[] buffer = new byte[1024];
395        while (length > 0) {
396            int count = in.read(buffer, 0, Math.min(buffer.length, length));
397            if (count == -1) {
398                return;
399            }
400            out.write(buffer, 0, count);
401            length -= count;
402        }
403    }
404
405    /**
406     * Returns the text from {@code in} until the next "\r\n", or null if
407     * {@code in} is exhausted.
408     */
409    private String readAsciiUntilCrlf(InputStream in) throws IOException {
410        StringBuilder builder = new StringBuilder();
411        while (true) {
412            int c = in.read();
413            if (c == '\n' && builder.length() > 0 && builder.charAt(builder.length() - 1) == '\r') {
414                builder.deleteCharAt(builder.length() - 1);
415                return builder.toString();
416            } else if (c == -1) {
417                return builder.toString();
418            } else {
419                builder.append((char) c);
420            }
421        }
422    }
423
424    private void readEmptyLine(InputStream in) throws IOException {
425        String line = readAsciiUntilCrlf(in);
426        if (!line.isEmpty()) {
427            throw new IllegalStateException("Expected empty but was: " + line);
428        }
429    }
430
431    /**
432     * An output stream that drops data after bodyLimit bytes.
433     */
434    private class TruncatingOutputStream extends ByteArrayOutputStream {
435        private int numBytesReceived = 0;
436        @Override public void write(byte[] buffer, int offset, int len) {
437            numBytesReceived += len;
438            super.write(buffer, offset, Math.min(len, bodyLimit - count));
439        }
440        @Override public void write(int oneByte) {
441            numBytesReceived++;
442            if (count < bodyLimit) {
443                super.write(oneByte);
444            }
445        }
446    }
447
448    private static Runnable namedRunnable(final String name, final Runnable runnable) {
449        return new Runnable() {
450            public void run() {
451                String originalName = Thread.currentThread().getName();
452                Thread.currentThread().setName(name);
453                try {
454                    runnable.run();
455                } finally {
456                    Thread.currentThread().setName(originalName);
457                }
458            }
459        };
460    }
461}
462