1/*
2 * Copyright (C) 2014 Square, Inc.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16package com.squareup.okhttp;
17
18import com.squareup.okhttp.mockwebserver.MockResponse;
19import com.squareup.okhttp.mockwebserver.MockWebServer;
20import com.squareup.okhttp.mockwebserver.RecordedRequest;
21import java.io.IOException;
22import java.util.Arrays;
23import java.util.List;
24import java.util.Locale;
25import java.util.concurrent.BlockingQueue;
26import java.util.concurrent.LinkedBlockingQueue;
27import java.util.concurrent.SynchronousQueue;
28import java.util.concurrent.ThreadPoolExecutor;
29import java.util.concurrent.TimeUnit;
30import okio.Buffer;
31import okio.BufferedSink;
32import okio.ForwardingSink;
33import okio.ForwardingSource;
34import okio.GzipSink;
35import okio.Okio;
36import okio.Sink;
37import okio.Source;
38import org.junit.Rule;
39import org.junit.Test;
40
41import static org.junit.Assert.assertEquals;
42import static org.junit.Assert.assertNotNull;
43import static org.junit.Assert.assertNull;
44import static org.junit.Assert.assertSame;
45import static org.junit.Assert.fail;
46
47public final class InterceptorTest {
48  @Rule public MockWebServer server = new MockWebServer();
49
50  private OkHttpClient client = new OkHttpClient();
51  private RecordingCallback callback = new RecordingCallback();
52
53  @Test public void applicationInterceptorsCanShortCircuitResponses() throws Exception {
54    server.shutdown(); // Accept no connections.
55
56    Request request = new Request.Builder()
57        .url("https://localhost:1/")
58        .build();
59
60    final Response interceptorResponse = new Response.Builder()
61        .request(request)
62        .protocol(Protocol.HTTP_1_1)
63        .code(200)
64        .message("Intercepted!")
65        .body(ResponseBody.create(MediaType.parse("text/plain; charset=utf-8"), "abc"))
66        .build();
67
68    client.interceptors().add(new Interceptor() {
69      @Override public Response intercept(Chain chain) throws IOException {
70        return interceptorResponse;
71      }
72    });
73
74    Response response = client.newCall(request).execute();
75    assertSame(interceptorResponse, response);
76  }
77
78  @Test public void networkInterceptorsCannotShortCircuitResponses() throws Exception {
79    server.enqueue(new MockResponse().setResponseCode(500));
80
81    Interceptor interceptor = new Interceptor() {
82      @Override public Response intercept(Chain chain) throws IOException {
83        return new Response.Builder()
84            .request(chain.request())
85            .protocol(Protocol.HTTP_1_1)
86            .code(200)
87            .message("Intercepted!")
88            .body(ResponseBody.create(MediaType.parse("text/plain; charset=utf-8"), "abc"))
89            .build();
90      }
91    };
92    client.networkInterceptors().add(interceptor);
93
94    Request request = new Request.Builder()
95        .url(server.url("/"))
96        .build();
97
98    try {
99      client.newCall(request).execute();
100      fail();
101    } catch (IllegalStateException expected) {
102      assertEquals("network interceptor " + interceptor + " must call proceed() exactly once",
103          expected.getMessage());
104    }
105  }
106
107  @Test public void networkInterceptorsCannotCallProceedMultipleTimes() throws Exception {
108    server.enqueue(new MockResponse());
109    server.enqueue(new MockResponse());
110
111    Interceptor interceptor = new Interceptor() {
112      @Override public Response intercept(Chain chain) throws IOException {
113        chain.proceed(chain.request());
114        return chain.proceed(chain.request());
115      }
116    };
117    client.networkInterceptors().add(interceptor);
118
119    Request request = new Request.Builder()
120        .url(server.url("/"))
121        .build();
122
123    try {
124      client.newCall(request).execute();
125      fail();
126    } catch (IllegalStateException expected) {
127      assertEquals("network interceptor " + interceptor + " must call proceed() exactly once",
128          expected.getMessage());
129    }
130  }
131
132  @Test public void networkInterceptorsCannotChangeServerAddress() throws Exception {
133    server.enqueue(new MockResponse().setResponseCode(500));
134
135    Interceptor interceptor = new Interceptor() {
136      @Override public Response intercept(Chain chain) throws IOException {
137        Address address = chain.connection().getRoute().getAddress();
138        String sameHost = address.getUriHost();
139        int differentPort = address.getUriPort() + 1;
140        return chain.proceed(chain.request().newBuilder()
141            .url(HttpUrl.parse("http://" + sameHost + ":" + differentPort + "/"))
142            .build());
143      }
144    };
145    client.networkInterceptors().add(interceptor);
146
147    Request request = new Request.Builder()
148        .url(server.url("/"))
149        .build();
150
151    try {
152      client.newCall(request).execute();
153      fail();
154    } catch (IllegalStateException expected) {
155      assertEquals("network interceptor " + interceptor + " must retain the same host and port",
156          expected.getMessage());
157    }
158  }
159
160  @Test public void networkInterceptorsHaveConnectionAccess() throws Exception {
161    server.enqueue(new MockResponse());
162
163    client.networkInterceptors().add(new Interceptor() {
164      @Override public Response intercept(Chain chain) throws IOException {
165        Connection connection = chain.connection();
166        assertNotNull(connection);
167        return chain.proceed(chain.request());
168      }
169    });
170
171    Request request = new Request.Builder()
172        .url(server.url("/"))
173        .build();
174    client.newCall(request).execute();
175  }
176
177  @Test public void networkInterceptorsObserveNetworkHeaders() throws Exception {
178    server.enqueue(new MockResponse()
179        .setBody(gzip("abcabcabc"))
180        .addHeader("Content-Encoding: gzip"));
181
182    client.networkInterceptors().add(new Interceptor() {
183      @Override public Response intercept(Chain chain) throws IOException {
184        // The network request has everything: User-Agent, Host, Accept-Encoding.
185        Request networkRequest = chain.request();
186        assertNotNull(networkRequest.header("User-Agent"));
187        assertEquals(server.getHostName() + ":" + server.getPort(),
188            networkRequest.header("Host"));
189        assertNotNull(networkRequest.header("Accept-Encoding"));
190
191        // The network response also has everything, including the raw gzipped content.
192        Response networkResponse = chain.proceed(networkRequest);
193        assertEquals("gzip", networkResponse.header("Content-Encoding"));
194        return networkResponse;
195      }
196    });
197
198    Request request = new Request.Builder()
199        .url(server.url("/"))
200        .build();
201
202    // No extra headers in the application's request.
203    assertNull(request.header("User-Agent"));
204    assertNull(request.header("Host"));
205    assertNull(request.header("Accept-Encoding"));
206
207    // No extra headers in the application's response.
208    Response response = client.newCall(request).execute();
209    assertNull(request.header("Content-Encoding"));
210    assertEquals("abcabcabc", response.body().string());
211  }
212
213  @Test public void networkInterceptorsCanChangeRequestMethodFromGetToPost() throws Exception {
214    server.enqueue(new MockResponse());
215
216    client.networkInterceptors().add(new Interceptor() {
217      @Override
218      public Response intercept(Chain chain) throws IOException {
219        Request originalRequest = chain.request();
220        MediaType mediaType = MediaType.parse("text/plain");
221        RequestBody body = RequestBody.create(mediaType, "abc");
222        return chain.proceed(originalRequest.newBuilder()
223            .method("POST", body)
224            .header("Content-Type", mediaType.toString())
225            .header("Content-Length", Long.toString(body.contentLength()))
226            .build());
227      }
228    });
229
230    Request request = new Request.Builder()
231        .url(server.url("/"))
232        .get()
233        .build();
234
235    client.newCall(request).execute();
236
237    RecordedRequest recordedRequest = server.takeRequest();
238    assertEquals("POST", recordedRequest.getMethod());
239    assertEquals("abc", recordedRequest.getBody().readUtf8());
240  }
241
242  @Test public void applicationInterceptorsRewriteRequestToServer() throws Exception {
243    rewriteRequestToServer(client.interceptors());
244  }
245
246  @Test public void networkInterceptorsRewriteRequestToServer() throws Exception {
247    rewriteRequestToServer(client.networkInterceptors());
248  }
249
250  private void rewriteRequestToServer(List<Interceptor> interceptors) throws Exception {
251    server.enqueue(new MockResponse());
252
253    interceptors.add(new Interceptor() {
254      @Override public Response intercept(Chain chain) throws IOException {
255        Request originalRequest = chain.request();
256        return chain.proceed(originalRequest.newBuilder()
257            .method("POST", uppercase(originalRequest.body()))
258            .addHeader("OkHttp-Intercepted", "yep")
259            .build());
260      }
261    });
262
263    Request request = new Request.Builder()
264        .url(server.url("/"))
265        .addHeader("Original-Header", "foo")
266        .method("PUT", RequestBody.create(MediaType.parse("text/plain"), "abc"))
267        .build();
268
269    client.newCall(request).execute();
270
271    RecordedRequest recordedRequest = server.takeRequest();
272    assertEquals("ABC", recordedRequest.getBody().readUtf8());
273    assertEquals("foo", recordedRequest.getHeader("Original-Header"));
274    assertEquals("yep", recordedRequest.getHeader("OkHttp-Intercepted"));
275    assertEquals("POST", recordedRequest.getMethod());
276  }
277
278  @Test public void applicationInterceptorsRewriteResponseFromServer() throws Exception {
279    rewriteResponseFromServer(client.interceptors());
280  }
281
282  @Test public void networkInterceptorsRewriteResponseFromServer() throws Exception {
283    rewriteResponseFromServer(client.networkInterceptors());
284  }
285
286  private void rewriteResponseFromServer(List<Interceptor> interceptors) throws Exception {
287    server.enqueue(new MockResponse()
288        .addHeader("Original-Header: foo")
289        .setBody("abc"));
290
291    interceptors.add(new Interceptor() {
292      @Override public Response intercept(Chain chain) throws IOException {
293        Response originalResponse = chain.proceed(chain.request());
294        return originalResponse.newBuilder()
295            .body(uppercase(originalResponse.body()))
296            .addHeader("OkHttp-Intercepted", "yep")
297            .build();
298      }
299    });
300
301    Request request = new Request.Builder()
302        .url(server.url("/"))
303        .build();
304
305    Response response = client.newCall(request).execute();
306    assertEquals("ABC", response.body().string());
307    assertEquals("yep", response.header("OkHttp-Intercepted"));
308    assertEquals("foo", response.header("Original-Header"));
309  }
310
311  @Test public void multipleApplicationInterceptors() throws Exception {
312    multipleInterceptors(client.interceptors());
313  }
314
315  @Test public void multipleNetworkInterceptors() throws Exception {
316    multipleInterceptors(client.networkInterceptors());
317  }
318
319  private void multipleInterceptors(List<Interceptor> interceptors) throws Exception {
320    server.enqueue(new MockResponse());
321
322    interceptors.add(new Interceptor() {
323      @Override public Response intercept(Chain chain) throws IOException {
324        Request originalRequest = chain.request();
325        Response originalResponse = chain.proceed(originalRequest.newBuilder()
326            .addHeader("Request-Interceptor", "Android") // 1. Added first.
327            .build());
328        return originalResponse.newBuilder()
329            .addHeader("Response-Interceptor", "Donut") // 4. Added last.
330            .build();
331      }
332    });
333    interceptors.add(new Interceptor() {
334      @Override public Response intercept(Chain chain) throws IOException {
335        Request originalRequest = chain.request();
336        Response originalResponse = chain.proceed(originalRequest.newBuilder()
337            .addHeader("Request-Interceptor", "Bob") // 2. Added second.
338            .build());
339        return originalResponse.newBuilder()
340            .addHeader("Response-Interceptor", "Cupcake") // 3. Added third.
341            .build();
342      }
343    });
344
345    Request request = new Request.Builder()
346        .url(server.url("/"))
347        .build();
348
349    Response response = client.newCall(request).execute();
350    assertEquals(Arrays.asList("Cupcake", "Donut"),
351        response.headers("Response-Interceptor"));
352
353    RecordedRequest recordedRequest = server.takeRequest();
354    assertEquals(Arrays.asList("Android", "Bob"),
355        recordedRequest.getHeaders().values("Request-Interceptor"));
356  }
357
358  @Test public void asyncApplicationInterceptors() throws Exception {
359    asyncInterceptors(client.interceptors());
360  }
361
362  @Test public void asyncNetworkInterceptors() throws Exception {
363    asyncInterceptors(client.networkInterceptors());
364  }
365
366  private void asyncInterceptors(List<Interceptor> interceptors) throws Exception {
367    server.enqueue(new MockResponse());
368
369    interceptors.add(new Interceptor() {
370      @Override public Response intercept(Chain chain) throws IOException {
371        Response originalResponse = chain.proceed(chain.request());
372        return originalResponse.newBuilder()
373            .addHeader("OkHttp-Intercepted", "yep")
374            .build();
375      }
376    });
377
378    Request request = new Request.Builder()
379        .url(server.url("/"))
380        .build();
381    client.newCall(request).enqueue(callback);
382
383    callback.await(request.httpUrl())
384        .assertCode(200)
385        .assertHeader("OkHttp-Intercepted", "yep");
386  }
387
388  @Test public void applicationInterceptorsCanMakeMultipleRequestsToServer() throws Exception {
389    server.enqueue(new MockResponse().setBody("a"));
390    server.enqueue(new MockResponse().setBody("b"));
391
392    client.interceptors().add(new Interceptor() {
393      @Override public Response intercept(Chain chain) throws IOException {
394        Response response1 = chain.proceed(chain.request());
395        response1.body().close();
396        return chain.proceed(chain.request());
397      }
398    });
399
400    Request request = new Request.Builder()
401        .url(server.url("/"))
402        .build();
403
404    Response response = client.newCall(request).execute();
405    assertEquals(response.body().string(), "b");
406  }
407
408  /** Make sure interceptors can interact with the OkHttp client. */
409  @Test public void interceptorMakesAnUnrelatedRequest() throws Exception {
410    server.enqueue(new MockResponse().setBody("a")); // Fetched by interceptor.
411    server.enqueue(new MockResponse().setBody("b")); // Fetched directly.
412
413    client.interceptors().add(new Interceptor() {
414      @Override public Response intercept(Chain chain) throws IOException {
415        if (chain.request().url().getPath().equals("/b")) {
416          Request requestA = new Request.Builder()
417              .url(server.url("/a"))
418              .build();
419          Response responseA = client.newCall(requestA).execute();
420          assertEquals("a", responseA.body().string());
421        }
422
423        return chain.proceed(chain.request());
424      }
425    });
426
427    Request requestB = new Request.Builder()
428        .url(server.url("/b"))
429        .build();
430    Response responseB = client.newCall(requestB).execute();
431    assertEquals("b", responseB.body().string());
432  }
433
434  /** Make sure interceptors can interact with the OkHttp client asynchronously. */
435  @Test public void interceptorMakesAnUnrelatedAsyncRequest() throws Exception {
436    server.enqueue(new MockResponse().setBody("a")); // Fetched by interceptor.
437    server.enqueue(new MockResponse().setBody("b")); // Fetched directly.
438
439    client.interceptors().add(new Interceptor() {
440      @Override public Response intercept(Chain chain) throws IOException {
441        if (chain.request().url().getPath().equals("/b")) {
442          Request requestA = new Request.Builder()
443              .url(server.url("/a"))
444              .build();
445
446          try {
447            RecordingCallback callbackA = new RecordingCallback();
448            client.newCall(requestA).enqueue(callbackA);
449            callbackA.await(requestA.httpUrl()).assertBody("a");
450          } catch (Exception e) {
451            throw new RuntimeException(e);
452          }
453        }
454
455        return chain.proceed(chain.request());
456      }
457    });
458
459    Request requestB = new Request.Builder()
460        .url(server.url("/b"))
461        .build();
462    RecordingCallback callbackB = new RecordingCallback();
463    client.newCall(requestB).enqueue(callbackB);
464    callbackB.await(requestB.httpUrl()).assertBody("b");
465  }
466
467  @Test public void applicationkInterceptorThrowsRuntimeExceptionSynchronous() throws Exception {
468    interceptorThrowsRuntimeExceptionSynchronous(client.interceptors());
469  }
470
471  @Test public void networkInterceptorThrowsRuntimeExceptionSynchronous() throws Exception {
472    interceptorThrowsRuntimeExceptionSynchronous(client.networkInterceptors());
473  }
474
475  /**
476   * When an interceptor throws an unexpected exception, synchronous callers can catch it and deal
477   * with it.
478   *
479   * TODO(jwilson): test that resources are not leaked when this happens.
480   */
481  private void interceptorThrowsRuntimeExceptionSynchronous(
482      List<Interceptor> interceptors) throws Exception {
483    interceptors.add(new Interceptor() {
484      @Override public Response intercept(Chain chain) throws IOException {
485        throw new RuntimeException("boom!");
486      }
487    });
488
489    Request request = new Request.Builder()
490        .url(server.url("/"))
491        .build();
492
493    try {
494      client.newCall(request).execute();
495      fail();
496    } catch (RuntimeException expected) {
497      assertEquals("boom!", expected.getMessage());
498    }
499  }
500
501  @Test public void networkInterceptorModifiedRequestIsReturned() throws IOException {
502    server.enqueue(new MockResponse());
503
504    Interceptor modifyHeaderInterceptor = new Interceptor() {
505      @Override public Response intercept(Chain chain) throws IOException {
506        return chain.proceed(chain.request().newBuilder()
507          .header("User-Agent", "intercepted request")
508          .build());
509      }
510    };
511
512    client.networkInterceptors().add(modifyHeaderInterceptor);
513
514    Request request = new Request.Builder()
515        .url(server.url("/"))
516        .header("User-Agent", "user request")
517        .build();
518
519    Response response = client.newCall(request).execute();
520    assertNotNull(response.request().header("User-Agent"));
521    assertEquals("user request", response.request().header("User-Agent"));
522    assertEquals("intercepted request", response.networkResponse().request().header("User-Agent"));
523  }
524
525  @Test public void applicationInterceptorThrowsRuntimeExceptionAsynchronous() throws Exception {
526    interceptorThrowsRuntimeExceptionAsynchronous(client.interceptors());
527  }
528
529  @Test public void networkInterceptorThrowsRuntimeExceptionAsynchronous() throws Exception {
530    interceptorThrowsRuntimeExceptionAsynchronous(client.networkInterceptors());
531  }
532
533  /**
534   * When an interceptor throws an unexpected exception, asynchronous callers are left hanging. The
535   * exception goes to the uncaught exception handler.
536   *
537   * TODO(jwilson): test that resources are not leaked when this happens.
538   */
539  private void interceptorThrowsRuntimeExceptionAsynchronous(
540        List<Interceptor> interceptors) throws Exception {
541    interceptors.add(new Interceptor() {
542      @Override public Response intercept(Chain chain) throws IOException {
543        throw new RuntimeException("boom!");
544      }
545    });
546
547    ExceptionCatchingExecutor executor = new ExceptionCatchingExecutor();
548    client.setDispatcher(new Dispatcher(executor));
549
550    Request request = new Request.Builder()
551        .url(server.url("/"))
552        .build();
553    client.newCall(request).enqueue(callback);
554
555    assertEquals("boom!", executor.takeException().getMessage());
556  }
557
558  @Test public void applicationInterceptorReturnsNull() throws Exception {
559    server.enqueue(new MockResponse());
560
561    Interceptor interceptor = new Interceptor() {
562      @Override public Response intercept(Chain chain) throws IOException {
563        chain.proceed(chain.request());
564        return null;
565      }
566    };
567    client.interceptors().add(interceptor);
568
569    ExceptionCatchingExecutor executor = new ExceptionCatchingExecutor();
570    client.setDispatcher(new Dispatcher(executor));
571
572    Request request = new Request.Builder()
573        .url(server.url("/"))
574        .build();
575    try {
576      client.newCall(request).execute();
577      fail();
578    } catch (NullPointerException expected) {
579      assertEquals("application interceptor " + interceptor
580          + " returned null", expected.getMessage());
581    }
582  }
583
584  @Test public void networkInterceptorReturnsNull() throws Exception {
585    server.enqueue(new MockResponse());
586
587    Interceptor interceptor = new Interceptor() {
588      @Override public Response intercept(Chain chain) throws IOException {
589        chain.proceed(chain.request());
590        return null;
591      }
592    };
593    client.networkInterceptors().add(interceptor);
594
595    ExceptionCatchingExecutor executor = new ExceptionCatchingExecutor();
596    client.setDispatcher(new Dispatcher(executor));
597
598    Request request = new Request.Builder()
599        .url(server.url("/"))
600        .build();
601    try {
602      client.newCall(request).execute();
603      fail();
604    } catch (NullPointerException expected) {
605      assertEquals("network interceptor " + interceptor + " returned null", expected.getMessage());
606    }
607  }
608
609  private RequestBody uppercase(final RequestBody original) {
610    return new RequestBody() {
611      @Override public MediaType contentType() {
612        return original.contentType();
613      }
614
615      @Override public long contentLength() throws IOException {
616        return original.contentLength();
617      }
618
619      @Override public void writeTo(BufferedSink sink) throws IOException {
620        Sink uppercase = uppercase(sink);
621        BufferedSink bufferedSink = Okio.buffer(uppercase);
622        original.writeTo(bufferedSink);
623        bufferedSink.emit();
624      }
625    };
626  }
627
628  private Sink uppercase(final BufferedSink original) {
629    return new ForwardingSink(original) {
630      @Override public void write(Buffer source, long byteCount) throws IOException {
631        original.writeUtf8(source.readUtf8(byteCount).toUpperCase(Locale.US));
632      }
633    };
634  }
635
636  static ResponseBody uppercase(ResponseBody original) throws IOException {
637    return ResponseBody.create(original.contentType(), original.contentLength(),
638        Okio.buffer(uppercase(original.source())));
639  }
640
641  private static Source uppercase(final Source original) {
642    return new ForwardingSource(original) {
643      @Override public long read(Buffer sink, long byteCount) throws IOException {
644        Buffer mixedCase = new Buffer();
645        long count = original.read(mixedCase, byteCount);
646        sink.writeUtf8(mixedCase.readUtf8().toUpperCase(Locale.US));
647        return count;
648      }
649    };
650  }
651
652  private Buffer gzip(String data) throws IOException {
653    Buffer result = new Buffer();
654    BufferedSink sink = Okio.buffer(new GzipSink(result));
655    sink.writeUtf8(data);
656    sink.close();
657    return result;
658  }
659
660  /** Catches exceptions that are otherwise headed for the uncaught exception handler. */
661  private static class ExceptionCatchingExecutor extends ThreadPoolExecutor {
662    private final BlockingQueue<Exception> exceptions = new LinkedBlockingQueue<>();
663
664    public ExceptionCatchingExecutor() {
665      super(1, 1, 0, TimeUnit.SECONDS, new SynchronousQueue<Runnable>());
666    }
667
668    @Override public void execute(final Runnable runnable) {
669      super.execute(new Runnable() {
670        @Override public void run() {
671          try {
672            runnable.run();
673          } catch (Exception e) {
674            exceptions.add(e);
675          }
676        }
677      });
678    }
679
680    public Exception takeException() throws InterruptedException {
681      return exceptions.take();
682    }
683  }
684}
685