1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package org.conscrypt;
18
19import static org.conscrypt.SSLUtils.EngineStates.STATE_CLOSED;
20import static org.conscrypt.SSLUtils.EngineStates.STATE_HANDSHAKE_STARTED;
21import static org.conscrypt.SSLUtils.EngineStates.STATE_NEW;
22import static org.conscrypt.SSLUtils.EngineStates.STATE_READY;
23import static org.conscrypt.SSLUtils.EngineStates.STATE_READY_HANDSHAKE_CUT_THROUGH;
24
25import java.io.IOException;
26import java.io.InputStream;
27import java.io.OutputStream;
28import java.net.InetAddress;
29import java.net.Socket;
30import java.net.SocketException;
31import java.security.InvalidKeyException;
32import java.security.PrivateKey;
33import java.security.cert.CertificateEncodingException;
34import java.security.cert.CertificateException;
35import java.security.interfaces.ECKey;
36import java.security.spec.ECParameterSpec;
37import javax.crypto.SecretKey;
38import javax.net.ssl.SSLException;
39import javax.net.ssl.SSLHandshakeException;
40import javax.net.ssl.SSLParameters;
41import javax.net.ssl.SSLProtocolException;
42import javax.net.ssl.SSLSession;
43import javax.net.ssl.X509KeyManager;
44import javax.net.ssl.X509TrustManager;
45import javax.security.auth.x500.X500Principal;
46import org.conscrypt.NativeRef.SSL_SESSION;
47
48/**
 * {@link OpenSSLSocketImpl} implementation based on OpenSSL that performs TLS I/O directly on the
 * underlying socket's file descriptor.
50 * <p>
51 * Extensions to SSLSocket include:
52 * <ul>
53 * <li>handshake timeout
54 * <li>session tickets
55 * <li>Server Name Indication
56 * </ul>
57 */
58final class ConscryptFileDescriptorSocket extends OpenSSLSocketImpl
59        implements NativeCrypto.SSLHandshakeCallbacks, SSLParametersImpl.AliasChooser,
60                   SSLParametersImpl.PSKCallbacks {
    // When true, read/write paths and awaitPendingOps() perform extra state assertions.
    private static final boolean DBG_STATE = false;

    /**
     * Protects {@code state}, {@code is} and {@code os}. It is also the monitor that threads wait
     * on (and are notified through) when the handshake state changes.
     */
    private final Object stateLock = new Object();

    // @GuardedBy("stateLock");
    // One of the SSLUtils.EngineStates constants; starts at STATE_NEW.
    private int state = STATE_NEW;

    /**
     * Wrapper around the native SSL object, created once in the constructor. startHandshake()
     * calls shutdownAndFreeSslNative() if the handshake fails or the socket is closed while the
     * handshake is in progress.
     */
    // @GuardedBy("stateLock");
    private final SslWrapper ssl;

    /**
     * Protected by synchronizing on stateLock. Starts as null, created lazily by
     * getInputStream.
     */
    // @GuardedBy("stateLock");
    private SSLInputStream is;

    /**
     * Protected by synchronizing on stateLock. Starts as null, created lazily by
     * getOutputStream.
     */
    // @GuardedBy("stateLock");
    private SSLOutputStream os;

    // SSL configuration for this socket (client mode, protocols, cipher suites, client auth, ...).
    private final SSLParametersImpl sslParameters;

    /*
     * A CloseGuard object on Android. On other platforms, this is nothing. Opened by
     * startHandshake().
     */
    private final Object guard = Platform.closeGuardGet();

    /**
     * Private key for the TLS Channel ID extension. This field is client-side
     * only. Set by setChannelIdPrivateKey and passed to the SSL object in startHandshake.
     */
    private OpenSSLKey channelIdPrivateKey;

    // Session state for this connection; updated when peer certificates arrive and when the
    // handshake completes.
    private final ActiveSession sslSession;

    // Timeout passed to every native write by SSLOutputStream; set via setSoWriteTimeout.
    private int writeTimeoutMilliseconds = 0;
    private int handshakeTimeoutMilliseconds = -1; // -1 = same as timeout; 0 = infinite
109
110    ConscryptFileDescriptorSocket(SSLParametersImpl sslParameters) throws IOException {
111        this.sslParameters = sslParameters;
112        this.ssl = newSsl(sslParameters, this);
113        sslSession = new ActiveSession(ssl, sslParameters.getSessionContext());
114    }
115
116    ConscryptFileDescriptorSocket(String hostname, int port, SSLParametersImpl sslParameters)
117            throws IOException {
118        super(hostname, port);
119        this.sslParameters = sslParameters;
120        this.ssl = newSsl(sslParameters, this);
121        sslSession = new ActiveSession(ssl, sslParameters.getSessionContext());
122    }
123
124    ConscryptFileDescriptorSocket(InetAddress address, int port, SSLParametersImpl sslParameters)
125            throws IOException {
126        super(address, port);
127        this.sslParameters = sslParameters;
128        this.ssl = newSsl(sslParameters, this);
129        sslSession = new ActiveSession(ssl, sslParameters.getSessionContext());
130    }
131
132    ConscryptFileDescriptorSocket(String hostname, int port, InetAddress clientAddress,
133            int clientPort, SSLParametersImpl sslParameters) throws IOException {
134        super(hostname, port, clientAddress, clientPort);
135        this.sslParameters = sslParameters;
136        this.ssl = newSsl(sslParameters, this);
137        sslSession = new ActiveSession(ssl, sslParameters.getSessionContext());
138    }
139
140    ConscryptFileDescriptorSocket(InetAddress address, int port, InetAddress clientAddress,
141            int clientPort, SSLParametersImpl sslParameters) throws IOException {
142        super(address, port, clientAddress, clientPort);
143        this.sslParameters = sslParameters;
144        this.ssl = newSsl(sslParameters, this);
145        sslSession = new ActiveSession(ssl, sslParameters.getSessionContext());
146    }
147
148    ConscryptFileDescriptorSocket(Socket socket, String hostname, int port, boolean autoClose,
149            SSLParametersImpl sslParameters) throws IOException {
150        super(socket, hostname, port, autoClose);
151        this.sslParameters = sslParameters;
152        this.ssl = newSsl(sslParameters, this);
153        sslSession = new ActiveSession(ssl, sslParameters.getSessionContext());
154    }
155
156    private static SslWrapper newSsl(SSLParametersImpl sslParameters,
157            ConscryptFileDescriptorSocket engine) {
158        try {
159            return SslWrapper.newInstance(sslParameters, engine, engine, engine);
160        } catch (SSLException e) {
161            throw new RuntimeException(e);
162        }
163    }
164
165    /**
166     * Starts a TLS/SSL handshake on this connection using some native methods
167     * from the OpenSSL library. It can negotiate new encryption keys, change
168     * cipher suites, or initiate a new session. The certificate chain is
169     * verified if the correspondent property in java.Security is set. All
170     * listeners are notified at the end of the TLS/SSL handshake.
171     */
172    @Override
173    public void startHandshake() throws IOException {
174        checkOpen();
175        synchronized (stateLock) {
176            if (state == STATE_NEW) {
177                state = STATE_HANDSHAKE_STARTED;
178            } else {
179                // We've either started the handshake already or have been closed.
180                // Do nothing in both cases.
181                return;
182            }
183        }
184
185        boolean releaseResources = true;
186        try {
187            Platform.closeGuardOpen(guard, "close");
188
189            // Prepare the SSL object for the handshake.
190            ssl.initialize(getHostname(), channelIdPrivateKey);
191
192            // For clients, offer to resume a previously cached session to avoid the
193            // full TLS handshake.
194            if (getUseClientMode()) {
195                SslSessionWrapper cachedSession = clientSessionContext().getCachedSession(
196                        getHostnameOrIP(), getPort(), sslParameters);
197                if (cachedSession != null) {
198                    cachedSession.offerToResume(ssl);
199                }
200            }
201
202            // Temporarily use a different timeout for the handshake process
203            int savedReadTimeoutMilliseconds = getSoTimeout();
204            int savedWriteTimeoutMilliseconds = getSoWriteTimeout();
205            if (handshakeTimeoutMilliseconds >= 0) {
206                setSoTimeout(handshakeTimeoutMilliseconds);
207                setSoWriteTimeout(handshakeTimeoutMilliseconds);
208            }
209
210            synchronized (stateLock) {
211                if (state == STATE_CLOSED) {
212                    return;
213                }
214            }
215
216            try {
217                ssl.doHandshake(Platform.getFileDescriptor(socket), getSoTimeout());
218            } catch (CertificateException e) {
219                SSLHandshakeException wrapper = new SSLHandshakeException(e.getMessage());
220                wrapper.initCause(e);
221                throw wrapper;
222            } catch (SSLException e) {
223                // Swallow this exception if it's thrown as the result of an interruption.
224                //
225                // TODO: SSL_read and SSL_write return -1 when interrupted, but SSL_do_handshake
226                // will throw the last sslError that it saw before sslSelect, usually SSL_WANT_READ
227                // (or WANT_WRITE). Catching that exception here doesn't seem much worse than
228                // changing the native code to return a "special" native pointer value when that
229                // happens.
230                synchronized (stateLock) {
231                    if (state == STATE_CLOSED) {
232                        return;
233                    }
234                }
235
236                // Write CCS errors to EventLog
237                String message = e.getMessage();
238                // Must match error string of SSL_R_UNEXPECTED_CCS
239                if (message.contains("unexpected CCS")) {
240                    String logMessage =
241                            String.format("ssl_unexpected_ccs: host=%s", getHostnameOrIP());
242                    Platform.logEvent(logMessage);
243                }
244
245                throw e;
246            }
247
248            synchronized (stateLock) {
249                if (state == STATE_CLOSED) {
250                    return;
251                }
252            }
253
254            // Restore the original timeout now that the handshake is complete
255            if (handshakeTimeoutMilliseconds >= 0) {
256                setSoTimeout(savedReadTimeoutMilliseconds);
257                setSoWriteTimeout(savedWriteTimeoutMilliseconds);
258            }
259
260            synchronized (stateLock) {
261                releaseResources = (state == STATE_CLOSED);
262
263                if (state == STATE_HANDSHAKE_STARTED) {
264                    state = STATE_READY_HANDSHAKE_CUT_THROUGH;
265                } else {
266                    state = STATE_READY;
267                }
268
269                if (!releaseResources) {
270                    // Unblock threads that are waiting for our state to transition
271                    // into STATE_READY or STATE_READY_HANDSHAKE_CUT_THROUGH.
272                    stateLock.notifyAll();
273                }
274            }
275        } catch (SSLProtocolException e) {
276            throw(SSLHandshakeException) new SSLHandshakeException("Handshake failed").initCause(e);
277        } finally {
278            // on exceptional exit, treat the socket as closed
279            if (releaseResources) {
280                synchronized (stateLock) {
281                    // Mark the socket as closed since we might have reached this as
282                    // a result on an exception thrown by the handshake process.
283                    //
284                    // The state will already be set to closed if we reach this as a result of
285                    // an early return or an interruption due to a concurrent call to close().
286                    state = STATE_CLOSED;
287                    stateLock.notifyAll();
288                }
289
290                try {
291                    shutdownAndFreeSslNative();
292                } catch (IOException ignored) {
293                    // Ignored.
294                }
295            }
296        }
297    }
298
299    @Override
300    @SuppressWarnings("unused") // used by NativeCrypto.SSLHandshakeCallbacks / client_cert_cb
301    public void clientCertificateRequested(byte[] keyTypeBytes, byte[][] asn1DerEncodedPrincipals)
302            throws CertificateEncodingException, SSLException {
303        ssl.chooseClientCertificate(keyTypeBytes, asn1DerEncodedPrincipals);
304    }
305
306    @Override
307    @SuppressWarnings("unused") // used by native psk_client_callback
308    public int clientPSKKeyRequested(String identityHint, byte[] identity, byte[] key) {
309        return ssl.clientPSKKeyRequested(identityHint, identity, key);
310    }
311
312    @Override
313    @SuppressWarnings("unused") // used by native psk_server_callback
314    public int serverPSKKeyRequested(String identityHint, String identity, byte[] key) {
315        return ssl.serverPSKKeyRequested(identityHint, identity, key);
316    }
317
    /**
     * Native info callback. Only {@code SSL_CB_HANDSHAKE_DONE} is handled. It:
     * <ol>
     * <li>records the established session on {@code sslSession};
     * <li>moves the socket to {@code STATE_READY}, unless it was already closed (in which case it
     *     returns without notifying anyone);
     * <li>notifies handshake-completed listeners, without holding {@code stateLock};
     * <li>wakes all threads waiting on {@code stateLock}.
     * </ol>
     *
     * @param type the native callback type
     * @param val the native callback value; not used here
     */
    @Override
    @SuppressWarnings("unused") // used by NativeCrypto.SSLHandshakeCallbacks / info_callback
    public void onSSLStateChange(int type, int val) {
        if (type != NativeConstants.SSL_CB_HANDSHAKE_DONE) {
            // We only care about successful completion.
            return;
        }

        // The handshake has completed successfully ...

        // Update the session from the current state of the SSL object.
        sslSession.onSessionEstablished(getHostnameOrIP(), getPort());

        // Next, update the state.
        synchronized (stateLock) {
            if (state == STATE_CLOSED) {
                // Someone called "close" but the handshake hasn't been interrupted yet.
                return;
            }

            // Mark the socket ready. Waiting threads are woken further below, after the
            // listeners have run.
            state = STATE_READY;
        }

        // Let listeners know we are finally done (called without stateLock held).
        notifyHandshakeCompletedListeners();

        synchronized (stateLock) {
            // Notify all threads waiting for the handshake to complete.
            stateLock.notifyAll();
        }
    }
351
352    @Override
353    @SuppressWarnings("unused") // used by NativeCrypto.SSLHandshakeCallbacks / new_session_callback
354    public void onNewSessionEstablished(long sslSessionNativePtr) {
355        try {
356            // Increment the reference count to "take ownership" of the session resource.
357            NativeCrypto.SSL_SESSION_up_ref(sslSessionNativePtr);
358
359            // Create a native reference which will release the SSL_SESSION in its finalizer.
360            // This constructor will only throw if the native pointer passed in is NULL, which
361            // BoringSSL guarantees will not happen.
362            NativeRef.SSL_SESSION ref = new SSL_SESSION(sslSessionNativePtr);
363
364            SslSessionWrapper sessionWrapper = SslSessionWrapper.newInstance(ref, sslSession);
365
366            // Cache the newly established session.
367            AbstractSessionContext ctx = sessionContext();
368            ctx.cacheSession(sessionWrapper);
369        } catch (Exception ignored) {
370            // Ignore.
371        }
372    }
373
374    @Override
375    public long serverSessionRequested(byte[] id) {
376        // TODO(nathanmittler): Implement server-side caching for TLS < 1.3
377        return 0;
378    }
379
380    @SuppressWarnings("unused") // used by NativeCrypto.SSLHandshakeCallbacks
381    @Override
382    public void verifyCertificateChain(long[] certRefs, String authMethod)
383            throws CertificateException {
384        try {
385            X509TrustManager x509tm = sslParameters.getX509TrustManager();
386            if (x509tm == null) {
387                throw new CertificateException("No X.509 TrustManager");
388            }
389            if (certRefs == null || certRefs.length == 0) {
390                throw new SSLException("Peer sent no certificate");
391            }
392            OpenSSLX509Certificate[] peerCertChain =
393                    OpenSSLX509Certificate.createCertChain(certRefs);
394
395            // Update the peer information on the session.
396            sslSession.onPeerCertificatesReceived(getHostnameOrIP(), getPort(), peerCertChain);
397
398            if (getUseClientMode()) {
399                Platform.checkServerTrusted(x509tm, peerCertChain, authMethod, this);
400            } else {
401                String authType = peerCertChain[0].getPublicKey().getAlgorithm();
402                Platform.checkClientTrusted(x509tm, peerCertChain, authType, this);
403            }
404        } catch (CertificateException e) {
405            throw e;
406        } catch (Exception e) {
407            throw new CertificateException(e);
408        }
409    }
410
411    @Override
412    public InputStream getInputStream() throws IOException {
413        checkOpen();
414
415        InputStream returnVal;
416        synchronized (stateLock) {
417            if (state == STATE_CLOSED) {
418                throw new SocketException("Socket is closed.");
419            }
420
421            if (is == null) {
422                is = new SSLInputStream();
423            }
424
425            returnVal = is;
426        }
427
428        // Block waiting for a handshake without a lock held. It's possible that the socket
429        // is closed at this point. If that happens, we'll still return the input stream but
430        // all reads on it will throw.
431        waitForHandshake();
432        return returnVal;
433    }
434
435    @Override
436    public OutputStream getOutputStream() throws IOException {
437        checkOpen();
438
439        OutputStream returnVal;
440        synchronized (stateLock) {
441            if (state == STATE_CLOSED) {
442                throw new SocketException("Socket is closed.");
443            }
444
445            if (os == null) {
446                os = new SSLOutputStream();
447            }
448
449            returnVal = os;
450        }
451
452        // Block waiting for a handshake without a lock held. It's possible that the socket
453        // is closed at this point. If that happens, we'll still return the output stream but
454        // all writes on it will throw.
455        waitForHandshake();
456        return returnVal;
457    }
458
459    private void assertReadableOrWriteableState() {
460        if (state == STATE_READY || state == STATE_READY_HANDSHAKE_CUT_THROUGH) {
461            return;
462        }
463
464        throw new AssertionError("Invalid state: " + state);
465    }
466
467    private void waitForHandshake() throws IOException {
468        startHandshake();
469
470        synchronized (stateLock) {
471            while (state != STATE_READY &&
472                    state != STATE_READY_HANDSHAKE_CUT_THROUGH &&
473                    state != STATE_CLOSED) {
474                try {
475                    stateLock.wait();
476                } catch (InterruptedException e) {
477                    Thread.currentThread().interrupt();
478                    throw new IOException("Interrupted waiting for handshake", e);
479                }
480            }
481
482            if (state == STATE_CLOSED) {
483                throw new SocketException("Socket is closed");
484            }
485        }
486    }
487
488    /**
489     * This inner class provides input data stream functionality
490     * for the OpenSSL native implementation. It is used to
491     * read data received via SSL protocol.
492     */
493    private class SSLInputStream extends InputStream {
494        /**
495         * OpenSSL only lets one thread read at a time, so this is used to
496         * make sure we serialize callers of SSL_read. Thread is already
497         * expected to have completed handshaking.
498         */
499        private final Object readLock = new Object();
500
501        SSLInputStream() {
502        }
503
504        /**
505         * Reads one byte. If there is no data in the underlying buffer,
506         * this operation can block until the data will be
507         * available.
508         */
509        @Override
510        public int read() throws IOException {
511            byte[] buffer = new byte[1];
512            int result = read(buffer, 0, 1);
513            return (result != -1) ? buffer[0] & 0xff : -1;
514        }
515
516        /**
517         * Method acts as described in spec for superclass.
518         * @see java.io.InputStream#read(byte[],int,int)
519         */
520        @Override
521        public int read(byte[] buf, int offset, int byteCount) throws IOException {
522            Platform.blockGuardOnNetwork();
523
524            checkOpen();
525            ArrayUtils.checkOffsetAndCount(buf.length, offset, byteCount);
526            if (byteCount == 0) {
527                return 0;
528            }
529
530            synchronized (readLock) {
531                synchronized (stateLock) {
532                    if (state == STATE_CLOSED) {
533                        throw new SocketException("socket is closed");
534                    }
535
536                    if (DBG_STATE) {
537                        assertReadableOrWriteableState();
538                    }
539                }
540
541                int ret =  ssl.read(
542                        Platform.getFileDescriptor(socket), buf, offset, byteCount, getSoTimeout());
543                if (ret == -1) {
544                    synchronized (stateLock) {
545                        if (state == STATE_CLOSED) {
546                            throw new SocketException("socket is closed");
547                        }
548                    }
549                }
550                return ret;
551            }
552        }
553
554        void awaitPendingOps() {
555            if (DBG_STATE) {
556                synchronized (stateLock) {
557                    if (state != STATE_CLOSED) {
558                        throw new AssertionError("State is: " + state);
559                    }
560                }
561            }
562
563            synchronized (readLock) {}
564        }
565    }
566
567    /**
568     * This inner class provides output data stream functionality
569     * for the OpenSSL native implementation. It is used to
570     * write data according to the encryption parameters given in SSL context.
571     */
572    private class SSLOutputStream extends OutputStream {
573        /**
574         * OpenSSL only lets one thread write at a time, so this is used
575         * to make sure we serialize callers of SSL_write. Thread is
576         * already expected to have completed handshaking.
577         */
578        private final Object writeLock = new Object();
579
580        SSLOutputStream() {
581        }
582
583        /**
584         * Method acts as described in spec for superclass.
585         * @see java.io.OutputStream#write(int)
586         */
587        @Override
588        public void write(int oneByte) throws IOException {
589            byte[] buffer = new byte[1];
590            buffer[0] = (byte) (oneByte & 0xff);
591            write(buffer);
592        }
593
594        /**
595         * Method acts as described in spec for superclass.
596         * @see java.io.OutputStream#write(byte[],int,int)
597         */
598        @Override
599        public void write(byte[] buf, int offset, int byteCount) throws IOException {
600            Platform.blockGuardOnNetwork();
601            checkOpen();
602            ArrayUtils.checkOffsetAndCount(buf.length, offset, byteCount);
603            if (byteCount == 0) {
604                return;
605            }
606
607            synchronized (writeLock) {
608                synchronized (stateLock) {
609                    if (state == STATE_CLOSED) {
610                        throw new SocketException("socket is closed");
611                    }
612
613                    if (DBG_STATE) {
614                        assertReadableOrWriteableState();
615                    }
616                }
617
618                ssl.write(Platform.getFileDescriptor(socket), buf, offset, byteCount,
619                        writeTimeoutMilliseconds);
620
621                synchronized (stateLock) {
622                    if (state == STATE_CLOSED) {
623                        throw new SocketException("socket is closed");
624                    }
625                }
626            }
627        }
628
629        void awaitPendingOps() {
630            if (DBG_STATE) {
631                synchronized (stateLock) {
632                    if (state != STATE_CLOSED) {
633                        throw new AssertionError("State is: " + state);
634                    }
635                }
636            }
637
638            synchronized (writeLock) {}
639        }
640    }
641
642    @Override
643    public SSLSession getSession() {
644        boolean handshakeCompleted = false;
645        synchronized (stateLock) {
646            try {
647                handshakeCompleted = state >= STATE_READY;
648                if (!handshakeCompleted && isConnected()) {
649                    waitForHandshake();
650                    handshakeCompleted = true;
651                }
652            } catch (IOException e) {
653                // Fall through.
654            }
655        }
656
657        if (!handshakeCompleted) {
658            // return an invalid session with
659            // invalid cipher suite of "SSL_NULL_WITH_NULL_NULL"
660            return SSLNullSession.getNullSession();
661        }
662
663        return Platform.wrapSSLSession(sslSession);
664    }
665
666    @Override
667    SSLSession getActiveSession() {
668        return sslSession;
669    }
670
671    @Override
672    public SSLSession getHandshakeSession() {
673        synchronized (stateLock) {
674            return state >= STATE_HANDSHAKE_STARTED && state < STATE_READY ? sslSession : null;
675        }
676    }
677
678    @Override
679    public boolean getEnableSessionCreation() {
680        return sslParameters.getEnableSessionCreation();
681    }
682
683    @Override
684    public void setEnableSessionCreation(boolean flag) {
685        sslParameters.setEnableSessionCreation(flag);
686    }
687
688    @Override
689    public String[] getSupportedCipherSuites() {
690        return NativeCrypto.getSupportedCipherSuites();
691    }
692
693    @Override
694    public String[] getEnabledCipherSuites() {
695        return sslParameters.getEnabledCipherSuites();
696    }
697
698    @Override
699    public void setEnabledCipherSuites(String[] suites) {
700        sslParameters.setEnabledCipherSuites(suites);
701    }
702
703    @Override
704    public String[] getSupportedProtocols() {
705        return NativeCrypto.getSupportedProtocols();
706    }
707
708    @Override
709    public String[] getEnabledProtocols() {
710        return sslParameters.getEnabledProtocols();
711    }
712
713    @Override
714    public void setEnabledProtocols(String[] protocols) {
715        sslParameters.setEnabledProtocols(protocols);
716    }
717
718    /**
719     * This method enables session ticket support.
720     *
721     * @param useSessionTickets True to enable session tickets
722     */
723    @Override
724    public void setUseSessionTickets(boolean useSessionTickets) {
725        sslParameters.setUseSessionTickets(useSessionTickets);
726    }
727
728    /**
729     * This method enables Server Name Indication
730     *
731     * @param hostname the desired SNI hostname, or null to disable
732     */
733    @Override
734    public void setHostname(String hostname) {
735        sslParameters.setUseSni(hostname != null);
736        super.setHostname(hostname);
737    }
738
739    /**
740     * Enables/disables TLS Channel ID for this server socket.
741     *
742     * <p>This method needs to be invoked before the handshake starts.
743     *
744     * @throws IllegalStateException if this is a client socket or if the handshake has already
745     *         started.
746     */
747    @Override
748    public void setChannelIdEnabled(boolean enabled) {
749        if (getUseClientMode()) {
750            throw new IllegalStateException("Client mode");
751        }
752
753        synchronized (stateLock) {
754            if (state != STATE_NEW) {
755                throw new IllegalStateException(
756                        "Could not enable/disable Channel ID after the initial handshake has"
757                                + " begun.");
758            }
759        }
760        sslParameters.channelIdEnabled = enabled;
761    }
762
763    /**
764     * Gets the TLS Channel ID for this server socket. Channel ID is only available once the
765     * handshake completes.
766     *
767     * @return channel ID or {@code null} if not available.
768     *
769     * @throws IllegalStateException if this is a client socket or if the handshake has not yet
770     *         completed.
771     * @throws SSLException if channel ID is available but could not be obtained.
772     */
773    @Override
774    public byte[] getChannelId() throws SSLException {
775        if (getUseClientMode()) {
776            throw new IllegalStateException("Client mode");
777        }
778
779        synchronized (stateLock) {
780            if (state != STATE_READY) {
781                throw new IllegalStateException(
782                        "Channel ID is only available after handshake completes");
783            }
784        }
785        return ssl.getTlsChannelId();
786    }
787
788    /**
789     * Sets the {@link PrivateKey} to be used for TLS Channel ID by this client socket.
790     *
791     * <p>This method needs to be invoked before the handshake starts.
792     *
793     * @param privateKey private key (enables TLS Channel ID) or {@code null} for no key (disables
794     *        TLS Channel ID). The private key must be an Elliptic Curve (EC) key based on the NIST
795     *        P-256 curve (aka SECG secp256r1 or ANSI X9.62 prime256v1).
796     *
797     * @throws IllegalStateException if this is a server socket or if the handshake has already
798     *         started.
799     */
800    @Override
801    public void setChannelIdPrivateKey(PrivateKey privateKey) {
802        if (!getUseClientMode()) {
803            throw new IllegalStateException("Server mode");
804        }
805
806        synchronized (stateLock) {
807            if (state != STATE_NEW) {
808                throw new IllegalStateException(
809                        "Could not change Channel ID private key after the initial handshake has"
810                                + " begun.");
811            }
812        }
813
814        if (privateKey == null) {
815            sslParameters.channelIdEnabled = false;
816            channelIdPrivateKey = null;
817        } else {
818            sslParameters.channelIdEnabled = true;
819            try {
820                ECParameterSpec ecParams = null;
821                if (privateKey instanceof ECKey) {
822                    ecParams = ((ECKey) privateKey).getParams();
823                }
824                if (ecParams == null) {
825                    // Assume this is a P-256 key, as specified in the contract of this method.
826                    ecParams =
827                            OpenSSLECGroupContext.getCurveByName("prime256v1").getECParameterSpec();
828                }
829                channelIdPrivateKey =
830                        OpenSSLKey.fromECPrivateKeyForTLSStackOnly(privateKey, ecParams);
831            } catch (InvalidKeyException e) {
832                // Will have error in startHandshake
833            }
834        }
835    }
836
837    @Override
838    public boolean getUseClientMode() {
839        return sslParameters.getUseClientMode();
840    }
841
842    @Override
843    public void setUseClientMode(boolean mode) {
844        synchronized (stateLock) {
845            if (state != STATE_NEW) {
846                throw new IllegalArgumentException(
847                        "Could not change the mode after the initial handshake has begun.");
848            }
849        }
850        sslParameters.setUseClientMode(mode);
851    }
852
853    @Override
854    public boolean getWantClientAuth() {
855        return sslParameters.getWantClientAuth();
856    }
857
858    @Override
859    public boolean getNeedClientAuth() {
860        return sslParameters.getNeedClientAuth();
861    }
862
863    @Override
864    public void setNeedClientAuth(boolean need) {
865        sslParameters.setNeedClientAuth(need);
866    }
867
868    @Override
869    public void setWantClientAuth(boolean want) {
870        sslParameters.setWantClientAuth(want);
871    }
872
873    /**
874     * Note write timeouts are not part of the javax.net.ssl.SSLSocket API
875     */
876    @Override
877    public void setSoWriteTimeout(int writeTimeoutMilliseconds) throws SocketException {
878        this.writeTimeoutMilliseconds = writeTimeoutMilliseconds;
879
880        Platform.setSocketWriteTimeout(this, writeTimeoutMilliseconds);
881    }
882
883    /**
884     * Note write timeouts are not part of the javax.net.ssl.SSLSocket API
885     */
886    @Override
887    public int getSoWriteTimeout() throws SocketException {
888        return writeTimeoutMilliseconds;
889    }
890
891    /**
892     * Set the handshake timeout on this socket.  This timeout is specified in
893     * milliseconds and will be used only during the handshake process.
894     */
895    @Override
896    public void setHandshakeTimeout(int handshakeTimeoutMilliseconds) throws SocketException {
897        this.handshakeTimeoutMilliseconds = handshakeTimeoutMilliseconds;
898    }
899
900    @Override
901    @SuppressWarnings("UnsynchronizedOverridesSynchronized")
902    public void close() throws IOException {
903        // TODO: Close SSL sockets using a background thread so they close gracefully.
904
905        SSLInputStream sslInputStream;
906        SSLOutputStream sslOutputStream;
907
908        synchronized (stateLock) {
909            if (state == STATE_CLOSED) {
910                // close() has already been called, so do nothing and return.
911                return;
912            }
913
914            int oldState = state;
915            state = STATE_CLOSED;
916
917            if (oldState == STATE_NEW) {
918                // The handshake hasn't been started yet, so there's no OpenSSL related
919                // state to clean up. We still need to close the underlying socket if
920                // we're wrapping it and were asked to autoClose.
921                free();
922                closeUnderlyingSocket();
923
924                stateLock.notifyAll();
925                return;
926            }
927
928            if (oldState != STATE_READY && oldState != STATE_READY_HANDSHAKE_CUT_THROUGH) {
929                // If we're in these states, we still haven't returned from startHandshake.
930                // We call SSL_interrupt so that we can interrupt SSL_do_handshake and then
931                // set the state to STATE_CLOSED. startHandshake will handle all cleanup
932                // after SSL_do_handshake returns, so we don't have anything to do here.
933                ssl.interrupt();
934
935                stateLock.notifyAll();
936                return;
937            }
938
939            stateLock.notifyAll();
940            // We've already returned from startHandshake, so we potentially have
941            // input and output streams to clean up.
942            sslInputStream = is;
943            sslOutputStream = os;
944        }
945
946        // Don't bother interrupting unless we have something to interrupt.
947        if (sslInputStream != null || sslOutputStream != null) {
948            ssl.interrupt();
949        }
950
951        // Wait for the input and output streams to finish any reads they have in
952        // progress. If there are no reads in progress at this point, future reads will
953        // throw because state == STATE_CLOSED
954        if (sslInputStream != null) {
955            sslInputStream.awaitPendingOps();
956        }
957        if (sslOutputStream != null) {
958            sslOutputStream.awaitPendingOps();
959        }
960
961        shutdownAndFreeSslNative();
962    }
963
964    private void shutdownAndFreeSslNative() throws IOException {
965        try {
966            Platform.blockGuardOnNetwork();
967            ssl.shutdown(Platform.getFileDescriptor(socket));
968        } catch (IOException ignored) {
969            /*
970             * Note that although close() can throw
971             * IOException, the RI does not throw if there
972             * is problem sending a "close notify" which
973             * can happen if the underlying socket is closed.
974             */
975        } finally {
976            free();
977            closeUnderlyingSocket();
978        }
979    }
980
981    private void closeUnderlyingSocket() throws IOException {
982        super.close();
983    }
984
985    private void free() {
986        if (!ssl.isClosed()) {
987            ssl.close();
988            Platform.closeGuardClose(guard);
989        }
990    }
991
992    @Override
993    protected void finalize() throws Throwable {
994        try {
995            /*
996             * Just worry about our own state. Notably we do not try and
997             * close anything. The SocketImpl, either our own
998             * PlainSocketImpl, or the Socket we are wrapping, will do
999             * that. This might mean we do not properly SSL_shutdown, but
1000             * if you want to do that, properly close the socket yourself.
1001             *
1002             * The reason why we don't try to SSL_shutdown, is that there
1003             * can be a race between finalizers where the PlainSocketImpl
1004             * finalizer runs first and closes the socket. However, in the
1005             * meanwhile, the underlying file descriptor could be reused
1006             * for another purpose. If we call SSL_shutdown, the
1007             * underlying socket BIOs still have the old file descriptor
1008             * and will write the close notify to some unsuspecting
1009             * reader.
1010             */
1011            if (guard != null) {
1012                Platform.closeGuardWarnIfOpen(guard);
1013            }
1014            free();
1015        } finally {
1016            super.finalize();
1017        }
1018    }
1019
1020    /**
1021     * Returns the protocol agreed upon by client and server, or {@code null} if
1022     * no protocol was agreed upon.
1023     */
1024    @Override
1025    public byte[] getAlpnSelectedProtocol() {
1026        return ssl.getAlpnSelectedProtocol();
1027    }
1028
1029    /**
1030     * Sets the list of ALPN protocols. This method internally converts the protocols to their
1031     * wire-format form.
1032     *
1033     * @param alpnProtocols the list of ALPN protocols
1034     * @see #setAlpnProtocols(byte[])
1035     */
1036    @Override
1037    public void setAlpnProtocols(String[] alpnProtocols) {
1038        sslParameters.setAlpnProtocols(alpnProtocols);
1039    }
1040
1041    /**
1042     * Alternate version of {@link #setAlpnProtocols(String[])} that directly sets the list of
1043     * ALPN in the wire-format form used by BoringSSL (length-prefixed 8-bit strings).
1044     * Requires that all strings be encoded with US-ASCII.
1045     *
1046     * @param alpnProtocols the encoded form of the ALPN protocol list
1047     * @see #setAlpnProtocols(String[])
1048     */
1049    @Override
1050    public void setAlpnProtocols(byte[] alpnProtocols) {
1051        sslParameters.setAlpnProtocols(alpnProtocols);
1052    }
1053
1054    @Override
1055    public SSLParameters getSSLParameters() {
1056        SSLParameters params = super.getSSLParameters();
1057        Platform.getSSLParameters(params, sslParameters, this);
1058        return params;
1059    }
1060
1061    @Override
1062    public void setSSLParameters(SSLParameters p) {
1063        super.setSSLParameters(p);
1064        Platform.setSSLParameters(p, sslParameters, this);
1065    }
1066
1067    @Override
1068    public String chooseServerAlias(X509KeyManager keyManager, String keyType) {
1069        return keyManager.chooseServerAlias(keyType, null, this);
1070    }
1071
1072    @Override
1073    public String chooseClientAlias(X509KeyManager keyManager, X500Principal[] issuers,
1074            String[] keyTypes) {
1075        return keyManager.chooseClientAlias(keyTypes, null, this);
1076    }
1077
1078    @Override
1079    @SuppressWarnings("deprecation") // PSKKeyManager is deprecated, but in our own package
1080    public String chooseServerPSKIdentityHint(PSKKeyManager keyManager) {
1081        return keyManager.chooseServerKeyIdentityHint(this);
1082    }
1083
1084    @Override
1085    @SuppressWarnings("deprecation") // PSKKeyManager is deprecated, but in our own package
1086    public String chooseClientPSKIdentity(PSKKeyManager keyManager, String identityHint) {
1087        return keyManager.chooseClientKeyIdentity(identityHint, this);
1088    }
1089
1090    @Override
1091    @SuppressWarnings("deprecation") // PSKKeyManager is deprecated, but in our own package
1092    public SecretKey getPSKKey(PSKKeyManager keyManager, String identityHint, String identity) {
1093        return keyManager.getKey(identityHint, identity, this);
1094    }
1095
1096    private ClientSessionContext clientSessionContext() {
1097        return sslParameters.getClientSessionContext();
1098    }
1099
1100    private AbstractSessionContext sessionContext() {
1101        return sslParameters.getSessionContext();
1102    }
1103}
1104