1/*
2 * Copyright (C) 2010 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package tests.http;
18
19import java.io.BufferedInputStream;
20import java.io.BufferedOutputStream;
21import java.io.ByteArrayOutputStream;
22import java.io.IOException;
23import java.io.InputStream;
24import java.io.OutputStream;
25import java.net.HttpURLConnection;
26import java.net.InetAddress;
27import java.net.InetSocketAddress;
28import java.net.MalformedURLException;
29import java.net.Proxy;
30import java.net.ServerSocket;
31import java.net.Socket;
32import java.net.SocketException;
33import java.net.URL;
34import java.net.UnknownHostException;
35import java.util.ArrayList;
36import java.util.Collections;
37import java.util.Iterator;
38import java.util.List;
39import java.util.Locale;
40import java.util.Set;
41import java.util.concurrent.BlockingQueue;
42import java.util.concurrent.ConcurrentHashMap;
43import java.util.concurrent.ExecutorService;
44import java.util.concurrent.Executors;
45import java.util.concurrent.LinkedBlockingDeque;
46import java.util.concurrent.LinkedBlockingQueue;
47import java.util.concurrent.atomic.AtomicInteger;
48import java.util.logging.Level;
49import java.util.logging.Logger;
50import javax.net.ssl.SSLSocket;
51import javax.net.ssl.SSLSocketFactory;
52import static tests.http.SocketPolicy.DISCONNECT_AT_START;
53
54/**
55 * A scriptable web server. Callers supply canned responses and the server
56 * replays them upon request in sequence.
57 *
58 * @deprecated Use {@code com.google.mockwebserver.MockWebServer} instead.
59 */
60@Deprecated
61public final class MockWebServer {
62
63    static final String ASCII = "US-ASCII";
64
65    private static final Logger logger = Logger.getLogger(MockWebServer.class.getName());
66    private final BlockingQueue<RecordedRequest> requestQueue
67            = new LinkedBlockingQueue<RecordedRequest>();
68    private final BlockingQueue<MockResponse> responseQueue
69            = new LinkedBlockingDeque<MockResponse>();
70    private final Set<Socket> openClientSockets
71            = Collections.newSetFromMap(new ConcurrentHashMap<Socket, Boolean>());
72    private boolean singleResponse;
73    private final AtomicInteger requestCount = new AtomicInteger();
74    private int bodyLimit = Integer.MAX_VALUE;
75    private ServerSocket serverSocket;
76    private SSLSocketFactory sslSocketFactory;
77    private ExecutorService executor;
78    private boolean tunnelProxy;
79
80    private int port = -1;
81
82    public int getPort() {
83        if (port == -1) {
84            throw new IllegalStateException("Cannot retrieve port before calling play()");
85        }
86        return port;
87    }
88
89    public String getHostName() {
90        return InetAddress.getLoopbackAddress().getHostName();
91    }
92
93    public Proxy toProxyAddress() {
94        return new Proxy(Proxy.Type.HTTP, new InetSocketAddress(getHostName(), getPort()));
95    }
96
97    /**
98     * Returns a URL for connecting to this server.
99     *
100     * @param path the request path, such as "/".
101     */
102    public URL getUrl(String path) throws MalformedURLException, UnknownHostException {
103        return sslSocketFactory != null
104                ? new URL("https://" + getHostName() + ":" + getPort() + path)
105                : new URL("http://" + getHostName() + ":" + getPort() + path);
106    }
107
108    /**
109     * Sets the number of bytes of the POST body to keep in memory to the given
110     * limit.
111     */
112    public void setBodyLimit(int maxBodyLength) {
113        this.bodyLimit = maxBodyLength;
114    }
115
116    /**
117     * Serve requests with HTTPS rather than otherwise.
118     *
119     * @param tunnelProxy whether to expect the HTTP CONNECT method before
120     *     negotiating TLS.
121     */
122    public void useHttps(SSLSocketFactory sslSocketFactory, boolean tunnelProxy) {
123        this.sslSocketFactory = sslSocketFactory;
124        this.tunnelProxy = tunnelProxy;
125    }
126
127    /**
128     * Awaits the next HTTP request, removes it, and returns it. Callers should
129     * use this to verify the request sent was as intended.
130     */
131    public RecordedRequest takeRequest() throws InterruptedException {
132        return requestQueue.take();
133    }
134
135    /**
136     * Returns the number of HTTP requests received thus far by this server.
137     * This may exceed the number of HTTP connections when connection reuse is
138     * in practice.
139     */
140    public int getRequestCount() {
141        return requestCount.get();
142    }
143
144    public void enqueue(MockResponse response) {
145        responseQueue.add(response.clone());
146    }
147
148    /**
149     * By default, this class processes requests coming in by adding them to a
150     * queue and serves responses by removing them from another queue. This mode
151     * is appropriate for correctness testing.
152     *
153     * <p>Serving a single response causes the server to be stateless: requests
154     * are not enqueued, and responses are not dequeued. This mode is appropriate
155     * for benchmarking.
156     */
157    public void setSingleResponse(boolean singleResponse) {
158        this.singleResponse = singleResponse;
159    }
160
161    /**
162     * Equivalent to {@code play(0)}.
163     */
164    public void play() throws IOException {
165        play(0);
166    }
167
168    /**
169     * Starts the server, serves all enqueued requests, and shuts the server
170     * down.
171     *
172     * @param port the port to listen to, or 0 for any available port.
173     *     Automated tests should always use port 0 to avoid flakiness when a
174     *     specific port is unavailable.
175     */
176    public void play(int port) throws IOException {
177        executor = Executors.newCachedThreadPool();
178        serverSocket = new ServerSocket(port);
179        serverSocket.setReuseAddress(true);
180
181        this.port = serverSocket.getLocalPort();
182        executor.execute(namedRunnable("MockWebServer-accept-" + port, new Runnable() {
183            public void run() {
184                try {
185                    acceptConnections();
186                } catch (Throwable e) {
187                    logger.log(Level.WARNING, "MockWebServer connection failed", e);
188                }
189
190                /*
191                 * This gnarly block of code will release all sockets and
192                 * all thread, even if any close fails.
193                 */
194                try {
195                    serverSocket.close();
196                } catch (Throwable e) {
197                    logger.log(Level.WARNING, "MockWebServer server socket close failed", e);
198                }
199                for (Iterator<Socket> s = openClientSockets.iterator(); s.hasNext();) {
200                    try {
201                        s.next().close();
202                        s.remove();
203                    } catch (Throwable e) {
204                        logger.log(Level.WARNING, "MockWebServer socket close failed", e);
205                    }
206                }
207                try {
208                    executor.shutdown();
209                } catch (Throwable e) {
210                    logger.log(Level.WARNING, "MockWebServer executor shutdown failed", e);
211                }
212            }
213
214            private void acceptConnections() throws Exception {
215                do {
216                    Socket socket;
217                    try {
218                        socket = serverSocket.accept();
219                    } catch (SocketException ignored) {
220                        continue;
221                    }
222                    MockResponse peek = responseQueue.peek();
223                    if (peek != null && peek.getSocketPolicy() == DISCONNECT_AT_START) {
224                        responseQueue.take();
225                        socket.close();
226                    } else {
227                        openClientSockets.add(socket);
228                        serveConnection(socket);
229                    }
230                } while (!responseQueue.isEmpty());
231            }
232        }));
233    }
234
235    public void shutdown() throws IOException {
236        if (serverSocket != null) {
237            serverSocket.close(); // should cause acceptConnections() to break out
238        }
239    }
240
241    private void serveConnection(final Socket raw) {
242        String name = "MockWebServer-" + raw.getRemoteSocketAddress();
243        executor.execute(namedRunnable(name, new Runnable() {
244            int sequenceNumber = 0;
245
246            public void run() {
247                try {
248                    processConnection();
249                } catch (Exception e) {
250                    logger.log(Level.WARNING, "MockWebServer connection failed", e);
251                }
252            }
253
254            public void processConnection() throws Exception {
255                Socket socket;
256                if (sslSocketFactory != null) {
257                    if (tunnelProxy) {
258                        createTunnel();
259                    }
260                    socket = sslSocketFactory.createSocket(
261                            raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true);
262                    ((SSLSocket) socket).setUseClientMode(false);
263                    openClientSockets.add(socket);
264                    openClientSockets.remove(raw);
265                } else {
266                    socket = raw;
267                }
268
269                InputStream in = new BufferedInputStream(socket.getInputStream());
270                OutputStream out = new BufferedOutputStream(socket.getOutputStream());
271
272                while (!responseQueue.isEmpty() && processOneRequest(in, out, socket)) {}
273
274                if (sequenceNumber == 0) {
275                    logger.warning("MockWebServer connection didn't make a request");
276                }
277
278                in.close();
279                out.close();
280                socket.close();
281                if (responseQueue.isEmpty()) {
282                    shutdown();
283                }
284                openClientSockets.remove(socket);
285            }
286
287            /**
288             * Respond to CONNECT requests until a SWITCH_TO_SSL_AT_END response
289             * is dispatched.
290             */
291            private void createTunnel() throws IOException, InterruptedException {
292                while (true) {
293                    MockResponse connect = responseQueue.peek();
294                    if (!processOneRequest(raw.getInputStream(), raw.getOutputStream(), raw)) {
295                        throw new IllegalStateException("Tunnel without any CONNECT!");
296                    }
297                    if (connect.getSocketPolicy() == SocketPolicy.UPGRADE_TO_SSL_AT_END) {
298                        return;
299                    }
300                }
301            }
302
303            /**
304             * Reads a request and writes its response. Returns true if a request
305             * was processed.
306             */
307            private boolean processOneRequest(InputStream in, OutputStream out, Socket socket)
308                    throws IOException, InterruptedException {
309                RecordedRequest request = readRequest(in, sequenceNumber);
310                if (request == null) {
311                    return false;
312                }
313                MockResponse response = dispatch(request);
314                writeResponse(out, response);
315                if (response.getSocketPolicy() == SocketPolicy.DISCONNECT_AT_END) {
316                    in.close();
317                    out.close();
318                } else if (response.getSocketPolicy() == SocketPolicy.SHUTDOWN_INPUT_AT_END) {
319                    socket.shutdownInput();
320                } else if (response.getSocketPolicy() == SocketPolicy.SHUTDOWN_OUTPUT_AT_END) {
321                    socket.shutdownOutput();
322                }
323                sequenceNumber++;
324                return true;
325            }
326        }));
327    }
328
329    /**
330     * @param sequenceNumber the index of this request on this connection.
331     */
332    private RecordedRequest readRequest(InputStream in, int sequenceNumber) throws IOException {
333        String request;
334        try {
335            request = readAsciiUntilCrlf(in);
336        } catch (IOException streamIsClosed) {
337            return null; // no request because we closed the stream
338        }
339        if (request.isEmpty()) {
340            return null; // no request because the stream is exhausted
341        }
342
343        List<String> headers = new ArrayList<String>();
344        int contentLength = -1;
345        boolean chunked = false;
346        String header;
347        while (!(header = readAsciiUntilCrlf(in)).isEmpty()) {
348            headers.add(header);
349            String lowercaseHeader = header.toLowerCase(Locale.ROOT);
350            if (contentLength == -1 && lowercaseHeader.startsWith("content-length:")) {
351                contentLength = Integer.parseInt(header.substring(15).trim());
352            }
353            if (lowercaseHeader.startsWith("transfer-encoding:") &&
354                    lowercaseHeader.substring(18).trim().equals("chunked")) {
355                chunked = true;
356            }
357        }
358
359        boolean hasBody = false;
360        TruncatingOutputStream requestBody = new TruncatingOutputStream();
361        List<Integer> chunkSizes = new ArrayList<Integer>();
362        if (contentLength != -1) {
363            hasBody = true;
364            transfer(contentLength, in, requestBody);
365        } else if (chunked) {
366            hasBody = true;
367            while (true) {
368                int chunkSize = Integer.parseInt(readAsciiUntilCrlf(in).trim(), 16);
369                if (chunkSize == 0) {
370                    readEmptyLine(in);
371                    break;
372                }
373                chunkSizes.add(chunkSize);
374                transfer(chunkSize, in, requestBody);
375                readEmptyLine(in);
376            }
377        }
378
379        if (request.startsWith("OPTIONS ") || request.startsWith("GET ")
380                || request.startsWith("HEAD ") || request.startsWith("DELETE ")
381                || request .startsWith("TRACE ") || request.startsWith("CONNECT ")) {
382            if (hasBody) {
383                throw new IllegalArgumentException("Request must not have a body: " + request);
384            }
385        } else if (request.startsWith("POST ") || request.startsWith("PUT ")) {
386            if (!hasBody) {
387                throw new IllegalArgumentException("Request must have a body: " + request);
388            }
389        } else {
390            throw new UnsupportedOperationException("Unexpected method: " + request);
391        }
392
393        return new RecordedRequest(request, headers, chunkSizes,
394                requestBody.numBytesReceived, requestBody.toByteArray(), sequenceNumber);
395    }
396
397    /**
398     * Returns a response to satisfy {@code request}.
399     */
400    private MockResponse dispatch(RecordedRequest request) throws InterruptedException {
401        if (responseQueue.isEmpty()) {
402            throw new IllegalStateException("Unexpected request: " + request);
403        }
404
405        // to permit interactive/browser testing, ignore requests for favicons
406        if (request.getRequestLine().equals("GET /favicon.ico HTTP/1.1")) {
407            System.out.println("served " + request.getRequestLine());
408            return new MockResponse()
409                        .setResponseCode(HttpURLConnection.HTTP_NOT_FOUND);
410        }
411
412        if (singleResponse) {
413            return responseQueue.peek();
414        } else {
415            requestCount.incrementAndGet();
416            requestQueue.add(request);
417            return responseQueue.take();
418        }
419    }
420
421    private void writeResponse(OutputStream out, MockResponse response) throws IOException {
422        out.write((response.getStatus() + "\r\n").getBytes(ASCII));
423        for (String header : response.getHeaders()) {
424            out.write((header + "\r\n").getBytes(ASCII));
425        }
426        out.write(("\r\n").getBytes(ASCII));
427        out.write(response.getBody());
428        out.flush();
429    }
430
431    /**
432     * Transfer bytes from {@code in} to {@code out} until either {@code length}
433     * bytes have been transferred or {@code in} is exhausted.
434     */
435    private void transfer(int length, InputStream in, OutputStream out) throws IOException {
436        byte[] buffer = new byte[1024];
437        while (length > 0) {
438            int count = in.read(buffer, 0, Math.min(buffer.length, length));
439            if (count == -1) {
440                return;
441            }
442            out.write(buffer, 0, count);
443            length -= count;
444        }
445    }
446
447    /**
448     * Returns the text from {@code in} until the next "\r\n", or null if
449     * {@code in} is exhausted.
450     */
451    private String readAsciiUntilCrlf(InputStream in) throws IOException {
452        StringBuilder builder = new StringBuilder();
453        while (true) {
454            int c = in.read();
455            if (c == '\n' && builder.length() > 0 && builder.charAt(builder.length() - 1) == '\r') {
456                builder.deleteCharAt(builder.length() - 1);
457                return builder.toString();
458            } else if (c == -1) {
459                return builder.toString();
460            } else {
461                builder.append((char) c);
462            }
463        }
464    }
465
466    private void readEmptyLine(InputStream in) throws IOException {
467        String line = readAsciiUntilCrlf(in);
468        if (!line.isEmpty()) {
469            throw new IllegalStateException("Expected empty but was: " + line);
470        }
471    }
472
473    /**
474     * An output stream that drops data after bodyLimit bytes.
475     */
476    private class TruncatingOutputStream extends ByteArrayOutputStream {
477        private int numBytesReceived = 0;
478        @Override public void write(byte[] buffer, int offset, int len) {
479            numBytesReceived += len;
480            super.write(buffer, offset, Math.min(len, bodyLimit - count));
481        }
482        @Override public void write(int oneByte) {
483            numBytesReceived++;
484            if (count < bodyLimit) {
485                super.write(oneByte);
486            }
487        }
488    }
489
490    private static Runnable namedRunnable(final String name, final Runnable runnable) {
491        return new Runnable() {
492            public void run() {
493                String originalName = Thread.currentThread().getName();
494                Thread.currentThread().setName(name);
495                try {
496                    runnable.run();
497                } finally {
498                    Thread.currentThread().setName(originalName);
499                }
500            }
501        };
502    }
503}
504