1/*
2 * Copyright (C) 2011 Google Inc.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package com.google.mockwebserver;
18
19import static com.google.mockwebserver.SocketPolicy.DISCONNECT_AT_START;
20import static com.google.mockwebserver.SocketPolicy.FAIL_HANDSHAKE;
21import java.io.BufferedInputStream;
22import java.io.BufferedOutputStream;
23import java.io.ByteArrayOutputStream;
24import java.io.IOException;
25import java.io.InputStream;
26import java.io.OutputStream;
27import java.net.InetAddress;
28import java.net.InetSocketAddress;
29import java.net.MalformedURLException;
30import java.net.Proxy;
31import java.net.ServerSocket;
32import java.net.Socket;
33import java.net.SocketException;
34import java.net.URL;
35import java.net.UnknownHostException;
36import java.nio.charset.StandardCharsets;
37import java.security.SecureRandom;
38import java.security.cert.CertificateException;
39import java.security.cert.X509Certificate;
40import java.util.ArrayList;
41import java.util.Iterator;
42import java.util.List;
43import java.util.Locale;
44import java.util.Map;
45import java.util.concurrent.BlockingQueue;
46import java.util.concurrent.ConcurrentHashMap;
47import java.util.concurrent.ExecutorService;
48import java.util.concurrent.Executors;
49import java.util.concurrent.LinkedBlockingQueue;
50import java.util.concurrent.atomic.AtomicInteger;
51import java.util.logging.Level;
52import java.util.logging.Logger;
53import javax.net.ssl.SSLContext;
54import javax.net.ssl.SSLSocket;
55import javax.net.ssl.SSLSocketFactory;
56import javax.net.ssl.TrustManager;
57import javax.net.ssl.X509TrustManager;
58
59/**
60 * A scriptable web server. Callers supply canned responses and the server
61 * replays them upon request in sequence.
62 */
63public final class MockWebServer {
64    private static final X509TrustManager UNTRUSTED_TRUST_MANAGER = new X509TrustManager() {
65        @Override public void checkClientTrusted(X509Certificate[] chain, String authType)
66                throws CertificateException {
67            throw new CertificateException();
68        }
69
70        @Override public void checkServerTrusted(X509Certificate[] chain, String authType) {
71            throw new AssertionError();
72        }
73
74        @Override public X509Certificate[] getAcceptedIssuers() {
75            throw new AssertionError();
76        }
77    };
78
79    private static final Logger logger = Logger.getLogger(MockWebServer.class.getName());
80
81    private final BlockingQueue<RecordedRequest> requestQueue
82            = new LinkedBlockingQueue<RecordedRequest>();
83
84    /** All map values are Boolean.TRUE. (Collections.newSetFromMap isn't available in Froyo) */
85    private final Map<Socket, Boolean> openClientSockets = new ConcurrentHashMap<Socket, Boolean>();
86    private final AtomicInteger requestCount = new AtomicInteger();
87    private int bodyLimit = Integer.MAX_VALUE;
88    private ServerSocket serverSocket;
89    private SSLSocketFactory sslSocketFactory;
90    private ExecutorService acceptExecutor;
91    private ExecutorService requestExecutor;
92    private boolean tunnelProxy;
93    private Dispatcher dispatcher = new QueueDispatcher();
94
95    private int port = -1;
96    private int workerThreads = Integer.MAX_VALUE;
97
98    public int getPort() {
99        if (port == -1) {
100            throw new IllegalStateException("Cannot retrieve port before calling play()");
101        }
102        return port;
103    }
104
105    public String getHostName() {
106        try {
107            return InetAddress.getLocalHost().getHostName();
108        } catch (UnknownHostException e) {
109            throw new AssertionError(e);
110        }
111    }
112
113    public Proxy toProxyAddress() {
114        return new Proxy(Proxy.Type.HTTP, new InetSocketAddress(getHostName(), getPort()));
115    }
116
117    /**
118     * Returns a URL for connecting to this server.
119     *
120     * @param path the request path, such as "/".
121     */
122    public URL getUrl(String path) {
123        try {
124            return sslSocketFactory != null
125                    ? new URL("https://" + getHostName() + ":" + getPort() + path)
126                    : new URL("http://" + getHostName() + ":" + getPort() + path);
127        } catch (MalformedURLException e) {
128            throw new AssertionError(e);
129        }
130    }
131
132    /**
133     * Returns a cookie domain for this server. This returns the server's
134     * non-loopback host name if it is known. Otherwise this returns ".local"
135     * for this server's loopback name.
136     */
137    public String getCookieDomain() {
138        String hostName = getHostName();
139        return hostName.contains(".") ? hostName : ".local";
140    }
141
142    public void setWorkerThreads(int threads) {
143        this.workerThreads = threads;
144    }
145
146    /**
147     * Sets the number of bytes of the POST body to keep in memory to the given
148     * limit.
149     */
150    public void setBodyLimit(int maxBodyLength) {
151        this.bodyLimit = maxBodyLength;
152    }
153
154    /**
155     * Serve requests with HTTPS rather than otherwise.
156     *
157     * @param tunnelProxy whether to expect the HTTP CONNECT method before
158     *     negotiating TLS.
159     */
160    public void useHttps(SSLSocketFactory sslSocketFactory, boolean tunnelProxy) {
161        this.sslSocketFactory = sslSocketFactory;
162        this.tunnelProxy = tunnelProxy;
163    }
164
165    /**
166     * Awaits the next HTTP request, removes it, and returns it. Callers should
167     * use this to verify the request sent was as intended.
168     */
169    public RecordedRequest takeRequest() throws InterruptedException {
170        return requestQueue.take();
171    }
172
173    /**
174     * Returns the number of HTTP requests received thus far by this server.
175     * This may exceed the number of HTTP connections when connection reuse is
176     * in practice.
177     */
178    public int getRequestCount() {
179        return requestCount.get();
180    }
181
182    /**
183     * Scripts {@code response} to be returned to a request made in sequence.
184     * The first request is served by the first enqueued response; the second
185     * request by the second enqueued response; and so on.
186     *
187     * @throws ClassCastException if the default dispatcher has been replaced
188     *     with {@link #setDispatcher(Dispatcher)}.
189     */
190    public void enqueue(MockResponse response) {
191        ((QueueDispatcher) dispatcher).enqueueResponse(response.clone());
192    }
193
194    /**
195     * Equivalent to {@code play(0)}.
196     */
197    public void play() throws IOException {
198        play(0);
199    }
200
201    /**
202     * Starts the server, serves all enqueued requests, and shuts the server
203     * down.
204     *
205     * @param port the port to listen to, or 0 for any available port.
206     *     Automated tests should always use port 0 to avoid flakiness when a
207     *     specific port is unavailable.
208     */
209    public void play(int port) throws IOException {
210        if (acceptExecutor != null) {
211            throw new IllegalStateException("play() already called");
212        }
213        // The acceptExecutor handles the Socket.accept() and hands each request off to the
214        // requestExecutor. It also handles shutdown.
215        acceptExecutor = Executors.newSingleThreadExecutor();
216        // The requestExecutor has a fixed number of worker threads. In order to get strict
217        // guarantees that requests are handled in the order in which they are accepted
218        // workerThreads should be set to 1.
219        requestExecutor = Executors.newFixedThreadPool(workerThreads);
220        serverSocket = new ServerSocket(port);
221        serverSocket.setReuseAddress(true);
222
223        this.port = serverSocket.getLocalPort();
224        acceptExecutor.execute(namedRunnable("MockWebServer-accept-" + port, new Runnable() {
225            public void run() {
226                try {
227                    acceptConnections();
228                } catch (Throwable e) {
229                    logger.log(Level.WARNING, "MockWebServer connection failed", e);
230                }
231
232                /*
233                 * This gnarly block of code will release all sockets and
234                 * all thread, even if any close fails.
235                 */
236                try {
237                    serverSocket.close();
238                } catch (Throwable e) {
239                    logger.log(Level.WARNING, "MockWebServer server socket close failed", e);
240                }
241                for (Iterator<Socket> s = openClientSockets.keySet().iterator(); s.hasNext(); ) {
242                    try {
243                        s.next().close();
244                        s.remove();
245                    } catch (Throwable e) {
246                        logger.log(Level.WARNING, "MockWebServer socket close failed", e);
247                    }
248                }
249                try {
250                    acceptExecutor.shutdown();
251                } catch (Throwable e) {
252                    logger.log(Level.WARNING, "MockWebServer acceptExecutor shutdown failed", e);
253                }
254                try {
255                    requestExecutor.shutdown();
256                } catch (Throwable e) {
257                    logger.log(Level.WARNING, "MockWebServer requestExecutor shutdown failed", e);
258                }
259            }
260
261            private void acceptConnections() throws Exception {
262                while (true) {
263                    Socket socket;
264                    try {
265                        socket = serverSocket.accept();
266                    } catch (SocketException e) {
267                        return;
268                    }
269                    SocketPolicy socketPolicy = dispatcher.peek().getSocketPolicy();
270                    if (socketPolicy == DISCONNECT_AT_START) {
271                        dispatchBookkeepingRequest(0, socket);
272                        socket.close();
273                    } else {
274                        openClientSockets.put(socket, true);
275                        serveConnection(socket);
276                    }
277                }
278            }
279        }));
280    }
281
282    public void shutdown() throws IOException {
283        if (serverSocket != null) {
284            serverSocket.close(); // should cause acceptConnections() to break out
285        }
286    }
287
288    private void serveConnection(final Socket raw) {
289        String name = "MockWebServer-" + raw.getRemoteSocketAddress();
290        requestExecutor.execute(namedRunnable(name, new Runnable() {
291            int sequenceNumber = 0;
292
293            public void run() {
294                try {
295                    processConnection();
296                } catch (Exception e) {
297                    logger.log(Level.WARNING, "MockWebServer connection failed", e);
298                }
299            }
300
301            public void processConnection() throws Exception {
302                Socket socket;
303                if (sslSocketFactory != null) {
304                    if (tunnelProxy) {
305                        createTunnel();
306                    }
307                    SocketPolicy socketPolicy = dispatcher.peek().getSocketPolicy();
308                    if (socketPolicy == FAIL_HANDSHAKE) {
309                        dispatchBookkeepingRequest(sequenceNumber, raw);
310                        processHandshakeFailure(raw);
311                        return;
312                    }
313                    socket = sslSocketFactory.createSocket(
314                            raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true);
315                    SSLSocket sslSocket = (SSLSocket) socket;
316                    sslSocket.setUseClientMode(false);
317                    openClientSockets.put(socket, true);
318
319                    sslSocket.startHandshake();
320
321                    openClientSockets.remove(raw);
322                } else {
323                    socket = raw;
324                }
325
326                InputStream in = new BufferedInputStream(socket.getInputStream());
327                OutputStream out = new BufferedOutputStream(socket.getOutputStream());
328
329                while (processOneRequest(socket, in, out)) {
330                }
331
332                if (sequenceNumber == 0) {
333                    logger.warning("MockWebServer connection didn't make a request");
334                }
335
336                in.close();
337                out.close();
338                socket.close();
339                openClientSockets.remove(socket);
340            }
341
342            /**
343             * Respond to CONNECT requests until a SWITCH_TO_SSL_AT_END response
344             * is dispatched.
345             */
346            private void createTunnel() throws IOException, InterruptedException {
347                while (true) {
348                    SocketPolicy socketPolicy = dispatcher.peek().getSocketPolicy();
349                    if (!processOneRequest(raw, raw.getInputStream(), raw.getOutputStream())) {
350                        throw new IllegalStateException("Tunnel without any CONNECT!");
351                    }
352                    if (socketPolicy == SocketPolicy.UPGRADE_TO_SSL_AT_END) return;
353                }
354            }
355
356            /**
357             * Reads a request and writes its response. Returns true if a request
358             * was processed.
359             */
360            private boolean processOneRequest(Socket socket, InputStream in, OutputStream out)
361                    throws IOException, InterruptedException {
362                RecordedRequest request = readRequest(socket, in, out, sequenceNumber);
363                if (request == null) {
364                    return false;
365                }
366                requestCount.incrementAndGet();
367                requestQueue.add(request);
368                MockResponse response = dispatcher.dispatch(request);
369                if (response.getSocketPolicy() == SocketPolicy.DISCONNECT_AFTER_READING_REQUEST) {
370                  logger.info("Received request: " + request + " and disconnected without responding");
371                  return false;
372                }
373                writeResponse(out, response);
374
375                // For socket policies that poison the socket after the response is written:
376                // The client has received the response and will no longer be blocked after
377                // writeResponse() has returned. A client can then re-use the connection before
378                // the socket is poisoned (i.e. keep-alive / connection pooling). The second
379                // request/response may fail at the beginning, middle, end, or even succeed
380                // depending on scheduling. Delays can be required in tests to improve the chances
381                // of sockets being in a known state when subsequent requests are made.
382                //
383                // For SHUTDOWN_OUTPUT_AT_END the client may detect a problem with its input socket
384                // after the request has been made but before the server has chosen a response.
385                // For clients that perform retries, this can cause the client to issue a retry
386                // request. The retry handler may call dispatcher.dispatch(request) before the
387                // initial, failed request handler does and cause non-obvious response ordering.
388                // Setting workerThreads = 1 ensures that the dispatcher is called for requests in
389                // the order they are received.
390
391                if (response.getSocketPolicy() == SocketPolicy.DISCONNECT_AT_END) {
392                    in.close();
393                    out.close();
394                } else if (response.getSocketPolicy() == SocketPolicy.SHUTDOWN_INPUT_AT_END) {
395                    socket.shutdownInput();
396                } else if (response.getSocketPolicy() == SocketPolicy.SHUTDOWN_OUTPUT_AT_END) {
397                    socket.shutdownOutput();
398                }
399                logger.info("Received request: " + request + " and responded: " + response);
400                sequenceNumber++;
401                return true;
402            }
403        }));
404    }
405
406    private void processHandshakeFailure(Socket raw) throws Exception {
407        SSLContext context = SSLContext.getInstance("TLS");
408        context.init(null, new TrustManager[] { UNTRUSTED_TRUST_MANAGER }, new SecureRandom());
409        SSLSocketFactory sslSocketFactory = context.getSocketFactory();
410        SSLSocket socket = (SSLSocket) sslSocketFactory.createSocket(
411                raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true);
412        try {
413            socket.startHandshake(); // we're testing a handshake failure
414            throw new AssertionError();
415        } catch (IOException expected) {
416        }
417        socket.close();
418    }
419
420    private void dispatchBookkeepingRequest(int sequenceNumber, Socket socket) throws InterruptedException {
421        requestCount.incrementAndGet();
422        RecordedRequest request = new RecordedRequest(null, null, null, -1, null, sequenceNumber,
423                socket);
424        dispatcher.dispatch(request);
425    }
426
427    /** @param sequenceNumber the index of this request on this connection. */
428    private RecordedRequest readRequest(Socket socket, InputStream in, OutputStream out,
429            int sequenceNumber) throws IOException {
430        String request;
431        try {
432            request = readAsciiUntilCrlf(in);
433        } catch (IOException streamIsClosed) {
434            return null; // no request because we closed the stream
435        }
436        if (request.length() == 0) {
437            return null; // no request because the stream is exhausted
438        }
439
440        List<String> headers = new ArrayList<String>();
441        long contentLength = -1;
442        boolean chunked = false;
443        boolean expectContinue = false;
444        String header;
445        while ((header = readAsciiUntilCrlf(in)).length() != 0) {
446            headers.add(header);
447            String lowercaseHeader = header.toLowerCase(Locale.US);
448            if (contentLength == -1 && lowercaseHeader.startsWith("content-length:")) {
449                contentLength = Long.parseLong(header.substring(15).trim());
450            }
451            if (lowercaseHeader.startsWith("transfer-encoding:")
452                    && lowercaseHeader.substring(18).trim().equals("chunked")) {
453                chunked = true;
454            }
455            if (lowercaseHeader.startsWith("expect:")
456                    && lowercaseHeader.substring(7).trim().equals("100-continue")) {
457                expectContinue = true;
458            }
459        }
460
461        if (expectContinue) {
462            out.write(("HTTP/1.1 100 Continue\r\n").getBytes(StandardCharsets.US_ASCII));
463            out.write(("Content-Length: 0\r\n").getBytes(StandardCharsets.US_ASCII));
464            out.write(("\r\n").getBytes(StandardCharsets.US_ASCII));
465            out.flush();
466        }
467
468        boolean hasBody = false;
469        TruncatingOutputStream requestBody = new TruncatingOutputStream();
470        List<Integer> chunkSizes = new ArrayList<Integer>();
471        MockResponse throttlePolicy = dispatcher.peek();
472        if (contentLength != -1) {
473            hasBody = true;
474            throttledTransfer(throttlePolicy, in, requestBody, contentLength);
475        } else if (chunked) {
476            hasBody = true;
477            while (true) {
478                int chunkSize = Integer.parseInt(readAsciiUntilCrlf(in).trim(), 16);
479                if (chunkSize == 0) {
480                    readEmptyLine(in);
481                    break;
482                }
483                chunkSizes.add(chunkSize);
484                throttledTransfer(throttlePolicy, in, requestBody, chunkSize);
485                readEmptyLine(in);
486            }
487        }
488
489        if (request.startsWith("OPTIONS ")
490                || request.startsWith("GET ")
491                || request.startsWith("HEAD ")
492                || request.startsWith("TRACE ")
493                || request.startsWith("CONNECT ")) {
494            if (hasBody) {
495                throw new IllegalArgumentException("Request must not have a body: " + request);
496            }
497        } else if (!request.startsWith("POST ")
498                && !request.startsWith("PUT ")
499                && !request.startsWith("PATCH ")
500                && !request.startsWith("DELETE ")) { // Permitted as spec is ambiguous.
501            throw new UnsupportedOperationException("Unexpected method: " + request);
502        }
503
504        return new RecordedRequest(request, headers, chunkSizes, requestBody.numBytesReceived,
505                requestBody.toByteArray(), sequenceNumber, socket);
506    }
507
508    private void writeResponse(OutputStream out, MockResponse response) throws IOException {
509        out.write((response.getStatus() + "\r\n").getBytes(StandardCharsets.US_ASCII));
510        List<String> headers = response.getHeaders();
511        for (int i = 0, size = headers.size(); i < size; i++) {
512            String header = headers.get(i);
513            out.write((header + "\r\n").getBytes(StandardCharsets.US_ASCII));
514        }
515        out.write(("\r\n").getBytes(StandardCharsets.US_ASCII));
516        out.flush();
517
518        InputStream in = response.getBodyStream();
519        if (in == null) return;
520        throttledTransfer(response, in, out, Long.MAX_VALUE);
521    }
522
523    /**
524     * Transfer bytes from {@code in} to {@code out} until either {@code length}
525     * bytes have been transferred or {@code in} is exhausted. The transfer is
526     * throttled according to {@code throttlePolicy}.
527     */
528    private void throttledTransfer(MockResponse throttlePolicy, InputStream in, OutputStream out,
529            long limit) throws IOException {
530        byte[] buffer = new byte[1024];
531        int bytesPerPeriod = throttlePolicy.getThrottleBytesPerPeriod();
532        long delayMs = throttlePolicy.getThrottleUnit().toMillis(throttlePolicy.getThrottlePeriod());
533
534        while (true) {
535            for (int b = 0; b < bytesPerPeriod; ) {
536                int toRead = (int) Math.min(Math.min(buffer.length, limit), bytesPerPeriod - b);
537                int read = in.read(buffer, 0, toRead);
538                if (read == -1) return;
539
540                out.write(buffer, 0, read);
541                out.flush();
542                b += read;
543                limit -= read;
544
545                if (limit == 0) return;
546            }
547
548            try {
549                if (delayMs != 0) Thread.sleep(delayMs);
550            } catch (InterruptedException e) {
551                throw new AssertionError();
552            }
553        }
554    }
555
556    /**
557     * Returns the text from {@code in} until the next "\r\n", or null if
558     * {@code in} is exhausted.
559     */
560    private String readAsciiUntilCrlf(InputStream in) throws IOException {
561        StringBuilder builder = new StringBuilder();
562        while (true) {
563            int c = in.read();
564            if (c == '\n' && builder.length() > 0 && builder.charAt(builder.length() - 1) == '\r') {
565                builder.deleteCharAt(builder.length() - 1);
566                return builder.toString();
567            } else if (c == -1) {
568                return builder.toString();
569            } else {
570                builder.append((char) c);
571            }
572        }
573    }
574
575    private void readEmptyLine(InputStream in) throws IOException {
576        String line = readAsciiUntilCrlf(in);
577        if (line.length() != 0) {
578            throw new IllegalStateException("Expected empty but was: " + line);
579        }
580    }
581
582    /**
583     * Sets the dispatcher used to match incoming requests to mock responses.
584     * The default dispatcher simply serves a fixed sequence of responses from
585     * a {@link #enqueue(MockResponse) queue}; custom dispatchers can vary the
586     * response based on timing or the content of the request.
587     */
588    public void setDispatcher(Dispatcher dispatcher) {
589        if (dispatcher == null) {
590            throw new NullPointerException();
591        }
592        this.dispatcher = dispatcher;
593    }
594
595    /**
596     * An output stream that drops data after bodyLimit bytes.
597     */
598    private class TruncatingOutputStream extends ByteArrayOutputStream {
599        private int numBytesReceived = 0;
600        @Override public void write(byte[] buffer, int offset, int len) {
601            numBytesReceived += len;
602            super.write(buffer, offset, Math.min(len, bodyLimit - count));
603        }
604        @Override public void write(int oneByte) {
605            numBytesReceived++;
606            if (count < bodyLimit) {
607                super.write(oneByte);
608            }
609        }
610    }
611
612    private static Runnable namedRunnable(final String name, final Runnable runnable) {
613        return new Runnable() {
614            public void run() {
615                String originalName = Thread.currentThread().getName();
616                Thread.currentThread().setName(name);
617                try {
618                    runnable.run();
619                } finally {
620                    Thread.currentThread().setName(originalName);
621                }
622            }
623        };
624    }
625}
626