MockWebServer.java revision 706d53593cd8841d378dbe298a8d1940db1e71df
/*
 * Copyright (C) 2010 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17package tests.http;
18
19import java.io.BufferedInputStream;
20import java.io.BufferedOutputStream;
21import java.io.ByteArrayOutputStream;
22import java.io.IOException;
23import java.io.InputStream;
24import java.io.OutputStream;
25import java.net.MalformedURLException;
26import java.net.ServerSocket;
27import java.net.Socket;
28import java.net.URL;
29import java.util.ArrayList;
30import java.util.List;
31import java.util.concurrent.BlockingQueue;
32import java.util.concurrent.Callable;
33import java.util.concurrent.ExecutorService;
34import java.util.concurrent.Executors;
35import java.util.concurrent.LinkedBlockingDeque;
36import java.util.concurrent.LinkedBlockingQueue;
37import javax.net.ssl.SSLSocket;
38import javax.net.ssl.SSLSocketFactory;
39
40/**
41 * A scriptable web server. Callers supply canned responses and the server
42 * replays them upon request in sequence.
43 */
44public final class MockWebServer {
45
46    static final String ASCII = "US-ASCII";
47
48    private final BlockingQueue<RecordedRequest> requestQueue
49            = new LinkedBlockingQueue<RecordedRequest>();
50    private final BlockingQueue<MockResponse> responseQueue
51            = new LinkedBlockingDeque<MockResponse>();
52    private int bodyLimit = Integer.MAX_VALUE;
53    private SSLSocketFactory sslSocketFactory;
54    private boolean tunnelProxy;
55    private final ExecutorService executor = Executors.newCachedThreadPool();
56
57    private int port = -1;
58
59    public int getPort() {
60        if (port == -1) {
61            throw new IllegalStateException("Cannot retrieve port before calling play()");
62        }
63        return port;
64    }
65
66    /**
67     * Returns a URL for connecting to this server.
68     *
69     * @param path the request path, such as "/".
70     */
71    public URL getUrl(String path) throws MalformedURLException {
72        return new URL("http://localhost:" + getPort() + path);
73    }
74
75    /**
76     * Sets the number of bytes of the POST body to keep in memory to the given
77     * limit.
78     */
79    public void setBodyLimit(int maxBodyLength) {
80        this.bodyLimit = maxBodyLength;
81    }
82
83    /**
84     * Serve requests with HTTPS rather than otherwise.
85     *
86     * @param tunnelProxy whether to expect the HTTP CONNECT method before
87     *     negotiating TLS.
88     */
89    public void useHttps(SSLSocketFactory sslSocketFactory, boolean tunnelProxy) {
90        this.sslSocketFactory = sslSocketFactory;
91        this.tunnelProxy = tunnelProxy;
92    }
93
94    /**
95     * Awaits the next HTTP request, removes it, and returns it. Callers should
96     * use this to verify the request sent was as intended.
97     */
98    public RecordedRequest takeRequest() throws InterruptedException {
99        return requestQueue.take();
100    }
101
102    public void enqueue(MockResponse response) {
103        responseQueue.add(response);
104    }
105
106    /**
107     * Starts the server, serves all enqueued requests, and shuts the server
108     * down.
109     */
110    public void play() throws IOException {
111        final ServerSocket ss;
112        ss = new ServerSocket(0);
113        ss.setReuseAddress(true);
114
115        port = ss.getLocalPort();
116        executor.submit(new Callable<Void>() {
117            public Void call() throws Exception {
118                int count = 0;
119                while (true) {
120                    if (count > 0 && responseQueue.isEmpty()) {
121                        ss.close();
122                        executor.shutdown();
123                        return null;
124                    }
125
126                    serveConnection(ss.accept());
127                    count++;
128                }
129            }
130        });
131    }
132
133    private void serveConnection(final Socket raw) {
134        executor.submit(new Callable<Void>() {
135            int sequenceNumber = 0;
136
137            public Void call() throws Exception {
138                Socket socket;
139                if (sslSocketFactory != null) {
140                    if (tunnelProxy) {
141                        if (!processOneRequest(raw.getInputStream(), raw.getOutputStream())) {
142                            throw new IllegalStateException("Tunnel without any CONNECT!");
143                        }
144                    }
145                    socket = sslSocketFactory.createSocket(
146                            raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true);
147                    ((SSLSocket) socket).setUseClientMode(false);
148                } else {
149                    socket = raw;
150                }
151
152                InputStream in = new BufferedInputStream(socket.getInputStream());
153                OutputStream out = new BufferedOutputStream(socket.getOutputStream());
154
155                if (!processOneRequest(in, out)) {
156                    throw new IllegalStateException("Connection without any request!");
157                }
158                while (processOneRequest(in, out)) {}
159
160                in.close();
161                out.close();
162                return null;
163            }
164
165            /**
166             * Reads a request and writes its response. Returns true if a request
167             * was processed.
168             */
169            private boolean processOneRequest(InputStream in, OutputStream out)
170                    throws IOException, InterruptedException {
171                RecordedRequest request = readRequest(in, sequenceNumber);
172                if (request == null) {
173                    return false;
174                }
175                requestQueue.add(request);
176                writeResponse(out, computeResponse(request));
177                sequenceNumber++;
178                return true;
179            }
180        });
181    }
182
183    /**
184     * @param sequenceNumber the index of this request on this connection.
185     */
186    private RecordedRequest readRequest(InputStream in, int sequenceNumber) throws IOException {
187        String request = readAsciiUntilCrlf(in);
188        if (request.isEmpty()) {
189            return null; // end of data; no more requests
190        }
191
192        List<String> headers = new ArrayList<String>();
193        int contentLength = -1;
194        boolean chunked = false;
195        String header;
196        while (!(header = readAsciiUntilCrlf(in)).isEmpty()) {
197            headers.add(header);
198            String lowercaseHeader = header.toLowerCase();
199            if (contentLength == -1 && lowercaseHeader.startsWith("content-length:")) {
200                contentLength = Integer.parseInt(header.substring(15).trim());
201            }
202            if (lowercaseHeader.startsWith("transfer-encoding:") &&
203                    lowercaseHeader.substring(18).trim().equals("chunked")) {
204                chunked = true;
205            }
206        }
207
208        boolean hasBody = false;
209        TruncatingOutputStream requestBody = new TruncatingOutputStream();
210        List<Integer> chunkSizes = new ArrayList<Integer>();
211        if (contentLength != -1) {
212            hasBody = true;
213            transfer(contentLength, in, requestBody);
214        } else if (chunked) {
215            hasBody = true;
216            while (true) {
217                int chunkSize = Integer.parseInt(readAsciiUntilCrlf(in).trim(), 16);
218                if (chunkSize == 0) {
219                    readEmptyLine(in);
220                    break;
221                }
222                chunkSizes.add(chunkSize);
223                transfer(chunkSize, in, requestBody);
224                readEmptyLine(in);
225            }
226        }
227
228        if (request.startsWith("GET ") || request.startsWith("CONNECT ")) {
229            if (hasBody) {
230                throw new IllegalArgumentException("GET requests should not have a body!");
231            }
232        } else if (request.startsWith("POST ")) {
233            if (!hasBody) {
234                throw new IllegalArgumentException("POST requests must have a body!");
235            }
236        } else {
237            throw new UnsupportedOperationException("Unexpected method: " + request);
238        }
239
240        return new RecordedRequest(request, headers, chunkSizes,
241                requestBody.numBytesReceived, requestBody.toByteArray(), sequenceNumber);
242    }
243
244    /**
245     * Returns a response to satisfy {@code request}.
246     */
247    private MockResponse computeResponse(RecordedRequest request) throws InterruptedException {
248        if (responseQueue.isEmpty()) {
249            throw new IllegalStateException("Unexpected request: " + request);
250        }
251        return responseQueue.take();
252    }
253
254    private void writeResponse(OutputStream out, MockResponse response) throws IOException {
255        out.write((response.getStatus() + "\r\n").getBytes(ASCII));
256        for (String header : response.getHeaders()) {
257            out.write((header + "\r\n").getBytes(ASCII));
258        }
259        out.write(("\r\n").getBytes(ASCII));
260        out.write(response.getBody());
261        out.write(("\r\n").getBytes(ASCII));
262        out.flush();
263    }
264
265    /**
266     * Transfer bytes from {@code in} to {@code out} until either {@code length}
267     * bytes have been transferred or {@code in} is exhausted.
268     */
269    private void transfer(int length, InputStream in, OutputStream out) throws IOException {
270        byte[] buffer = new byte[1024];
271        while (length > 0) {
272            int count = in.read(buffer, 0, Math.min(buffer.length, length));
273            if (count == -1) {
274                return;
275            }
276            out.write(buffer, 0, count);
277            length -= count;
278        }
279    }
280
281    /**
282     * Returns the text from {@code in} until the next "\r\n", or null if
283     * {@code in} is exhausted.
284     */
285    private String readAsciiUntilCrlf(InputStream in) throws IOException {
286        StringBuilder builder = new StringBuilder();
287        while (true) {
288            int c = in.read();
289            if (c == '\n' && builder.length() > 0 && builder.charAt(builder.length() - 1) == '\r') {
290                builder.deleteCharAt(builder.length() - 1);
291                return builder.toString();
292            } else if (c == -1) {
293                return builder.toString();
294            } else {
295                builder.append((char) c);
296            }
297        }
298    }
299
300    private void readEmptyLine(InputStream in) throws IOException {
301        String line = readAsciiUntilCrlf(in);
302        if (!line.isEmpty()) {
303            throw new IllegalStateException("Expected empty but was: " + line);
304        }
305    }
306
307    /**
308     * An output stream that drops data after bodyLimit bytes.
309     */
310    private class TruncatingOutputStream extends ByteArrayOutputStream {
311        private int numBytesReceived = 0;
312        @Override public void write(byte[] buffer, int offset, int len) {
313            numBytesReceived += len;
314            super.write(buffer, offset, Math.min(len, bodyLimit - count));
315        }
316        @Override public void write(int oneByte) {
317            numBytesReceived++;
318            if (count < bodyLimit) {
319                super.write(oneByte);
320            }
321        }
322    }
323}
324