1/*
2 * Copyright (C) 2010 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package coretestutils.http;
18
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.net.MalformedURLException;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.URL;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
import java.util.Queue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

import android.util.Log;
45
46/**
47 * A scriptable web server. Callers supply canned responses and the server
48 * replays them upon request in sequence.
49 *
50 * TODO: merge with the version from libcore/support/src/tests/java once it's in.
51 */
52public final class MockWebServer {
53    static final String ASCII = "US-ASCII";
54    static final String LOG_TAG = "coretestutils.http.MockWebServer";
55
56    private final BlockingQueue<RecordedRequest> requestQueue
57            = new LinkedBlockingQueue<RecordedRequest>();
58    private final BlockingQueue<MockResponse> responseQueue
59            = new LinkedBlockingQueue<MockResponse>();
60    private int bodyLimit = Integer.MAX_VALUE;
61    private final ExecutorService executor = Executors.newCachedThreadPool();
62    // keep Futures around so we can rethrow any exceptions thrown by Callables
63    private final Queue<Future<?>> futures = new LinkedList<Future<?>>();
64    private final Object downloadPauseLock = new Object();
65    // global flag to signal when downloads should resume on the server
66    private volatile boolean downloadResume = false;
67
68    private int port = -1;
69
70    public int getPort() {
71        if (port == -1) {
72            throw new IllegalStateException("Cannot retrieve port before calling play()");
73        }
74        return port;
75    }
76
77    /**
78     * Returns a URL for connecting to this server.
79     *
80     * @param path the request path, such as "/".
81     */
82    public URL getUrl(String path) throws MalformedURLException {
83        return new URL("http://localhost:" + getPort() + path);
84    }
85
86    /**
87     * Sets the number of bytes of the POST body to keep in memory to the given
88     * limit.
89     */
90    public void setBodyLimit(int maxBodyLength) {
91        this.bodyLimit = maxBodyLength;
92    }
93
94    public void enqueue(MockResponse response) {
95        responseQueue.add(response);
96    }
97
98    /**
99     * Awaits the next HTTP request, removes it, and returns it. Callers should
100     * use this to verify the request sent was as intended.
101     */
102    public RecordedRequest takeRequest() throws InterruptedException {
103        return requestQueue.take();
104    }
105
106    public RecordedRequest takeRequestWithTimeout(long timeoutMillis) throws InterruptedException {
107        return requestQueue.poll(timeoutMillis, TimeUnit.MILLISECONDS);
108    }
109
110    public List<RecordedRequest> drainRequests() {
111        List<RecordedRequest> requests = new ArrayList<RecordedRequest>();
112        requestQueue.drainTo(requests);
113        return requests;
114    }
115
116    /**
117     * Starts the server, serves all enqueued requests, and shuts the server
118     * down using the default (server-assigned) port.
119     */
120    public void play() throws IOException {
121        play(0);
122    }
123
124    /**
125     * Starts the server, serves all enqueued requests, and shuts the server
126     * down.
127     *
128     * @param port The port number to use to listen to connections on; pass in 0 to have the
129     * server automatically assign a free port
130     */
131    public void play(int portNumber) throws IOException {
132        final ServerSocket ss = new ServerSocket(portNumber);
133        ss.setReuseAddress(true);
134        port = ss.getLocalPort();
135        submitCallable(new Callable<Void>() {
136            public Void call() throws Exception {
137                int count = 0;
138                while (true) {
139                    if (count > 0 && responseQueue.isEmpty()) {
140                        ss.close();
141                        executor.shutdown();
142                        return null;
143                    }
144
145                    serveConnection(ss.accept());
146                    count++;
147                }
148            }
149        });
150    }
151
152    private void serveConnection(final Socket s) {
153        submitCallable(new Callable<Void>() {
154            public Void call() throws Exception {
155                InputStream in = new BufferedInputStream(s.getInputStream());
156                OutputStream out = new BufferedOutputStream(s.getOutputStream());
157
158                int sequenceNumber = 0;
159                while (true) {
160                    RecordedRequest request = readRequest(in, sequenceNumber);
161                    if (request == null) {
162                        if (sequenceNumber == 0) {
163                            throw new IllegalStateException("Connection without any request!");
164                        } else {
165                            break;
166                        }
167                    }
168                    requestQueue.add(request);
169                    MockResponse response = computeResponse(request);
170                    writeResponse(out, response);
171                    if (response.shouldCloseConnectionAfter()) {
172                        break;
173                    }
174                    sequenceNumber++;
175                }
176
177                in.close();
178                out.close();
179                return null;
180            }
181        });
182    }
183
184    private void submitCallable(Callable<?> callable) {
185        Future<?> future = executor.submit(callable);
186        futures.add(future);
187    }
188
189    /**
190     * Check for and raise any exceptions that have been thrown by child threads.  Will not block on
191     * children still running.
192     * @throws ExecutionException for the first child thread that threw an exception
193     */
194    public void checkForExceptions() throws ExecutionException, InterruptedException {
195        final int originalSize = futures.size();
196        for (int i = 0; i < originalSize; i++) {
197            Future<?> future = futures.remove();
198            try {
199                future.get(0, TimeUnit.SECONDS);
200            } catch (TimeoutException e) {
201                futures.add(future); // still running
202            }
203        }
204    }
205
206    /**
207     * @param sequenceNumber the index of this request on this connection.
208     */
209    private RecordedRequest readRequest(InputStream in, int sequenceNumber) throws IOException {
210        String request = readAsciiUntilCrlf(in);
211        if (request.equals("")) {
212            return null; // end of data; no more requests
213        }
214
215        List<String> headers = new ArrayList<String>();
216        int contentLength = -1;
217        boolean chunked = false;
218        String header;
219        while (!(header = readAsciiUntilCrlf(in)).equals("")) {
220            headers.add(header);
221            String lowercaseHeader = header.toLowerCase();
222            if (contentLength == -1 && lowercaseHeader.startsWith("content-length:")) {
223                contentLength = Integer.parseInt(header.substring(15).trim());
224            }
225            if (lowercaseHeader.startsWith("transfer-encoding:") &&
226                    lowercaseHeader.substring(18).trim().equals("chunked")) {
227                chunked = true;
228            }
229        }
230
231        boolean hasBody = false;
232        TruncatingOutputStream requestBody = new TruncatingOutputStream();
233        List<Integer> chunkSizes = new ArrayList<Integer>();
234        if (contentLength != -1) {
235            hasBody = true;
236            transfer(contentLength, in, requestBody);
237        } else if (chunked) {
238            hasBody = true;
239            while (true) {
240                int chunkSize = Integer.parseInt(readAsciiUntilCrlf(in).trim(), 16);
241                if (chunkSize == 0) {
242                    readEmptyLine(in);
243                    break;
244                }
245                chunkSizes.add(chunkSize);
246                transfer(chunkSize, in, requestBody);
247                readEmptyLine(in);
248            }
249        }
250
251        if (request.startsWith("GET ")) {
252            if (hasBody) {
253                throw new IllegalArgumentException("GET requests should not have a body!");
254            }
255        } else if (request.startsWith("POST ")) {
256            if (!hasBody) {
257                throw new IllegalArgumentException("POST requests must have a body!");
258            }
259        } else {
260            throw new UnsupportedOperationException("Unexpected method: " + request);
261        }
262
263        return new RecordedRequest(request, headers, chunkSizes,
264                requestBody.numBytesReceived, requestBody.toByteArray(), sequenceNumber);
265    }
266
267    /**
268     * Returns a response to satisfy {@code request}.
269     */
270    private MockResponse computeResponse(RecordedRequest request) throws InterruptedException {
271        if (responseQueue.isEmpty()) {
272            throw new IllegalStateException("Unexpected request: " + request);
273        }
274        return responseQueue.take();
275    }
276
277    private void writeResponse(OutputStream out, MockResponse response) throws IOException {
278        out.write((response.getStatus() + "\r\n").getBytes(ASCII));
279        boolean doCloseConnectionAfterHeader = (response.getCloseConnectionAfterHeader() != null);
280
281        // Send headers
282        String closeConnectionAfterHeader = response.getCloseConnectionAfterHeader();
283        for (String header : response.getHeaders()) {
284            out.write((header + "\r\n").getBytes(ASCII));
285
286            if (doCloseConnectionAfterHeader && header.startsWith(closeConnectionAfterHeader)) {
287                Log.i(LOG_TAG, "Closing connection after header" + header);
288                break;
289            }
290        }
291
292        // Send actual body data
293        if (!doCloseConnectionAfterHeader) {
294            out.write(("\r\n").getBytes(ASCII));
295
296            InputStream body = response.getBody();
297            final int READ_BLOCK_SIZE = 10000;  // process blocks this size
298            byte[] currentBlock = new byte[READ_BLOCK_SIZE];
299            int currentBlockSize = 0;
300            int writtenSoFar = 0;
301
302            boolean shouldPause = response.getShouldPause();
303            boolean shouldClose = response.getShouldClose();
304            int pause = response.getPauseConnectionAfterXBytes();
305            int close = response.getCloseConnectionAfterXBytes();
306
307            // Don't bother pausing if it's set to pause -after- the connection should be dropped
308            if (shouldPause && shouldClose && (pause > close)) {
309                shouldPause = false;
310            }
311
312            // Process each block we read in...
313            while ((currentBlockSize = body.read(currentBlock)) != -1) {
314                int startIndex = 0;
315                int writeLength = currentBlockSize;
316
317                // handle the case of pausing
318                if (shouldPause && (writtenSoFar + currentBlockSize >= pause)) {
319                    writeLength = pause - writtenSoFar;
320                    out.write(currentBlock, 0, writeLength);
321                    out.flush();
322                    writtenSoFar += writeLength;
323
324                    // now pause...
325                    try {
326                        Log.i(LOG_TAG, "Pausing connection after " + pause + " bytes");
327                        // Wait until someone tells us to resume sending...
328                        synchronized(downloadPauseLock) {
329                            while (!downloadResume) {
330                                downloadPauseLock.wait();
331                            }
332                            // reset resume back to false
333                            downloadResume = false;
334                        }
335                    } catch (InterruptedException e) {
336                        Log.e(LOG_TAG, "Server was interrupted during pause in download.");
337                    }
338
339                    startIndex = writeLength;
340                    writeLength = currentBlockSize - writeLength;
341                }
342
343                // handle the case of closing the connection
344                if (shouldClose && (writtenSoFar + writeLength > close)) {
345                    writeLength = close - writtenSoFar;
346                    out.write(currentBlock, startIndex, writeLength);
347                    writtenSoFar += writeLength;
348                    Log.i(LOG_TAG, "Closing connection after " + close + " bytes");
349                    break;
350                }
351                out.write(currentBlock, startIndex, writeLength);
352                writtenSoFar += writeLength;
353            }
354        }
355        out.flush();
356    }
357
358    /**
359     * Transfer bytes from {@code in} to {@code out} until either {@code length}
360     * bytes have been transferred or {@code in} is exhausted.
361     */
362    private void transfer(int length, InputStream in, OutputStream out) throws IOException {
363        byte[] buffer = new byte[1024];
364        while (length > 0) {
365            int count = in.read(buffer, 0, Math.min(buffer.length, length));
366            if (count == -1) {
367                return;
368            }
369            out.write(buffer, 0, count);
370            length -= count;
371        }
372    }
373
374    /**
375     * Returns the text from {@code in} until the next "\r\n", or null if
376     * {@code in} is exhausted.
377     */
378    private String readAsciiUntilCrlf(InputStream in) throws IOException {
379        StringBuilder builder = new StringBuilder();
380        while (true) {
381            int c = in.read();
382            if (c == '\n' && builder.length() > 0 && builder.charAt(builder.length() - 1) == '\r') {
383                builder.deleteCharAt(builder.length() - 1);
384                return builder.toString();
385            } else if (c == -1) {
386                return builder.toString();
387            } else {
388                builder.append((char) c);
389            }
390        }
391    }
392
393    private void readEmptyLine(InputStream in) throws IOException {
394        String line = readAsciiUntilCrlf(in);
395        if (!line.equals("")) {
396            throw new IllegalStateException("Expected empty but was: " + line);
397        }
398    }
399
400    /**
401     * An output stream that drops data after bodyLimit bytes.
402     */
403    private class TruncatingOutputStream extends ByteArrayOutputStream {
404        private int numBytesReceived = 0;
405        @Override public void write(byte[] buffer, int offset, int len) {
406            numBytesReceived += len;
407            super.write(buffer, offset, Math.min(len, bodyLimit - count));
408        }
409        @Override public void write(int oneByte) {
410            numBytesReceived++;
411            if (count < bodyLimit) {
412                super.write(oneByte);
413            }
414        }
415    }
416
417    /**
418     * Trigger the server to resume sending the download
419     */
420    public void doResumeDownload() {
421        synchronized (downloadPauseLock) {
422            downloadResume = true;
423            downloadPauseLock.notifyAll();
424        }
425    }
426}
427