1/*
2 * Copyright (C) 2011 Google Inc.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package com.google.mockwebserver;
18
19import static com.google.mockwebserver.SocketPolicy.DISCONNECT_AT_START;
20import static com.google.mockwebserver.SocketPolicy.FAIL_HANDSHAKE;
21import java.io.BufferedInputStream;
22import java.io.BufferedOutputStream;
23import java.io.ByteArrayOutputStream;
24import java.io.IOException;
25import java.io.InputStream;
26import java.io.OutputStream;
27import java.net.InetAddress;
28import java.net.InetSocketAddress;
29import java.net.MalformedURLException;
30import java.net.Proxy;
31import java.net.ServerSocket;
32import java.net.Socket;
33import java.net.SocketException;
34import java.net.URL;
35import java.net.UnknownHostException;
36import java.security.cert.CertificateException;
37import java.security.cert.X509Certificate;
38import java.util.ArrayList;
39import java.util.Iterator;
40import java.util.List;
41import java.util.Map;
42import java.util.concurrent.BlockingQueue;
43import java.util.concurrent.ConcurrentHashMap;
44import java.util.concurrent.ExecutorService;
45import java.util.concurrent.Executors;
46import java.util.concurrent.LinkedBlockingQueue;
47import java.util.concurrent.atomic.AtomicInteger;
48import java.util.logging.Level;
49import java.util.logging.Logger;
50import javax.net.ssl.SSLContext;
51import javax.net.ssl.SSLSocket;
52import javax.net.ssl.SSLSocketFactory;
53import javax.net.ssl.TrustManager;
54import javax.net.ssl.X509TrustManager;
55
56/**
57 * A scriptable web server. Callers supply canned responses and the server
58 * replays them upon request in sequence.
59 */
60public final class MockWebServer {
61
    /** Name of the charset used to encode the status line and headers of responses. */
    static final String ASCII = "US-ASCII";

    private static final Logger logger = Logger.getLogger(MockWebServer.class.getName());
    /** Requests in the order they were recorded; drained by {@code takeRequest()}. */
    private final BlockingQueue<RecordedRequest> requestQueue
            = new LinkedBlockingQueue<RecordedRequest>();
    /** All map values are Boolean.TRUE. (Collections.newSetFromMap isn't available in Froyo) */
    private final Map<Socket, Boolean> openClientSockets = new ConcurrentHashMap<Socket, Boolean>();
    /** Number of requests counted so far, including bookkeeping requests. */
    private final AtomicInteger requestCount = new AtomicInteger();
    /** Maximum number of request-body bytes retained in memory per request. */
    private int bodyLimit = Integer.MAX_VALUE;
    /** The listening socket; null until {@code play()} has been called. */
    private ServerSocket serverSocket;
    /** When non-null, connections are served over TLS using sockets from this factory. */
    private SSLSocketFactory sslSocketFactory;
    /** Runs the accept loop; non-null once {@code play()} has been called. */
    private ExecutorService acceptExecutor;
    /** Runs one task per accepted connection. */
    private ExecutorService requestExecutor;
    /** Whether CONNECT requests are expected on the raw socket before TLS is negotiated. */
    private boolean tunnelProxy;
    /** Chooses the response for each request; a queue-backed dispatcher by default. */
    private Dispatcher dispatcher = new QueueDispatcher();

    /** The bound port, or -1 before {@code play()} has been called. */
    private int port = -1;
    /** Size of the thread pool that serves connections. */
    private int workerThreads = Integer.MAX_VALUE;
80
81
82    public int getPort() {
83        if (port == -1) {
84            throw new IllegalStateException("Cannot retrieve port before calling play()");
85        }
86        return port;
87    }
88
89    public String getHostName() {
90        try {
91            return InetAddress.getLocalHost().getHostName();
92        } catch (UnknownHostException e) {
93            throw new AssertionError();
94        }
95    }
96
97    public Proxy toProxyAddress() {
98        return new Proxy(Proxy.Type.HTTP, new InetSocketAddress(getHostName(), getPort()));
99    }
100
101    /**
102     * Returns a URL for connecting to this server.
103     *
104     * @param path the request path, such as "/".
105     */
106    public URL getUrl(String path) {
107        try {
108            return sslSocketFactory != null
109                    ? new URL("https://" + getHostName() + ":" + getPort() + path)
110                    : new URL("http://" + getHostName() + ":" + getPort() + path);
111        } catch (MalformedURLException e) {
112            throw new AssertionError(e);
113        }
114    }
115
116    /**
117     * Returns a cookie domain for this server. This returns the server's
118     * non-loopback host name if it is known. Otherwise this returns ".local"
119     * for this server's loopback name.
120     */
121    public String getCookieDomain() {
122        String hostName = getHostName();
123        return hostName.contains(".") ? hostName : ".local";
124    }
125
126    public void setWorkerThreads(int threads) {
127        this.workerThreads = threads;
128    }
129
130    /**
131     * Sets the number of bytes of the POST body to keep in memory to the given
132     * limit.
133     */
134    public void setBodyLimit(int maxBodyLength) {
135        this.bodyLimit = maxBodyLength;
136    }
137
138    /**
139     * Serve requests with HTTPS rather than otherwise.
140     *
141     * @param tunnelProxy whether to expect the HTTP CONNECT method before
142     *     negotiating TLS.
143     */
144    public void useHttps(SSLSocketFactory sslSocketFactory, boolean tunnelProxy) {
145        this.sslSocketFactory = sslSocketFactory;
146        this.tunnelProxy = tunnelProxy;
147    }
148
149    /**
150     * Awaits the next HTTP request, removes it, and returns it. Callers should
151     * use this to verify the request sent was as intended.
152     */
153    public RecordedRequest takeRequest() throws InterruptedException {
154        return requestQueue.take();
155    }
156
157    /**
158     * Returns the number of HTTP requests received thus far by this server.
159     * This may exceed the number of HTTP connections when connection reuse is
160     * in practice.
161     */
162    public int getRequestCount() {
163        return requestCount.get();
164    }
165
166    /**
167     * Scripts {@code response} to be returned to a request made in sequence.
168     * The first request is served by the first enqueued response; the second
169     * request by the second enqueued response; and so on.
170     *
171     * @throws ClassCastException if the default dispatcher has been replaced
172     *     with {@link #setDispatcher(Dispatcher)}.
173     */
174    public void enqueue(MockResponse response) {
175        ((QueueDispatcher) dispatcher).enqueueResponse(response.clone());
176    }
177
178    /**
179     * Equivalent to {@code play(0)}.
180     */
181    public void play() throws IOException {
182        play(0);
183    }
184
185    /**
186     * Starts the server, serves all enqueued requests, and shuts the server
187     * down.
188     *
189     * @param port the port to listen to, or 0 for any available port.
190     *     Automated tests should always use port 0 to avoid flakiness when a
191     *     specific port is unavailable.
192     */
193    public void play(int port) throws IOException {
194        if (acceptExecutor != null) {
195            throw new IllegalStateException("play() already called");
196        }
197        // The acceptExecutor handles the Socket.accept() and hands each request off to the
198        // requestExecutor. It also handles shutdown.
199        acceptExecutor = Executors.newSingleThreadExecutor();
200        // The requestExecutor has a fixed number of worker threads. In order to get strict
201        // guarantees that requests are handled in the order in which they are accepted
202        // workerThreads should be set to 1.
203        requestExecutor = Executors.newFixedThreadPool(workerThreads);
204        serverSocket = new ServerSocket(port);
205        serverSocket.setReuseAddress(true);
206
207        this.port = serverSocket.getLocalPort();
208        acceptExecutor.execute(namedRunnable("MockWebServer-accept-" + port, new Runnable() {
209            public void run() {
210                try {
211                    acceptConnections();
212                } catch (Throwable e) {
213                    logger.log(Level.WARNING, "MockWebServer connection failed", e);
214                }
215
216                /*
217                 * This gnarly block of code will release all sockets and
218                 * all thread, even if any close fails.
219                 */
220                try {
221                    serverSocket.close();
222                } catch (Throwable e) {
223                    logger.log(Level.WARNING, "MockWebServer server socket close failed", e);
224                }
225                for (Iterator<Socket> s = openClientSockets.keySet().iterator(); s.hasNext(); ) {
226                    try {
227                        s.next().close();
228                        s.remove();
229                    } catch (Throwable e) {
230                        logger.log(Level.WARNING, "MockWebServer socket close failed", e);
231                    }
232                }
233                try {
234                    acceptExecutor.shutdown();
235                } catch (Throwable e) {
236                    logger.log(Level.WARNING, "MockWebServer acceptExecutor shutdown failed", e);
237                }
238                try {
239                    requestExecutor.shutdown();
240                } catch (Throwable e) {
241                    logger.log(Level.WARNING, "MockWebServer requestExecutor shutdown failed", e);
242                }
243            }
244
245            private void acceptConnections() throws Exception {
246                while (true) {
247                    Socket socket;
248                    try {
249                        socket = serverSocket.accept();
250                    } catch (SocketException e) {
251                        return;
252                    }
253                    final SocketPolicy socketPolicy = dispatcher.peekSocketPolicy();
254                    if (socketPolicy == DISCONNECT_AT_START) {
255                        dispatchBookkeepingRequest(0, socket);
256                        socket.close();
257                    } else {
258                        openClientSockets.put(socket, true);
259                        serveConnection(socket);
260                    }
261                }
262            }
263        }));
264    }
265
266    public void shutdown() throws IOException {
267        if (serverSocket != null) {
268            serverSocket.close(); // should cause acceptConnections() to break out
269        }
270    }
271
272    private void serveConnection(final Socket raw) {
273        String name = "MockWebServer-" + raw.getRemoteSocketAddress();
274        requestExecutor.execute(namedRunnable(name, new Runnable() {
275            int sequenceNumber = 0;
276
277            public void run() {
278                try {
279                    processConnection();
280                } catch (Exception e) {
281                    logger.log(Level.WARNING, "MockWebServer connection failed", e);
282                }
283            }
284
285            public void processConnection() throws Exception {
286                Socket socket;
287                if (sslSocketFactory != null) {
288                    if (tunnelProxy) {
289                        createTunnel();
290                    }
291                    final SocketPolicy socketPolicy = dispatcher.peekSocketPolicy();
292                    if (socketPolicy == FAIL_HANDSHAKE) {
293                        dispatchBookkeepingRequest(sequenceNumber, raw);
294                        processHandshakeFailure(raw, sequenceNumber++);
295                        return;
296                    }
297                    socket = sslSocketFactory.createSocket(
298                            raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true);
299                    ((SSLSocket) socket).setUseClientMode(false);
300                    openClientSockets.put(socket, true);
301                    openClientSockets.remove(raw);
302                } else {
303                    socket = raw;
304                }
305
306                InputStream in = new BufferedInputStream(socket.getInputStream());
307                OutputStream out = new BufferedOutputStream(socket.getOutputStream());
308
309                while (processOneRequest(socket, in, out)) {
310                }
311
312                if (sequenceNumber == 0) {
313                    logger.warning("MockWebServer connection didn't make a request");
314                }
315
316                in.close();
317                out.close();
318                socket.close();
319                openClientSockets.remove(socket);
320            }
321
322            /**
323             * Respond to CONNECT requests until a SWITCH_TO_SSL_AT_END response
324             * is dispatched.
325             */
326            private void createTunnel() throws IOException, InterruptedException {
327                while (true) {
328                    final SocketPolicy socketPolicy = dispatcher.peekSocketPolicy();
329                    if (!processOneRequest(raw, raw.getInputStream(), raw.getOutputStream())) {
330                        throw new IllegalStateException("Tunnel without any CONNECT!");
331                    }
332                    if (socketPolicy == SocketPolicy.UPGRADE_TO_SSL_AT_END) {
333                        return;
334                    }
335                }
336            }
337
338            /**
339             * Reads a request and writes its response. Returns true if a request
340             * was processed.
341             */
342            private boolean processOneRequest(Socket socket, InputStream in, OutputStream out)
343                    throws IOException, InterruptedException {
344                RecordedRequest request = readRequest(socket, in, sequenceNumber);
345                if (request == null) {
346                    return false;
347                }
348                requestCount.incrementAndGet();
349                requestQueue.add(request);
350                MockResponse response = dispatcher.dispatch(request);
351                if (response.getSocketPolicy() == SocketPolicy.DISCONNECT_AFTER_READING_REQUEST) {
352                  logger.info("Received request: " + request + " and disconnected without responding");
353                  return false;
354                }
355                writeResponse(out, response);
356
357                // For socket policies that poison the socket after the response is written:
358                // The client has received the response and will no longer be blocked after
359                // writeResponse() has returned. A client can then re-use the connection before
360                // the socket is poisoned (i.e. keep-alive / connection pooling). The second
361                // request/response may fail at the beginning, middle, end, or even succeed
362                // depending on scheduling. Delays can be required in tests to improve the chances
363                // of sockets being in a known state when subsequent requests are made.
364                //
365                // For SHUTDOWN_OUTPUT_AT_END the client may detect a problem with its input socket
366                // after the request has been made but before the server has chosen a response.
367                // For clients that perform retries, this can cause the client to issue a retry
368                // request. The retry handler may call dispatcher.dispatch(request) before the
369                // initial, failed request handler does and cause non-obvious response ordering.
370                // Setting workerThreads = 1 ensures that the dispatcher is called for requests in
371                // the order they are received.
372
373                if (response.getSocketPolicy() == SocketPolicy.DISCONNECT_AT_END) {
374                    in.close();
375                    out.close();
376                } else if (response.getSocketPolicy() == SocketPolicy.SHUTDOWN_INPUT_AT_END) {
377                    socket.shutdownInput();
378                } else if (response.getSocketPolicy() == SocketPolicy.SHUTDOWN_OUTPUT_AT_END) {
379                    socket.shutdownOutput();
380                }
381                logger.info("Received request: " + request + " and responded: " + response);
382                sequenceNumber++;
383                return true;
384            }
385        }));
386    }
387
388    private void processHandshakeFailure(Socket raw, int sequenceNumber) throws Exception {
389        X509TrustManager untrusted = new X509TrustManager() {
390            @Override public void checkClientTrusted(X509Certificate[] chain, String authType)
391                    throws CertificateException {
392                throw new CertificateException();
393            }
394            @Override public void checkServerTrusted(X509Certificate[] chain, String authType) {
395                throw new AssertionError();
396            }
397            @Override public X509Certificate[] getAcceptedIssuers() {
398                throw new AssertionError();
399            }
400        };
401        SSLContext context = SSLContext.getInstance("TLS");
402        context.init(null, new TrustManager[] { untrusted }, new java.security.SecureRandom());
403        SSLSocketFactory sslSocketFactory = context.getSocketFactory();
404        SSLSocket socket = (SSLSocket) sslSocketFactory.createSocket(
405                raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true);
406        try {
407            socket.startHandshake(); // we're testing a handshake failure
408            throw new AssertionError();
409        } catch (IOException expected) {
410        }
411        socket.close();
412    }
413
414    private void dispatchBookkeepingRequest(int sequenceNumber, Socket socket) throws InterruptedException {
415        requestCount.incrementAndGet();
416        RecordedRequest request = new RecordedRequest(null, null, null, -1, null, sequenceNumber,
417                socket);
418        dispatcher.dispatch(request);
419        requestQueue.add(request);
420    }
421
422    /**
423     * @param sequenceNumber the index of this request on this connection.
424     */
425    private RecordedRequest readRequest(Socket socket, InputStream in, int sequenceNumber)
426            throws IOException {
427        String request;
428        try {
429            request = readAsciiUntilCrlf(in);
430        } catch (IOException streamIsClosed) {
431            return null; // no request because we closed the stream
432        }
433        if (request.length() == 0) {
434            return null; // no request because the stream is exhausted
435        }
436
437        List<String> headers = new ArrayList<String>();
438        int contentLength = -1;
439        boolean chunked = false;
440        String header;
441        while ((header = readAsciiUntilCrlf(in)).length() != 0) {
442            headers.add(header);
443            String lowercaseHeader = header.toLowerCase();
444            if (contentLength == -1 && lowercaseHeader.startsWith("content-length:")) {
445                contentLength = Integer.parseInt(header.substring(15).trim());
446            }
447            if (lowercaseHeader.startsWith("transfer-encoding:") &&
448                    lowercaseHeader.substring(18).trim().equals("chunked")) {
449                chunked = true;
450            }
451        }
452
453        boolean hasBody = false;
454        TruncatingOutputStream requestBody = new TruncatingOutputStream();
455        List<Integer> chunkSizes = new ArrayList<Integer>();
456        if (contentLength != -1) {
457            hasBody = true;
458            transfer(contentLength, in, requestBody);
459        } else if (chunked) {
460            hasBody = true;
461            while (true) {
462                int chunkSize = Integer.parseInt(readAsciiUntilCrlf(in).trim(), 16);
463                if (chunkSize == 0) {
464                    readEmptyLine(in);
465                    break;
466                }
467                chunkSizes.add(chunkSize);
468                transfer(chunkSize, in, requestBody);
469                readEmptyLine(in);
470            }
471        }
472
473        if (request.startsWith("OPTIONS ") || request.startsWith("GET ")
474                || request.startsWith("HEAD ") || request.startsWith("DELETE ")
475                || request.startsWith("TRACE ") || request.startsWith("CONNECT ")) {
476            if (hasBody) {
477                throw new IllegalArgumentException("Request must not have a body: " + request);
478            }
479        } else if (!request.startsWith("POST ") && !request.startsWith("PUT ")) {
480            throw new UnsupportedOperationException("Unexpected method: " + request);
481        }
482
483        return new RecordedRequest(request, headers, chunkSizes,
484                requestBody.numBytesReceived, requestBody.toByteArray(), sequenceNumber, socket);
485    }
486
487    private void writeResponse(OutputStream out, MockResponse response) throws IOException {
488        out.write((response.getStatus() + "\r\n").getBytes(ASCII));
489        for (String header : response.getHeaders()) {
490            out.write((header + "\r\n").getBytes(ASCII));
491        }
492        out.write(("\r\n").getBytes(ASCII));
493        out.flush();
494
495        final InputStream in = response.getBodyStream();
496        if (in == null) {
497            return;
498        }
499        final int bytesPerSecond = response.getBytesPerSecond();
500
501        // Stream data in MTU-sized increments
502        final byte[] buffer = new byte[1452];
503        final long delayMs;
504        if (bytesPerSecond == Integer.MAX_VALUE) {
505            delayMs = 0;
506        } else {
507            delayMs = (1000 * buffer.length) / bytesPerSecond;
508        }
509
510        int read;
511        long sinceDelay = 0;
512        while ((read = in.read(buffer)) != -1) {
513            out.write(buffer, 0, read);
514            out.flush();
515
516            sinceDelay += read;
517            if (sinceDelay >= buffer.length && delayMs > 0) {
518                sinceDelay %= buffer.length;
519                try {
520                    Thread.sleep(delayMs);
521                } catch (InterruptedException e) {
522                    throw new AssertionError();
523                }
524            }
525        }
526    }
527
528    /**
529     * Transfer bytes from {@code in} to {@code out} until either {@code length}
530     * bytes have been transferred or {@code in} is exhausted.
531     */
532    private void transfer(int length, InputStream in, OutputStream out) throws IOException {
533        byte[] buffer = new byte[1024];
534        while (length > 0) {
535            int count = in.read(buffer, 0, Math.min(buffer.length, length));
536            if (count == -1) {
537                return;
538            }
539            out.write(buffer, 0, count);
540            length -= count;
541        }
542    }
543
544    /**
545     * Returns the text from {@code in} until the next "\r\n", or null if
546     * {@code in} is exhausted.
547     */
548    private String readAsciiUntilCrlf(InputStream in) throws IOException {
549        StringBuilder builder = new StringBuilder();
550        while (true) {
551            int c = in.read();
552            if (c == '\n' && builder.length() > 0 && builder.charAt(builder.length() - 1) == '\r') {
553                builder.deleteCharAt(builder.length() - 1);
554                return builder.toString();
555            } else if (c == -1) {
556                return builder.toString();
557            } else {
558                builder.append((char) c);
559            }
560        }
561    }
562
563    private void readEmptyLine(InputStream in) throws IOException {
564        String line = readAsciiUntilCrlf(in);
565        if (line.length() != 0) {
566            throw new IllegalStateException("Expected empty but was: " + line);
567        }
568    }
569
570    /**
571     * Sets the dispatcher used to match incoming requests to mock responses.
572     * The default dispatcher simply serves a fixed sequence of responses from
573     * a {@link #enqueue(MockResponse) queue}; custom dispatchers can vary the
574     * response based on timing or the content of the request.
575     */
576    public void setDispatcher(Dispatcher dispatcher) {
577        if (dispatcher == null) {
578            throw new NullPointerException();
579        }
580        this.dispatcher = dispatcher;
581    }
582
583    /**
584     * An output stream that drops data after bodyLimit bytes.
585     */
586    private class TruncatingOutputStream extends ByteArrayOutputStream {
587        private int numBytesReceived = 0;
588        @Override public void write(byte[] buffer, int offset, int len) {
589            numBytesReceived += len;
590            super.write(buffer, offset, Math.min(len, bodyLimit - count));
591        }
592        @Override public void write(int oneByte) {
593            numBytesReceived++;
594            if (count < bodyLimit) {
595                super.write(oneByte);
596            }
597        }
598    }
599
600    private static Runnable namedRunnable(final String name, final Runnable runnable) {
601        return new Runnable() {
602            public void run() {
603                String originalName = Thread.currentThread().getName();
604                Thread.currentThread().setName(name);
605                try {
606                    runnable.run();
607                } finally {
608                    Thread.currentThread().setName(originalName);
609                }
610            }
611        };
612    }
613}
614