1/*
2 * Copyright (C) 2010 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package tests.http;
18
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.HttpURLConnection;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.MalformedURLException;
import java.net.Proxy;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketException;
import java.net.URL;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
import static tests.http.SocketPolicy.DISCONNECT_AT_START;
52
53/**
54 * A scriptable web server. Callers supply canned responses and the server
55 * replays them upon request in sequence.
56 *
57 * @deprecated prefer com.google.mockwebserver.MockWebServer
58 */
59@Deprecated
60public final class MockWebServer {
61
62    static final String ASCII = "US-ASCII";
63
64    private static final Logger logger = Logger.getLogger(MockWebServer.class.getName());
65    private final BlockingQueue<RecordedRequest> requestQueue
66            = new LinkedBlockingQueue<RecordedRequest>();
67    private final BlockingQueue<MockResponse> responseQueue
68            = new LinkedBlockingDeque<MockResponse>();
69    private final Set<Socket> openClientSockets
70            = Collections.newSetFromMap(new ConcurrentHashMap<Socket, Boolean>());
71    private boolean singleResponse;
72    private final AtomicInteger requestCount = new AtomicInteger();
73    private int bodyLimit = Integer.MAX_VALUE;
74    private ServerSocket serverSocket;
75    private SSLSocketFactory sslSocketFactory;
76    private ExecutorService executor;
77    private boolean tunnelProxy;
78
79    private int port = -1;
80
81    public int getPort() {
82        if (port == -1) {
83            throw new IllegalStateException("Cannot retrieve port before calling play()");
84        }
85        return port;
86    }
87
88    public String getHostName() {
89        return InetAddress.getLoopbackAddress().getHostName();
90    }
91
92    public Proxy toProxyAddress() {
93        return new Proxy(Proxy.Type.HTTP, new InetSocketAddress(getHostName(), getPort()));
94    }
95
96    /**
97     * Returns a URL for connecting to this server.
98     *
99     * @param path the request path, such as "/".
100     */
101    public URL getUrl(String path) throws MalformedURLException, UnknownHostException {
102        return sslSocketFactory != null
103                ? new URL("https://" + getHostName() + ":" + getPort() + path)
104                : new URL("http://" + getHostName() + ":" + getPort() + path);
105    }
106
107    /**
108     * Sets the number of bytes of the POST body to keep in memory to the given
109     * limit.
110     */
111    public void setBodyLimit(int maxBodyLength) {
112        this.bodyLimit = maxBodyLength;
113    }
114
115    /**
116     * Serve requests with HTTPS rather than otherwise.
117     *
118     * @param tunnelProxy whether to expect the HTTP CONNECT method before
119     *     negotiating TLS.
120     */
121    public void useHttps(SSLSocketFactory sslSocketFactory, boolean tunnelProxy) {
122        this.sslSocketFactory = sslSocketFactory;
123        this.tunnelProxy = tunnelProxy;
124    }
125
126    /**
127     * Awaits the next HTTP request, removes it, and returns it. Callers should
128     * use this to verify the request sent was as intended.
129     */
130    public RecordedRequest takeRequest() throws InterruptedException {
131        return requestQueue.take();
132    }
133
134    /**
135     * Returns the number of HTTP requests received thus far by this server.
136     * This may exceed the number of HTTP connections when connection reuse is
137     * in practice.
138     */
139    public int getRequestCount() {
140        return requestCount.get();
141    }
142
143    public void enqueue(MockResponse response) {
144        responseQueue.add(response.clone());
145    }
146
147    /**
148     * By default, this class processes requests coming in by adding them to a
149     * queue and serves responses by removing them from another queue. This mode
150     * is appropriate for correctness testing.
151     *
152     * <p>Serving a single response causes the server to be stateless: requests
153     * are not enqueued, and responses are not dequeued. This mode is appropriate
154     * for benchmarking.
155     */
156    public void setSingleResponse(boolean singleResponse) {
157        this.singleResponse = singleResponse;
158    }
159
160    /**
161     * Equivalent to {@code play(0)}.
162     */
163    public void play() throws IOException {
164        play(0);
165    }
166
167    /**
168     * Starts the server, serves all enqueued requests, and shuts the server
169     * down.
170     *
171     * @param port the port to listen to, or 0 for any available port.
172     *     Automated tests should always use port 0 to avoid flakiness when a
173     *     specific port is unavailable.
174     */
175    public void play(int port) throws IOException {
176        executor = Executors.newCachedThreadPool();
177        serverSocket = new ServerSocket(port);
178        serverSocket.setReuseAddress(true);
179
180        this.port = serverSocket.getLocalPort();
181        executor.execute(namedRunnable("MockWebServer-accept-" + port, new Runnable() {
182            public void run() {
183                try {
184                    acceptConnections();
185                } catch (Throwable e) {
186                    logger.log(Level.WARNING, "MockWebServer connection failed", e);
187                }
188
189                /*
190                 * This gnarly block of code will release all sockets and
191                 * all thread, even if any close fails.
192                 */
193                try {
194                    serverSocket.close();
195                } catch (Throwable e) {
196                    logger.log(Level.WARNING, "MockWebServer server socket close failed", e);
197                }
198                for (Iterator<Socket> s = openClientSockets.iterator(); s.hasNext();) {
199                    try {
200                        s.next().close();
201                        s.remove();
202                    } catch (Throwable e) {
203                        logger.log(Level.WARNING, "MockWebServer socket close failed", e);
204                    }
205                }
206                try {
207                    executor.shutdown();
208                } catch (Throwable e) {
209                    logger.log(Level.WARNING, "MockWebServer executor shutdown failed", e);
210                }
211            }
212
213            private void acceptConnections() throws Exception {
214                do {
215                    Socket socket;
216                    try {
217                        socket = serverSocket.accept();
218                    } catch (SocketException ignored) {
219                        continue;
220                    }
221                    MockResponse peek = responseQueue.peek();
222                    if (peek != null && peek.getSocketPolicy() == DISCONNECT_AT_START) {
223                        responseQueue.take();
224                        socket.close();
225                    } else {
226                        openClientSockets.add(socket);
227                        serveConnection(socket);
228                    }
229                } while (!responseQueue.isEmpty());
230            }
231        }));
232    }
233
234    public void shutdown() throws IOException {
235        if (serverSocket != null) {
236            serverSocket.close(); // should cause acceptConnections() to break out
237        }
238    }
239
240    private void serveConnection(final Socket raw) {
241        String name = "MockWebServer-" + raw.getRemoteSocketAddress();
242        executor.execute(namedRunnable(name, new Runnable() {
243            int sequenceNumber = 0;
244
245            public void run() {
246                try {
247                    processConnection();
248                } catch (Exception e) {
249                    logger.log(Level.WARNING, "MockWebServer connection failed", e);
250                }
251            }
252
253            public void processConnection() throws Exception {
254                Socket socket;
255                if (sslSocketFactory != null) {
256                    if (tunnelProxy) {
257                        createTunnel();
258                    }
259                    socket = sslSocketFactory.createSocket(
260                            raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true);
261                    ((SSLSocket) socket).setUseClientMode(false);
262                    openClientSockets.add(socket);
263                    openClientSockets.remove(raw);
264                } else {
265                    socket = raw;
266                }
267
268                InputStream in = new BufferedInputStream(socket.getInputStream());
269                OutputStream out = new BufferedOutputStream(socket.getOutputStream());
270
271                while (!responseQueue.isEmpty() && processOneRequest(in, out, socket)) {}
272
273                if (sequenceNumber == 0) {
274                    logger.warning("MockWebServer connection didn't make a request");
275                }
276
277                in.close();
278                out.close();
279                socket.close();
280                if (responseQueue.isEmpty()) {
281                    shutdown();
282                }
283                openClientSockets.remove(socket);
284            }
285
286            /**
287             * Respond to CONNECT requests until a SWITCH_TO_SSL_AT_END response
288             * is dispatched.
289             */
290            private void createTunnel() throws IOException, InterruptedException {
291                while (true) {
292                    MockResponse connect = responseQueue.peek();
293                    if (!processOneRequest(raw.getInputStream(), raw.getOutputStream(), raw)) {
294                        throw new IllegalStateException("Tunnel without any CONNECT!");
295                    }
296                    if (connect.getSocketPolicy() == SocketPolicy.UPGRADE_TO_SSL_AT_END) {
297                        return;
298                    }
299                }
300            }
301
302            /**
303             * Reads a request and writes its response. Returns true if a request
304             * was processed.
305             */
306            private boolean processOneRequest(InputStream in, OutputStream out, Socket socket)
307                    throws IOException, InterruptedException {
308                RecordedRequest request = readRequest(in, sequenceNumber);
309                if (request == null) {
310                    return false;
311                }
312                MockResponse response = dispatch(request);
313                writeResponse(out, response);
314                if (response.getSocketPolicy() == SocketPolicy.DISCONNECT_AT_END) {
315                    in.close();
316                    out.close();
317                } else if (response.getSocketPolicy() == SocketPolicy.SHUTDOWN_INPUT_AT_END) {
318                    socket.shutdownInput();
319                } else if (response.getSocketPolicy() == SocketPolicy.SHUTDOWN_OUTPUT_AT_END) {
320                    socket.shutdownOutput();
321                }
322                sequenceNumber++;
323                return true;
324            }
325        }));
326    }
327
328    /**
329     * @param sequenceNumber the index of this request on this connection.
330     */
331    private RecordedRequest readRequest(InputStream in, int sequenceNumber) throws IOException {
332        String request;
333        try {
334            request = readAsciiUntilCrlf(in);
335        } catch (IOException streamIsClosed) {
336            return null; // no request because we closed the stream
337        }
338        if (request.isEmpty()) {
339            return null; // no request because the stream is exhausted
340        }
341
342        List<String> headers = new ArrayList<String>();
343        int contentLength = -1;
344        boolean chunked = false;
345        String header;
346        while (!(header = readAsciiUntilCrlf(in)).isEmpty()) {
347            headers.add(header);
348            String lowercaseHeader = header.toLowerCase();
349            if (contentLength == -1 && lowercaseHeader.startsWith("content-length:")) {
350                contentLength = Integer.parseInt(header.substring(15).trim());
351            }
352            if (lowercaseHeader.startsWith("transfer-encoding:") &&
353                    lowercaseHeader.substring(18).trim().equals("chunked")) {
354                chunked = true;
355            }
356        }
357
358        boolean hasBody = false;
359        TruncatingOutputStream requestBody = new TruncatingOutputStream();
360        List<Integer> chunkSizes = new ArrayList<Integer>();
361        if (contentLength != -1) {
362            hasBody = true;
363            transfer(contentLength, in, requestBody);
364        } else if (chunked) {
365            hasBody = true;
366            while (true) {
367                int chunkSize = Integer.parseInt(readAsciiUntilCrlf(in).trim(), 16);
368                if (chunkSize == 0) {
369                    readEmptyLine(in);
370                    break;
371                }
372                chunkSizes.add(chunkSize);
373                transfer(chunkSize, in, requestBody);
374                readEmptyLine(in);
375            }
376        }
377
378        if (request.startsWith("OPTIONS ") || request.startsWith("GET ")
379                || request.startsWith("HEAD ") || request.startsWith("DELETE ")
380                || request .startsWith("TRACE ") || request.startsWith("CONNECT ")) {
381            if (hasBody) {
382                throw new IllegalArgumentException("Request must not have a body: " + request);
383            }
384        } else if (request.startsWith("POST ") || request.startsWith("PUT ")) {
385            if (!hasBody) {
386                throw new IllegalArgumentException("Request must have a body: " + request);
387            }
388        } else {
389            throw new UnsupportedOperationException("Unexpected method: " + request);
390        }
391
392        return new RecordedRequest(request, headers, chunkSizes,
393                requestBody.numBytesReceived, requestBody.toByteArray(), sequenceNumber);
394    }
395
396    /**
397     * Returns a response to satisfy {@code request}.
398     */
399    private MockResponse dispatch(RecordedRequest request) throws InterruptedException {
400        if (responseQueue.isEmpty()) {
401            throw new IllegalStateException("Unexpected request: " + request);
402        }
403
404        // to permit interactive/browser testing, ignore requests for favicons
405        if (request.getRequestLine().equals("GET /favicon.ico HTTP/1.1")) {
406            System.out.println("served " + request.getRequestLine());
407            return new MockResponse()
408                        .setResponseCode(HttpURLConnection.HTTP_NOT_FOUND);
409        }
410
411        if (singleResponse) {
412            return responseQueue.peek();
413        } else {
414            requestCount.incrementAndGet();
415            requestQueue.add(request);
416            return responseQueue.take();
417        }
418    }
419
420    private void writeResponse(OutputStream out, MockResponse response) throws IOException {
421        out.write((response.getStatus() + "\r\n").getBytes(ASCII));
422        for (String header : response.getHeaders()) {
423            out.write((header + "\r\n").getBytes(ASCII));
424        }
425        out.write(("\r\n").getBytes(ASCII));
426        out.write(response.getBody());
427        out.flush();
428    }
429
430    /**
431     * Transfer bytes from {@code in} to {@code out} until either {@code length}
432     * bytes have been transferred or {@code in} is exhausted.
433     */
434    private void transfer(int length, InputStream in, OutputStream out) throws IOException {
435        byte[] buffer = new byte[1024];
436        while (length > 0) {
437            int count = in.read(buffer, 0, Math.min(buffer.length, length));
438            if (count == -1) {
439                return;
440            }
441            out.write(buffer, 0, count);
442            length -= count;
443        }
444    }
445
446    /**
447     * Returns the text from {@code in} until the next "\r\n", or null if
448     * {@code in} is exhausted.
449     */
450    private String readAsciiUntilCrlf(InputStream in) throws IOException {
451        StringBuilder builder = new StringBuilder();
452        while (true) {
453            int c = in.read();
454            if (c == '\n' && builder.length() > 0 && builder.charAt(builder.length() - 1) == '\r') {
455                builder.deleteCharAt(builder.length() - 1);
456                return builder.toString();
457            } else if (c == -1) {
458                return builder.toString();
459            } else {
460                builder.append((char) c);
461            }
462        }
463    }
464
465    private void readEmptyLine(InputStream in) throws IOException {
466        String line = readAsciiUntilCrlf(in);
467        if (!line.isEmpty()) {
468            throw new IllegalStateException("Expected empty but was: " + line);
469        }
470    }
471
472    /**
473     * An output stream that drops data after bodyLimit bytes.
474     */
475    private class TruncatingOutputStream extends ByteArrayOutputStream {
476        private int numBytesReceived = 0;
477        @Override public void write(byte[] buffer, int offset, int len) {
478            numBytesReceived += len;
479            super.write(buffer, offset, Math.min(len, bodyLimit - count));
480        }
481        @Override public void write(int oneByte) {
482            numBytesReceived++;
483            if (count < bodyLimit) {
484                super.write(oneByte);
485            }
486        }
487    }
488
489    private static Runnable namedRunnable(final String name, final Runnable runnable) {
490        return new Runnable() {
491            public void run() {
492                String originalName = Thread.currentThread().getName();
493                Thread.currentThread().setName(name);
494                try {
495                    runnable.run();
496                } finally {
497                    Thread.currentThread().setName(originalName);
498                }
499            }
500        };
501    }
502}
503