SSLSocketTest.java revision c9461f39290f815f560f2ec50e9ccde5ff4eb8f7
1/*
2 * Copyright (C) 2010 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package libcore.javax.net.ssl;
18
19import java.io.EOFException;
20import java.io.IOException;
21import java.io.InputStream;
22import java.io.OutputStream;
23import java.lang.Thread.UncaughtExceptionHandler;
24import java.lang.reflect.Method;
25import java.net.InetAddress;
26import java.net.InetSocketAddress;
27import java.net.ServerSocket;
28import java.net.Socket;
29import java.net.SocketException;
30import java.net.SocketTimeoutException;
31import java.security.Principal;
32import java.security.PrivateKey;
33import java.security.cert.Certificate;
34import java.security.cert.CertificateException;
35import java.security.cert.X509Certificate;
36import java.util.Arrays;
37import java.util.concurrent.Callable;
38import java.util.concurrent.ExecutorService;
39import java.util.concurrent.Executors;
40import java.util.concurrent.Future;
41import java.util.concurrent.ThreadFactory;
42import java.util.concurrent.TimeUnit;
43import javax.crypto.SecretKey;
44import javax.crypto.spec.SecretKeySpec;
45import javax.net.ServerSocketFactory;
46import javax.net.ssl.HandshakeCompletedEvent;
47import javax.net.ssl.HandshakeCompletedListener;
48import javax.net.ssl.KeyManager;
49import javax.net.ssl.SSLContext;
50import javax.net.ssl.SSLException;
51import javax.net.ssl.SSLHandshakeException;
52import javax.net.ssl.SSLParameters;
53import javax.net.ssl.SSLPeerUnverifiedException;
54import javax.net.ssl.SSLProtocolException;
55import javax.net.ssl.SSLServerSocket;
56import javax.net.ssl.SSLSession;
57import javax.net.ssl.SSLSocket;
58import javax.net.ssl.SSLSocketFactory;
59import javax.net.ssl.TrustManager;
60import javax.net.ssl.X509KeyManager;
61import javax.net.ssl.X509TrustManager;
62import junit.framework.TestCase;
63import libcore.io.IoUtils;
64import libcore.io.Streams;
65import libcore.java.security.StandardNames;
66import libcore.java.security.TestKeyStore;
67
68public class SSLSocketTest extends TestCase {
69
70    public void test_SSLSocket_defaultConfiguration() throws Exception {
71        SSLDefaultConfigurationAsserts.assertSSLSocket(
72                (SSLSocket) SSLSocketFactory.getDefault().createSocket());
73    }
74
75    public void test_SSLSocket_getSupportedCipherSuites_returnsCopies() throws Exception {
76        SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
77        SSLSocket ssl = (SSLSocket) sf.createSocket();
78        assertNotSame(ssl.getSupportedCipherSuites(), ssl.getSupportedCipherSuites());
79    }
80
81    public void test_SSLSocket_getSupportedCipherSuites_connect() throws Exception {
82        // note the rare usage of non-RSA keys
83        TestKeyStore testKeyStore = new TestKeyStore.Builder()
84                .keyAlgorithms("RSA", "DSA", "EC", "EC_RSA")
85                .aliasPrefix("rsa-dsa-ec")
86                .ca(true)
87                .build();
88        StringBuilder error = new StringBuilder();
89        test_SSLSocket_getSupportedCipherSuites_connect(testKeyStore, error);
90        if (error.length() > 0) {
91            throw new Exception("One or more problems in "
92                    + "test_SSLSocket_getSupportedCipherSuites_connect:\n" + error);
93        }
94    }
95    private void test_SSLSocket_getSupportedCipherSuites_connect(TestKeyStore testKeyStore,
96                                                                 StringBuilder error)
97            throws Exception {
98
99        String clientToServerString = "this is sent from the client to the server...";
100        String serverToClientString = "... and this from the server to the client";
101        byte[] clientToServer = clientToServerString.getBytes();
102        byte[] serverToClient = serverToClientString.getBytes();
103
104        KeyManager pskKeyManager = PSKKeyManagerProxy.getConscryptPSKKeyManager(
105                new PSKKeyManagerProxy() {
106            @Override
107            protected SecretKey getKey(String identityHint, String identity, Socket socket) {
108                return new SecretKeySpec("Just an arbitrary key".getBytes(), "RAW");
109            }
110        });
111        TestSSLContext c = TestSSLContext.createWithAdditionalKeyManagers(
112                testKeyStore, testKeyStore,
113                new KeyManager[] {pskKeyManager}, new KeyManager[] {pskKeyManager});
114
115        String[] cipherSuites = c.clientContext.getSocketFactory().getSupportedCipherSuites();
116
117        for (String cipherSuite : cipherSuites) {
118            boolean errorExpected = StandardNames.IS_RI && cipherSuite.endsWith("_SHA256");
119            try {
120                /*
121                 * TLS_EMPTY_RENEGOTIATION_INFO_SCSV cannot be used on
122                 * its own, but instead in conjunction with other
123                 * cipher suites.
124                 */
125                if (cipherSuite.equals(StandardNames.CIPHER_SUITE_SECURE_RENEGOTIATION)) {
126                    continue;
127                }
128                /*
129                 * Kerberos cipher suites require external setup. See "Kerberos Requirements" in
130                 * https://java.sun.com/j2se/1.5.0/docs/guide/security/jsse/JSSERefGuide.html
131                 * #KRBRequire
132                 */
133                if (cipherSuite.startsWith("TLS_KRB5_")) {
134                    continue;
135                }
136
137                String[] clientCipherSuiteArray = new String[] {
138                        cipherSuite,
139                        StandardNames.CIPHER_SUITE_SECURE_RENEGOTIATION };
140                String[] serverCipherSuiteArray = clientCipherSuiteArray;
141                SSLSocket[] pair = TestSSLSocketPair.connect(c,
142                                                             clientCipherSuiteArray,
143                                                             serverCipherSuiteArray);
144
145                SSLSocket server = pair[0];
146                SSLSocket client = pair[1];
147
148                // Check that the client can read the message sent by the server
149                server.getOutputStream().write(serverToClient);
150                byte[] clientFromServer = new byte[serverToClient.length];
151                Streams.readFully(client.getInputStream(), clientFromServer);
152                assertEquals(serverToClientString, new String(clientFromServer));
153
154                // Check that the server can read the message sent by the client
155                client.getOutputStream().write(clientToServer);
156                byte[] serverFromClient = new byte[clientToServer.length];
157                Streams.readFully(server.getInputStream(), serverFromClient);
158                assertEquals(clientToServerString, new String(serverFromClient));
159
160                // Check that the server and the client cannot read anything else
161                // (reads should time out)
162                server.setSoTimeout(10);
163                try {
164                  server.getInputStream().read();
165                  fail();
166                } catch (IOException expected) {}
167                client.setSoTimeout(10);
168                try {
169                  client.getInputStream().read();
170                  fail();
171                } catch (IOException expected) {}
172
173                client.close();
174                server.close();
175                assertFalse(errorExpected);
176            } catch (Exception maybeExpected) {
177                if (!errorExpected) {
178                    String message = ("Problem trying to connect cipher suite " + cipherSuite);
179                    System.out.println(message);
180                    maybeExpected.printStackTrace();
181                    error.append(message);
182                    error.append('\n');
183                }
184            }
185        }
186        c.close();
187    }
188
189    public void test_SSLSocket_getEnabledCipherSuites_returnsCopies() throws Exception {
190        SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
191        SSLSocket ssl = (SSLSocket) sf.createSocket();
192        assertNotSame(ssl.getEnabledCipherSuites(), ssl.getEnabledCipherSuites());
193    }
194
195    public void test_SSLSocket_setEnabledCipherSuites() throws Exception {
196        SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
197        SSLSocket ssl = (SSLSocket) sf.createSocket();
198
199        try {
200            ssl.setEnabledCipherSuites(null);
201            fail();
202        } catch (IllegalArgumentException expected) {
203        }
204        try {
205            ssl.setEnabledCipherSuites(new String[1]);
206            fail();
207        } catch (IllegalArgumentException expected) {
208        }
209        try {
210            ssl.setEnabledCipherSuites(new String[] { "Bogus" } );
211            fail();
212        } catch (IllegalArgumentException expected) {
213        }
214
215        ssl.setEnabledCipherSuites(new String[0]);
216        ssl.setEnabledCipherSuites(ssl.getEnabledCipherSuites());
217        ssl.setEnabledCipherSuites(ssl.getSupportedCipherSuites());
218
219        // Check that setEnabledCipherSuites affects getEnabledCipherSuites
220        String[] cipherSuites = new String[] { ssl.getSupportedCipherSuites()[0] };
221        ssl.setEnabledCipherSuites(cipherSuites);
222        assertEquals(Arrays.asList(cipherSuites), Arrays.asList(ssl.getEnabledCipherSuites()));
223    }
224
225    public void test_SSLSocket_getSupportedProtocols_returnsCopies() throws Exception {
226        SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
227        SSLSocket ssl = (SSLSocket) sf.createSocket();
228        assertNotSame(ssl.getSupportedProtocols(), ssl.getSupportedProtocols());
229    }
230
231    public void test_SSLSocket_getEnabledProtocols_returnsCopies() throws Exception {
232        SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
233        SSLSocket ssl = (SSLSocket) sf.createSocket();
234        assertNotSame(ssl.getEnabledProtocols(), ssl.getEnabledProtocols());
235    }
236
237    public void test_SSLSocket_setEnabledProtocols() throws Exception {
238        SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
239        SSLSocket ssl = (SSLSocket) sf.createSocket();
240
241        try {
242            ssl.setEnabledProtocols(null);
243            fail();
244        } catch (IllegalArgumentException expected) {
245        }
246        try {
247            ssl.setEnabledProtocols(new String[1]);
248            fail();
249        } catch (IllegalArgumentException expected) {
250        }
251        try {
252            ssl.setEnabledProtocols(new String[] { "Bogus" } );
253            fail();
254        } catch (IllegalArgumentException expected) {
255        }
256        ssl.setEnabledProtocols(new String[0]);
257        ssl.setEnabledProtocols(ssl.getEnabledProtocols());
258        ssl.setEnabledProtocols(ssl.getSupportedProtocols());
259
260        // Check that setEnabledProtocols affects getEnabledProtocols
261        for (String protocol : ssl.getSupportedProtocols()) {
262            if ("SSLv2Hello".equals(protocol)) {
263                try {
264                    ssl.setEnabledProtocols(new String[] { protocol });
265                    fail("Should fail when SSLv2Hello is set by itself");
266                } catch (IllegalArgumentException expected) {}
267            } else {
268                String[] protocols = new String[] { protocol };
269                ssl.setEnabledProtocols(protocols);
270                assertEquals(Arrays.deepToString(protocols),
271                        Arrays.deepToString(ssl.getEnabledProtocols()));
272            }
273        }
274    }
275
276    public void test_SSLSocket_getSession() throws Exception {
277        SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
278        SSLSocket ssl = (SSLSocket) sf.createSocket();
279        SSLSession session = ssl.getSession();
280        assertNotNull(session);
281        assertFalse(session.isValid());
282    }
283
284    public void test_SSLSocket_getHandshakeSession() throws Exception {
285        SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
286        SSLSocket ssl = (SSLSocket) sf.createSocket();
287        SSLSession session = ssl.getHandshakeSession();
288        assertNull(session);
289    }
290
291    public void test_SSLSocket_startHandshake() throws Exception {
292        final TestSSLContext c = TestSSLContext.create();
293        SSLSocket client = (SSLSocket) c.clientContext.getSocketFactory().createSocket(c.host,
294                                                                                       c.port);
295        final SSLSocket server = (SSLSocket) c.serverSocket.accept();
296        ExecutorService executor = Executors.newSingleThreadExecutor();
297        Future<Void> future = executor.submit(new Callable<Void>() {
298            @Override public Void call() throws Exception {
299                server.startHandshake();
300                assertNotNull(server.getSession());
301                assertNull(server.getHandshakeSession());
302                try {
303                    server.getSession().getPeerCertificates();
304                    fail();
305                } catch (SSLPeerUnverifiedException expected) {
306                }
307                Certificate[] localCertificates = server.getSession().getLocalCertificates();
308                assertNotNull(localCertificates);
309                TestKeyStore.assertChainLength(localCertificates);
310                assertNotNull(localCertificates[0]);
311                TestSSLContext.assertServerCertificateChain(c.serverTrustManager,
312                                                            localCertificates);
313                TestSSLContext.assertCertificateInKeyStore(localCertificates[0],
314                                                           c.serverKeyStore);
315                return null;
316            }
317        });
318        executor.shutdown();
319        client.startHandshake();
320        assertNotNull(client.getSession());
321        assertNull(client.getSession().getLocalCertificates());
322        Certificate[] peerCertificates = client.getSession().getPeerCertificates();
323        assertNotNull(peerCertificates);
324        TestKeyStore.assertChainLength(peerCertificates);
325        assertNotNull(peerCertificates[0]);
326        TestSSLContext.assertServerCertificateChain(c.clientTrustManager,
327                                                    peerCertificates);
328        TestSSLContext.assertCertificateInKeyStore(peerCertificates[0], c.serverKeyStore);
329        future.get();
330        client.close();
331        server.close();
332        c.close();
333    }
334
335    private static final class SSLServerSessionIdCallable implements Callable<byte[]> {
336        private final SSLSocket server;
337        private SSLServerSessionIdCallable(SSLSocket server) {
338            this.server = server;
339        }
340        @Override public byte[] call() throws Exception {
341            server.startHandshake();
342            assertNotNull(server.getSession());
343            assertNotNull(server.getSession().getId());
344            return server.getSession().getId();
345        }
346    }
347
348    public void test_SSLSocket_confirmSessionReuse() throws Exception {
349        final TestSSLContext c = TestSSLContext.create();
350        final ExecutorService executor = Executors.newSingleThreadExecutor();
351
352        final SSLSocket client1 = (SSLSocket) c.clientContext.getSocketFactory().createSocket(c.host,
353                                                                                       c.port);
354        final SSLSocket server1 = (SSLSocket) c.serverSocket.accept();
355        final Future<byte[]> future1 = executor.submit(new SSLServerSessionIdCallable(server1));
356        client1.startHandshake();
357        assertNotNull(client1.getSession());
358        assertNotNull(client1.getSession().getId());
359        final byte[] clientSessionId1 = client1.getSession().getId();
360        final byte[] serverSessionId1 = future1.get();
361        assertTrue(Arrays.equals(clientSessionId1, serverSessionId1));
362        client1.close();
363        server1.close();
364
365        final SSLSocket client2 = (SSLSocket) c.clientContext.getSocketFactory().createSocket(c.host,
366                                                                                       c.port);
367        final SSLSocket server2 = (SSLSocket) c.serverSocket.accept();
368        final Future<byte[]> future2 = executor.submit(new SSLServerSessionIdCallable(server2));
369        client2.startHandshake();
370        assertNotNull(client2.getSession());
371        assertNotNull(client2.getSession().getId());
372        final byte[] clientSessionId2 = client2.getSession().getId();
373        final byte[] serverSessionId2 = future2.get();
374        assertTrue(Arrays.equals(clientSessionId2, serverSessionId2));
375        client2.close();
376        server2.close();
377
378        assertTrue(Arrays.equals(clientSessionId1, clientSessionId2));
379
380        executor.shutdown();
381        c.close();
382    }
383
384    public void test_SSLSocket_startHandshake_noKeyStore() throws Exception {
385        TestSSLContext c = TestSSLContext.create(null, null, null, null, null, null, null, null,
386                                                 SSLContext.getDefault(), SSLContext.getDefault());
387        SSLSocket client = (SSLSocket) c.clientContext.getSocketFactory().createSocket(c.host,
388                                                                                       c.port);
389        final SSLSocket server = (SSLSocket) c.serverSocket.accept();
390        ExecutorService executor = Executors.newSingleThreadExecutor();
391        Future<Void> future = executor.submit(new Callable<Void>() {
392            @Override public Void call() throws Exception {
393                try {
394                    server.startHandshake();
395                    fail();
396                } catch (SSLHandshakeException expected) {
397                }
398                return null;
399            }
400        });
401        executor.shutdown();
402        try {
403            client.startHandshake();
404            fail();
405        } catch (SSLHandshakeException expected) {
406        }
407        future.get();
408        server.close();
409        client.close();
410        c.close();
411    }
412
413    public void test_SSLSocket_startHandshake_noClientCertificate() throws Exception {
414        TestSSLContext c = TestSSLContext.create();
415        SSLContext serverContext = c.serverContext;
416        SSLContext clientContext = c.clientContext;
417        SSLSocket client = (SSLSocket)
418            clientContext.getSocketFactory().createSocket(c.host, c.port);
419        final SSLSocket server = (SSLSocket) c.serverSocket.accept();
420        ExecutorService executor = Executors.newSingleThreadExecutor();
421        Future<Void> future = executor.submit(new Callable<Void>() {
422            @Override public Void call() throws Exception {
423                server.startHandshake();
424                return null;
425            }
426        });
427        executor.shutdown();
428        client.startHandshake();
429        future.get();
430        client.close();
431        server.close();
432        c.close();
433    }
434
435    public void test_SSLSocket_HandshakeCompletedListener() throws Exception {
436        final TestSSLContext c = TestSSLContext.create();
437        final SSLSocket client = (SSLSocket)
438                c.clientContext.getSocketFactory().createSocket(c.host, c.port);
439        final SSLSocket server = (SSLSocket) c.serverSocket.accept();
440        ExecutorService executor = Executors.newSingleThreadExecutor();
441        Future<Void> future = executor.submit(new Callable<Void>() {
442            @Override public Void call() throws Exception {
443                server.startHandshake();
444                return null;
445            }
446        });
447        executor.shutdown();
448        final boolean[] handshakeCompletedListenerCalled = new boolean[1];
449        client.addHandshakeCompletedListener(new HandshakeCompletedListener() {
450            public void handshakeCompleted(HandshakeCompletedEvent event) {
451                try {
452                    SSLSession session = event.getSession();
453                    String cipherSuite = event.getCipherSuite();
454                    Certificate[] localCertificates = event.getLocalCertificates();
455                    Certificate[] peerCertificates = event.getPeerCertificates();
456                    javax.security.cert.X509Certificate[] peerCertificateChain
457                            = event.getPeerCertificateChain();
458                    Principal peerPrincipal = event.getPeerPrincipal();
459                    Principal localPrincipal = event.getLocalPrincipal();
460                    Socket socket = event.getSocket();
461
462                    if (false) {
463                        System.out.println("Session=" + session);
464                        System.out.println("CipherSuite=" + cipherSuite);
465                        System.out.println("LocalCertificates="
466                                + Arrays.toString(localCertificates));
467                        System.out.println("PeerCertificates="
468                                + Arrays.toString(peerCertificates));
469                        System.out.println("PeerCertificateChain="
470                                + Arrays.toString(peerCertificateChain));
471                        System.out.println("PeerPrincipal=" + peerPrincipal);
472                        System.out.println("LocalPrincipal=" + localPrincipal);
473                        System.out.println("Socket=" + socket);
474                    }
475
476                    assertNotNull(session);
477                    byte[] id = session.getId();
478                    assertNotNull(id);
479                    assertEquals(32, id.length);
480                    assertNotNull(c.clientContext.getClientSessionContext().getSession(id));
481
482                    assertNotNull(cipherSuite);
483                    assertTrue(Arrays.asList(
484                            client.getEnabledCipherSuites()).contains(cipherSuite));
485                    assertTrue(Arrays.asList(
486                            c.serverSocket.getEnabledCipherSuites()).contains(cipherSuite));
487
488                    assertNull(localCertificates);
489
490                    assertNotNull(peerCertificates);
491                    TestKeyStore.assertChainLength(peerCertificates);
492                    assertNotNull(peerCertificates[0]);
493                    TestSSLContext.assertServerCertificateChain(c.clientTrustManager,
494                                                                peerCertificates);
495                    TestSSLContext.assertCertificateInKeyStore(peerCertificates[0],
496                                                               c.serverKeyStore);
497
498                    assertNotNull(peerCertificateChain);
499                    TestKeyStore.assertChainLength(peerCertificateChain);
500                    assertNotNull(peerCertificateChain[0]);
501                    TestSSLContext.assertCertificateInKeyStore(
502                        peerCertificateChain[0].getSubjectDN(), c.serverKeyStore);
503
504                    assertNotNull(peerPrincipal);
505                    TestSSLContext.assertCertificateInKeyStore(peerPrincipal, c.serverKeyStore);
506
507                    assertNull(localPrincipal);
508
509                    assertNotNull(socket);
510                    assertSame(client, socket);
511
512                    assertTrue(socket instanceof SSLSocket);
513                    assertNull(((SSLSocket) socket).getHandshakeSession());
514
515                    synchronized (handshakeCompletedListenerCalled) {
516                        handshakeCompletedListenerCalled[0] = true;
517                        handshakeCompletedListenerCalled.notify();
518                    }
519                    handshakeCompletedListenerCalled[0] = true;
520                } catch (RuntimeException e) {
521                    throw e;
522                } catch (Exception e) {
523                    throw new RuntimeException(e);
524                }
525            }
526        });
527        client.startHandshake();
528        future.get();
529        if (!TestSSLContext.sslServerSocketSupportsSessionTickets()) {
530            assertNotNull(c.serverContext.getServerSessionContext().getSession(
531                    client.getSession().getId()));
532        }
533        synchronized (handshakeCompletedListenerCalled) {
534            while (!handshakeCompletedListenerCalled[0]) {
535                handshakeCompletedListenerCalled.wait();
536            }
537        }
538        client.close();
539        server.close();
540        c.close();
541    }
542
543    private static final class TestUncaughtExceptionHandler implements UncaughtExceptionHandler {
544        Throwable actualException;
545        @Override public void uncaughtException(Thread thread, Throwable ex) {
546            assertNull(actualException);
547            actualException = ex;
548        }
549    }
550
551    public void test_SSLSocket_HandshakeCompletedListener_RuntimeException() throws Exception {
552        final Thread self = Thread.currentThread();
553        final UncaughtExceptionHandler original = self.getUncaughtExceptionHandler();
554
555        final RuntimeException expectedException = new RuntimeException("expected");
556        final TestUncaughtExceptionHandler test = new TestUncaughtExceptionHandler();
557        self.setUncaughtExceptionHandler(test);
558
559        final TestSSLContext c = TestSSLContext.create();
560        final SSLSocket client = (SSLSocket)
561                c.clientContext.getSocketFactory().createSocket(c.host, c.port);
562        final SSLSocket server = (SSLSocket) c.serverSocket.accept();
563        ExecutorService executor = Executors.newSingleThreadExecutor();
564        Future<Void> future = executor.submit(new Callable<Void>() {
565            @Override public Void call() throws Exception {
566                server.startHandshake();
567                return null;
568            }
569        });
570        executor.shutdown();
571        client.addHandshakeCompletedListener(new HandshakeCompletedListener() {
572            public void handshakeCompleted(HandshakeCompletedEvent event) {
573                throw expectedException;
574            }
575        });
576        client.startHandshake();
577        future.get();
578        client.close();
579        server.close();
580        c.close();
581
582        assertSame(expectedException, test.actualException);
583        self.setUncaughtExceptionHandler(original);
584    }
585
586    public void test_SSLSocket_getUseClientMode() throws Exception {
587        TestSSLContext c = TestSSLContext.create();
588        SSLSocket client = (SSLSocket) c.clientContext.getSocketFactory().createSocket(c.host,
589                                                                                       c.port);
590        SSLSocket server = (SSLSocket) c.serverSocket.accept();
591        assertTrue(client.getUseClientMode());
592        assertFalse(server.getUseClientMode());
593        client.close();
594        server.close();
595        c.close();
596    }
597
598    public void test_SSLSocket_setUseClientMode() throws Exception {
599        // client is client, server is server
600        test_SSLSocket_setUseClientMode(true, false);
601        // client is server, server is client
602        test_SSLSocket_setUseClientMode(true, false);
603        // both are client
604        try {
605            test_SSLSocket_setUseClientMode(true, true);
606            fail();
607        } catch (SSLProtocolException expected) {
608            assertTrue(StandardNames.IS_RI);
609        } catch (SSLHandshakeException expected) {
610            assertFalse(StandardNames.IS_RI);
611        }
612
613        // both are server
614        try {
615            test_SSLSocket_setUseClientMode(false, false);
616            fail();
617        } catch (SocketTimeoutException expected) {
618        }
619    }
620
621    private void test_SSLSocket_setUseClientMode(final boolean clientClientMode,
622                                                 final boolean serverClientMode)
623            throws Exception {
624        TestSSLContext c = TestSSLContext.create();
625        SSLSocket client = (SSLSocket) c.clientContext.getSocketFactory().createSocket(c.host,
626                                                                                       c.port);
627        final SSLSocket server = (SSLSocket) c.serverSocket.accept();
628
629        ExecutorService executor = Executors.newSingleThreadExecutor();
630        Future<IOException> future = executor.submit(new Callable<IOException>() {
631            @Override public IOException call() throws Exception {
632                try {
633                    if (!serverClientMode) {
634                        server.setSoTimeout(1 * 1000);
635                    }
636                    server.setUseClientMode(serverClientMode);
637                    server.startHandshake();
638                    return null;
639                } catch (SSLHandshakeException e) {
640                    return e;
641                } catch (SocketTimeoutException e) {
642                    return e;
643                }
644            }
645        });
646        executor.shutdown();
647        if (!clientClientMode) {
648            client.setSoTimeout(1 * 1000);
649        }
650        client.setUseClientMode(clientClientMode);
651        client.startHandshake();
652        IOException ioe = future.get();
653        if (ioe != null) {
654            throw ioe;
655        }
656        client.close();
657        server.close();
658        c.close();
659    }
660
661    public void test_SSLSocket_setUseClientMode_afterHandshake() throws Exception {
662
663        // can't set after handshake
664        TestSSLSocketPair pair = TestSSLSocketPair.create();
665        try {
666            pair.server.setUseClientMode(false);
667            fail();
668        } catch (IllegalArgumentException expected) {
669        }
670        try {
671            pair.client.setUseClientMode(false);
672            fail();
673        } catch (IllegalArgumentException expected) {
674        }
675    }
676
677    public void test_SSLSocket_untrustedServer() throws Exception {
678        TestSSLContext c = TestSSLContext.create(TestKeyStore.getClientCA2(),
679                                                 TestKeyStore.getServer());
680        SSLSocket client = (SSLSocket) c.clientContext.getSocketFactory().createSocket(c.host,
681                                                                                       c.port);
682        final SSLSocket server = (SSLSocket) c.serverSocket.accept();
683        ExecutorService executor = Executors.newSingleThreadExecutor();
684        Future<Void> future = executor.submit(new Callable<Void>() {
685            @Override public Void call() throws Exception {
686                try {
687                    server.startHandshake();
688                    fail();
689                } catch (SSLHandshakeException expected) {
690                }
691                return null;
692            }
693        });
694        executor.shutdown();
695        try {
696            client.startHandshake();
697            fail();
698        } catch (SSLHandshakeException expected) {
699            assertTrue(expected.getCause() instanceof CertificateException);
700        }
701        future.get();
702        client.close();
703        server.close();
704        c.close();
705    }
706
707    public void test_SSLSocket_clientAuth() throws Exception {
708        TestSSLContext c = TestSSLContext.create(TestKeyStore.getClientCertificate(),
709                                                 TestKeyStore.getServer());
710        SSLSocket client = (SSLSocket) c.clientContext.getSocketFactory().createSocket(c.host,
711                                                                                       c.port);
712        final SSLSocket server = (SSLSocket) c.serverSocket.accept();
713        ExecutorService executor = Executors.newSingleThreadExecutor();
714        Future<Void> future = executor.submit(new Callable<Void>() {
715            @Override public Void call() throws Exception {
716                assertFalse(server.getWantClientAuth());
717                assertFalse(server.getNeedClientAuth());
718
719                // confirm turning one on by itself
720                server.setWantClientAuth(true);
721                assertTrue(server.getWantClientAuth());
722                assertFalse(server.getNeedClientAuth());
723
724                // confirm turning setting on toggles the other
725                server.setNeedClientAuth(true);
726                assertFalse(server.getWantClientAuth());
727                assertTrue(server.getNeedClientAuth());
728
729                // confirm toggling back
730                server.setWantClientAuth(true);
731                assertTrue(server.getWantClientAuth());
732                assertFalse(server.getNeedClientAuth());
733
734                server.startHandshake();
735                return null;
736            }
737        });
738        executor.shutdown();
739        client.startHandshake();
740        assertNotNull(client.getSession().getLocalCertificates());
741        TestKeyStore.assertChainLength(client.getSession().getLocalCertificates());
742        TestSSLContext.assertClientCertificateChain(c.clientTrustManager,
743                                                    client.getSession().getLocalCertificates());
744        future.get();
745        client.close();
746        server.close();
747        c.close();
748    }
749
750    public void test_SSLSocket_clientAuth_bogusAlias() throws Exception {
751        TestSSLContext c = TestSSLContext.create();
752        SSLContext clientContext = SSLContext.getInstance("TLS");
753        X509KeyManager keyManager = new X509KeyManager() {
754            @Override public String chooseClientAlias(String[] keyType,
755                                                      Principal[] issuers,
756                                                      Socket socket) {
757                return "bogus";
758            }
759            @Override public String chooseServerAlias(String keyType,
760                                                      Principal[] issuers,
761                                                      Socket socket) {
762                throw new AssertionError();
763            }
764            @Override public X509Certificate[] getCertificateChain(String alias) {
765                // return null for "bogus" alias
766                return null;
767            }
768            @Override public String[] getClientAliases(String keyType, Principal[] issuers) {
769                throw new AssertionError();
770            }
771            @Override public String[] getServerAliases(String keyType, Principal[] issuers) {
772                throw new AssertionError();
773            }
774            @Override public PrivateKey getPrivateKey(String alias) {
775                // return null for "bogus" alias
776                return null;
777            }
778        };
779        clientContext.init(new KeyManager[] { keyManager },
780                           new TrustManager[] { c.clientTrustManager },
781                           null);
782        SSLSocket client = (SSLSocket) clientContext.getSocketFactory().createSocket(c.host,
783                                                                                     c.port);
784        final SSLSocket server = (SSLSocket) c.serverSocket.accept();
785        ExecutorService executor = Executors.newSingleThreadExecutor();
786        Future<Void> future = executor.submit(new Callable<Void>() {
787            @Override public Void call() throws Exception {
788                try {
789                    server.setNeedClientAuth(true);
790                    server.startHandshake();
791                    fail();
792                } catch (SSLHandshakeException expected) {
793                }
794                return null;
795            }
796        });
797
798        executor.shutdown();
799        try {
800            client.startHandshake();
801            fail();
802        } catch (SSLHandshakeException expected) {
803            // before we would get a NullPointerException from passing
804            // due to the null PrivateKey return by the X509KeyManager.
805        }
806        future.get();
807        client.close();
808        server.close();
809        c.close();
810    }
811
812    public void test_SSLSocket_TrustManagerRuntimeException() throws Exception {
813        TestSSLContext c = TestSSLContext.create();
814        SSLContext clientContext = SSLContext.getInstance("TLS");
815        X509TrustManager trustManager = new X509TrustManager() {
816            @Override public void checkClientTrusted(X509Certificate[] chain, String authType)
817                    throws CertificateException {
818                throw new AssertionError();
819            }
820            @Override public void checkServerTrusted(X509Certificate[] chain, String authType)
821                    throws CertificateException {
822                throw new RuntimeException();  // throw a RuntimeException from custom TrustManager
823            }
824            @Override public X509Certificate[] getAcceptedIssuers() {
825                throw new AssertionError();
826            }
827        };
828        clientContext.init(null, new TrustManager[] { trustManager }, null);
829        SSLSocket client = (SSLSocket) clientContext.getSocketFactory().createSocket(c.host,
830                                                                                     c.port);
831        final SSLSocket server = (SSLSocket) c.serverSocket.accept();
832        ExecutorService executor = Executors.newSingleThreadExecutor();
833        Future<Void> future = executor.submit(new Callable<Void>() {
834            @Override public Void call() throws Exception {
835                try {
836                    server.startHandshake();
837                    fail();
838                } catch (SSLHandshakeException expected) {
839                }
840                return null;
841            }
842        });
843
844        executor.shutdown();
845        try {
846            client.startHandshake();
847            fail();
848        } catch (SSLHandshakeException expected) {
849            // before we would get a RuntimeException from checkServerTrusted.
850        }
851        future.get();
852        client.close();
853        server.close();
854        c.close();
855    }
856
857    public void test_SSLSocket_getEnableSessionCreation() throws Exception {
858        TestSSLContext c = TestSSLContext.create();
859        SSLSocket client = (SSLSocket) c.clientContext.getSocketFactory().createSocket(c.host,
860                                                                                       c.port);
861        SSLSocket server = (SSLSocket) c.serverSocket.accept();
862        assertTrue(client.getEnableSessionCreation());
863        assertTrue(server.getEnableSessionCreation());
864        client.close();
865        server.close();
866        c.close();
867    }
868
869    public void test_SSLSocket_setEnableSessionCreation_server() throws Exception {
870        TestSSLContext c = TestSSLContext.create();
871        SSLSocket client = (SSLSocket) c.clientContext.getSocketFactory().createSocket(c.host,
872                                                                                       c.port);
873        final SSLSocket server = (SSLSocket) c.serverSocket.accept();
874        ExecutorService executor = Executors.newSingleThreadExecutor();
875        Future<Void> future = executor.submit(new Callable<Void>() {
876            @Override public Void call() throws Exception {
877                server.setEnableSessionCreation(false);
878                try {
879                    server.startHandshake();
880                    fail();
881                } catch (SSLException expected) {
882                }
883                return null;
884            }
885        });
886        executor.shutdown();
887        try {
888            client.startHandshake();
889            fail();
890        } catch (SSLException expected) {
891        }
892        future.get();
893        client.close();
894        server.close();
895        c.close();
896    }
897
898    public void test_SSLSocket_setEnableSessionCreation_client() throws Exception {
899        TestSSLContext c = TestSSLContext.create();
900        SSLSocket client = (SSLSocket) c.clientContext.getSocketFactory().createSocket(c.host,
901                                                                                       c.port);
902        final SSLSocket server = (SSLSocket) c.serverSocket.accept();
903        ExecutorService executor = Executors.newSingleThreadExecutor();
904        Future<Void> future = executor.submit(new Callable<Void>() {
905            @Override public Void call() throws Exception {
906                try {
907                    server.startHandshake();
908                    fail();
909                } catch (SSLException expected) {
910                }
911                return null;
912            }
913        });
914        executor.shutdown();
915        client.setEnableSessionCreation(false);
916        try {
917            client.startHandshake();
918            fail();
919        } catch (SSLException expected) {
920        }
921        future.get();
922        client.close();
923        server.close();
924        c.close();
925    }
926
927    public void test_SSLSocket_getSSLParameters() throws Exception {
928        SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
929        SSLSocket ssl = (SSLSocket) sf.createSocket();
930
931        SSLParameters p = ssl.getSSLParameters();
932        assertNotNull(p);
933
934        String[] cipherSuites = p.getCipherSuites();
935        assertNotSame(cipherSuites, ssl.getEnabledCipherSuites());
936        assertEquals(Arrays.asList(cipherSuites), Arrays.asList(ssl.getEnabledCipherSuites()));
937
938        String[] protocols = p.getProtocols();
939        assertNotSame(protocols, ssl.getEnabledProtocols());
940        assertEquals(Arrays.asList(protocols), Arrays.asList(ssl.getEnabledProtocols()));
941
942        assertEquals(p.getWantClientAuth(), ssl.getWantClientAuth());
943        assertEquals(p.getNeedClientAuth(), ssl.getNeedClientAuth());
944
945        assertNull(p.getEndpointIdentificationAlgorithm());
946        p.setEndpointIdentificationAlgorithm(null);
947        assertNull(p.getEndpointIdentificationAlgorithm());
948        p.setEndpointIdentificationAlgorithm("HTTPS");
949        assertEquals("HTTPS", p.getEndpointIdentificationAlgorithm());
950        p.setEndpointIdentificationAlgorithm("FOO");
951        assertEquals("FOO", p.getEndpointIdentificationAlgorithm());
952    }
953
954    public void test_SSLSocket_setSSLParameters() throws Exception {
955        SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
956        SSLSocket ssl = (SSLSocket) sf.createSocket();
957        String[] defaultCipherSuites = ssl.getEnabledCipherSuites();
958        String[] defaultProtocols = ssl.getEnabledProtocols();
959        String[] supportedCipherSuites = ssl.getSupportedCipherSuites();
960        String[] supportedProtocols = ssl.getSupportedProtocols();
961
962        {
963            SSLParameters p = new SSLParameters();
964            ssl.setSSLParameters(p);
965            assertEquals(Arrays.asList(defaultCipherSuites),
966                         Arrays.asList(ssl.getEnabledCipherSuites()));
967            assertEquals(Arrays.asList(defaultProtocols),
968                         Arrays.asList(ssl.getEnabledProtocols()));
969        }
970
971        {
972            SSLParameters p = new SSLParameters(supportedCipherSuites,
973                                                supportedProtocols);
974            ssl.setSSLParameters(p);
975            assertEquals(Arrays.asList(supportedCipherSuites),
976                         Arrays.asList(ssl.getEnabledCipherSuites()));
977            assertEquals(Arrays.asList(supportedProtocols),
978                         Arrays.asList(ssl.getEnabledProtocols()));
979        }
980        {
981            SSLParameters p = new SSLParameters();
982
983            p.setNeedClientAuth(true);
984            assertFalse(ssl.getNeedClientAuth());
985            assertFalse(ssl.getWantClientAuth());
986            ssl.setSSLParameters(p);
987            assertTrue(ssl.getNeedClientAuth());
988            assertFalse(ssl.getWantClientAuth());
989
990            p.setWantClientAuth(true);
991            assertTrue(ssl.getNeedClientAuth());
992            assertFalse(ssl.getWantClientAuth());
993            ssl.setSSLParameters(p);
994            assertFalse(ssl.getNeedClientAuth());
995            assertTrue(ssl.getWantClientAuth());
996
997            p.setWantClientAuth(false);
998            assertFalse(ssl.getNeedClientAuth());
999            assertTrue(ssl.getWantClientAuth());
1000            ssl.setSSLParameters(p);
1001            assertFalse(ssl.getNeedClientAuth());
1002            assertFalse(ssl.getWantClientAuth());
1003        }
1004    }
1005
    /**
     * Closes both ends of a socket pair from {@code TestSSLSocketPair.create()} and checks
     * which operations remain legal afterwards:
     * <ul>
     * <li>repeated close() and most getters/setters succeed;</li>
     * <li>starting a handshake, obtaining streams, and I/O on previously obtained streams
     * fail;</li>
     * <li>changing client mode is rejected with IllegalArgumentException.</li>
     * </ul>
     */
    public void test_SSLSocket_close() throws Exception {
        TestSSLSocketPair pair = TestSSLSocketPair.create();
        SSLSocket server = pair.server;
        SSLSocket client = pair.client;
        assertFalse(server.isClosed());
        assertFalse(client.isClosed());
        // Grab the client streams while the socket is still open so that their
        // behavior after close() can be checked below.
        InputStream input = client.getInputStream();
        OutputStream output = client.getOutputStream();
        server.close();
        client.close();
        assertTrue(server.isClosed());
        assertTrue(client.isClosed());

        // close after close is okay...
        server.close();
        client.close();

        // ...so are a lot of other operations...
        HandshakeCompletedListener l = new HandshakeCompletedListener () {
            public void handshakeCompleted(HandshakeCompletedEvent e) {}
        };
        client.addHandshakeCompletedListener(l);
        assertNotNull(client.getEnabledCipherSuites());
        assertNotNull(client.getEnabledProtocols());
        client.getEnableSessionCreation();
        client.getNeedClientAuth();
        assertNotNull(client.getSession());
        assertNotNull(client.getSSLParameters());
        assertNotNull(client.getSupportedProtocols());
        client.getUseClientMode();
        client.getWantClientAuth();
        client.removeHandshakeCompletedListener(l);
        client.setEnabledCipherSuites(new String[0]);
        client.setEnabledProtocols(new String[0]);
        client.setEnableSessionCreation(false);
        client.setNeedClientAuth(false);
        client.setSSLParameters(client.getSSLParameters());
        client.setWantClientAuth(false);

        // ...but some operations are expected to give SocketException...
        try {
            client.startHandshake();
            fail();
        } catch (SocketException expected) {
        }
        try {
            client.getInputStream();
            fail();
        } catch (SocketException expected) {
        }
        try {
            client.getOutputStream();
            fail();
        } catch (SocketException expected) {
        }
        try {
            input.read();
            fail();
        } catch (SocketException expected) {
        }
        // For a null buffer, the RI is expected to throw NullPointerException, while
        // other implementations are expected to report the closed socket instead.
        try {
            input.read(null, -1, -1);
            fail();
        } catch (NullPointerException expected) {
            assertTrue(StandardNames.IS_RI);
        } catch (SocketException expected) {
            assertFalse(StandardNames.IS_RI);
        }
        try {
            output.write(-1);
            fail();
        } catch (SocketException expected) {
        }
        // Same RI/non-RI split as for the null-buffer read above.
        try {
            output.write(null, -1, -1);
            fail();
        } catch (NullPointerException expected) {
            assertTrue(StandardNames.IS_RI);
        } catch (SocketException expected) {
            assertFalse(StandardNames.IS_RI);
        }

        // ... and one gives IllegalArgumentException
        try {
            client.setUseClientMode(false);
            fail();
        } catch (IllegalArgumentException expected) {
        }

        pair.close();
    }
1097
1098    /**
1099     * b/3350645 Test to confirm that an SSLSocket.close() performing
1100     * an SSL_shutdown does not throw an IOException if the peer
1101     * socket has been closed.
1102     */
1103    public void test_SSLSocket_shutdownCloseOnClosedPeer() throws Exception {
1104        TestSSLContext c = TestSSLContext.create();
1105        final Socket underlying = new Socket(c.host, c.port);
1106        final SSLSocket wrapping = (SSLSocket)
1107                c.clientContext.getSocketFactory().createSocket(underlying,
1108                                                                c.host.getHostName(),
1109                                                                c.port,
1110                                                                false);
1111        ExecutorService executor = Executors.newSingleThreadExecutor();
1112        Future<Void> clientFuture = executor.submit(new Callable<Void>() {
1113            @Override public Void call() throws Exception {
1114                wrapping.startHandshake();
1115                wrapping.getOutputStream().write(42);
1116                // close the underlying socket,
1117                // so that no SSL shutdown is sent
1118                underlying.close();
1119                wrapping.close();
1120                return null;
1121            }
1122        });
1123        executor.shutdown();
1124
1125        SSLSocket server = (SSLSocket) c.serverSocket.accept();
1126        server.startHandshake();
1127        server.getInputStream().read();
1128        // wait for thread to finish so we know client is closed.
1129        clientFuture.get();
1130        // close should cause an SSL_shutdown which will fail
1131        // because the peer has closed, but it shouldn't throw.
1132        server.close();
1133    }
1134
1135    public void test_SSLSocket_endpointIdentification_Success() throws Exception {
1136        final TestSSLContext c = TestSSLContext.create();
1137        SSLSocket client = (SSLSocket) c.clientContext.getSocketFactory().createSocket();
1138        SSLParameters p = client.getSSLParameters();
1139        p.setEndpointIdentificationAlgorithm("HTTPS");
1140        client.connect(new InetSocketAddress(c.host, c.port));
1141        final SSLSocket server = (SSLSocket) c.serverSocket.accept();
1142        ExecutorService executor = Executors.newSingleThreadExecutor();
1143        Future<Void> future = executor.submit(new Callable<Void>() {
1144            @Override public Void call() throws Exception {
1145                server.startHandshake();
1146                assertNotNull(server.getSession());
1147                try {
1148                    server.getSession().getPeerCertificates();
1149                    fail();
1150                } catch (SSLPeerUnverifiedException expected) {
1151                }
1152                Certificate[] localCertificates = server.getSession().getLocalCertificates();
1153                assertNotNull(localCertificates);
1154                TestKeyStore.assertChainLength(localCertificates);
1155                assertNotNull(localCertificates[0]);
1156                TestSSLContext.assertCertificateInKeyStore(localCertificates[0],
1157                                                           c.serverKeyStore);
1158                return null;
1159            }
1160        });
1161        executor.shutdown();
1162        client.startHandshake();
1163        assertNotNull(client.getSession());
1164        assertNull(client.getSession().getLocalCertificates());
1165        Certificate[] peerCertificates = client.getSession().getPeerCertificates();
1166        assertNotNull(peerCertificates);
1167        TestKeyStore.assertChainLength(peerCertificates);
1168        assertNotNull(peerCertificates[0]);
1169        TestSSLContext.assertCertificateInKeyStore(peerCertificates[0], c.serverKeyStore);
1170        future.get();
1171        client.close();
1172        server.close();
1173        c.close();
1174    }
1175
1176    public void test_SSLSocket_endpointIdentification_Failure() throws Exception {
1177
1178        final TestSSLContext c = TestSSLContext.create();
1179        SSLSocket client = (SSLSocket) c.clientContext.getSocketFactory().createSocket(
1180                InetAddress.getByName("127.0.0.2"), c.port);
1181        SSLParameters p = client.getSSLParameters();
1182        p.setEndpointIdentificationAlgorithm("HTTPS");
1183        client.setSSLParameters(p);
1184        // client.connect(new InetSocketAddress(c.host, c.port));
1185        final SSLSocket server = (SSLSocket) c.serverSocket.accept();
1186        ExecutorService executor = Executors.newSingleThreadExecutor();
1187        Future<Void> future = executor.submit(new Callable<Void>() {
1188            @Override public Void call() throws Exception {
1189                try {
1190                    server.startHandshake();
1191                    fail("Should receive SSLHandshakeException as server");
1192                } catch (SSLHandshakeException expected) {
1193                }
1194                return null;
1195            }
1196        });
1197        executor.shutdown();
1198        try {
1199            client.startHandshake();
1200            fail("Should throw when hostname does not match expected");
1201        } catch (SSLHandshakeException expected) {
1202        } finally {
1203            try {
1204                future.get();
1205            } finally {
1206                client.close();
1207                server.close();
1208                c.close();
1209            }
1210        }
1211    }
1212
1213    public void test_SSLSocket_setSoTimeout_basic() throws Exception {
1214        ServerSocket listening = new ServerSocket(0);
1215
1216        Socket underlying = new Socket(listening.getInetAddress(), listening.getLocalPort());
1217        assertEquals(0, underlying.getSoTimeout());
1218
1219        SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
1220        Socket wrapping = sf.createSocket(underlying, null, -1, false);
1221        assertEquals(0, wrapping.getSoTimeout());
1222
1223        // setting wrapper sets underlying and ...
1224        int expectedTimeoutMillis = 1000;  // 10 was too small because it was affected by rounding
1225        wrapping.setSoTimeout(expectedTimeoutMillis);
1226        assertEquals(expectedTimeoutMillis, wrapping.getSoTimeout());
1227        assertEquals(expectedTimeoutMillis, underlying.getSoTimeout());
1228
1229        // ... getting wrapper inspects underlying
1230        underlying.setSoTimeout(0);
1231        assertEquals(0, wrapping.getSoTimeout());
1232        assertEquals(0, underlying.getSoTimeout());
1233    }
1234
1235    public void test_SSLSocket_setSoTimeout_wrapper() throws Exception {
1236        if (StandardNames.IS_RI) {
1237            // RI cannot handle this case
1238            return;
1239        }
1240        ServerSocket listening = new ServerSocket(0);
1241
1242        // setSoTimeout applies to read, not connect, so connect first
1243        Socket underlying = new Socket(listening.getInetAddress(), listening.getLocalPort());
1244        Socket server = listening.accept();
1245
1246        SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
1247        Socket clientWrapping = sf.createSocket(underlying, null, -1, false);
1248
1249        underlying.setSoTimeout(1);
1250        try {
1251            clientWrapping.getInputStream().read();
1252            fail();
1253        } catch (SocketTimeoutException expected) {
1254        }
1255
1256        clientWrapping.close();
1257        server.close();
1258        underlying.close();
1259        listening.close();
1260    }
1261
1262    public void test_SSLSocket_setSoWriteTimeout() throws Exception {
1263        if (StandardNames.IS_RI) {
1264            // RI does not support write timeout on sockets
1265            return;
1266        }
1267
1268        final TestSSLContext c = TestSSLContext.create();
1269        SSLSocket client = (SSLSocket) c.clientContext.getSocketFactory().createSocket();
1270
1271        // Try to make the client SO_SNDBUF size as small as possible
1272        // (it can default to 512k or even megabytes).  Note that
1273        // socket(7) says that the kernel will double the request to
1274        // leave room for its own book keeping and that the minimal
1275        // value will be 2048. Also note that tcp(7) says the value
1276        // needs to be set before connect(2).
1277        int sendBufferSize = 1024;
1278        client.setSendBufferSize(sendBufferSize);
1279        sendBufferSize = client.getSendBufferSize();
1280
1281        // In jb-mr2 it was found that we need to also set SO_RCVBUF
1282        // to a minimal size or the write would not block. While
1283        // tcp(2) says the value has to be set before listen(2), it
1284        // seems fine to set it before accept(2).
1285        final int recvBufferSize = 128;
1286        c.serverSocket.setReceiveBufferSize(recvBufferSize);
1287
1288        client.connect(new InetSocketAddress(c.host, c.port));
1289
1290        final SSLSocket server = (SSLSocket) c.serverSocket.accept();
1291        ExecutorService executor = Executors.newSingleThreadExecutor();
1292        Future<Void> future = executor.submit(new Callable<Void>() {
1293            @Override public Void call() throws Exception {
1294                server.startHandshake();
1295                return null;
1296            }
1297        });
1298        executor.shutdown();
1299        client.startHandshake();
1300
1301        // Reflection is used so this can compile on the RI
1302        String expectedClassName = "com.android.org.conscrypt.OpenSSLSocketImpl";
1303        Class actualClass = client.getClass();
1304        assertEquals(expectedClassName, actualClass.getName());
1305        Method setSoWriteTimeout = actualClass.getMethod("setSoWriteTimeout",
1306                                                         new Class[] { Integer.TYPE });
1307        setSoWriteTimeout.invoke(client, 1);
1308
1309
1310        try {
1311            // Add extra space to the write to exceed the send buffer
1312            // size and cause the write to block.
1313            final int extra = 1;
1314            client.getOutputStream().write(new byte[sendBufferSize + extra]);
1315            fail();
1316        } catch (SocketTimeoutException expected) {
1317        }
1318
1319        future.get();
1320        client.close();
1321        server.close();
1322        c.close();
1323    }
1324
1325    public void test_SSLSocket_reusedNpnSocket() throws Exception {
1326        if (StandardNames.IS_RI) {
1327            // RI does not support NPN/ALPN
1328            return;
1329        }
1330
1331        byte[] npnProtocols = new byte[] {
1332                8, 'h', 't', 't', 'p', '/', '1', '.', '1'
1333        };
1334
1335        final TestSSLContext c = TestSSLContext.create();
1336        SSLSocket client = (SSLSocket) c.clientContext.getSocketFactory().createSocket();
1337
1338        // Reflection is used so this can compile on the RI
1339        String expectedClassName = "com.android.org.conscrypt.OpenSSLSocketImpl";
1340        Class<?> actualClass = client.getClass();
1341        assertEquals(expectedClassName, actualClass.getName());
1342        Method setNpnProtocols = actualClass.getMethod("setNpnProtocols", byte[].class);
1343
1344        ExecutorService executor = Executors.newSingleThreadExecutor();
1345
1346        // First connection with NPN set on client and server
1347        {
1348            setNpnProtocols.invoke(client, npnProtocols);
1349            client.connect(new InetSocketAddress(c.host, c.port));
1350
1351            final SSLSocket server = (SSLSocket) c.serverSocket.accept();
1352            assertEquals(expectedClassName, server.getClass().getName());
1353            setNpnProtocols.invoke(server, npnProtocols);
1354
1355            Future<Void> future = executor.submit(new Callable<Void>() {
1356                @Override
1357                public Void call() throws Exception {
1358                    server.startHandshake();
1359                    return null;
1360                }
1361            });
1362            client.startHandshake();
1363
1364            future.get();
1365            client.close();
1366            server.close();
1367        }
1368
1369        // Second connection with client NPN already set on the SSL context, but
1370        // without server NPN set.
1371        {
1372            SSLServerSocket serverSocket = (SSLServerSocket) c.serverContext
1373                    .getServerSocketFactory().createServerSocket(0);
1374            InetAddress host = InetAddress.getLocalHost();
1375            int port = serverSocket.getLocalPort();
1376
1377            client = (SSLSocket) c.clientContext.getSocketFactory().createSocket();
1378            client.connect(new InetSocketAddress(host, port));
1379
1380            final SSLSocket server = (SSLSocket) serverSocket.accept();
1381
1382            Future<Void> future = executor.submit(new Callable<Void>() {
1383                @Override
1384                public Void call() throws Exception {
1385                    server.startHandshake();
1386                    return null;
1387                }
1388            });
1389            client.startHandshake();
1390
1391            future.get();
1392            client.close();
1393            server.close();
1394            serverSocket.close();
1395        }
1396
1397        c.close();
1398    }
1399
1400    public void test_SSLSocket_interrupt() throws Exception {
1401        test_SSLSocket_interrupt_case(true, true);
1402        test_SSLSocket_interrupt_case(true, false);
1403        test_SSLSocket_interrupt_case(false, true);
1404        test_SSLSocket_interrupt_case(false, false);
1405    }
1406
1407    private void test_SSLSocket_interrupt_case(boolean readUnderlying, boolean closeUnderlying)
1408            throws Exception {
1409
1410        ServerSocket listening = new ServerSocket(0);
1411
1412        Socket underlying = new Socket(listening.getInetAddress(), listening.getLocalPort());
1413        Socket server = listening.accept();
1414
1415        SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
1416        Socket clientWrapping = sf.createSocket(underlying, null, -1, true);
1417
1418        final Socket toRead = (readUnderlying) ? underlying : clientWrapping;
1419        final Socket toClose = (closeUnderlying) ? underlying : clientWrapping;
1420
1421        ExecutorService executor = Executors.newSingleThreadExecutor();
1422        Future<Void> future = executor.submit(new Callable<Void>() {
1423            @Override public Void call() throws Exception {
1424                Thread.sleep(1 * 1000);
1425                toClose.close();
1426                return null;
1427            }
1428        });
1429        executor.shutdown();
1430        try {
1431            toRead.setSoTimeout(5 * 1000);
1432            toRead.getInputStream().read();
1433            fail();
1434        } catch (SocketTimeoutException e) {
1435            throw e;
1436        } catch (SocketException expected) {
1437        }
1438        future.get();
1439
1440        server.close();
1441        underlying.close();
1442        listening.close();
1443    }
1444
1445    /**
1446     * b/7014266 Test to confirm that an SSLSocket.close() on one
1447     * thread will interrupt another thread blocked reading on the same
1448     * socket.
1449     */
1450    public void test_SSLSocket_interrupt_read() throws Exception {
1451        TestSSLContext c = TestSSLContext.create();
1452        final Socket underlying = new Socket(c.host, c.port);
1453        final SSLSocket wrapping = (SSLSocket)
1454                c.clientContext.getSocketFactory().createSocket(underlying,
1455                                                                c.host.getHostName(),
1456                                                                c.port,
1457                                                                false);
1458
1459        // Create our own thread group so we can inspect the stack state later.
1460        final ThreadGroup clientGroup = new ThreadGroup("client");
1461        ExecutorService executor = Executors.newSingleThreadExecutor(new ThreadFactory() {
1462            @Override
1463            public Thread newThread(Runnable r) {
1464                return new Thread(clientGroup, r);
1465            }
1466        });
1467
1468        Future<Void> clientFuture = executor.submit(new Callable<Void>() {
1469            @Override public Void call() throws Exception {
1470                try {
1471                    wrapping.startHandshake();
1472                    assertFalse(StandardNames.IS_RI);
1473                    wrapping.setSoTimeout(5 * 1000);
1474                    assertEquals(-1, wrapping.getInputStream().read());
1475                } catch (Exception e) {
1476                    assertTrue(StandardNames.IS_RI);
1477                }
1478                return null;
1479            }
1480        });
1481        executor.shutdown();
1482
1483        SSLSocket server = (SSLSocket) c.serverSocket.accept();
1484        server.startHandshake();
1485
1486        /*
1487         * Wait for the client to at least be in the "read" method before
1488         * calling close()
1489         */
1490        Thread[] threads = new Thread[1];
1491        clientGroup.enumerate(threads);
1492        if (threads[0] != null) {
1493            boolean clientInRead = false;
1494            while (!clientInRead) {
1495                StackTraceElement[] elements = threads[0].getStackTrace();
1496                for (StackTraceElement element : elements) {
1497                    if ("read".equals(element.getMethodName())) {
1498                        clientInRead = true;
1499                        break;
1500                    }
1501                }
1502            }
1503        }
1504
1505        wrapping.close();
1506        clientFuture.get();
1507        server.close();
1508    }
1509
1510    public void test_TestSSLSocketPair_create() {
1511        TestSSLSocketPair test = TestSSLSocketPair.create();
1512        assertNotNull(test.c);
1513        assertNotNull(test.server);
1514        assertNotNull(test.client);
1515        assertTrue(test.server.isConnected());
1516        assertTrue(test.client.isConnected());
1517        assertFalse(test.server.isClosed());
1518        assertFalse(test.client.isClosed());
1519        assertNotNull(test.server.getSession());
1520        assertNotNull(test.client.getSession());
1521        assertTrue(test.server.getSession().isValid());
1522        assertTrue(test.client.getSession().isValid());
1523        test.close();
1524    }
1525
1526    public void test_SSLSocket_ClientHello_size() throws Exception {
1527        // This test checks the size of ClientHello of the default SSLSocket. TLS/SSL handshakes
1528        // with older/unpatched F5/BIG-IP appliances are known to stall and time out when
1529        // the fragment containing ClientHello is between 256 and 511 (inclusive) bytes long.
1530        //
1531        // Since there's no straightforward way to obtain a ClientHello from SSLSocket, this test
1532        // does the following:
1533        // 1. Creates a listening server socket (a plain one rather than a TLS/SSL one).
1534        // 2. Creates a client SSLSocket, which connects to the server socket and initiates the
1535        //    TLS/SSL handshake.
1536        // 3. Makes the server socket accept an incoming connection on the server socket, and reads
1537        //    the first chunk of data received. This chunk is assumed to be the ClientHello.
1538        // NOTE: Steps 2 and 3 run concurrently.
1539        ServerSocket listeningSocket = null;
1540        ExecutorService executorService = Executors.newFixedThreadPool(2);
1541
1542        // Some Socket operations are not interruptible via Thread.interrupt for some reason. To
1543        // work around, we unblock these sockets using Socket.close.
1544        final Socket[] sockets = new Socket[2];
1545        try {
1546            // 1. Create the listening server socket.
1547            listeningSocket = ServerSocketFactory.getDefault().createServerSocket(0);
1548            final ServerSocket finalListeningSocket = listeningSocket;
1549            // 2. (in background) Wait for an incoming connection and read its first chunk.
1550            final Future<byte[]> readFirstReceivedChunkFuture =
1551                    executorService.submit(new Callable<byte[]>() {
1552                        @Override
1553                        public byte[] call() throws Exception {
1554                            Socket socket = finalListeningSocket.accept();
1555                            sockets[1] = socket;
1556                            try {
1557                                byte[] buffer = new byte[64 * 1024];
1558                                int bytesRead = socket.getInputStream().read(buffer);
1559                                if (bytesRead == -1) {
1560                                    throw new EOFException("Failed to read anything");
1561                                }
1562                                return Arrays.copyOf(buffer, bytesRead);
1563                            } finally {
1564                                IoUtils.closeQuietly(socket);
1565                            }
1566                        }
1567                    });
1568
1569            // 3. Create a client socket, connect it to the server socket, and start the TLS/SSL
1570            //    handshake.
1571            executorService.submit(new Callable<Void>() {
1572                @Override
1573                public Void call() throws Exception {
1574                    SSLContext sslContext = SSLContext.getInstance("TLS");
1575                    sslContext.init(null, null, null);
1576                    SSLSocket client = (SSLSocket) sslContext.getSocketFactory().createSocket();
1577                    sockets[0] = client;
1578                    try {
1579                        // Enable SNI extension on the socket (this is typically enabled by default)
1580                        // to increase the size of ClientHello.
1581                        try {
1582                            Method setHostname =
1583                                    client.getClass().getMethod("setHostname", String.class);
1584                            setHostname.invoke(client, "sslsockettest.androidcts.google.com");
1585                        } catch (NoSuchMethodException ignored) {}
1586
1587                        // Enable Session Tickets extension on the socket (this is typically enabled
1588                        // by default) to increase the size of ClientHello.
1589                        try {
1590                            Method setUseSessionTickets =
1591                                    client.getClass().getMethod(
1592                                            "setUseSessionTickets", boolean.class);
1593                            setUseSessionTickets.invoke(client, true);
1594                        } catch (NoSuchMethodException ignored) {}
1595
1596                        client.connect(finalListeningSocket.getLocalSocketAddress());
1597                        // Initiate the TLS/SSL handshake which is expected to fail as soon as the
1598                        // server socket receives a ClientHello.
1599                        try {
1600                            client.startHandshake();
1601                            fail();
1602                            return null;
1603                        } catch (IOException expected) {}
1604                        return null;
1605                    } finally {
1606                        IoUtils.closeQuietly(client);
1607
1608                        // Cancel the reading task. If this task succeeded, then the reading task
1609                        // is done and this will have no effect. If this task failed prematurely,
1610                        // then the reading task might get unblocked (we're interrupting the thread
1611                        // it's running on), will fail early, and we'll thus save some time in this
1612                        // test.
1613                        readFirstReceivedChunkFuture.cancel(true);
1614                    }
1615                }
1616            });
1617
1618            // Wait for the ClientHello to arrive
1619            byte[] clientHello = readFirstReceivedChunkFuture.get(10, TimeUnit.SECONDS);
1620
1621            // Check for ClientHello length that may cause handshake to fail/time out with older
1622            // F5/BIG-IP appliances.
1623            assertEquals("TLS record type: handshake", 22, clientHello[0]);
1624            int fragmentLength = ((clientHello[3] & 0xff) << 8) | (clientHello[4] & 0xff);
1625            if ((fragmentLength >= 256) && (fragmentLength <= 511)) {
1626                fail("Fragment containing ClientHello is of dangerous length: "
1627                        + fragmentLength + " bytes");
1628            }
1629        } finally {
1630            executorService.shutdownNow();
1631            IoUtils.closeQuietly(listeningSocket);
1632            IoUtils.closeQuietly(sockets[0]);
1633            IoUtils.closeQuietly(sockets[1]);
1634            if (!executorService.awaitTermination(5, TimeUnit.SECONDS)) {
1635                fail("Timed out while waiting for the test to shut down");
1636            }
1637        }
1638    }
1639
1640    /**
1641     * Not run by default by JUnit, but can be run by Vogar by
1642     * specifying it explicitly (or with main method below)
1643     */
1644    public void stress_test_TestSSLSocketPair_create() {
1645        final boolean verbose = true;
1646        while (true) {
1647            TestSSLSocketPair test = TestSSLSocketPair.create();
1648            if (verbose) {
1649                System.out.println("client=" + test.client.getLocalPort()
1650                                   + " server=" + test.server.getLocalPort());
1651            } else {
1652                System.out.print("X");
1653            }
1654
1655            /*
1656              We don't close on purpose in this stress test to add
1657              races in file descriptors reuse when the garbage
1658              collector runs concurrently and finalizes sockets
1659            */
1660            // test.close();
1661
1662        }
1663    }
1664
1665    public static void main (String[] args) {
1666        new SSLSocketTest().stress_test_TestSSLSocketPair_create();
1667    }
1668}
1669