1/*
2 *  Licensed to the Apache Software Foundation (ASF) under one or more
3 *  contributor license agreements.  See the NOTICE file distributed with
4 *  this work for additional information regarding copyright ownership.
5 *  The ASF licenses this file to You under the Apache License, Version 2.0
6 *  (the "License"); you may not use this file except in compliance with
7 *  the License.  You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 *  Unless required by applicable law or agreed to in writing, software
12 *  distributed under the License is distributed on an "AS IS" BASIS,
13 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 *  See the License for the specific language governing permissions and
15 *  limitations under the License.
16 */
17
18package com.squareup.okhttp.internal;
19
20import com.android.org.conscrypt.OpenSSLSocketImpl;
21import com.squareup.okhttp.Protocol;
22
23import org.junit.Test;
24
25import java.io.IOException;
26import java.nio.charset.StandardCharsets;
27import java.security.SecureRandom;
28import java.security.cert.CertificateException;
29import java.security.cert.X509Certificate;
30import java.util.Arrays;
31import java.util.List;
32import javax.net.ssl.HandshakeCompletedListener;
33import javax.net.ssl.SSLContext;
34import javax.net.ssl.SSLSession;
35import javax.net.ssl.SSLSocket;
36import javax.net.ssl.SSLSocketFactory;
37import javax.net.ssl.TrustManager;
38import javax.net.ssl.X509TrustManager;
39
40import static org.junit.Assert.assertArrayEquals;
41import static org.junit.Assert.assertEquals;
42import static org.junit.Assert.assertNotNull;
43import static org.junit.Assert.assertNull;
44import static org.junit.Assert.assertTrue;
45import static org.junit.Assert.fail;
46
47/**
48 * Tests for {@link Platform}.
49 */
50public class PlatformTest {
51
52  @Test
53  public void enableTlsExtensionOptionalMethods() throws Exception {
54    Platform platform = new Platform();
55
56    // Expect no error
57    TestSSLSocketImpl arbitrarySocketImpl = new TestSSLSocketImpl();
58    List<Protocol> protocols = Arrays.asList(Protocol.HTTP_1_1, Protocol.SPDY_3);
59    platform.configureTlsExtensions(arbitrarySocketImpl, "host", protocols);
60    NpnOnlySSLSocketImpl npnOnlySSLSocketImpl = new NpnOnlySSLSocketImpl();
61    platform.configureTlsExtensions(npnOnlySSLSocketImpl, "host", protocols);
62
63    FullOpenSSLSocketImpl openSslSocket = new FullOpenSSLSocketImpl();
64    platform.configureTlsExtensions(openSslSocket, "host", protocols);
65    assertTrue(openSslSocket.useSessionTickets);
66    assertEquals("host", openSslSocket.hostname);
67    assertArrayEquals(Platform.concatLengthPrefixed(protocols), openSslSocket.alpnProtocols);
68  }
69
70  @Test
71  public void getSelectedProtocol() throws Exception {
72    Platform platform = new Platform();
73    String selectedProtocol = "alpn";
74
75    TestSSLSocketImpl arbitrarySocketImpl = new TestSSLSocketImpl();
76    assertNull(platform.getSelectedProtocol(arbitrarySocketImpl));
77
78    NpnOnlySSLSocketImpl npnOnlySSLSocketImpl = new NpnOnlySSLSocketImpl();
79    assertNull(platform.getSelectedProtocol(npnOnlySSLSocketImpl));
80
81    FullOpenSSLSocketImpl openSslSocket = new FullOpenSSLSocketImpl();
82    openSslSocket.alpnProtocols = selectedProtocol.getBytes(StandardCharsets.UTF_8);
83    assertEquals(selectedProtocol, platform.getSelectedProtocol(openSslSocket));
84  }
85
86  @Test public void rootTrustIndex_notNull_viaSocketFactory() throws Exception {
87    Platform platform = new Platform();
88    SSLContext sslContext = SSLContext.getInstance("TLSv1.2");
89    sslContext.init(null, new TrustManager[] { TRUST_NO_ONE_TRUST_MANAGER }, new SecureRandom());
90    SSLSocketFactory socketFactory = sslContext.getSocketFactory();
91    X509TrustManager trustManager = platform.trustManager(socketFactory);
92    assertNotNull(platform.trustRootIndex(trustManager));
93  }
94
95  @Test public void rootTrustIndex_notNull() throws Exception {
96    Platform platform = new Platform();
97    assertNotNull(platform.trustRootIndex(TRUST_NO_ONE_TRUST_MANAGER));
98  }
99
100  @Test public void trustManager() throws Exception {
101    Platform platform = new Platform();
102    SSLContext sslContext = SSLContext.getInstance("TLSv1.2");
103    sslContext.init(null, new TrustManager[] { TRUST_NO_ONE_TRUST_MANAGER }, new SecureRandom());
104    SSLSocketFactory socketFactory = sslContext.getSocketFactory();
105    X509TrustManager trustManager = platform.trustManager(socketFactory);
106    assertEquals(TRUST_NO_ONE_TRUST_MANAGER, trustManager);
107  }
108
109  private static final X509TrustManager TRUST_NO_ONE_TRUST_MANAGER = new X509TrustManager() {
110    @Override public void checkClientTrusted(X509Certificate[] chain, String authType)
111            throws CertificateException {
112      throw new CertificateException();
113    }
114
115    @Override public void checkServerTrusted(X509Certificate[] chain, String authType) {
116      throw new AssertionError();
117    }
118
119    @Override public X509Certificate[] getAcceptedIssuers() {
120      return new X509Certificate[0];
121    }
122  };
123
124  private static class FullOpenSSLSocketImpl extends OpenSSLSocketImpl {
125    private boolean useSessionTickets;
126    private String hostname;
127    private byte[] alpnProtocols;
128
129    public FullOpenSSLSocketImpl() throws IOException {
130      super(null);
131    }
132
133    @Override
134    public void setUseSessionTickets(boolean useSessionTickets) {
135      this.useSessionTickets = useSessionTickets;
136    }
137
138    @Override
139    public void setHostname(String hostname) {
140      this.hostname = hostname;
141    }
142
143    @Override
144    public void setAlpnProtocols(byte[] alpnProtocols) {
145      this.alpnProtocols = alpnProtocols;
146    }
147
148    @Override
149    public byte[] getAlpnSelectedProtocol() {
150      return alpnProtocols;
151    }
152  }
153
154  // Legacy case - NPN support has been dropped.
155  private static class NpnOnlySSLSocketImpl extends TestSSLSocketImpl {
156
157    private byte[] npnProtocols;
158
159    public void setNpnProtocols(byte[] npnProtocols) {
160      this.npnProtocols = npnProtocols;
161    }
162
163    public byte[] getNpnSelectedProtocol() {
164      return npnProtocols;
165    }
166  }
167
168  private static class TestSSLSocketImpl extends SSLSocket {
169
170    @Override
171    public String[] getSupportedCipherSuites() {
172      return new String[0];
173    }
174
175    @Override
176    public String[] getEnabledCipherSuites() {
177      return new String[0];
178    }
179
180    @Override
181    public void setEnabledCipherSuites(String[] suites) {
182    }
183
184    @Override
185    public String[] getSupportedProtocols() {
186      return new String[0];
187    }
188
189    @Override
190    public String[] getEnabledProtocols() {
191      return new String[0];
192    }
193
194    @Override
195    public void setEnabledProtocols(String[] protocols) {
196    }
197
198    @Override
199    public SSLSession getSession() {
200      return null;
201    }
202
203    @Override
204    public void addHandshakeCompletedListener(HandshakeCompletedListener listener) {
205    }
206
207    @Override
208    public void removeHandshakeCompletedListener(HandshakeCompletedListener listener) {
209    }
210
211    @Override
212    public void startHandshake() throws IOException {
213    }
214
215    @Override
216    public void setUseClientMode(boolean mode) {
217    }
218
219    @Override
220    public boolean getUseClientMode() {
221      return false;
222    }
223
224    @Override
225    public void setNeedClientAuth(boolean need) {
226    }
227
228    @Override
229    public void setWantClientAuth(boolean want) {
230    }
231
232    @Override
233    public boolean getNeedClientAuth() {
234      return false;
235    }
236
237    @Override
238    public boolean getWantClientAuth() {
239      return false;
240    }
241
242    @Override
243    public void setEnableSessionCreation(boolean flag) {
244    }
245
246    @Override
247    public boolean getEnableSessionCreation() {
248      return false;
249    }
250  }
251}
252