1/*
2 * Copyright (C) 2007 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package org.conscrypt;
18
19import org.conscrypt.util.ArrayUtils;
20import org.conscrypt.ct.CTVerifier;
21import org.conscrypt.ct.CTVerificationResult;
22import dalvik.system.BlockGuard;
23import dalvik.system.CloseGuard;
24import java.io.FileDescriptor;
25import java.io.IOException;
26import java.io.InputStream;
27import java.io.OutputStream;
28import java.net.InetAddress;
29import java.net.InetSocketAddress;
30import java.net.Socket;
31import java.net.SocketAddress;
32import java.net.SocketException;
33import java.security.InvalidKeyException;
34import java.security.PrivateKey;
35import java.security.SecureRandom;
36import java.security.cert.CertificateEncodingException;
37import java.security.cert.CertificateException;
38import java.security.interfaces.ECKey;
39import java.security.spec.ECParameterSpec;
40import java.util.ArrayList;
41import javax.crypto.SecretKey;
42import javax.net.ssl.HandshakeCompletedEvent;
43import javax.net.ssl.HandshakeCompletedListener;
44import javax.net.ssl.SSLException;
45import javax.net.ssl.SSLHandshakeException;
46import javax.net.ssl.SSLParameters;
47import javax.net.ssl.SSLProtocolException;
48import javax.net.ssl.SSLSession;
49import javax.net.ssl.X509KeyManager;
50import javax.net.ssl.X509TrustManager;
51import javax.security.auth.x500.X500Principal;
52
53/**
54 * Implementation of {@link javax.net.ssl.SSLSocket} based on OpenSSL.
55 * <p>
56 * Extensions to SSLSocket include:
57 * <ul>
58 * <li>handshake timeout
59 * <li>session tickets
60 * <li>Server Name Indication
61 * </ul>
62 */
63public class OpenSSLSocketImpl
64        extends javax.net.ssl.SSLSocket
65        implements NativeCrypto.SSLHandshakeCallbacks, SSLParametersImpl.AliasChooser,
66        SSLParametersImpl.PSKCallbacks {
67
68    private static final boolean DBG_STATE = false;
69
70    /**
71     * Protects state, sslNativePointer, is and os.
72     */
73    private final Object stateLock = new Object();
74
75    /**
76     * The {@link OpenSSLSocketImpl} object is constructed, but {@link #startHandshake()}
77     * has not yet been called.
78     */
79    private static final int STATE_NEW = 0;
80
81    /**
82     * {@link #startHandshake()} has been called at least once.
83     */
84    private static final int STATE_HANDSHAKE_STARTED = 1;
85
86    /**
87     * {@link #handshakeCompleted()} has been called, but {@link #startHandshake()} hasn't
88     * returned yet.
89     */
90    private static final int STATE_HANDSHAKE_COMPLETED = 2;
91
92    /**
93     * {@link #startHandshake()} has completed but {@link #handshakeCompleted()} hasn't
94     * been called. This is expected behaviour in cut-through mode, where SSL_do_handshake
95     * returns before the handshake is complete. We can now start writing data to the socket.
96     */
97    private static final int STATE_READY_HANDSHAKE_CUT_THROUGH = 3;
98
99    /**
100     * {@link #startHandshake()} has completed and {@link #handshakeCompleted()} has been
101     * called.
102     */
103    private static final int STATE_READY = 4;
104
105    /**
106     * {@link #close()} has been called at least once.
107     */
108    private static final int STATE_CLOSED = 5;
109
110    // @GuardedBy("stateLock");
111    private int state = STATE_NEW;
112
113    /**
114     * Protected by synchronizing on stateLock. Starts as 0, set by
115     * startHandshake, reset to 0 on close.
116     */
117    // @GuardedBy("stateLock");
118    private long sslNativePointer;
119
120    /**
121     * Protected by synchronizing on stateLock. Starts as null, set by
122     * getInputStream.
123     */
124    // @GuardedBy("stateLock");
125    private SSLInputStream is;
126
127    /**
128     * Protected by synchronizing on stateLock. Starts as null, set by
129     * getOutputStream.
130     */
131    // @GuardedBy("stateLock");
132    private SSLOutputStream os;
133
134    private final Socket socket;
135    private final boolean autoClose;
136
137    /**
138     * The peer's DNS hostname if it was supplied during creation. Note that
139     * this may be a raw IP address, so it should be checked before use with
140     * extensions that don't use it like Server Name Indication (SNI).
141     */
142    private String peerHostname;
143
144    /**
145     * The peer's port if it was supplied during creation. Should only be set if
146     * {@link #peerHostname} is also set.
147     */
148    private final int peerPort;
149
150    private final SSLParametersImpl sslParameters;
151
152    /*
153     * A CloseGuard object on Android. On other platforms, this is nothing.
154     */
155    private final Object guard = Platform.closeGuardGet();
156
157    private ArrayList<HandshakeCompletedListener> listeners;
158
159    /**
160     * Private key for the TLS Channel ID extension. This field is client-side
161     * only. Set during startHandshake.
162     */
163    OpenSSLKey channelIdPrivateKey;
164
165    /** Set during startHandshake. */
166    private OpenSSLSessionImpl sslSession;
167
168    /** Used during handshake callbacks. */
169    private OpenSSLSessionImpl handshakeSession;
170
171    /**
172     * Local cache of timeout to avoid getsockopt on every read and
173     * write for non-wrapped sockets. Note that
174     * OpenSSLSocketImplWrapper overrides setSoTimeout and
175     * getSoTimeout to delegate to the wrapped socket.
176     */
177    private int readTimeoutMilliseconds = 0;
178    private int writeTimeoutMilliseconds = 0;
179
180    private int handshakeTimeoutMilliseconds = -1;  // -1 = same as timeout; 0 = infinite
181
/**
 * Creates a socket that acts as its own transport. No peer hostname or port is recorded
 * and {@code autoClose} is disabled.
 *
 * @param sslParameters the SSL configuration this socket keeps a reference to
 * @throws IOException declared by the signature
 */
protected OpenSSLSocketImpl(SSLParametersImpl sslParameters) throws IOException {
    this.sslParameters = sslParameters;
    this.autoClose = false;
    this.peerPort = -1;
    this.peerHostname = null;
    this.socket = this;
}
189
/**
 * Creates a socket via the superclass {@code (hostname, port)} constructor and records
 * that hostname and port as the peer's. The socket acts as its own transport.
 *
 * @param hostname the peer hostname, passed to the superclass and stored
 * @param port the peer port, passed to the superclass and stored
 * @param sslParameters the SSL configuration this socket keeps a reference to
 * @throws IOException if the superclass constructor fails
 */
protected OpenSSLSocketImpl(String hostname, int port, SSLParametersImpl sslParameters)
        throws IOException {
    super(hostname, port);
    this.sslParameters = sslParameters;
    this.autoClose = false;
    this.peerPort = port;
    this.peerHostname = hostname;
    this.socket = this;
}
199
/**
 * Creates a socket via the superclass {@code (address, port)} constructor. Because only
 * an address is supplied, no peer hostname or port is recorded.
 *
 * @param address the peer address, passed to the superclass
 * @param port the peer port, passed to the superclass
 * @param sslParameters the SSL configuration this socket keeps a reference to
 * @throws IOException if the superclass constructor fails
 */
protected OpenSSLSocketImpl(InetAddress address, int port, SSLParametersImpl sslParameters)
        throws IOException {
    super(address, port);
    this.sslParameters = sslParameters;
    this.autoClose = false;
    this.peerPort = -1;
    this.peerHostname = null;
    this.socket = this;
}
209
210
/**
 * Creates a socket via the superclass
 * {@code (hostname, port, clientAddress, clientPort)} constructor and records the given
 * hostname and port as the peer's. The socket acts as its own transport.
 *
 * @param hostname the peer hostname, passed to the superclass and stored
 * @param port the peer port, passed to the superclass and stored
 * @param clientAddress the local address, passed to the superclass
 * @param clientPort the local port, passed to the superclass
 * @param sslParameters the SSL configuration this socket keeps a reference to
 * @throws IOException if the superclass constructor fails
 */
protected OpenSSLSocketImpl(String hostname, int port,
                            InetAddress clientAddress, int clientPort,
                            SSLParametersImpl sslParameters) throws IOException {
    super(hostname, port, clientAddress, clientPort);
    this.sslParameters = sslParameters;
    this.autoClose = false;
    this.peerPort = port;
    this.peerHostname = hostname;
    this.socket = this;
}
221
/**
 * Creates a socket via the superclass
 * {@code (address, port, clientAddress, clientPort)} constructor. Because only an address
 * is supplied, no peer hostname or port is recorded.
 *
 * @param address the peer address, passed to the superclass
 * @param port the peer port, passed to the superclass
 * @param clientAddress the local address, passed to the superclass
 * @param clientPort the local port, passed to the superclass
 * @param sslParameters the SSL configuration this socket keeps a reference to
 * @throws IOException if the superclass constructor fails
 */
protected OpenSSLSocketImpl(InetAddress address, int port,
                            InetAddress clientAddress, int clientPort,
                            SSLParametersImpl sslParameters) throws IOException {
    super(address, port, clientAddress, clientPort);
    this.sslParameters = sslParameters;
    this.autoClose = false;
    this.peerPort = -1;
    this.peerHostname = null;
    this.socket = this;
}
232
/**
 * Creates an SSL socket layered over an existing {@code socket}. Invoked by the
 * OpenSSLSocketImplWrapper constructor.
 *
 * @param socket the socket to wrap; stored as the underlying transport
 * @param hostname the peer hostname to record
 * @param port the peer port to record
 * @param autoClose stored in the {@code autoClose} field
 * @param sslParameters the SSL configuration this socket keeps a reference to
 * @throws IOException declared by the signature
 */
protected OpenSSLSocketImpl(Socket socket, String hostname, int port,
        boolean autoClose, SSLParametersImpl sslParameters) throws IOException {
    this.sslParameters = sslParameters;
    this.autoClose = autoClose;
    this.peerPort = port;
    this.peerHostname = hostname;
    this.socket = socket;

    // The cached read timeout is deliberately left untouched here:
    // OpenSSLSocketImplWrapper.getSoTimeout delegates the timeout to the
    // wrapped socket instead.
}
249
250    @Override
251    public void connect(SocketAddress endpoint) throws IOException {
252        connect(endpoint, 0);
253    }
254
255    /**
256     * Try to extract the peer's hostname if it's available from the endpoint address.
257     */
258    @Override
259    public void connect(SocketAddress endpoint, int timeout) throws IOException {
260        if (peerHostname == null && endpoint instanceof InetSocketAddress) {
261            peerHostname = Platform.getHostStringFromInetSocketAddress(
262                    (InetSocketAddress) endpoint);
263        }
264
265        super.connect(endpoint, timeout);
266    }
267
268    private void checkOpen() throws SocketException {
269        if (isClosed()) {
270            throw new SocketException("Socket is closed");
271        }
272    }
273
274    /**
275     * Starts a TLS/SSL handshake on this connection using some native methods
276     * from the OpenSSL library. It can negotiate new encryption keys, change
277     * cipher suites, or initiate a new session. The certificate chain is
278     * verified if the correspondent property in java.Security is set. All
279     * listeners are notified at the end of the TLS/SSL handshake.
280     */
281    @Override
282    public void startHandshake() throws IOException {
283        checkOpen();
284        synchronized (stateLock) {
285            if (state == STATE_NEW) {
286                state = STATE_HANDSHAKE_STARTED;
287            } else {
288                // We've either started the handshake already or have been closed.
289                // Do nothing in both cases.
290                return;
291            }
292        }
293
294        // For BoringSSL, RAND_seed and RAND_load_file are no-ops since the RNG
295        // reads directly from the random device node.
296        if (!NativeCrypto.isBoringSSL) {
297            // note that this modifies the global seed, not something specific
298            // to the connection
299            final int seedLengthInBytes = NativeCrypto.RAND_SEED_LENGTH_IN_BYTES;
300            final SecureRandom secureRandom = sslParameters.getSecureRandomMember();
301            if (secureRandom == null) {
302                NativeCrypto.RAND_load_file("/dev/urandom", seedLengthInBytes);
303            } else {
304                NativeCrypto.RAND_seed(secureRandom.generateSeed(seedLengthInBytes));
305            }
306        }
307
308        final boolean client = sslParameters.getUseClientMode();
309
310        sslNativePointer = 0;
311        boolean releaseResources = true;
312        try {
313            final AbstractSessionContext sessionContext = sslParameters.getSessionContext();
314            final long sslCtxNativePointer = sessionContext.sslCtxNativePointer;
315            sslNativePointer = NativeCrypto.SSL_new(sslCtxNativePointer);
316            Platform.closeGuardOpen(guard, "close");
317
318            boolean enableSessionCreation = getEnableSessionCreation();
319            if (!enableSessionCreation) {
320                NativeCrypto.SSL_set_session_creation_enabled(sslNativePointer,
321                        enableSessionCreation);
322            }
323
324            // Allow servers to trigger renegotiation. Some inadvisable server
325            // configurations cause them to attempt to renegotiate during
326            // certain protocols.
327            NativeCrypto.SSL_set_reject_peer_renegotiations(sslNativePointer, false);
328
329            if (client && sslParameters.isCTVerificationEnabled(getHostname())) {
330                NativeCrypto.SSL_enable_signed_cert_timestamps(sslNativePointer);
331                NativeCrypto.SSL_enable_ocsp_stapling(sslNativePointer);
332            }
333
334            final OpenSSLSessionImpl sessionToReuse = sslParameters.getSessionToReuse(
335                    sslNativePointer, getHostnameOrIP(), getPort());
336            sslParameters.setSSLParameters(sslCtxNativePointer, sslNativePointer, this, this,
337                    getHostname());
338            sslParameters.setCertificateValidation(sslNativePointer);
339            sslParameters.setTlsChannelId(sslNativePointer, channelIdPrivateKey);
340
341            // Temporarily use a different timeout for the handshake process
342            int savedReadTimeoutMilliseconds = getSoTimeout();
343            int savedWriteTimeoutMilliseconds = getSoWriteTimeout();
344            if (handshakeTimeoutMilliseconds >= 0) {
345                setSoTimeout(handshakeTimeoutMilliseconds);
346                setSoWriteTimeout(handshakeTimeoutMilliseconds);
347            }
348
349            synchronized (stateLock) {
350                if (state == STATE_CLOSED) {
351                    return;
352                }
353            }
354
355            long sslSessionNativePointer;
356            try {
357                sslSessionNativePointer = NativeCrypto.SSL_do_handshake(sslNativePointer,
358                        Platform.getFileDescriptor(socket), this, getSoTimeout(), client,
359                        sslParameters.npnProtocols, client ? null : sslParameters.alpnProtocols);
360            } catch (CertificateException e) {
361                SSLHandshakeException wrapper = new SSLHandshakeException(e.getMessage());
362                wrapper.initCause(e);
363                throw wrapper;
364            } catch (SSLException e) {
365                // Swallow this exception if it's thrown as the result of an interruption.
366                //
367                // TODO: SSL_read and SSL_write return -1 when interrupted, but SSL_do_handshake
368                // will throw the last sslError that it saw before sslSelect, usually SSL_WANT_READ
369                // (or WANT_WRITE). Catching that exception here doesn't seem much worse than
370                // changing the native code to return a "special" native pointer value when that
371                // happens.
372                synchronized (stateLock) {
373                    if (state == STATE_CLOSED) {
374                        return;
375                    }
376                }
377
378                // Write CCS errors to EventLog
379                String message = e.getMessage();
380                // Must match error string of SSL_R_UNEXPECTED_CCS
381                if (message.contains("unexpected CCS")) {
382                    String logMessage = String.format("ssl_unexpected_ccs: host=%s",
383                            getHostnameOrIP());
384                    Platform.logEvent(logMessage);
385                }
386
387                throw e;
388            }
389
390            boolean handshakeCompleted = false;
391            synchronized (stateLock) {
392                if (state == STATE_HANDSHAKE_COMPLETED) {
393                    handshakeCompleted = true;
394                } else if (state == STATE_CLOSED) {
395                    return;
396                }
397            }
398
399            sslSession = sslParameters.setupSession(sslSessionNativePointer, sslNativePointer,
400                    sessionToReuse, getHostnameOrIP(), getPort(), handshakeCompleted);
401
402            // Restore the original timeout now that the handshake is complete
403            if (handshakeTimeoutMilliseconds >= 0) {
404                setSoTimeout(savedReadTimeoutMilliseconds);
405                setSoWriteTimeout(savedWriteTimeoutMilliseconds);
406            }
407
408            // if not, notifyHandshakeCompletedListeners later in handshakeCompleted() callback
409            if (handshakeCompleted) {
410                notifyHandshakeCompletedListeners();
411            }
412
413            synchronized (stateLock) {
414                releaseResources = (state == STATE_CLOSED);
415
416                if (state == STATE_HANDSHAKE_STARTED) {
417                    state = STATE_READY_HANDSHAKE_CUT_THROUGH;
418                } else if (state == STATE_HANDSHAKE_COMPLETED) {
419                    state = STATE_READY;
420                }
421
422                if (!releaseResources) {
423                    // Unblock threads that are waiting for our state to transition
424                    // into STATE_READY or STATE_READY_HANDSHAKE_CUT_THROUGH.
425                    stateLock.notifyAll();
426                }
427            }
428        } catch (SSLProtocolException e) {
429            throw (SSLHandshakeException) new SSLHandshakeException("Handshake failed")
430                    .initCause(e);
431        } finally {
432            // on exceptional exit, treat the socket as closed
433            if (releaseResources) {
434                synchronized (stateLock) {
435                    // Mark the socket as closed since we might have reached this as
436                    // a result on an exception thrown by the handshake process.
437                    //
438                    // The state will already be set to closed if we reach this as a result of
439                    // an early return or an interruption due to a concurrent call to close().
440                    state = STATE_CLOSED;
441                    stateLock.notifyAll();
442                }
443
444                try {
445                    shutdownAndFreeSslNative();
446                } catch (IOException ignored) {
447
448                }
449            }
450        }
451    }
452
453    /**
454     * Returns the hostname that was supplied during socket creation. No DNS resolution is
455     * attempted before returning the hostname.
456     */
457    public String getHostname() {
458        return peerHostname;
459    }
460
461    /**
462     * For the purposes of an SSLSession, we want a way to represent the supplied hostname
463     * or the IP address in a textual representation. We do not want to perform reverse DNS
464     * lookups on this address.
465     */
466    public String getHostnameOrIP() {
467        if (peerHostname != null) {
468            return peerHostname;
469        }
470
471        InetAddress peerAddress = getInetAddress();
472        if (peerAddress != null) {
473            return peerAddress.getHostAddress();
474        }
475
476        return null;
477    }
478
479    @Override
480    public int getPort() {
481        return peerPort == -1 ? super.getPort() : peerPort;
482    }
483
484    @Override
485    @SuppressWarnings("unused") // used by NativeCrypto.SSLHandshakeCallbacks / client_cert_cb
486    public void clientCertificateRequested(byte[] keyTypeBytes, byte[][] asn1DerEncodedPrincipals)
487            throws CertificateEncodingException, SSLException {
488        sslParameters.chooseClientCertificate(keyTypeBytes, asn1DerEncodedPrincipals,
489                sslNativePointer, this);
490    }
491
492    @Override
493    @SuppressWarnings("unused") // used by native psk_client_callback
494    public int clientPSKKeyRequested(String identityHint, byte[] identity, byte[] key) {
495        return sslParameters.clientPSKKeyRequested(identityHint, identity, key, this);
496    }
497
498    @Override
499    @SuppressWarnings("unused") // used by native psk_server_callback
500    public int serverPSKKeyRequested(String identityHint, String identity, byte[] key) {
501        return sslParameters.serverPSKKeyRequested(identityHint, identity, key, this);
502    }
503
504    @Override
505    @SuppressWarnings("unused") // used by NativeCrypto.SSLHandshakeCallbacks / info_callback
506    public void onSSLStateChange(long sslSessionNativePtr, int type, int val) {
507        if (type != NativeConstants.SSL_CB_HANDSHAKE_DONE) {
508            return;
509        }
510
511        synchronized (stateLock) {
512            if (state == STATE_HANDSHAKE_STARTED) {
513                // If sslSession is null, the handshake was completed during
514                // the call to NativeCrypto.SSL_do_handshake and not during a
515                // later read operation. That means we do not need to fix up
516                // the SSLSession and session cache or notify
517                // HandshakeCompletedListeners, it will be done in
518                // startHandshake.
519
520                state = STATE_HANDSHAKE_COMPLETED;
521                return;
522            } else if (state == STATE_READY_HANDSHAKE_CUT_THROUGH) {
523                // We've returned from startHandshake, which means we've set a sslSession etc.
524                // we need to fix them up, which we'll do outside this lock.
525            } else if (state == STATE_CLOSED) {
526                // Someone called "close" but the handshake hasn't been interrupted yet.
527                return;
528            }
529        }
530
531        // reset session id from the native pointer and update the
532        // appropriate cache.
533        sslSession.resetId();
534        AbstractSessionContext sessionContext =
535            (sslParameters.getUseClientMode())
536            ? sslParameters.getClientSessionContext()
537                : sslParameters.getServerSessionContext();
538        sessionContext.putSession(sslSession);
539
540        // let listeners know we are finally done
541        notifyHandshakeCompletedListeners();
542
543        synchronized (stateLock) {
544            // Now that we've fixed up our state, we can tell waiting threads that
545            // we're ready.
546            state = STATE_READY;
547            // Notify all threads waiting for the handshake to complete.
548            stateLock.notifyAll();
549        }
550    }
551
552    private void notifyHandshakeCompletedListeners() {
553        if (listeners != null && !listeners.isEmpty()) {
554            // notify the listeners
555            HandshakeCompletedEvent event =
556                new HandshakeCompletedEvent(this, sslSession);
557            for (HandshakeCompletedListener listener : listeners) {
558                try {
559                    listener.handshakeCompleted(event);
560                } catch (RuntimeException e) {
561                    // The RI runs the handlers in a separate thread,
562                    // which we do not. But we try to preserve their
563                    // behavior of logging a problem and not killing
564                    // the handshaking thread just because a listener
565                    // has a problem.
566                    Thread thread = Thread.currentThread();
567                    thread.getUncaughtExceptionHandler().uncaughtException(thread, e);
568                }
569            }
570        }
571    }
572
573    @SuppressWarnings("unused") // used by NativeCrypto.SSLHandshakeCallbacks
574    @Override
575    public void verifyCertificateChain(long sslSessionNativePtr, long[] certRefs, String authMethod)
576            throws CertificateException {
577        try {
578            X509TrustManager x509tm = sslParameters.getX509TrustManager();
579            if (x509tm == null) {
580                throw new CertificateException("No X.509 TrustManager");
581            }
582            if (certRefs == null || certRefs.length == 0) {
583                throw new SSLException("Peer sent no certificate");
584            }
585            OpenSSLX509Certificate[] peerCertChain = new OpenSSLX509Certificate[certRefs.length];
586            for (int i = 0; i < certRefs.length; i++) {
587                peerCertChain[i] = new OpenSSLX509Certificate(certRefs[i]);
588            }
589
590            // Used for verifyCertificateChain callback
591            handshakeSession = new OpenSSLSessionImpl(sslSessionNativePtr, null, peerCertChain,
592                    getHostnameOrIP(), getPort(), null);
593
594            boolean client = sslParameters.getUseClientMode();
595            if (client) {
596                Platform.checkServerTrusted(x509tm, peerCertChain, authMethod, this);
597                if (sslParameters.isCTVerificationEnabled(getHostname())) {
598                    byte[] tlsData = NativeCrypto.SSL_get_signed_cert_timestamp_list(
599                                        sslNativePointer);
600                    byte[] ocspData = NativeCrypto.SSL_get_ocsp_response(sslNativePointer);
601
602                    CTVerifier ctVerifier = sslParameters.getCTVerifier();
603                    CTVerificationResult result =
604                        ctVerifier.verifySignedCertificateTimestamps(peerCertChain, tlsData, ocspData);
605
606                    if (result.getValidSCTs().size() == 0) {
607                        throw new CertificateException("No valid SCT found");
608                    }
609                }
610            } else {
611                String authType = peerCertChain[0].getPublicKey().getAlgorithm();
612                Platform.checkClientTrusted(x509tm, peerCertChain, authType, this);
613            }
614        } catch (CertificateException e) {
615            throw e;
616        } catch (Exception e) {
617            throw new CertificateException(e);
618        } finally {
619            // Clear this before notifying handshake completed listeners
620            handshakeSession = null;
621        }
622    }
623
624    @Override
625    public InputStream getInputStream() throws IOException {
626        checkOpen();
627
628        InputStream returnVal;
629        synchronized (stateLock) {
630            if (state == STATE_CLOSED) {
631                throw new SocketException("Socket is closed.");
632            }
633
634            if (is == null) {
635                is = new SSLInputStream();
636            }
637
638            returnVal = is;
639        }
640
641        // Block waiting for a handshake without a lock held. It's possible that the socket
642        // is closed at this point. If that happens, we'll still return the input stream but
643        // all reads on it will throw.
644        waitForHandshake();
645        return returnVal;
646    }
647
648    @Override
649    public OutputStream getOutputStream() throws IOException {
650        checkOpen();
651
652        OutputStream returnVal;
653        synchronized (stateLock) {
654            if (state == STATE_CLOSED) {
655                throw new SocketException("Socket is closed.");
656            }
657
658            if (os == null) {
659                os = new SSLOutputStream();
660            }
661
662            returnVal = os;
663        }
664
665        // Block waiting for a handshake without a lock held. It's possible that the socket
666        // is closed at this point. If that happens, we'll still return the output stream but
667        // all writes on it will throw.
668        waitForHandshake();
669        return returnVal;
670    }
671
672    private void assertReadableOrWriteableState() {
673        if (state == STATE_READY || state == STATE_READY_HANDSHAKE_CUT_THROUGH) {
674            return;
675        }
676
677        throw new AssertionError("Invalid state: " + state);
678    }
679
680
681    private void waitForHandshake() throws IOException {
682        startHandshake();
683
684        synchronized (stateLock) {
685            while (state != STATE_READY &&
686                    state != STATE_READY_HANDSHAKE_CUT_THROUGH &&
687                    state != STATE_CLOSED) {
688                try {
689                    stateLock.wait();
690                } catch (InterruptedException e) {
691                    Thread.currentThread().interrupt();
692                    IOException ioe = new IOException("Interrupted waiting for handshake");
693                    ioe.initCause(e);
694
695                    throw ioe;
696                }
697            }
698
699            if (state == STATE_CLOSED) {
700                throw new SocketException("Socket is closed");
701            }
702        }
703    }
704
705    /**
706     * This inner class provides input data stream functionality
707     * for the OpenSSL native implementation. It is used to
708     * read data received via SSL protocol.
709     */
710    private class SSLInputStream extends InputStream {
711        /**
712         * OpenSSL only lets one thread read at a time, so this is used to
713         * make sure we serialize callers of SSL_read. Thread is already
714         * expected to have completed handshaking.
715         */
716        private final Object readLock = new Object();
717
718        SSLInputStream() {
719        }
720
721        /**
722         * Reads one byte. If there is no data in the underlying buffer,
723         * this operation can block until the data will be
724         * available.
725         * @return read value.
726         * @throws IOException
727         */
728        @Override
729        public int read() throws IOException {
730            byte[] buffer = new byte[1];
731            int result = read(buffer, 0, 1);
732            return (result != -1) ? buffer[0] & 0xff : -1;
733        }
734
735        /**
736         * Method acts as described in spec for superclass.
737         * @see java.io.InputStream#read(byte[],int,int)
738         */
739        @Override
740        public int read(byte[] buf, int offset, int byteCount) throws IOException {
741            Platform.blockGuardOnNetwork();
742
743            checkOpen();
744            ArrayUtils.checkOffsetAndCount(buf.length, offset, byteCount);
745            if (byteCount == 0) {
746                return 0;
747            }
748
749            synchronized (readLock) {
750                synchronized (stateLock) {
751                    if (state == STATE_CLOSED) {
752                        throw new SocketException("socket is closed");
753                    }
754
755                    if (DBG_STATE) assertReadableOrWriteableState();
756                }
757
758                return NativeCrypto.SSL_read(sslNativePointer, Platform.getFileDescriptor(socket),
759                        OpenSSLSocketImpl.this, buf, offset, byteCount, getSoTimeout());
760            }
761        }
762
763        public void awaitPendingOps() {
764            if (DBG_STATE) {
765                synchronized (stateLock) {
766                    if (state != STATE_CLOSED) throw new AssertionError("State is: " + state);
767                }
768            }
769
770            synchronized (readLock) { }
771        }
772    }
773
774    /**
775     * This inner class provides output data stream functionality
776     * for the OpenSSL native implementation. It is used to
777     * write data according to the encryption parameters given in SSL context.
778     */
779    private class SSLOutputStream extends OutputStream {
780
781        /**
782         * OpenSSL only lets one thread write at a time, so this is used
783         * to make sure we serialize callers of SSL_write. Thread is
784         * already expected to have completed handshaking.
785         */
786        private final Object writeLock = new Object();
787
788        SSLOutputStream() {
789        }
790
791        /**
792         * Method acts as described in spec for superclass.
793         * @see java.io.OutputStream#write(int)
794         */
795        @Override
796        public void write(int oneByte) throws IOException {
797            byte[] buffer = new byte[1];
798            buffer[0] = (byte) (oneByte & 0xff);
799            write(buffer);
800        }
801
802        /**
803         * Method acts as described in spec for superclass.
804         * @see java.io.OutputStream#write(byte[],int,int)
805         */
806        @Override
807        public void write(byte[] buf, int offset, int byteCount) throws IOException {
808            Platform.blockGuardOnNetwork();
809            checkOpen();
810            ArrayUtils.checkOffsetAndCount(buf.length, offset, byteCount);
811            if (byteCount == 0) {
812                return;
813            }
814
815            synchronized (writeLock) {
816                synchronized (stateLock) {
817                    if (state == STATE_CLOSED) {
818                        throw new SocketException("socket is closed");
819                    }
820
821                    if (DBG_STATE) assertReadableOrWriteableState();
822                }
823
824                NativeCrypto.SSL_write(sslNativePointer, Platform.getFileDescriptor(socket),
825                        OpenSSLSocketImpl.this, buf, offset, byteCount, writeTimeoutMilliseconds);
826            }
827        }
828
829
830        public void awaitPendingOps() {
831            if (DBG_STATE) {
832                synchronized (stateLock) {
833                    if (state != STATE_CLOSED) throw new AssertionError("State is: " + state);
834                }
835            }
836
837            synchronized (writeLock) { }
838        }
839    }
840
841
842    @Override
843    public SSLSession getSession() {
844        if (sslSession == null) {
845            try {
846                waitForHandshake();
847            } catch (IOException e) {
848                // return an invalid session with
849                // invalid cipher suite of "SSL_NULL_WITH_NULL_NULL"
850                return SSLNullSession.getNullSession();
851            }
852        }
853        return Platform.wrapSSLSession(sslSession);
854    }
855
856    // Comment annotation to compile Conscrypt unbundled with Java 6.
857    /* @Override */
858    public SSLSession getHandshakeSession() {
859        return handshakeSession;
860    }
861
862    @Override
863    public void addHandshakeCompletedListener(
864            HandshakeCompletedListener listener) {
865        if (listener == null) {
866            throw new IllegalArgumentException("Provided listener is null");
867        }
868        if (listeners == null) {
869            listeners = new ArrayList<HandshakeCompletedListener>();
870        }
871        listeners.add(listener);
872    }
873
874    @Override
875    public void removeHandshakeCompletedListener(
876            HandshakeCompletedListener listener) {
877        if (listener == null) {
878            throw new IllegalArgumentException("Provided listener is null");
879        }
880        if (listeners == null) {
881            throw new IllegalArgumentException(
882                    "Provided listener is not registered");
883        }
884        if (!listeners.remove(listener)) {
885            throw new IllegalArgumentException(
886                    "Provided listener is not registered");
887        }
888    }
889
890    @Override
891    public boolean getEnableSessionCreation() {
892        return sslParameters.getEnableSessionCreation();
893    }
894
895    @Override
896    public void setEnableSessionCreation(boolean flag) {
897        sslParameters.setEnableSessionCreation(flag);
898    }
899
900    @Override
901    public String[] getSupportedCipherSuites() {
902        return NativeCrypto.getSupportedCipherSuites();
903    }
904
905    @Override
906    public String[] getEnabledCipherSuites() {
907        return sslParameters.getEnabledCipherSuites();
908    }
909
910    @Override
911    public void setEnabledCipherSuites(String[] suites) {
912        sslParameters.setEnabledCipherSuites(suites);
913    }
914
915    @Override
916    public String[] getSupportedProtocols() {
917        return NativeCrypto.getSupportedProtocols();
918    }
919
920    @Override
921    public String[] getEnabledProtocols() {
922        return sslParameters.getEnabledProtocols();
923    }
924
925    @Override
926    public void setEnabledProtocols(String[] protocols) {
927        sslParameters.setEnabledProtocols(protocols);
928    }
929
930    /**
931     * This method enables session ticket support.
932     *
933     * @param useSessionTickets True to enable session tickets
934     */
935    public void setUseSessionTickets(boolean useSessionTickets) {
936        sslParameters.useSessionTickets = useSessionTickets;
937    }
938
939    /**
940     * This method enables Server Name Indication
941     *
942     * @param hostname the desired SNI hostname, or null to disable
943     */
944    public void setHostname(String hostname) {
945        sslParameters.setUseSni(hostname != null);
946        peerHostname = hostname;
947    }
948
949    /**
950     * Enables/disables TLS Channel ID for this server socket.
951     *
952     * <p>This method needs to be invoked before the handshake starts.
953     *
954     * @throws IllegalStateException if this is a client socket or if the handshake has already
955     *         started.
956     */
957    public void setChannelIdEnabled(boolean enabled) {
958        if (getUseClientMode()) {
959            throw new IllegalStateException("Client mode");
960        }
961
962        synchronized (stateLock) {
963            if (state != STATE_NEW) {
964                throw new IllegalStateException(
965                        "Could not enable/disable Channel ID after the initial handshake has"
966                                + " begun.");
967            }
968        }
969        sslParameters.channelIdEnabled = enabled;
970    }
971
972    /**
973     * Gets the TLS Channel ID for this server socket. Channel ID is only available once the
974     * handshake completes.
975     *
976     * @return channel ID or {@code null} if not available.
977     *
978     * @throws IllegalStateException if this is a client socket or if the handshake has not yet
979     *         completed.
980     * @throws SSLException if channel ID is available but could not be obtained.
981     */
982    public byte[] getChannelId() throws SSLException {
983        if (getUseClientMode()) {
984            throw new IllegalStateException("Client mode");
985        }
986
987        synchronized (stateLock) {
988            if (state != STATE_READY) {
989                throw new IllegalStateException(
990                        "Channel ID is only available after handshake completes");
991            }
992        }
993        return NativeCrypto.SSL_get_tls_channel_id(sslNativePointer);
994    }
995
996    /**
997     * Sets the {@link PrivateKey} to be used for TLS Channel ID by this client socket.
998     *
999     * <p>This method needs to be invoked before the handshake starts.
1000     *
1001     * @param privateKey private key (enables TLS Channel ID) or {@code null} for no key (disables
1002     *        TLS Channel ID). The private key must be an Elliptic Curve (EC) key based on the NIST
1003     *        P-256 curve (aka SECG secp256r1 or ANSI X9.62 prime256v1).
1004     *
1005     * @throws IllegalStateException if this is a server socket or if the handshake has already
1006     *         started.
1007     */
1008    public void setChannelIdPrivateKey(PrivateKey privateKey) {
1009        if (!getUseClientMode()) {
1010            throw new IllegalStateException("Server mode");
1011        }
1012
1013        synchronized (stateLock) {
1014            if (state != STATE_NEW) {
1015                throw new IllegalStateException(
1016                        "Could not change Channel ID private key after the initial handshake has"
1017                                + " begun.");
1018            }
1019        }
1020
1021        if (privateKey == null) {
1022            sslParameters.channelIdEnabled = false;
1023            channelIdPrivateKey = null;
1024        } else {
1025            sslParameters.channelIdEnabled = true;
1026            try {
1027                ECParameterSpec ecParams = null;
1028                if (privateKey instanceof ECKey) {
1029                    ecParams = ((ECKey) privateKey).getParams();
1030                }
1031                if (ecParams == null) {
1032                    // Assume this is a P-256 key, as specified in the contract of this method.
1033                    ecParams =
1034                            OpenSSLECGroupContext.getCurveByName("prime256v1").getECParameterSpec();
1035                }
1036                channelIdPrivateKey =
1037                        OpenSSLKey.fromECPrivateKeyForTLSStackOnly(privateKey, ecParams);
1038            } catch (InvalidKeyException e) {
1039                // Will have error in startHandshake
1040            }
1041        }
1042    }
1043
1044    @Override
1045    public boolean getUseClientMode() {
1046        return sslParameters.getUseClientMode();
1047    }
1048
1049    @Override
1050    public void setUseClientMode(boolean mode) {
1051        synchronized (stateLock) {
1052            if (state != STATE_NEW) {
1053                throw new IllegalArgumentException(
1054                        "Could not change the mode after the initial handshake has begun.");
1055            }
1056        }
1057        sslParameters.setUseClientMode(mode);
1058    }
1059
1060    @Override
1061    public boolean getWantClientAuth() {
1062        return sslParameters.getWantClientAuth();
1063    }
1064
1065    @Override
1066    public boolean getNeedClientAuth() {
1067        return sslParameters.getNeedClientAuth();
1068    }
1069
1070    @Override
1071    public void setNeedClientAuth(boolean need) {
1072        sslParameters.setNeedClientAuth(need);
1073    }
1074
1075    @Override
1076    public void setWantClientAuth(boolean want) {
1077        sslParameters.setWantClientAuth(want);
1078    }
1079
1080    @Override
1081    public void sendUrgentData(int data) throws IOException {
1082        throw new SocketException("Method sendUrgentData() is not supported.");
1083    }
1084
1085    @Override
1086    public void setOOBInline(boolean on) throws SocketException {
1087        throw new SocketException("Methods sendUrgentData, setOOBInline are not supported.");
1088    }
1089
1090    @Override
1091    public void setSoTimeout(int readTimeoutMilliseconds) throws SocketException {
1092        if (socket != this) {
1093            socket.setSoTimeout(readTimeoutMilliseconds);
1094        } else {
1095            super.setSoTimeout(readTimeoutMilliseconds);
1096        }
1097
1098        this.readTimeoutMilliseconds = readTimeoutMilliseconds;
1099    }
1100
1101    @Override
1102    public int getSoTimeout() throws SocketException {
1103        return readTimeoutMilliseconds;
1104    }
1105
1106    /**
1107     * Note write timeouts are not part of the javax.net.ssl.SSLSocket API
1108     */
1109    public void setSoWriteTimeout(int writeTimeoutMilliseconds) throws SocketException {
1110        this.writeTimeoutMilliseconds = writeTimeoutMilliseconds;
1111
1112        Platform.setSocketWriteTimeout(this, writeTimeoutMilliseconds);
1113    }
1114
1115    /**
1116     * Note write timeouts are not part of the javax.net.ssl.SSLSocket API
1117     */
1118    public int getSoWriteTimeout() throws SocketException {
1119        return writeTimeoutMilliseconds;
1120    }
1121
1122    /**
1123     * Set the handshake timeout on this socket.  This timeout is specified in
1124     * milliseconds and will be used only during the handshake process.
1125     */
1126    public void setHandshakeTimeout(int handshakeTimeoutMilliseconds) throws SocketException {
1127        this.handshakeTimeoutMilliseconds = handshakeTimeoutMilliseconds;
1128    }
1129
    /**
     * Closes this socket. The state is switched to {@code STATE_CLOSED} under
     * {@code stateLock}; what happens next depends on the previous state:
     * <ul>
     * <li>already {@code STATE_CLOSED}: returns immediately;</li>
     * <li>{@code STATE_NEW}: only the underlying socket is closed;</li>
     * <li>any state other than {@code STATE_READY} or
     *     {@code STATE_READY_HANDSHAKE_CUT_THROUGH}: the native SSL operation
     *     is interrupted and cleanup is left to startHandshake;</li>
     * <li>otherwise: pending stream operations are interrupted and awaited,
     *     then the native SSL object is shut down and freed.</li>
     * </ul>
     * Threads waiting on {@code stateLock} are notified on every path that
     * changes the state.
     *
     * @throws IOException if closing the underlying socket fails.
     */
    @Override
    public void close() throws IOException {
        // TODO: Close SSL sockets using a background thread so they close gracefully.

        SSLInputStream sslInputStream = null;
        SSLOutputStream sslOutputStream = null;

        synchronized (stateLock) {
            if (state == STATE_CLOSED) {
                // close() has already been called, so do nothing and return.
                return;
            }

            // Remember where we came from; the state itself is flipped right away
            // so that concurrent readers/writers observe STATE_CLOSED.
            int oldState = state;
            state = STATE_CLOSED;

            if (oldState == STATE_NEW) {
                // The handshake hasn't been started yet, so there's no OpenSSL related
                // state to clean up. We still need to close the underlying socket if
                // we're wrapping it and were asked to autoClose.
                // NOTE(review): if closeUnderlyingSocket() throws, the notifyAll()
                // below is skipped; confirm no thread can be waiting in STATE_NEW.
                closeUnderlyingSocket();

                stateLock.notifyAll();
                return;
            }

            if (oldState != STATE_READY && oldState != STATE_READY_HANDSHAKE_CUT_THROUGH) {
                // If we're in these states, we still haven't returned from startHandshake.
                // We call SSL_interrupt so that we can interrupt SSL_do_handshake and then
                // set the state to STATE_CLOSED. startHandshake will handle all cleanup
                // after SSL_do_handshake returns, so we don't have anything to do here.
                NativeCrypto.SSL_interrupt(sslNativePointer);

                stateLock.notifyAll();
                return;
            }

            stateLock.notifyAll();
            // We've already returned from startHandshake, so we potentially have
            // input and output streams to clean up.
            sslInputStream = is;
            sslOutputStream = os;
        }

        // Don't bother interrupting unless we have something to interrupt.
        if (sslInputStream != null || sslOutputStream != null) {
            NativeCrypto.SSL_interrupt(sslNativePointer);
        }

        // Wait for the input and output streams to finish any reads or writes they
        // have in progress. If there are none in progress at this point, future
        // reads and writes will throw because state == STATE_CLOSED
        if (sslInputStream != null) {
            sslInputStream.awaitPendingOps();
        }
        if (sslOutputStream != null) {
            sslOutputStream.awaitPendingOps();
        }

        // Only now, with no stream operation in flight, release the native state.
        shutdownAndFreeSslNative();
    }
1191
1192    private void shutdownAndFreeSslNative() throws IOException {
1193        try {
1194            Platform.blockGuardOnNetwork();
1195            NativeCrypto.SSL_shutdown(sslNativePointer, Platform.getFileDescriptor(socket),
1196                    this);
1197        } catch (IOException ignored) {
1198            /*
1199            * Note that although close() can throw
1200            * IOException, the RI does not throw if there
1201            * is problem sending a "close notify" which
1202            * can happen if the underlying socket is closed.
1203            */
1204        } finally {
1205            free();
1206            closeUnderlyingSocket();
1207        }
1208    }
1209
1210    private void closeUnderlyingSocket() throws IOException {
1211        if (socket != this) {
1212            if (autoClose && !socket.isClosed()) {
1213                socket.close();
1214            }
1215        } else {
1216            if (!super.isClosed()) {
1217                super.close();
1218            }
1219        }
1220    }
1221
1222    private void free() {
1223        if (sslNativePointer == 0) {
1224            return;
1225        }
1226        NativeCrypto.SSL_free(sslNativePointer);
1227        sslNativePointer = 0;
1228        Platform.closeGuardClose(guard);
1229    }
1230
    /**
     * Finalizer: warns through the close guard if one is present, frees the
     * native SSL object via {@code free()}, and always chains to the
     * superclass finalizer. No SSL shutdown and no socket close is attempted
     * here.
     */
    @Override
    protected void finalize() throws Throwable {
        try {
            /*
             * Just worry about our own state. Notably we do not try and
             * close anything. The SocketImpl, either our own
             * PlainSocketImpl, or the Socket we are wrapping, will do
             * that. This might mean we do not properly SSL_shutdown, but
             * if you want to do that, properly close the socket yourself.
             *
             * The reason why we don't try to SSL_shutdown, is that there
             * can be a race between finalizers where the PlainSocketImpl
             * finalizer runs first and closes the socket. However, in the
             * meanwhile, the underlying file descriptor could be reused
             * for another purpose. If we call SSL_shutdown, the
             * underlying socket BIOs still have the old file descriptor
             * and will write the close notify to some unsuspecting
             * reader.
             */
            // Warn before free(): free() closes the guard when it still has
            // native state to release.
            if (guard != null) {
                Platform.closeGuardWarnIfOpen(guard);
            }
            free();
        } finally {
            super.finalize();
        }
    }
1258
1259    /* @Override */
1260    public FileDescriptor getFileDescriptor$() {
1261        if (socket == this) {
1262            return Platform.getFileDescriptorFromSSLSocket(this);
1263        } else {
1264            return Platform.getFileDescriptor(socket);
1265        }
1266    }
1267
1268    /**
1269     * Returns the protocol agreed upon by client and server, or null if no
1270     * protocol was agreed upon.
1271     */
1272    public byte[] getNpnSelectedProtocol() {
1273        return NativeCrypto.SSL_get_npn_negotiated_protocol(sslNativePointer);
1274    }
1275
1276    /**
1277     * Returns the protocol agreed upon by client and server, or {@code null} if
1278     * no protocol was agreed upon.
1279     */
1280    public byte[] getAlpnSelectedProtocol() {
1281        return NativeCrypto.SSL_get0_alpn_selected(sslNativePointer);
1282    }
1283
1284    /**
1285     * Sets the list of protocols this peer is interested in. If null no
1286     * protocols will be used.
1287     *
1288     * @param npnProtocols a non-empty array of protocol names. From
1289     *     SSL_select_next_proto, "vector of 8-bit, length prefixed byte
1290     *     strings. The length byte itself is not included in the length. A byte
1291     *     string of length 0 is invalid. No byte string may be truncated.".
1292     */
1293    public void setNpnProtocols(byte[] npnProtocols) {
1294        if (npnProtocols != null && npnProtocols.length == 0) {
1295            throw new IllegalArgumentException("npnProtocols.length == 0");
1296        }
1297        sslParameters.npnProtocols = npnProtocols;
1298    }
1299
1300    /**
1301     * Sets the list of protocols this peer is interested in. If the list is
1302     * {@code null}, no protocols will be used.
1303     *
1304     * @param alpnProtocols a non-empty array of protocol names. From
1305     *            SSL_select_next_proto, "vector of 8-bit, length prefixed byte
1306     *            strings. The length byte itself is not included in the length.
1307     *            A byte string of length 0 is invalid. No byte string may be
1308     *            truncated.".
1309     */
1310    public void setAlpnProtocols(byte[] alpnProtocols) {
1311        if (alpnProtocols != null && alpnProtocols.length == 0) {
1312            throw new IllegalArgumentException("alpnProtocols.length == 0");
1313        }
1314        sslParameters.alpnProtocols = alpnProtocols;
1315    }
1316
1317    @Override
1318    public SSLParameters getSSLParameters() {
1319        SSLParameters params = super.getSSLParameters();
1320        Platform.getSSLParameters(params, sslParameters, this);
1321        return params;
1322    }
1323
1324    @Override
1325    public void setSSLParameters(SSLParameters p) {
1326        super.setSSLParameters(p);
1327        Platform.setSSLParameters(p, sslParameters, this);
1328    }
1329
1330    @Override
1331    public String chooseServerAlias(X509KeyManager keyManager, String keyType) {
1332        return keyManager.chooseServerAlias(keyType, null, this);
1333    }
1334
1335    @Override
1336    public String chooseClientAlias(X509KeyManager keyManager, X500Principal[] issuers,
1337            String[] keyTypes) {
1338        return keyManager.chooseClientAlias(keyTypes, null, this);
1339    }
1340
1341    @Override
1342    public String chooseServerPSKIdentityHint(PSKKeyManager keyManager) {
1343        return keyManager.chooseServerKeyIdentityHint(this);
1344    }
1345
1346    @Override
1347    public String chooseClientPSKIdentity(PSKKeyManager keyManager, String identityHint) {
1348        return keyManager.chooseClientKeyIdentity(identityHint, this);
1349    }
1350
1351    @Override
1352    public SecretKey getPSKKey(PSKKeyManager keyManager, String identityHint, String identity) {
1353        return keyManager.getKey(identityHint, identity, this);
1354    }
1355}
1356