1/*
2 * Copyright (C) 2011 Google Inc.
3 * Copyright (C) 2013 Square, Inc.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *      http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package com.squareup.okhttp.mockwebserver;
19
20import com.squareup.okhttp.Protocol;
21import com.squareup.okhttp.internal.NamedRunnable;
22import com.squareup.okhttp.internal.Platform;
23import com.squareup.okhttp.internal.Util;
24import com.squareup.okhttp.internal.spdy.ErrorCode;
25import com.squareup.okhttp.internal.spdy.Header;
26import com.squareup.okhttp.internal.spdy.IncomingStreamHandler;
27import com.squareup.okhttp.internal.spdy.SpdyConnection;
28import com.squareup.okhttp.internal.spdy.SpdyStream;
29import java.io.BufferedInputStream;
30import java.io.BufferedOutputStream;
31import java.io.ByteArrayOutputStream;
32import java.io.IOException;
33import java.io.InputStream;
34import java.io.OutputStream;
35import java.net.InetAddress;
36import java.net.InetSocketAddress;
37import java.net.MalformedURLException;
38import java.net.Proxy;
39import java.net.ServerSocket;
40import java.net.Socket;
41import java.net.SocketException;
42import java.net.URL;
43import java.net.UnknownHostException;
44import java.security.SecureRandom;
45import java.security.cert.CertificateException;
46import java.security.cert.X509Certificate;
47import java.util.ArrayList;
48import java.util.Collections;
49import java.util.Iterator;
50import java.util.List;
51import java.util.Locale;
52import java.util.Map;
53import java.util.concurrent.BlockingQueue;
54import java.util.concurrent.ConcurrentHashMap;
55import java.util.concurrent.ExecutorService;
56import java.util.concurrent.Executors;
57import java.util.concurrent.LinkedBlockingQueue;
58import java.util.concurrent.atomic.AtomicInteger;
59import java.util.logging.Level;
60import java.util.logging.Logger;
61import javax.net.ssl.SSLContext;
62import javax.net.ssl.SSLSocket;
63import javax.net.ssl.SSLSocketFactory;
64import javax.net.ssl.TrustManager;
65import javax.net.ssl.X509TrustManager;
66import okio.BufferedSink;
67import okio.ByteString;
68import okio.OkBuffer;
69import okio.Okio;
70
71import static com.squareup.okhttp.mockwebserver.SocketPolicy.DISCONNECT_AT_START;
72import static com.squareup.okhttp.mockwebserver.SocketPolicy.FAIL_HANDSHAKE;
73
74/**
75 * A scriptable web server. Callers supply canned responses and the server
76 * replays them upon request in sequence.
77 */
78public final class MockWebServer {
  /**
   * Trust manager installed by {@link #processHandshakeFailure} to make a TLS handshake fail:
   * every client certificate chain is rejected with a {@link CertificateException}. The other two
   * callbacks throw {@link AssertionError}, signalling that they are not expected to be reached.
   * NOTE(review): confirm no JSSE provider calls getAcceptedIssuers() during that handshake.
   */
  private static final X509TrustManager UNTRUSTED_TRUST_MANAGER = new X509TrustManager() {
    @Override public void checkClientTrusted(X509Certificate[] chain, String authType)
        throws CertificateException {
      throw new CertificateException();
    }

    @Override public void checkServerTrusted(X509Certificate[] chain, String authType) {
      throw new AssertionError();
    }

    @Override public X509Certificate[] getAcceptedIssuers() {
      throw new AssertionError();
    }
  };

  private static final Logger logger = Logger.getLogger(MockWebServer.class.getName());

  /** Requests recorded so far, in arrival order; drained by takeRequest(). */
  private final BlockingQueue<RecordedRequest> requestQueue =
      new LinkedBlockingQueue<RecordedRequest>();

  /**
   * Client sockets currently being served; all are closed when the accept loop exits.
   * All map values are Boolean.TRUE. (Collections.newSetFromMap isn't available in Froyo)
   */
  private final Map<Socket, Boolean> openClientSockets = new ConcurrentHashMap<Socket, Boolean>();
  /** SPDY/HTTP2 connections; all are closed when the accept loop exits. Values are TRUE. */
  private final Map<SpdyConnection, Boolean> openSpdyConnections
      = new ConcurrentHashMap<SpdyConnection, Boolean>();
  /** Number of requests received, including placeholder "bookkeeping" requests. */
  private final AtomicInteger requestCount = new AtomicInteger();
  /** Maximum number of HTTP/1.x request body bytes retained in memory; see setBodyLimit(). */
  private int bodyLimit = Integer.MAX_VALUE;
  /** Listening socket; null until play() creates it. */
  private ServerSocket serverSocket;
  /** Non-null when connections should be served over TLS; see useHttps(). */
  private SSLSocketFactory sslSocketFactory;
  /** Runs the accept loop and one task per connection; non-null once play() has been called. */
  private ExecutorService executor;
  /** Whether HTTPS connections start with plaintext CONNECT requests before TLS. */
  private boolean tunnelProxy;
  /** Chooses a response for each request; a QueueDispatcher unless setDispatcher() replaced it. */
  private Dispatcher dispatcher = new QueueDispatcher();

  /** Bound local port, or -1 before play() has bound the listening socket. */
  private int port = -1;
  /** Whether NPN is used to negotiate SPDY/HTTP2 on HTTPS connections. */
  private boolean npnEnabled = true;
  /** Protocols advertised via NPN, in order of preference. */
  private List<Protocol> npnProtocols = Protocol.HTTP2_SPDY3_AND_HTTP;
114
115  public int getPort() {
116    if (port == -1) throw new IllegalStateException("Cannot retrieve port before calling play()");
117    return port;
118  }
119
120  public String getHostName() {
121    try {
122      return InetAddress.getLocalHost().getHostName();
123    } catch (UnknownHostException e) {
124      throw new AssertionError(e);
125    }
126  }
127
128  public Proxy toProxyAddress() {
129    return new Proxy(Proxy.Type.HTTP, new InetSocketAddress(getHostName(), getPort()));
130  }
131
132  /**
133   * Returns a URL for connecting to this server.
134   * @param path the request path, such as "/".
135   */
136  public URL getUrl(String path) {
137    try {
138      return sslSocketFactory != null
139          ? new URL("https://" + getHostName() + ":" + getPort() + path)
140          : new URL("http://" + getHostName() + ":" + getPort() + path);
141    } catch (MalformedURLException e) {
142      throw new AssertionError(e);
143    }
144  }
145
146  /**
147   * Returns a cookie domain for this server. This returns the server's
148   * non-loopback host name if it is known. Otherwise this returns ".local" for
149   * this server's loopback name.
150   */
151  public String getCookieDomain() {
152    String hostName = getHostName();
153    return hostName.contains(".") ? hostName : ".local";
154  }
155
156  /**
157   * Sets the number of bytes of the POST body to keep in memory to the given
158   * limit.
159   */
160  public void setBodyLimit(int maxBodyLength) {
161    this.bodyLimit = maxBodyLength;
162  }
163
164  /**
165   * Sets whether NPN is used on incoming HTTPS connections to negotiate a
166   * protocol like HTTP/1.1 or SPDY/3. Call this method to disable NPN and
167   * SPDY.
168   */
169  public void setNpnEnabled(boolean npnEnabled) {
170    this.npnEnabled = npnEnabled;
171  }
172
173  /**
174   * Indicates the protocols supported by NPN on incoming HTTPS connections.
175   * This list is ignored when npn is disabled.
176   *
177   * @param protocols the protocols to use, in order of preference. The list
178   *     must contain "http/1.1". It must not contain null.
179   */
180  public void setNpnProtocols(List<Protocol> protocols) {
181    protocols = Util.immutableList(protocols);
182    if (!protocols.contains(Protocol.HTTP_11)) {
183      throw new IllegalArgumentException("protocols doesn't contain http/1.1: " + protocols);
184    }
185    if (protocols.contains(null)) {
186      throw new IllegalArgumentException("protocols must not contain null");
187    }
188    this.npnProtocols = Util.immutableList(protocols);
189  }
190
191  /**
192   * Serve requests with HTTPS rather than otherwise.
193   * @param tunnelProxy true to expect the HTTP CONNECT method before
194   *     negotiating TLS.
195   */
196  public void useHttps(SSLSocketFactory sslSocketFactory, boolean tunnelProxy) {
197    this.sslSocketFactory = sslSocketFactory;
198    this.tunnelProxy = tunnelProxy;
199  }
200
201  /**
202   * Awaits the next HTTP request, removes it, and returns it. Callers should
203   * use this to verify the request was sent as intended.
204   */
205  public RecordedRequest takeRequest() throws InterruptedException {
206    return requestQueue.take();
207  }
208
209  /**
210   * Returns the number of HTTP requests received thus far by this server. This
211   * may exceed the number of HTTP connections when connection reuse is in
212   * practice.
213   */
214  public int getRequestCount() {
215    return requestCount.get();
216  }
217
218  /**
219   * Scripts {@code response} to be returned to a request made in sequence. The
220   * first request is served by the first enqueued response; the second request
221   * by the second enqueued response; and so on.
222   *
223   * @throws ClassCastException if the default dispatcher has been replaced
224   *     with {@link #setDispatcher(Dispatcher)}.
225   */
226  public void enqueue(MockResponse response) {
227    ((QueueDispatcher) dispatcher).enqueueResponse(response.clone());
228  }
229
230  /** Equivalent to {@code play(0)}. */
231  public void play() throws IOException {
232    play(0);
233  }
234
235  /**
236   * Starts the server, serves all enqueued requests, and shuts the server down.
237   *
238   * @param port the port to listen to, or 0 for any available port. Automated
239   *     tests should always use port 0 to avoid flakiness when a specific port
240   *     is unavailable.
241   */
242  public void play(int port) throws IOException {
243    if (executor != null) throw new IllegalStateException("play() already called");
244    executor = Executors.newCachedThreadPool(Util.threadFactory("MockWebServer", false));
245    serverSocket = new ServerSocket(port);
246    serverSocket.setReuseAddress(true);
247
248    this.port = serverSocket.getLocalPort();
249    executor.execute(new NamedRunnable("MockWebServer %s", port) {
250      @Override protected void execute() {
251        try {
252          acceptConnections();
253        } catch (Throwable e) {
254          logger.log(Level.WARNING, "MockWebServer connection failed", e);
255        }
256
257        // This gnarly block of code will release all sockets and all thread,
258        // even if any close fails.
259        Util.closeQuietly(serverSocket);
260        for (Iterator<Socket> s = openClientSockets.keySet().iterator(); s.hasNext(); ) {
261          Util.closeQuietly(s.next());
262          s.remove();
263        }
264        for (Iterator<SpdyConnection> s = openSpdyConnections.keySet().iterator(); s.hasNext(); ) {
265          Util.closeQuietly(s.next());
266          s.remove();
267        }
268        executor.shutdown();
269      }
270
271      private void acceptConnections() throws Exception {
272        while (true) {
273          Socket socket;
274          try {
275            socket = serverSocket.accept();
276          } catch (SocketException e) {
277            return;
278          }
279          SocketPolicy socketPolicy = dispatcher.peek().getSocketPolicy();
280          if (socketPolicy == DISCONNECT_AT_START) {
281            dispatchBookkeepingRequest(0, socket);
282            socket.close();
283          } else {
284            openClientSockets.put(socket, true);
285            serveConnection(socket);
286          }
287        }
288      }
289    });
290  }
291
292  public void shutdown() throws IOException {
293    if (serverSocket != null) {
294      serverSocket.close(); // Should cause acceptConnections() to break out.
295    }
296  }
297
  /**
   * Serves the accepted socket {@code raw} on its own executor task.
   *
   * <p>For HTTPS, plaintext CONNECT requests are answered first when tunnelProxy is set. Then the
   * next response's socket policy decides the next step: either a handshake failure is forced, or
   * the socket is wrapped in TLS and NPN optionally picks the protocol. A SPDY-variant protocol
   * hands the socket to a {@link SpdyConnection}. Otherwise HTTP/1.x requests are read and
   * answered until the stream ends. Exceptions are logged, not rethrown.
   */
  private void serveConnection(final Socket raw) {
    executor.execute(new NamedRunnable("MockWebServer %s", raw.getRemoteSocketAddress()) {
      /** Index of the next HTTP/1.x request on this connection (CONNECT requests included). */
      int sequenceNumber = 0;

      @Override protected void execute() {
        try {
          processConnection();
        } catch (Exception e) {
          logger.log(Level.WARNING, "MockWebServer connection failed", e);
        }
      }

      /** Negotiates TLS/protocol as configured, then serves requests until the stream ends. */
      public void processConnection() throws Exception {
        Protocol protocol = Protocol.HTTP_11;
        Socket socket;
        if (sslSocketFactory != null) {
          if (tunnelProxy) {
            // Answer plaintext CONNECT requests on the raw socket before TLS begins.
            createTunnel();
          }
          // Inspect (via peek()) the next response to decide whether to fail the handshake.
          SocketPolicy socketPolicy = dispatcher.peek().getSocketPolicy();
          if (socketPolicy == FAIL_HANDSHAKE) {
            dispatchBookkeepingRequest(sequenceNumber, raw);
            processHandshakeFailure(raw);
            return;
          }
          // autoClose=true: closing the TLS socket also closes raw.
          socket = sslSocketFactory.createSocket(
              raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true);
          SSLSocket sslSocket = (SSLSocket) socket;
          sslSocket.setUseClientMode(false);
          openClientSockets.put(socket, true);

          if (npnEnabled) {
            Platform.get().setNpnProtocols(sslSocket, npnProtocols);
          }

          sslSocket.startHandshake();

          if (npnEnabled) {
            ByteString selectedProtocol = Platform.get().getNpnSelectedProtocol(sslSocket);
            protocol = Protocol.find(selectedProtocol);
          }
          // The TLS socket is tracked from here on; raw is closed along with it.
          openClientSockets.remove(raw);
        } else {
          socket = raw;
        }

        if (protocol.spdyVariant) {
          // SpdyConnection now owns the socket; it is tracked in openSpdyConnections instead and
          // closed when the accept loop exits.
          SpdySocketHandler spdySocketHandler = new SpdySocketHandler(socket, protocol);
          SpdyConnection spdyConnection = new SpdyConnection.Builder(false, socket)
              .protocol(protocol)
              .handler(spdySocketHandler).build();
          openSpdyConnections.put(spdyConnection, Boolean.TRUE);
          openClientSockets.remove(socket);
          return;
        }

        InputStream in = new BufferedInputStream(socket.getInputStream());
        OutputStream out = new BufferedOutputStream(socket.getOutputStream());

        while (processOneRequest(socket, in, out)) {
          // Keep serving until readRequest() reports no further request.
        }

        if (sequenceNumber == 0) {
          logger.warning("MockWebServer connection didn't make a request");
        }

        in.close();
        out.close();
        socket.close();
        openClientSockets.remove(socket);
      }

      /**
       * Respond to CONNECT requests until an UPGRADE_TO_SSL_AT_END response is
       * dispatched. The raw (unbuffered) streams are used, presumably so that no bytes of the
       * following TLS handshake are read ahead into a buffer.
       *
       * @throws IllegalStateException if the stream ends before any request is read.
       */
      private void createTunnel() throws IOException, InterruptedException {
        while (true) {
          // Capture the policy before processOneRequest() dispatches (consumes) the response.
          SocketPolicy socketPolicy = dispatcher.peek().getSocketPolicy();
          if (!processOneRequest(raw, raw.getInputStream(), raw.getOutputStream())) {
            throw new IllegalStateException("Tunnel without any CONNECT!");
          }
          if (socketPolicy == SocketPolicy.UPGRADE_TO_SSL_AT_END) return;
        }
      }

      /**
       * Reads a request and writes its response. Returns true if a request was
       * processed. After writing, the response's socket policy may close the streams or shut
       * down one direction of the socket.
       */
      private boolean processOneRequest(Socket socket, InputStream in, OutputStream out)
          throws IOException, InterruptedException {
        RecordedRequest request = readRequest(socket, in, out, sequenceNumber);
        if (request == null) return false;
        requestCount.incrementAndGet();
        requestQueue.add(request);
        MockResponse response = dispatcher.dispatch(request);
        writeResponse(out, response);
        if (response.getSocketPolicy() == SocketPolicy.DISCONNECT_AT_END) {
          in.close();
          out.close();
        } else if (response.getSocketPolicy() == SocketPolicy.SHUTDOWN_INPUT_AT_END) {
          socket.shutdownInput();
        } else if (response.getSocketPolicy() == SocketPolicy.SHUTDOWN_OUTPUT_AT_END) {
          socket.shutdownOutput();
        }
        if (logger.isLoggable(Level.INFO)) {
          logger.info("Received request: " + request + " and responded: " + response);
        }
        sequenceNumber++;
        return true;
      }
    });
  }
412
413  private void processHandshakeFailure(Socket raw) throws Exception {
414    SSLContext context = SSLContext.getInstance("TLS");
415    context.init(null, new TrustManager[] { UNTRUSTED_TRUST_MANAGER }, new SecureRandom());
416    SSLSocketFactory sslSocketFactory = context.getSocketFactory();
417    SSLSocket socket = (SSLSocket) sslSocketFactory.createSocket(
418        raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true);
419    try {
420      socket.startHandshake(); // we're testing a handshake failure
421      throw new AssertionError();
422    } catch (IOException expected) {
423    }
424    socket.close();
425  }
426
427  private void dispatchBookkeepingRequest(int sequenceNumber, Socket socket)
428      throws InterruptedException {
429    requestCount.incrementAndGet();
430    dispatcher.dispatch(new RecordedRequest(null, null, null, -1, null, sequenceNumber, socket));
431  }
432
433  /** @param sequenceNumber the index of this request on this connection. */
434  private RecordedRequest readRequest(Socket socket, InputStream in, OutputStream out,
435      int sequenceNumber) throws IOException {
436    String request;
437    try {
438      request = readAsciiUntilCrlf(in);
439    } catch (IOException streamIsClosed) {
440      return null; // no request because we closed the stream
441    }
442    if (request.length() == 0) {
443      return null; // no request because the stream is exhausted
444    }
445
446    List<String> headers = new ArrayList<String>();
447    long contentLength = -1;
448    boolean chunked = false;
449    boolean expectContinue = false;
450    String header;
451    while ((header = readAsciiUntilCrlf(in)).length() != 0) {
452      headers.add(header);
453      String lowercaseHeader = header.toLowerCase(Locale.US);
454      if (contentLength == -1 && lowercaseHeader.startsWith("content-length:")) {
455        contentLength = Long.parseLong(header.substring(15).trim());
456      }
457      if (lowercaseHeader.startsWith("transfer-encoding:")
458          && lowercaseHeader.substring(18).trim().equals("chunked")) {
459        chunked = true;
460      }
461      if (lowercaseHeader.startsWith("expect:")
462          && lowercaseHeader.substring(7).trim().equals("100-continue")) {
463        expectContinue = true;
464      }
465    }
466
467    if (expectContinue) {
468      out.write(("HTTP/1.1 100 Continue\r\n").getBytes(Util.US_ASCII));
469      out.write(("Content-Length: 0\r\n").getBytes(Util.US_ASCII));
470      out.write(("\r\n").getBytes(Util.US_ASCII));
471      out.flush();
472    }
473
474    boolean hasBody = false;
475    TruncatingOutputStream requestBody = new TruncatingOutputStream();
476    List<Integer> chunkSizes = new ArrayList<Integer>();
477    MockResponse throttlePolicy = dispatcher.peek();
478    if (contentLength != -1) {
479      hasBody = true;
480      throttledTransfer(throttlePolicy, in, requestBody, contentLength);
481    } else if (chunked) {
482      hasBody = true;
483      while (true) {
484        int chunkSize = Integer.parseInt(readAsciiUntilCrlf(in).trim(), 16);
485        if (chunkSize == 0) {
486          readEmptyLine(in);
487          break;
488        }
489        chunkSizes.add(chunkSize);
490        throttledTransfer(throttlePolicy, in, requestBody, chunkSize);
491        readEmptyLine(in);
492      }
493    }
494
495    if (request.startsWith("OPTIONS ")
496        || request.startsWith("GET ")
497        || request.startsWith("HEAD ")
498        || request.startsWith("TRACE ")
499        || request.startsWith("CONNECT ")) {
500      if (hasBody) {
501        throw new IllegalArgumentException("Request must not have a body: " + request);
502      }
503    } else if (!request.startsWith("POST ")
504        && !request.startsWith("PUT ")
505        && !request.startsWith("PATCH ")
506        && !request.startsWith("DELETE ")) { // Permitted as spec is ambiguous.
507      throw new UnsupportedOperationException("Unexpected method: " + request);
508    }
509
510    return new RecordedRequest(request, headers, chunkSizes, requestBody.numBytesReceived,
511        requestBody.toByteArray(), sequenceNumber, socket);
512  }
513
514  private void writeResponse(OutputStream out, MockResponse response) throws IOException {
515    out.write((response.getStatus() + "\r\n").getBytes(Util.US_ASCII));
516    List<String> headers = response.getHeaders();
517    for (int i = 0, size = headers.size(); i < size; i++) {
518      String header = headers.get(i);
519      out.write((header + "\r\n").getBytes(Util.US_ASCII));
520    }
521    out.write(("\r\n").getBytes(Util.US_ASCII));
522    out.flush();
523
524    InputStream in = response.getBodyStream();
525    if (in == null) return;
526    throttledTransfer(response, in, out, Long.MAX_VALUE);
527  }
528
529  /**
530   * Transfer bytes from {@code in} to {@code out} until either {@code length}
531   * bytes have been transferred or {@code in} is exhausted. The transfer is
532   * throttled according to {@code throttlePolicy}.
533   */
534  private void throttledTransfer(MockResponse throttlePolicy, InputStream in, OutputStream out,
535      long limit) throws IOException {
536    byte[] buffer = new byte[1024];
537    int bytesPerPeriod = throttlePolicy.getThrottleBytesPerPeriod();
538    long delayMs = throttlePolicy.getThrottleUnit().toMillis(throttlePolicy.getThrottlePeriod());
539
540    while (true) {
541      for (int b = 0; b < bytesPerPeriod; ) {
542        int toRead = (int) Math.min(Math.min(buffer.length, limit), bytesPerPeriod - b);
543        int read = in.read(buffer, 0, toRead);
544        if (read == -1) return;
545
546        out.write(buffer, 0, read);
547        out.flush();
548        b += read;
549        limit -= read;
550
551        if (limit == 0) return;
552      }
553
554      try {
555        if (delayMs != 0) Thread.sleep(delayMs);
556      } catch (InterruptedException e) {
557        throw new AssertionError();
558      }
559    }
560  }
561
562  /**
563   * Returns the text from {@code in} until the next "\r\n", or null if {@code
564   * in} is exhausted.
565   */
566  private String readAsciiUntilCrlf(InputStream in) throws IOException {
567    StringBuilder builder = new StringBuilder();
568    while (true) {
569      int c = in.read();
570      if (c == '\n' && builder.length() > 0 && builder.charAt(builder.length() - 1) == '\r') {
571        builder.deleteCharAt(builder.length() - 1);
572        return builder.toString();
573      } else if (c == -1) {
574        return builder.toString();
575      } else {
576        builder.append((char) c);
577      }
578    }
579  }
580
581  private void readEmptyLine(InputStream in) throws IOException {
582    String line = readAsciiUntilCrlf(in);
583    if (line.length() != 0) throw new IllegalStateException("Expected empty but was: " + line);
584  }
585
586  /**
587   * Sets the dispatcher used to match incoming requests to mock responses.
588   * The default dispatcher simply serves a fixed sequence of responses from
589   * a {@link #enqueue(MockResponse) queue}; custom dispatchers can vary the
590   * response based on timing or the content of the request.
591   */
592  public void setDispatcher(Dispatcher dispatcher) {
593    if (dispatcher == null) throw new NullPointerException();
594    this.dispatcher = dispatcher;
595  }
596
597  /** An output stream that drops data after bodyLimit bytes. */
598  private class TruncatingOutputStream extends ByteArrayOutputStream {
599    private long numBytesReceived = 0;
600
601    @Override public void write(byte[] buffer, int offset, int len) {
602      numBytesReceived += len;
603      super.write(buffer, offset, Math.min(len, bodyLimit - count));
604    }
605
606    @Override public void write(int oneByte) {
607      numBytesReceived++;
608      if (count < bodyLimit) {
609        super.write(oneByte);
610      }
611    }
612  }
613
614  /** Processes HTTP requests layered over SPDY/3. */
615  private class SpdySocketHandler implements IncomingStreamHandler {
616    private final Socket socket;
617    private final Protocol protocol;
618    private final AtomicInteger sequenceNumber = new AtomicInteger();
619
620    private SpdySocketHandler(Socket socket, Protocol protocol) {
621      this.socket = socket;
622      this.protocol = protocol;
623    }
624
625    @Override public void receive(SpdyStream stream) throws IOException {
626      RecordedRequest request = readRequest(stream);
627      requestQueue.add(request);
628      MockResponse response;
629      try {
630        response = dispatcher.dispatch(request);
631      } catch (InterruptedException e) {
632        throw new AssertionError(e);
633      }
634      writeResponse(stream, response);
635      if (logger.isLoggable(Level.INFO)) {
636        logger.info("Received request: " + request + " and responded: " + response
637            + " protocol is " + protocol.name.utf8());
638      }
639    }
640
641    private RecordedRequest readRequest(SpdyStream stream) throws IOException {
642      List<Header> spdyHeaders = stream.getRequestHeaders();
643      List<String> httpHeaders = new ArrayList<String>();
644      String method = "<:method omitted>";
645      String path = "<:path omitted>";
646      String version = protocol == Protocol.SPDY_3 ? "<:version omitted>" : "HTTP/1.1";
647      for (int i = 0, size = spdyHeaders.size(); i < size; i++) {
648        ByteString name = spdyHeaders.get(i).name;
649        String value = spdyHeaders.get(i).value.utf8();
650        if (name.equals(Header.TARGET_METHOD)) {
651          method = value;
652        } else if (name.equals(Header.TARGET_PATH)) {
653          path = value;
654        } else if (name.equals(Header.VERSION)) {
655          version = value;
656        } else {
657          httpHeaders.add(name.utf8() + ": " + value);
658        }
659      }
660
661      InputStream bodyIn = Okio.buffer(stream.getSource()).inputStream();
662      ByteArrayOutputStream bodyOut = new ByteArrayOutputStream();
663      byte[] buffer = new byte[8192];
664      int count;
665      while ((count = bodyIn.read(buffer)) != -1) {
666        bodyOut.write(buffer, 0, count);
667      }
668      bodyIn.close();
669      String requestLine = method + ' ' + path + ' ' + version;
670      List<Integer> chunkSizes = Collections.emptyList(); // No chunked encoding for SPDY.
671      return new RecordedRequest(requestLine, httpHeaders, chunkSizes, bodyOut.size(),
672          bodyOut.toByteArray(), sequenceNumber.getAndIncrement(), socket);
673    }
674
675    private void writeResponse(SpdyStream stream, MockResponse response) throws IOException {
676      if (response.getSocketPolicy() == SocketPolicy.NO_RESPONSE) {
677        return;
678      }
679      List<Header> spdyHeaders = new ArrayList<Header>();
680      String[] statusParts = response.getStatus().split(" ", 2);
681      if (statusParts.length != 2) {
682        throw new AssertionError("Unexpected status: " + response.getStatus());
683      }
684      // TODO: constants for well-known header names.
685      spdyHeaders.add(new Header(Header.RESPONSE_STATUS, statusParts[1]));
686      if (protocol == Protocol.SPDY_3) {
687        spdyHeaders.add(new Header(Header.VERSION, statusParts[0]));
688      }
689      List<String> headers = response.getHeaders();
690      for (int i = 0, size = headers.size(); i < size; i++) {
691        String header = headers.get(i);
692        String[] headerParts = header.split(":", 2);
693        if (headerParts.length != 2) {
694          throw new AssertionError("Unexpected header: " + header);
695        }
696        spdyHeaders.add(new Header(headerParts[0], headerParts[1]));
697      }
698      OkBuffer body = new OkBuffer();
699      if (response.getBody() != null) {
700        body.write(response.getBody());
701      }
702      boolean closeStreamAfterHeaders = body.size() > 0 || !response.getPushPromises().isEmpty();
703      stream.reply(spdyHeaders, closeStreamAfterHeaders);
704      pushPromises(stream, response.getPushPromises());
705      if (body.size() > 0) {
706        if (response.getBodyDelayTimeMs() != 0) {
707          try {
708            Thread.sleep(response.getBodyDelayTimeMs());
709          } catch (InterruptedException e) {
710            throw new AssertionError(e);
711          }
712        }
713        BufferedSink sink = Okio.buffer(stream.getSink());
714        if (response.getThrottleBytesPerPeriod() == Integer.MAX_VALUE) {
715          sink.write(body, body.size());
716          sink.flush();
717        } else {
718          while (body.size() > 0) {
719            long toWrite = Math.min(body.size(), response.getThrottleBytesPerPeriod());
720            sink.write(body, toWrite);
721            sink.flush();
722            try {
723              long delayMs = response.getThrottleUnit().toMillis(response.getThrottlePeriod());
724              if (delayMs != 0) Thread.sleep(delayMs);
725            } catch (InterruptedException e) {
726              throw new AssertionError();
727            }
728          }
729        }
730        sink.close();
731      } else if (closeStreamAfterHeaders) {
732        stream.close(ErrorCode.NO_ERROR);
733      }
734    }
735
736    private void pushPromises(SpdyStream stream, List<PushPromise> promises) throws IOException {
737      for (PushPromise pushPromise : promises) {
738        List<Header> pushedHeaders = new ArrayList<Header>();
739        pushedHeaders.add(new Header(stream.getConnection().getProtocol() == Protocol.SPDY_3
740            ? Header.TARGET_HOST
741            : Header.TARGET_AUTHORITY, getUrl(pushPromise.getPath()).getHost()));
742        pushedHeaders.add(new Header(Header.TARGET_METHOD, pushPromise.getMethod()));
743        pushedHeaders.add(new Header(Header.TARGET_PATH, pushPromise.getPath()));
744        for (int i = 0, size = pushPromise.getHeaders().size(); i < size; i++) {
745          String header = pushPromise.getHeaders().get(i);
746          String[] headerParts = header.split(":", 2);
747          if (headerParts.length != 2) {
748            throw new AssertionError("Unexpected header: " + header);
749          }
750          pushedHeaders.add(new Header(headerParts[0], headerParts[1].trim()));
751        }
752        String requestLine = pushPromise.getMethod() + ' ' + pushPromise.getPath() + " HTTP/1.1";
753        List<Integer> chunkSizes = Collections.emptyList(); // No chunked encoding for SPDY.
754        requestQueue.add(new RecordedRequest(requestLine, pushPromise.getHeaders(), chunkSizes, 0,
755            Util.EMPTY_BYTE_ARRAY, sequenceNumber.getAndIncrement(), socket));
756        byte[] pushedBody = pushPromise.getResponse().getBody();
757        SpdyStream pushedStream =
758            stream.getConnection().pushStream(stream.getId(), pushedHeaders, pushedBody.length > 0);
759        writeResponse(pushedStream, pushPromise.getResponse());
760      }
761    }
762  }
763}
764