1/*
2 * Copyright (C) 2013 Square, Inc.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16package com.squareup.okhttp.internal.http;
17
18import com.squareup.okhttp.HttpResponseCache;
19import com.squareup.okhttp.OkHttpClient;
20import com.squareup.okhttp.Protocol;
21import com.squareup.okhttp.internal.RecordingAuthenticator;
22import com.squareup.okhttp.internal.SslContextBuilder;
23import com.squareup.okhttp.internal.Util;
24import com.squareup.okhttp.mockwebserver.MockResponse;
25import com.squareup.okhttp.mockwebserver.MockWebServer;
26import com.squareup.okhttp.mockwebserver.RecordedRequest;
27import com.squareup.okhttp.mockwebserver.SocketPolicy;
28import java.io.ByteArrayOutputStream;
29import java.io.File;
30import java.io.IOException;
31import java.io.InputStream;
32import java.io.OutputStream;
33import java.net.Authenticator;
34import java.net.CookieManager;
35import java.net.HttpURLConnection;
36import java.net.URL;
37import java.util.Arrays;
38import java.util.Collection;
39import java.util.Collections;
40import java.util.List;
41import java.util.Map;
42import java.util.UUID;
43import java.util.concurrent.CountDownLatch;
44import java.util.concurrent.ExecutorService;
45import java.util.concurrent.Executors;
46import java.util.zip.GZIPOutputStream;
47import javax.net.ssl.HostnameVerifier;
48import javax.net.ssl.SSLContext;
49import javax.net.ssl.SSLSession;
50import org.junit.After;
51import org.junit.Before;
52import org.junit.Ignore;
53import org.junit.Test;
54
55import static java.util.concurrent.TimeUnit.SECONDS;
56import static org.junit.Assert.assertArrayEquals;
57import static org.junit.Assert.assertEquals;
58import static org.junit.Assert.assertNull;
59import static org.junit.Assert.assertTrue;
60import static org.junit.Assert.fail;
61
62/** Test how SPDY interacts with HTTP features. */
63public abstract class HttpOverSpdyTest {
64
65  /** Protocol to test, for example {@link com.squareup.okhttp.Protocol#SPDY_3} */
66  private final Protocol protocol;
67  protected String hostHeader = ":host";
68
69  protected HttpOverSpdyTest(Protocol protocol){
70    this.protocol = protocol;
71  }
72
73  private static final HostnameVerifier NULL_HOSTNAME_VERIFIER = new HostnameVerifier() {
74    public boolean verify(String hostname, SSLSession session) {
75      return true;
76    }
77  };
78
79  private static final SSLContext sslContext = SslContextBuilder.localhost();
80  protected final MockWebServer server = new MockWebServer();
81  protected final String hostName = server.getHostName();
82  protected final OkHttpClient client = new OkHttpClient();
83  protected HttpURLConnection connection;
84  protected HttpResponseCache cache;
85
86  @Before public void setUp() throws Exception {
87    server.useHttps(sslContext.getSocketFactory(), false);
88    client.setProtocols(Arrays.asList(protocol, Protocol.HTTP_11));
89    client.setSslSocketFactory(sslContext.getSocketFactory());
90    client.setHostnameVerifier(NULL_HOSTNAME_VERIFIER);
91    String systemTmpDir = System.getProperty("java.io.tmpdir");
92    File cacheDir = new File(systemTmpDir, "HttpCache-" + protocol + "-" + UUID.randomUUID());
93    cache = new HttpResponseCache(cacheDir, Integer.MAX_VALUE);
94  }
95
96  @After public void tearDown() throws Exception {
97    Authenticator.setDefault(null);
98    server.shutdown();
99  }
100
101  @Test public void get() throws Exception {
102    MockResponse response = new MockResponse().setBody("ABCDE").setStatus("HTTP/1.1 200 Sweet");
103    server.enqueue(response);
104    server.play();
105
106    connection = client.open(server.getUrl("/foo"));
107    assertContent("ABCDE", connection, Integer.MAX_VALUE);
108    assertEquals(200, connection.getResponseCode());
109    assertEquals("Sweet", connection.getResponseMessage());
110
111    RecordedRequest request = server.takeRequest();
112    assertEquals("GET /foo HTTP/1.1", request.getRequestLine());
113    assertContains(request.getHeaders(), ":scheme: https");
114    assertContains(request.getHeaders(), hostHeader + ": " + hostName + ":" + server.getPort());
115  }
116
117  @Test public void emptyResponse() throws IOException {
118    server.enqueue(new MockResponse());
119    server.play();
120
121    connection = client.open(server.getUrl("/foo"));
122    assertEquals(-1, connection.getInputStream().read());
123  }
124
125  byte[] postBytes = "FGHIJ".getBytes(Util.UTF_8);
126
127  /** An output stream can be written to more than once, so we can't guess content length. */
128  @Test public void noDefaultContentLengthOnPost() throws Exception {
129    MockResponse response = new MockResponse().setBody("ABCDE");
130    server.enqueue(response);
131    server.play();
132
133    connection = client.open(server.getUrl("/foo"));
134    connection.setDoOutput(true);
135    connection.getOutputStream().write(postBytes);
136    assertContent("ABCDE", connection, Integer.MAX_VALUE);
137
138    RecordedRequest request = server.takeRequest();
139    assertEquals("POST /foo HTTP/1.1", request.getRequestLine());
140    assertArrayEquals(postBytes, request.getBody());
141    assertNull(request.getHeader("Content-Length"));
142  }
143
144  @Test public void userSuppliedContentLengthHeader() throws Exception {
145    MockResponse response = new MockResponse().setBody("ABCDE");
146    server.enqueue(response);
147    server.play();
148
149    connection = client.open(server.getUrl("/foo"));
150    connection.setRequestProperty("Content-Length", String.valueOf(postBytes.length));
151    connection.setDoOutput(true);
152    connection.getOutputStream().write(postBytes);
153    assertContent("ABCDE", connection, Integer.MAX_VALUE);
154
155    RecordedRequest request = server.takeRequest();
156    assertEquals("POST /foo HTTP/1.1", request.getRequestLine());
157    assertArrayEquals(postBytes, request.getBody());
158    assertEquals(postBytes.length, Integer.parseInt(request.getHeader("Content-Length")));
159  }
160
161  @Test public void closeAfterFlush() throws Exception {
162    MockResponse response = new MockResponse().setBody("ABCDE");
163    server.enqueue(response);
164    server.play();
165
166    connection = client.open(server.getUrl("/foo"));
167    connection.setRequestProperty("Content-Length", String.valueOf(postBytes.length));
168    connection.setDoOutput(true);
169    connection.getOutputStream().write(postBytes); // push bytes into SpdyDataOutputStream.buffer
170    connection.getOutputStream().flush(); // SpdyConnection.writeData subject to write window
171    connection.getOutputStream().close(); // SpdyConnection.writeData empty frame
172    assertContent("ABCDE", connection, Integer.MAX_VALUE);
173
174    RecordedRequest request = server.takeRequest();
175    assertEquals("POST /foo HTTP/1.1", request.getRequestLine());
176    assertArrayEquals(postBytes, request.getBody());
177    assertEquals(postBytes.length, Integer.parseInt(request.getHeader("Content-Length")));
178  }
179
180  @Test public void setFixedLengthStreamingModeSetsContentLength() throws Exception {
181    MockResponse response = new MockResponse().setBody("ABCDE");
182    server.enqueue(response);
183    server.play();
184
185    connection = client.open(server.getUrl("/foo"));
186    connection.setFixedLengthStreamingMode(postBytes.length);
187    connection.setDoOutput(true);
188    connection.getOutputStream().write(postBytes);
189    assertContent("ABCDE", connection, Integer.MAX_VALUE);
190
191    RecordedRequest request = server.takeRequest();
192    assertEquals("POST /foo HTTP/1.1", request.getRequestLine());
193    assertArrayEquals(postBytes, request.getBody());
194    assertEquals(postBytes.length, Integer.parseInt(request.getHeader("Content-Length")));
195  }
196
197  @Test public void spdyConnectionReuse() throws Exception {
198    server.enqueue(new MockResponse().setBody("ABCDEF"));
199    server.enqueue(new MockResponse().setBody("GHIJKL"));
200    server.play();
201
202    HttpURLConnection connection1 = client.open(server.getUrl("/r1"));
203    HttpURLConnection connection2 = client.open(server.getUrl("/r2"));
204    assertEquals("ABC", readAscii(connection1.getInputStream(), 3));
205    assertEquals("GHI", readAscii(connection2.getInputStream(), 3));
206    assertEquals("DEF", readAscii(connection1.getInputStream(), 3));
207    assertEquals("JKL", readAscii(connection2.getInputStream(), 3));
208    assertEquals(0, server.takeRequest().getSequenceNumber());
209    assertEquals(1, server.takeRequest().getSequenceNumber());
210  }
211
212  @Test @Ignore public void synchronousSpdyRequest() throws Exception {
213    server.enqueue(new MockResponse().setBody("A"));
214    server.enqueue(new MockResponse().setBody("A"));
215    server.play();
216
217    ExecutorService executor = Executors.newCachedThreadPool();
218    CountDownLatch countDownLatch = new CountDownLatch(2);
219    executor.execute(new SpdyRequest("/r1", countDownLatch));
220    executor.execute(new SpdyRequest("/r2", countDownLatch));
221    countDownLatch.await();
222    assertEquals(0, server.takeRequest().getSequenceNumber());
223    assertEquals(1, server.takeRequest().getSequenceNumber());
224  }
225
226  @Test public void gzippedResponseBody() throws Exception {
227    server.enqueue(new MockResponse().addHeader("Content-Encoding: gzip")
228        .setBody(gzip("ABCABCABC".getBytes(Util.UTF_8))));
229    server.play();
230    assertContent("ABCABCABC", client.open(server.getUrl("/r1")), Integer.MAX_VALUE);
231  }
232
233  @Test public void authenticate() throws Exception {
234    server.enqueue(new MockResponse().setResponseCode(HttpURLConnection.HTTP_UNAUTHORIZED)
235        .addHeader("www-authenticate: Basic realm=\"protected area\"")
236        .setBody("Please authenticate."));
237    server.enqueue(new MockResponse().setBody("Successful auth!"));
238    server.play();
239
240    Authenticator.setDefault(new RecordingAuthenticator());
241    connection = client.open(server.getUrl("/"));
242    assertEquals("Successful auth!", readAscii(connection.getInputStream(), Integer.MAX_VALUE));
243
244    RecordedRequest denied = server.takeRequest();
245    assertContainsNoneMatching(denied.getHeaders(), "authorization: Basic .*");
246    RecordedRequest accepted = server.takeRequest();
247    assertEquals("GET / HTTP/1.1", accepted.getRequestLine());
248    assertContains(accepted.getHeaders(),
249        "authorization: Basic " + RecordingAuthenticator.BASE_64_CREDENTIALS);
250  }
251
252  @Test public void redirect() throws Exception {
253    server.enqueue(new MockResponse().setResponseCode(HttpURLConnection.HTTP_MOVED_TEMP)
254        .addHeader("Location: /foo")
255        .setBody("This page has moved!"));
256    server.enqueue(new MockResponse().setBody("This is the new location!"));
257    server.play();
258
259    connection = client.open(server.getUrl("/"));
260    assertContent("This is the new location!", connection, Integer.MAX_VALUE);
261
262    RecordedRequest request1 = server.takeRequest();
263    assertEquals("/", request1.getPath());
264    RecordedRequest request2 = server.takeRequest();
265    assertEquals("/foo", request2.getPath());
266  }
267
268  @Test public void readAfterLastByte() throws Exception {
269    server.enqueue(new MockResponse().setBody("ABC"));
270    server.play();
271
272    connection = client.open(server.getUrl("/"));
273    InputStream in = connection.getInputStream();
274    assertEquals("ABC", readAscii(in, 3));
275    assertEquals(-1, in.read());
276    assertEquals(-1, in.read());
277  }
278
279  @Ignore // See https://github.com/square/okhttp/issues/578
280  @Test(timeout = 3000) public void readResponseHeaderTimeout() throws Exception {
281    server.enqueue(new MockResponse().setSocketPolicy(SocketPolicy.NO_RESPONSE));
282    server.enqueue(new MockResponse().setBody("A"));
283    server.play();
284
285    connection = client.open(server.getUrl("/"));
286    connection.setReadTimeout(1000);
287    assertContent("A", connection, Integer.MAX_VALUE);
288  }
289
290  /**
291   * Test to ensure we don't  throw a read timeout on responses that are
292   * progressing.  For this case, we take a 4KiB body and throttle it to
293   * 1KiB/second.  We set the read timeout to two seconds.  If our
294   * implementation is acting correctly, it will not throw, as it is
295   * progressing.
296   */
297  @Test public void readTimeoutMoreGranularThanBodySize() throws Exception {
298    char[] body = new char[4096]; // 4KiB to read
299    Arrays.fill(body, 'y');
300    server.enqueue(new MockResponse()
301        .setBody(new String(body))
302        .throttleBody(1024, 1, SECONDS)); // slow connection 1KiB/second
303    server.play();
304
305    connection = client.open(server.getUrl("/"));
306    connection.setReadTimeout(2000); // 2 seconds to read something.
307    assertContent(new String(body), connection, Integer.MAX_VALUE);
308  }
309
310  /**
311   * Test to ensure we throw a read timeout on responses that are progressing
312   * too slowly.  For this case, we take a 2KiB body and throttle it to
313   * 1KiB/second.  We set the read timeout to half a second.  If our
314   * implementation is acting correctly, it will throw, as a byte doesn't
315   * arrive in time.
316   */
317  @Test public void readTimeoutOnSlowConnection() throws Exception {
318    char[] body = new char[2048]; // 2KiB to read
319    Arrays.fill(body, 'y');
320    server.enqueue(new MockResponse()
321        .setBody(new String(body))
322        .throttleBody(1024, 1, SECONDS)); // slow connection 1KiB/second
323    server.play();
324
325    connection = client.open(server.getUrl("/"));
326    connection.setReadTimeout(500); // half a second to read something
327    connection.connect();
328    try {
329      readAscii(connection.getInputStream(), Integer.MAX_VALUE);
330      fail("Should have timed out!");
331    } catch (IOException e){
332      assertEquals("Read timed out", e.getMessage());
333    }
334  }
335
336  @Test public void spdyConnectionTimeout() throws Exception {
337    MockResponse response = new MockResponse().setBody("A");
338    response.setBodyDelayTimeMs(1000);
339    server.enqueue(response);
340    server.play();
341
342    HttpURLConnection connection1 = client.open(server.getUrl("/"));
343    connection1.setReadTimeout(2000);
344    HttpURLConnection connection2 = client.open(server.getUrl("/"));
345    connection2.setReadTimeout(200);
346    connection1.connect();
347    connection2.connect();
348    assertContent("A", connection1, Integer.MAX_VALUE);
349  }
350
351  @Test public void responsesAreCached() throws IOException {
352    client.setOkResponseCache(cache);
353
354    server.enqueue(new MockResponse().addHeader("cache-control: max-age=60").setBody("A"));
355    server.play();
356
357    assertContent("A", client.open(server.getUrl("/")), Integer.MAX_VALUE);
358    assertEquals(1, cache.getRequestCount());
359    assertEquals(1, cache.getNetworkCount());
360    assertEquals(0, cache.getHitCount());
361    assertContent("A", client.open(server.getUrl("/")), Integer.MAX_VALUE);
362    assertContent("A", client.open(server.getUrl("/")), Integer.MAX_VALUE);
363    assertEquals(3, cache.getRequestCount());
364    assertEquals(1, cache.getNetworkCount());
365    assertEquals(2, cache.getHitCount());
366  }
367
368  @Test public void conditionalCache() throws IOException {
369    client.setOkResponseCache(cache);
370
371    server.enqueue(new MockResponse().addHeader("ETag: v1").setBody("A"));
372    server.enqueue(new MockResponse().setResponseCode(HttpURLConnection.HTTP_NOT_MODIFIED));
373    server.play();
374
375    assertContent("A", client.open(server.getUrl("/")), Integer.MAX_VALUE);
376    assertEquals(1, cache.getRequestCount());
377    assertEquals(1, cache.getNetworkCount());
378    assertEquals(0, cache.getHitCount());
379    assertContent("A", client.open(server.getUrl("/")), Integer.MAX_VALUE);
380    assertEquals(2, cache.getRequestCount());
381    assertEquals(2, cache.getNetworkCount());
382    assertEquals(1, cache.getHitCount());
383  }
384
385  @Test public void responseCachedWithoutConsumingFullBody() throws IOException {
386    client.setOkResponseCache(cache);
387
388    server.enqueue(new MockResponse().addHeader("cache-control: max-age=60").setBody("ABCD"));
389    server.enqueue(new MockResponse().addHeader("cache-control: max-age=60").setBody("EFGH"));
390    server.play();
391
392    HttpURLConnection connection1 = client.open(server.getUrl("/"));
393    InputStream in1 = connection1.getInputStream();
394    assertEquals("AB", readAscii(in1, 2));
395    in1.close();
396
397    HttpURLConnection connection2 = client.open(server.getUrl("/"));
398    InputStream in2 = connection2.getInputStream();
399    assertEquals("ABCD", readAscii(in2, Integer.MAX_VALUE));
400    in2.close();
401  }
402
403  @Test public void acceptAndTransmitCookies() throws Exception {
404    CookieManager cookieManager = new CookieManager();
405    client.setCookieHandler(cookieManager);
406    server.enqueue(
407        new MockResponse().addHeader("set-cookie: c=oreo; domain=" + server.getCookieDomain())
408            .setBody("A"));
409    server.enqueue(new MockResponse().setBody("B"));
410    server.play();
411
412    URL url = server.getUrl("/");
413    assertContent("A", client.open(url), Integer.MAX_VALUE);
414    Map<String, List<String>> requestHeaders = Collections.emptyMap();
415    assertEquals(Collections.singletonMap("Cookie", Arrays.asList("c=oreo")),
416        cookieManager.get(url.toURI(), requestHeaders));
417
418    assertContent("B", client.open(url), Integer.MAX_VALUE);
419    RecordedRequest requestA = server.takeRequest();
420    assertContainsNoneMatching(requestA.getHeaders(), "Cookie.*");
421    RecordedRequest requestB = server.takeRequest();
422    assertContains(requestB.getHeaders(), "cookie: c=oreo");
423  }
424
425  <T> void assertContains(Collection<T> collection, T value) {
426    assertTrue(collection.toString(), collection.contains(value));
427  }
428
429  void assertContent(String expected, HttpURLConnection connection, int limit)
430      throws IOException {
431    connection.connect();
432    assertEquals(expected, readAscii(connection.getInputStream(), limit));
433  }
434
435  private void assertContainsNoneMatching(List<String> headers, String pattern) {
436    for (String header : headers) {
437      if (header.matches(pattern)) {
438        fail("Header " + header + " matches " + pattern);
439      }
440    }
441  }
442
443  private String readAscii(InputStream in, int count) throws IOException {
444    StringBuilder result = new StringBuilder();
445    for (int i = 0; i < count; i++) {
446      int value = in.read();
447      if (value == -1) {
448        in.close();
449        break;
450      }
451      result.append((char) value);
452    }
453    return result.toString();
454  }
455
456  public byte[] gzip(byte[] bytes) throws IOException {
457    ByteArrayOutputStream bytesOut = new ByteArrayOutputStream();
458    OutputStream gzippedOut = new GZIPOutputStream(bytesOut);
459    gzippedOut.write(bytes);
460    gzippedOut.close();
461    return bytesOut.toByteArray();
462  }
463
464  class SpdyRequest implements Runnable {
465    String path;
466    CountDownLatch countDownLatch;
467    public SpdyRequest(String path, CountDownLatch countDownLatch) {
468      this.path = path;
469      this.countDownLatch = countDownLatch;
470    }
471
472    @Override public void run() {
473      try {
474        HttpURLConnection conn = client.open(server.getUrl(path));
475        assertEquals("A", readAscii(conn.getInputStream(), 1));
476        countDownLatch.countDown();
477      } catch (Exception e) {
478        throw new RuntimeException(e);
479      }
480    }
481  }
482}
483