1/* 2 * Copyright (C) 2011 Google Inc. 3 * Copyright (C) 2013 Square, Inc. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 */ 17 18package com.squareup.okhttp.mockwebserver; 19 20import com.squareup.okhttp.Protocol; 21import com.squareup.okhttp.internal.NamedRunnable; 22import com.squareup.okhttp.internal.Platform; 23import com.squareup.okhttp.internal.Util; 24import com.squareup.okhttp.internal.spdy.ErrorCode; 25import com.squareup.okhttp.internal.spdy.Header; 26import com.squareup.okhttp.internal.spdy.IncomingStreamHandler; 27import com.squareup.okhttp.internal.spdy.SpdyConnection; 28import com.squareup.okhttp.internal.spdy.SpdyStream; 29import java.io.BufferedInputStream; 30import java.io.BufferedOutputStream; 31import java.io.ByteArrayOutputStream; 32import java.io.IOException; 33import java.io.InputStream; 34import java.io.OutputStream; 35import java.net.InetAddress; 36import java.net.InetSocketAddress; 37import java.net.MalformedURLException; 38import java.net.Proxy; 39import java.net.ServerSocket; 40import java.net.Socket; 41import java.net.SocketException; 42import java.net.URL; 43import java.net.UnknownHostException; 44import java.security.SecureRandom; 45import java.security.cert.CertificateException; 46import java.security.cert.X509Certificate; 47import java.util.ArrayList; 48import java.util.Collections; 49import java.util.Iterator; 50import java.util.List; 51import java.util.Locale; 52import java.util.Map; 53import java.util.concurrent.BlockingQueue; 54import java.util.concurrent.ConcurrentHashMap; 55import java.util.concurrent.ExecutorService; 56import java.util.concurrent.Executors; 57import java.util.concurrent.LinkedBlockingQueue; 58import java.util.concurrent.atomic.AtomicInteger; 59import java.util.logging.Level; 60import java.util.logging.Logger; 61import javax.net.ssl.SSLContext; 62import javax.net.ssl.SSLSocket; 63import javax.net.ssl.SSLSocketFactory; 64import javax.net.ssl.TrustManager; 65import javax.net.ssl.X509TrustManager; 66import okio.BufferedSink; 67import okio.ByteString; 68import okio.OkBuffer; 69import okio.Okio; 70 71import static com.squareup.okhttp.mockwebserver.SocketPolicy.DISCONNECT_AT_START; 72import static com.squareup.okhttp.mockwebserver.SocketPolicy.FAIL_HANDSHAKE; 73 74/** 75 * A scriptable web server. Callers supply canned responses and the server 76 * replays them upon request in sequence. 77 */ 78public final class MockWebServer { 79 private static final X509TrustManager UNTRUSTED_TRUST_MANAGER = new X509TrustManager() { 80 @Override public void checkClientTrusted(X509Certificate[] chain, String authType) 81 throws CertificateException { 82 throw new CertificateException(); 83 } 84 85 @Override public void checkServerTrusted(X509Certificate[] chain, String authType) { 86 throw new AssertionError(); 87 } 88 89 @Override public X509Certificate[] getAcceptedIssuers() { 90 throw new AssertionError(); 91 } 92 }; 93 94 private static final Logger logger = Logger.getLogger(MockWebServer.class.getName()); 95 96 private final BlockingQueue<RecordedRequest> requestQueue = 97 new LinkedBlockingQueue<RecordedRequest>(); 98 99 /** All map values are Boolean.TRUE. (Collections.newSetFromMap isn't available in Froyo) */ 100 private final Map<Socket, Boolean> openClientSockets = new ConcurrentHashMap<Socket, Boolean>(); 101 private final Map<SpdyConnection, Boolean> openSpdyConnections 102 = new ConcurrentHashMap<SpdyConnection, Boolean>(); 103 private final AtomicInteger requestCount = new AtomicInteger(); 104 private int bodyLimit = Integer.MAX_VALUE; 105 private ServerSocket serverSocket; 106 private SSLSocketFactory sslSocketFactory; 107 private ExecutorService executor; 108 private boolean tunnelProxy; 109 private Dispatcher dispatcher = new QueueDispatcher(); 110 111 private int port = -1; 112 private boolean npnEnabled = true; 113 private List<Protocol> npnProtocols = Protocol.HTTP2_SPDY3_AND_HTTP; 114 115 public int getPort() { 116 if (port == -1) throw new IllegalStateException("Cannot retrieve port before calling play()"); 117 return port; 118 } 119 120 public String getHostName() { 121 try { 122 return InetAddress.getLocalHost().getHostName(); 123 } catch (UnknownHostException e) { 124 throw new AssertionError(e); 125 } 126 } 127 128 public Proxy toProxyAddress() { 129 return new Proxy(Proxy.Type.HTTP, new InetSocketAddress(getHostName(), getPort())); 130 } 131 132 /** 133 * Returns a URL for connecting to this server. 134 * @param path the request path, such as "/". 135 */ 136 public URL getUrl(String path) { 137 try { 138 return sslSocketFactory != null 139 ? new URL("https://" + getHostName() + ":" + getPort() + path) 140 : new URL("http://" + getHostName() + ":" + getPort() + path); 141 } catch (MalformedURLException e) { 142 throw new AssertionError(e); 143 } 144 } 145 146 /** 147 * Returns a cookie domain for this server. This returns the server's 148 * non-loopback host name if it is known. Otherwise this returns ".local" for 149 * this server's loopback name. 150 */ 151 public String getCookieDomain() { 152 String hostName = getHostName(); 153 return hostName.contains(".") ? hostName : ".local"; 154 } 155 156 /** 157 * Sets the number of bytes of the POST body to keep in memory to the given 158 * limit. 159 */ 160 public void setBodyLimit(int maxBodyLength) { 161 this.bodyLimit = maxBodyLength; 162 } 163 164 /** 165 * Sets whether NPN is used on incoming HTTPS connections to negotiate a 166 * protocol like HTTP/1.1 or SPDY/3. Call this method to disable NPN and 167 * SPDY. 168 */ 169 public void setNpnEnabled(boolean npnEnabled) { 170 this.npnEnabled = npnEnabled; 171 } 172 173 /** 174 * Indicates the protocols supported by NPN on incoming HTTPS connections. 175 * This list is ignored when npn is disabled. 176 * 177 * @param protocols the protocols to use, in order of preference. The list 178 * must contain "http/1.1". It must not contain null. 179 */ 180 public void setNpnProtocols(List<Protocol> protocols) { 181 protocols = Util.immutableList(protocols); 182 if (!protocols.contains(Protocol.HTTP_11)) { 183 throw new IllegalArgumentException("protocols doesn't contain http/1.1: " + protocols); 184 } 185 if (protocols.contains(null)) { 186 throw new IllegalArgumentException("protocols must not contain null"); 187 } 188 this.npnProtocols = Util.immutableList(protocols); 189 } 190 191 /** 192 * Serve requests with HTTPS rather than otherwise. 193 * @param tunnelProxy true to expect the HTTP CONNECT method before 194 * negotiating TLS. 195 */ 196 public void useHttps(SSLSocketFactory sslSocketFactory, boolean tunnelProxy) { 197 this.sslSocketFactory = sslSocketFactory; 198 this.tunnelProxy = tunnelProxy; 199 } 200 201 /** 202 * Awaits the next HTTP request, removes it, and returns it. Callers should 203 * use this to verify the request was sent as intended. 204 */ 205 public RecordedRequest takeRequest() throws InterruptedException { 206 return requestQueue.take(); 207 } 208 209 /** 210 * Returns the number of HTTP requests received thus far by this server. This 211 * may exceed the number of HTTP connections when connection reuse is in 212 * practice. 213 */ 214 public int getRequestCount() { 215 return requestCount.get(); 216 } 217 218 /** 219 * Scripts {@code response} to be returned to a request made in sequence. The 220 * first request is served by the first enqueued response; the second request 221 * by the second enqueued response; and so on. 222 * 223 * @throws ClassCastException if the default dispatcher has been replaced 224 * with {@link #setDispatcher(Dispatcher)}. 225 */ 226 public void enqueue(MockResponse response) { 227 ((QueueDispatcher) dispatcher).enqueueResponse(response.clone()); 228 } 229 230 /** Equivalent to {@code play(0)}. */ 231 public void play() throws IOException { 232 play(0); 233 } 234 235 /** 236 * Starts the server, serves all enqueued requests, and shuts the server down. 237 * 238 * @param port the port to listen to, or 0 for any available port. Automated 239 * tests should always use port 0 to avoid flakiness when a specific port 240 * is unavailable. 241 */ 242 public void play(int port) throws IOException { 243 if (executor != null) throw new IllegalStateException("play() already called"); 244 executor = Executors.newCachedThreadPool(Util.threadFactory("MockWebServer", false)); 245 serverSocket = new ServerSocket(port); 246 serverSocket.setReuseAddress(true); 247 248 this.port = serverSocket.getLocalPort(); 249 executor.execute(new NamedRunnable("MockWebServer %s", port) { 250 @Override protected void execute() { 251 try { 252 acceptConnections(); 253 } catch (Throwable e) { 254 logger.log(Level.WARNING, "MockWebServer connection failed", e); 255 } 256 257 // This gnarly block of code will release all sockets and all thread, 258 // even if any close fails. 259 Util.closeQuietly(serverSocket); 260 for (Iterator<Socket> s = openClientSockets.keySet().iterator(); s.hasNext(); ) { 261 Util.closeQuietly(s.next()); 262 s.remove(); 263 } 264 for (Iterator<SpdyConnection> s = openSpdyConnections.keySet().iterator(); s.hasNext(); ) { 265 Util.closeQuietly(s.next()); 266 s.remove(); 267 } 268 executor.shutdown(); 269 } 270 271 private void acceptConnections() throws Exception { 272 while (true) { 273 Socket socket; 274 try { 275 socket = serverSocket.accept(); 276 } catch (SocketException e) { 277 return; 278 } 279 SocketPolicy socketPolicy = dispatcher.peek().getSocketPolicy(); 280 if (socketPolicy == DISCONNECT_AT_START) { 281 dispatchBookkeepingRequest(0, socket); 282 socket.close(); 283 } else { 284 openClientSockets.put(socket, true); 285 serveConnection(socket); 286 } 287 } 288 } 289 }); 290 } 291 292 public void shutdown() throws IOException { 293 if (serverSocket != null) { 294 serverSocket.close(); // Should cause acceptConnections() to break out. 295 } 296 } 297 298 private void serveConnection(final Socket raw) { 299 executor.execute(new NamedRunnable("MockWebServer %s", raw.getRemoteSocketAddress()) { 300 int sequenceNumber = 0; 301 302 @Override protected void execute() { 303 try { 304 processConnection(); 305 } catch (Exception e) { 306 logger.log(Level.WARNING, "MockWebServer connection failed", e); 307 } 308 } 309 310 public void processConnection() throws Exception { 311 Protocol protocol = Protocol.HTTP_11; 312 Socket socket; 313 if (sslSocketFactory != null) { 314 if (tunnelProxy) { 315 createTunnel(); 316 } 317 SocketPolicy socketPolicy = dispatcher.peek().getSocketPolicy(); 318 if (socketPolicy == FAIL_HANDSHAKE) { 319 dispatchBookkeepingRequest(sequenceNumber, raw); 320 processHandshakeFailure(raw); 321 return; 322 } 323 socket = sslSocketFactory.createSocket( 324 raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true); 325 SSLSocket sslSocket = (SSLSocket) socket; 326 sslSocket.setUseClientMode(false); 327 openClientSockets.put(socket, true); 328 329 if (npnEnabled) { 330 Platform.get().setNpnProtocols(sslSocket, npnProtocols); 331 } 332 333 sslSocket.startHandshake(); 334 335 if (npnEnabled) { 336 ByteString selectedProtocol = Platform.get().getNpnSelectedProtocol(sslSocket); 337 protocol = Protocol.find(selectedProtocol); 338 } 339 openClientSockets.remove(raw); 340 } else { 341 socket = raw; 342 } 343 344 if (protocol.spdyVariant) { 345 SpdySocketHandler spdySocketHandler = new SpdySocketHandler(socket, protocol); 346 SpdyConnection spdyConnection = new SpdyConnection.Builder(false, socket) 347 .protocol(protocol) 348 .handler(spdySocketHandler).build(); 349 openSpdyConnections.put(spdyConnection, Boolean.TRUE); 350 openClientSockets.remove(socket); 351 return; 352 } 353 354 InputStream in = new BufferedInputStream(socket.getInputStream()); 355 OutputStream out = new BufferedOutputStream(socket.getOutputStream()); 356 357 while (processOneRequest(socket, in, out)) { 358 } 359 360 if (sequenceNumber == 0) { 361 logger.warning("MockWebServer connection didn't make a request"); 362 } 363 364 in.close(); 365 out.close(); 366 socket.close(); 367 openClientSockets.remove(socket); 368 } 369 370 /** 371 * Respond to CONNECT requests until a SWITCH_TO_SSL_AT_END response is 372 * dispatched. 373 */ 374 private void createTunnel() throws IOException, InterruptedException { 375 while (true) { 376 SocketPolicy socketPolicy = dispatcher.peek().getSocketPolicy(); 377 if (!processOneRequest(raw, raw.getInputStream(), raw.getOutputStream())) { 378 throw new IllegalStateException("Tunnel without any CONNECT!"); 379 } 380 if (socketPolicy == SocketPolicy.UPGRADE_TO_SSL_AT_END) return; 381 } 382 } 383 384 /** 385 * Reads a request and writes its response. Returns true if a request was 386 * processed. 387 */ 388 private boolean processOneRequest(Socket socket, InputStream in, OutputStream out) 389 throws IOException, InterruptedException { 390 RecordedRequest request = readRequest(socket, in, out, sequenceNumber); 391 if (request == null) return false; 392 requestCount.incrementAndGet(); 393 requestQueue.add(request); 394 MockResponse response = dispatcher.dispatch(request); 395 writeResponse(out, response); 396 if (response.getSocketPolicy() == SocketPolicy.DISCONNECT_AT_END) { 397 in.close(); 398 out.close(); 399 } else if (response.getSocketPolicy() == SocketPolicy.SHUTDOWN_INPUT_AT_END) { 400 socket.shutdownInput(); 401 } else if (response.getSocketPolicy() == SocketPolicy.SHUTDOWN_OUTPUT_AT_END) { 402 socket.shutdownOutput(); 403 } 404 if (logger.isLoggable(Level.INFO)) { 405 logger.info("Received request: " + request + " and responded: " + response); 406 } 407 sequenceNumber++; 408 return true; 409 } 410 }); 411 } 412 413 private void processHandshakeFailure(Socket raw) throws Exception { 414 SSLContext context = SSLContext.getInstance("TLS"); 415 context.init(null, new TrustManager[] { UNTRUSTED_TRUST_MANAGER }, new SecureRandom()); 416 SSLSocketFactory sslSocketFactory = context.getSocketFactory(); 417 SSLSocket socket = (SSLSocket) sslSocketFactory.createSocket( 418 raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true); 419 try { 420 socket.startHandshake(); // we're testing a handshake failure 421 throw new AssertionError(); 422 } catch (IOException expected) { 423 } 424 socket.close(); 425 } 426 427 private void dispatchBookkeepingRequest(int sequenceNumber, Socket socket) 428 throws InterruptedException { 429 requestCount.incrementAndGet(); 430 dispatcher.dispatch(new RecordedRequest(null, null, null, -1, null, sequenceNumber, socket)); 431 } 432 433 /** @param sequenceNumber the index of this request on this connection. */ 434 private RecordedRequest readRequest(Socket socket, InputStream in, OutputStream out, 435 int sequenceNumber) throws IOException { 436 String request; 437 try { 438 request = readAsciiUntilCrlf(in); 439 } catch (IOException streamIsClosed) { 440 return null; // no request because we closed the stream 441 } 442 if (request.length() == 0) { 443 return null; // no request because the stream is exhausted 444 } 445 446 List<String> headers = new ArrayList<String>(); 447 long contentLength = -1; 448 boolean chunked = false; 449 boolean expectContinue = false; 450 String header; 451 while ((header = readAsciiUntilCrlf(in)).length() != 0) { 452 headers.add(header); 453 String lowercaseHeader = header.toLowerCase(Locale.US); 454 if (contentLength == -1 && lowercaseHeader.startsWith("content-length:")) { 455 contentLength = Long.parseLong(header.substring(15).trim()); 456 } 457 if (lowercaseHeader.startsWith("transfer-encoding:") 458 && lowercaseHeader.substring(18).trim().equals("chunked")) { 459 chunked = true; 460 } 461 if (lowercaseHeader.startsWith("expect:") 462 && lowercaseHeader.substring(7).trim().equals("100-continue")) { 463 expectContinue = true; 464 } 465 } 466 467 if (expectContinue) { 468 out.write(("HTTP/1.1 100 Continue\r\n").getBytes(Util.US_ASCII)); 469 out.write(("Content-Length: 0\r\n").getBytes(Util.US_ASCII)); 470 out.write(("\r\n").getBytes(Util.US_ASCII)); 471 out.flush(); 472 } 473 474 boolean hasBody = false; 475 TruncatingOutputStream requestBody = new TruncatingOutputStream(); 476 List<Integer> chunkSizes = new ArrayList<Integer>(); 477 MockResponse throttlePolicy = dispatcher.peek(); 478 if (contentLength != -1) { 479 hasBody = true; 480 throttledTransfer(throttlePolicy, in, requestBody, contentLength); 481 } else if (chunked) { 482 hasBody = true; 483 while (true) { 484 int chunkSize = Integer.parseInt(readAsciiUntilCrlf(in).trim(), 16); 485 if (chunkSize == 0) { 486 readEmptyLine(in); 487 break; 488 } 489 chunkSizes.add(chunkSize); 490 throttledTransfer(throttlePolicy, in, requestBody, chunkSize); 491 readEmptyLine(in); 492 } 493 } 494 495 if (request.startsWith("OPTIONS ") 496 || request.startsWith("GET ") 497 || request.startsWith("HEAD ") 498 || request.startsWith("TRACE ") 499 || request.startsWith("CONNECT ")) { 500 if (hasBody) { 501 throw new IllegalArgumentException("Request must not have a body: " + request); 502 } 503 } else if (!request.startsWith("POST ") 504 && !request.startsWith("PUT ") 505 && !request.startsWith("PATCH ") 506 && !request.startsWith("DELETE ")) { // Permitted as spec is ambiguous. 507 throw new UnsupportedOperationException("Unexpected method: " + request); 508 } 509 510 return new RecordedRequest(request, headers, chunkSizes, requestBody.numBytesReceived, 511 requestBody.toByteArray(), sequenceNumber, socket); 512 } 513 514 private void writeResponse(OutputStream out, MockResponse response) throws IOException { 515 out.write((response.getStatus() + "\r\n").getBytes(Util.US_ASCII)); 516 List<String> headers = response.getHeaders(); 517 for (int i = 0, size = headers.size(); i < size; i++) { 518 String header = headers.get(i); 519 out.write((header + "\r\n").getBytes(Util.US_ASCII)); 520 } 521 out.write(("\r\n").getBytes(Util.US_ASCII)); 522 out.flush(); 523 524 InputStream in = response.getBodyStream(); 525 if (in == null) return; 526 throttledTransfer(response, in, out, Long.MAX_VALUE); 527 } 528 529 /** 530 * Transfer bytes from {@code in} to {@code out} until either {@code length} 531 * bytes have been transferred or {@code in} is exhausted. The transfer is 532 * throttled according to {@code throttlePolicy}. 533 */ 534 private void throttledTransfer(MockResponse throttlePolicy, InputStream in, OutputStream out, 535 long limit) throws IOException { 536 byte[] buffer = new byte[1024]; 537 int bytesPerPeriod = throttlePolicy.getThrottleBytesPerPeriod(); 538 long delayMs = throttlePolicy.getThrottleUnit().toMillis(throttlePolicy.getThrottlePeriod()); 539 540 while (true) { 541 for (int b = 0; b < bytesPerPeriod; ) { 542 int toRead = (int) Math.min(Math.min(buffer.length, limit), bytesPerPeriod - b); 543 int read = in.read(buffer, 0, toRead); 544 if (read == -1) return; 545 546 out.write(buffer, 0, read); 547 out.flush(); 548 b += read; 549 limit -= read; 550 551 if (limit == 0) return; 552 } 553 554 try { 555 if (delayMs != 0) Thread.sleep(delayMs); 556 } catch (InterruptedException e) { 557 throw new AssertionError(); 558 } 559 } 560 } 561 562 /** 563 * Returns the text from {@code in} until the next "\r\n", or null if {@code 564 * in} is exhausted. 565 */ 566 private String readAsciiUntilCrlf(InputStream in) throws IOException { 567 StringBuilder builder = new StringBuilder(); 568 while (true) { 569 int c = in.read(); 570 if (c == '\n' && builder.length() > 0 && builder.charAt(builder.length() - 1) == '\r') { 571 builder.deleteCharAt(builder.length() - 1); 572 return builder.toString(); 573 } else if (c == -1) { 574 return builder.toString(); 575 } else { 576 builder.append((char) c); 577 } 578 } 579 } 580 581 private void readEmptyLine(InputStream in) throws IOException { 582 String line = readAsciiUntilCrlf(in); 583 if (line.length() != 0) throw new IllegalStateException("Expected empty but was: " + line); 584 } 585 586 /** 587 * Sets the dispatcher used to match incoming requests to mock responses. 588 * The default dispatcher simply serves a fixed sequence of responses from 589 * a {@link #enqueue(MockResponse) queue}; custom dispatchers can vary the 590 * response based on timing or the content of the request. 591 */ 592 public void setDispatcher(Dispatcher dispatcher) { 593 if (dispatcher == null) throw new NullPointerException(); 594 this.dispatcher = dispatcher; 595 } 596 597 /** An output stream that drops data after bodyLimit bytes. */ 598 private class TruncatingOutputStream extends ByteArrayOutputStream { 599 private long numBytesReceived = 0; 600 601 @Override public void write(byte[] buffer, int offset, int len) { 602 numBytesReceived += len; 603 super.write(buffer, offset, Math.min(len, bodyLimit - count)); 604 } 605 606 @Override public void write(int oneByte) { 607 numBytesReceived++; 608 if (count < bodyLimit) { 609 super.write(oneByte); 610 } 611 } 612 } 613 614 /** Processes HTTP requests layered over SPDY/3. */ 615 private class SpdySocketHandler implements IncomingStreamHandler { 616 private final Socket socket; 617 private final Protocol protocol; 618 private final AtomicInteger sequenceNumber = new AtomicInteger(); 619 620 private SpdySocketHandler(Socket socket, Protocol protocol) { 621 this.socket = socket; 622 this.protocol = protocol; 623 } 624 625 @Override public void receive(SpdyStream stream) throws IOException { 626 RecordedRequest request = readRequest(stream); 627 requestQueue.add(request); 628 MockResponse response; 629 try { 630 response = dispatcher.dispatch(request); 631 } catch (InterruptedException e) { 632 throw new AssertionError(e); 633 } 634 writeResponse(stream, response); 635 if (logger.isLoggable(Level.INFO)) { 636 logger.info("Received request: " + request + " and responded: " + response 637 + " protocol is " + protocol.name.utf8()); 638 } 639 } 640 641 private RecordedRequest readRequest(SpdyStream stream) throws IOException { 642 List<Header> spdyHeaders = stream.getRequestHeaders(); 643 List<String> httpHeaders = new ArrayList<String>(); 644 String method = "<:method omitted>"; 645 String path = "<:path omitted>"; 646 String version = protocol == Protocol.SPDY_3 ? "<:version omitted>" : "HTTP/1.1"; 647 for (int i = 0, size = spdyHeaders.size(); i < size; i++) { 648 ByteString name = spdyHeaders.get(i).name; 649 String value = spdyHeaders.get(i).value.utf8(); 650 if (name.equals(Header.TARGET_METHOD)) { 651 method = value; 652 } else if (name.equals(Header.TARGET_PATH)) { 653 path = value; 654 } else if (name.equals(Header.VERSION)) { 655 version = value; 656 } else { 657 httpHeaders.add(name.utf8() + ": " + value); 658 } 659 } 660 661 InputStream bodyIn = Okio.buffer(stream.getSource()).inputStream(); 662 ByteArrayOutputStream bodyOut = new ByteArrayOutputStream(); 663 byte[] buffer = new byte[8192]; 664 int count; 665 while ((count = bodyIn.read(buffer)) != -1) { 666 bodyOut.write(buffer, 0, count); 667 } 668 bodyIn.close(); 669 String requestLine = method + ' ' + path + ' ' + version; 670 List<Integer> chunkSizes = Collections.emptyList(); // No chunked encoding for SPDY. 671 return new RecordedRequest(requestLine, httpHeaders, chunkSizes, bodyOut.size(), 672 bodyOut.toByteArray(), sequenceNumber.getAndIncrement(), socket); 673 } 674 675 private void writeResponse(SpdyStream stream, MockResponse response) throws IOException { 676 if (response.getSocketPolicy() == SocketPolicy.NO_RESPONSE) { 677 return; 678 } 679 List<Header> spdyHeaders = new ArrayList<Header>(); 680 String[] statusParts = response.getStatus().split(" ", 2); 681 if (statusParts.length != 2) { 682 throw new AssertionError("Unexpected status: " + response.getStatus()); 683 } 684 // TODO: constants for well-known header names. 685 spdyHeaders.add(new Header(Header.RESPONSE_STATUS, statusParts[1])); 686 if (protocol == Protocol.SPDY_3) { 687 spdyHeaders.add(new Header(Header.VERSION, statusParts[0])); 688 } 689 List<String> headers = response.getHeaders(); 690 for (int i = 0, size = headers.size(); i < size; i++) { 691 String header = headers.get(i); 692 String[] headerParts = header.split(":", 2); 693 if (headerParts.length != 2) { 694 throw new AssertionError("Unexpected header: " + header); 695 } 696 spdyHeaders.add(new Header(headerParts[0], headerParts[1])); 697 } 698 OkBuffer body = new OkBuffer(); 699 if (response.getBody() != null) { 700 body.write(response.getBody()); 701 } 702 boolean closeStreamAfterHeaders = body.size() > 0 || !response.getPushPromises().isEmpty(); 703 stream.reply(spdyHeaders, closeStreamAfterHeaders); 704 pushPromises(stream, response.getPushPromises()); 705 if (body.size() > 0) { 706 if (response.getBodyDelayTimeMs() != 0) { 707 try { 708 Thread.sleep(response.getBodyDelayTimeMs()); 709 } catch (InterruptedException e) { 710 throw new AssertionError(e); 711 } 712 } 713 BufferedSink sink = Okio.buffer(stream.getSink()); 714 if (response.getThrottleBytesPerPeriod() == Integer.MAX_VALUE) { 715 sink.write(body, body.size()); 716 sink.flush(); 717 } else { 718 while (body.size() > 0) { 719 long toWrite = Math.min(body.size(), response.getThrottleBytesPerPeriod()); 720 sink.write(body, toWrite); 721 sink.flush(); 722 try { 723 long delayMs = response.getThrottleUnit().toMillis(response.getThrottlePeriod()); 724 if (delayMs != 0) Thread.sleep(delayMs); 725 } catch (InterruptedException e) { 726 throw new AssertionError(); 727 } 728 } 729 } 730 sink.close(); 731 } else if (closeStreamAfterHeaders) { 732 stream.close(ErrorCode.NO_ERROR); 733 } 734 } 735 736 private void pushPromises(SpdyStream stream, List<PushPromise> promises) throws IOException { 737 for (PushPromise pushPromise : promises) { 738 List<Header> pushedHeaders = new ArrayList<Header>(); 739 pushedHeaders.add(new Header(stream.getConnection().getProtocol() == Protocol.SPDY_3 740 ? Header.TARGET_HOST 741 : Header.TARGET_AUTHORITY, getUrl(pushPromise.getPath()).getHost())); 742 pushedHeaders.add(new Header(Header.TARGET_METHOD, pushPromise.getMethod())); 743 pushedHeaders.add(new Header(Header.TARGET_PATH, pushPromise.getPath())); 744 for (int i = 0, size = pushPromise.getHeaders().size(); i < size; i++) { 745 String header = pushPromise.getHeaders().get(i); 746 String[] headerParts = header.split(":", 2); 747 if (headerParts.length != 2) { 748 throw new AssertionError("Unexpected header: " + header); 749 } 750 pushedHeaders.add(new Header(headerParts[0], headerParts[1].trim())); 751 } 752 String requestLine = pushPromise.getMethod() + ' ' + pushPromise.getPath() + " HTTP/1.1"; 753 List<Integer> chunkSizes = Collections.emptyList(); // No chunked encoding for SPDY. 754 requestQueue.add(new RecordedRequest(requestLine, pushPromise.getHeaders(), chunkSizes, 0, 755 Util.EMPTY_BYTE_ARRAY, sequenceNumber.getAndIncrement(), socket)); 756 byte[] pushedBody = pushPromise.getResponse().getBody(); 757 SpdyStream pushedStream = 758 stream.getConnection().pushStream(stream.getId(), pushedHeaders, pushedBody.length > 0); 759 writeResponse(pushedStream, pushPromise.getResponse()); 760 } 761 } 762 } 763} 764