1/*
2 * Copyright (C) 2007 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package org.conscrypt;
18
19import org.conscrypt.util.Arrays;
20import dalvik.system.BlockGuard;
21import dalvik.system.CloseGuard;
22import java.io.FileDescriptor;
23import java.io.IOException;
24import java.io.InputStream;
25import java.io.OutputStream;
26import java.net.InetAddress;
27import java.net.Socket;
28import java.net.SocketException;
29import java.security.InvalidKeyException;
30import java.security.PrivateKey;
31import java.security.SecureRandom;
32import java.security.cert.CertificateEncodingException;
33import java.security.cert.CertificateException;
34import java.util.ArrayList;
35import javax.crypto.SecretKey;
36import javax.net.ssl.HandshakeCompletedEvent;
37import javax.net.ssl.HandshakeCompletedListener;
38import javax.net.ssl.SSLException;
39import javax.net.ssl.SSLHandshakeException;
40import javax.net.ssl.SSLParameters;
41import javax.net.ssl.SSLProtocolException;
42import javax.net.ssl.SSLSession;
43import javax.net.ssl.X509KeyManager;
44import javax.net.ssl.X509TrustManager;
45import javax.security.auth.x500.X500Principal;
46
47/**
 * An {@link javax.net.ssl.SSLSocket} implementation backed by OpenSSL through the native
 * {@code NativeCrypto} bindings.
49 * <p>
50 * Extensions to SSLSocket include:
51 * <ul>
52 * <li>handshake timeout
53 * <li>session tickets
54 * <li>Server Name Indication
55 * </ul>
56 */
57public class OpenSSLSocketImpl
58        extends javax.net.ssl.SSLSocket
59        implements NativeCrypto.SSLHandshakeCallbacks, SSLParametersImpl.AliasChooser,
60        SSLParametersImpl.PSKCallbacks {
61
    /**
     * Debug switch: when true, the stream read/write paths assert that the state machine is in
     * a readable/writable state, and awaitPendingOps() asserts that the socket is closed.
     */
    private static final boolean DBG_STATE = false;

    /**
     * Guards {@link #state} and the other fields marked {@code @GuardedBy("stateLock")} below.
     * It is also the monitor on which threads wait for, and are notified of, state transitions.
     */
    private final Object stateLock = new Object();

    /**
     * The {@link OpenSSLSocketImpl} object is constructed, but {@link #startHandshake()}
     * has not yet been called.
     */
    private static final int STATE_NEW = 0;

    /**
     * {@link #startHandshake()} has been called at least once.
     */
    private static final int STATE_HANDSHAKE_STARTED = 1;

    /**
     * {@link #onSSLStateChange(long, int, int)} has reported that the handshake is done, but
     * {@link #startHandshake()} hasn't returned yet.
     */
    private static final int STATE_HANDSHAKE_COMPLETED = 2;

    /**
     * {@link #startHandshake()} has completed but {@link #onSSLStateChange(long, int, int)}
     * hasn't yet reported that the handshake is done. This is expected behaviour in cut-through
     * mode, where SSL_do_handshake returns before the handshake is complete. We can now start
     * writing data to the socket.
     */
    private static final int STATE_READY_HANDSHAKE_CUT_THROUGH = 3;

    /**
     * {@link #startHandshake()} has completed and {@link #onSSLStateChange(long, int, int)} has
     * reported that the handshake is done.
     */
    private static final int STATE_READY = 4;

    /**
     * The socket is closed: either {@link #close()} has been called at least once, or
     * {@link #startHandshake()} failed or was interrupted and marked the socket closed while
     * releasing its native resources.
     */
    private static final int STATE_CLOSED = 5;

    // Current position in the STATE_* lifecycle declared above.
    // @GuardedBy("stateLock");
    private int state = STATE_NEW;

    /**
     * Native SSL handle obtained from NativeCrypto.SSL_new. Protected by synchronizing on
     * stateLock. Starts as 0, set by startHandshake, reset to 0 on close.
     */
    // @GuardedBy("stateLock");
    private long sslNativePointer;

    /**
     * Protected by synchronizing on stateLock. Starts as null, created lazily by
     * getInputStream.
     */
    // @GuardedBy("stateLock");
    private SSLInputStream is;

    /**
     * Protected by synchronizing on stateLock. Starts as null, created lazily by
     * getOutputStream.
     */
    // @GuardedBy("stateLock");
    private SSLOutputStream os;

    /**
     * The socket whose file descriptor carries the TLS traffic: {@code this} for a plain
     * socket, or the wrapped socket when created through the wrapping constructor.
     */
    private final Socket socket;
    // Caller-supplied by the wrapping constructor; false for every other constructor.
    // NOTE(review): presumably decides whether the wrapped socket is closed too — confirm in close().
    private final boolean autoClose;

    /**
     * The peer's DNS hostname if it was supplied during creation.
     */
    private String peerHostname;

    /**
     * The DNS hostname from reverse lookup on the socket, filled in lazily by getHostname().
     * Should never be used for Server Name Indication (SNI).
     */
    private String resolvedHostname;

    /**
     * The peer's port if it was supplied during creation, or -1 so that {@link #getPort()}
     * falls back to the connected port. Should only be set if {@link #peerHostname} is also set.
     */
    private final int peerPort;

    // Negotiation settings: client/server mode, session contexts, protocols and cipher suites.
    private final SSLParametersImpl sslParameters;
    // Dalvik CloseGuard, armed in startHandshake via guard.open("close").
    private final CloseGuard guard = CloseGuard.get();

    // Registered handshake listeners; allocated on first add. Accessed without holding stateLock.
    private ArrayList<HandshakeCompletedListener> listeners;

    /**
     * Private key for the TLS Channel ID extension. This field is client-side
     * only. Passed to the native layer during startHandshake.
     */
    OpenSSLKey channelIdPrivateKey;

    /** Set during startHandshake; refreshed in onSSLStateChange after a cut-through handshake. */
    private OpenSSLSessionImpl sslSession;

    /**
     * Provisional session exposed only while the peer's certificate chain is being verified;
     * set and cleared in verifyCertificateChain.
     */
    private OpenSSLSessionImpl handshakeSession;

    /**
     * Local cache of timeout to avoid getsockopt on every read and
     * write for non-wrapped sockets. Note that
     * OpenSSLSocketImplWrapper overrides setSoTimeout and
     * getSoTimeout to delegate to the wrapped socket.
     * writeTimeoutMilliseconds is the timeout passed to SSL_write.
     */
    private int readTimeoutMilliseconds = 0;
    private int writeTimeoutMilliseconds = 0;

    // Timeout applied to both read and write only while startHandshake runs, when >= 0.
    private int handshakeTimeoutMilliseconds = -1;  // -1 = same as timeout; 0 = infinite
175
176    protected OpenSSLSocketImpl(SSLParametersImpl sslParameters) throws IOException {
177        this.socket = this;
178        this.peerHostname = null;
179        this.peerPort = -1;
180        this.autoClose = false;
181        this.sslParameters = sslParameters;
182    }
183
184    protected OpenSSLSocketImpl(String hostname, int port, SSLParametersImpl sslParameters)
185            throws IOException {
186        super(hostname, port);
187        this.socket = this;
188        this.peerHostname = hostname;
189        this.peerPort = port;
190        this.autoClose = false;
191        this.sslParameters = sslParameters;
192    }
193
194    protected OpenSSLSocketImpl(InetAddress address, int port, SSLParametersImpl sslParameters)
195            throws IOException {
196        super(address, port);
197        this.socket = this;
198        this.peerHostname = null;
199        this.peerPort = -1;
200        this.autoClose = false;
201        this.sslParameters = sslParameters;
202    }
203
204
205    protected OpenSSLSocketImpl(String hostname, int port,
206                                InetAddress clientAddress, int clientPort,
207                                SSLParametersImpl sslParameters) throws IOException {
208        super(hostname, port, clientAddress, clientPort);
209        this.socket = this;
210        this.peerHostname = hostname;
211        this.peerPort = port;
212        this.autoClose = false;
213        this.sslParameters = sslParameters;
214    }
215
216    protected OpenSSLSocketImpl(InetAddress address, int port,
217                                InetAddress clientAddress, int clientPort,
218                                SSLParametersImpl sslParameters) throws IOException {
219        super(address, port, clientAddress, clientPort);
220        this.socket = this;
221        this.peerHostname = null;
222        this.peerPort = -1;
223        this.autoClose = false;
224        this.sslParameters = sslParameters;
225    }
226
227    /**
228     * Create an SSL socket that wraps another socket. Invoked by
229     * OpenSSLSocketImplWrapper constructor.
230     */
231    protected OpenSSLSocketImpl(Socket socket, String hostname, int port,
232            boolean autoClose, SSLParametersImpl sslParameters) throws IOException {
233        this.socket = socket;
234        this.peerHostname = hostname;
235        this.peerPort = port;
236        this.autoClose = autoClose;
237        this.sslParameters = sslParameters;
238
239        // this.timeout is not set intentionally.
240        // OpenSSLSocketImplWrapper.getSoTimeout will delegate timeout
241        // to wrapped socket
242    }
243
244    private void checkOpen() throws SocketException {
245        if (isClosed()) {
246            throw new SocketException("Socket is closed");
247        }
248    }
249
250    /**
251     * Starts a TLS/SSL handshake on this connection using some native methods
252     * from the OpenSSL library. It can negotiate new encryption keys, change
253     * cipher suites, or initiate a new session. The certificate chain is
254     * verified if the correspondent property in java.Security is set. All
255     * listeners are notified at the end of the TLS/SSL handshake.
256     */
257    @Override
258    public void startHandshake() throws IOException {
259        checkOpen();
260        synchronized (stateLock) {
261            if (state == STATE_NEW) {
262                state = STATE_HANDSHAKE_STARTED;
263            } else {
264                // We've either started the handshake already or have been closed.
265                // Do nothing in both cases.
266                return;
267            }
268        }
269
270        // note that this modifies the global seed, not something specific to the connection
271        final int seedLengthInBytes = NativeCrypto.RAND_SEED_LENGTH_IN_BYTES;
272        final SecureRandom secureRandom = sslParameters.getSecureRandomMember();
273        if (secureRandom == null) {
274            NativeCrypto.RAND_load_file("/dev/urandom", seedLengthInBytes);
275        } else {
276            NativeCrypto.RAND_seed(secureRandom.generateSeed(seedLengthInBytes));
277        }
278
279        final boolean client = sslParameters.getUseClientMode();
280
281        sslNativePointer = 0;
282        boolean releaseResources = true;
283        try {
284            final AbstractSessionContext sessionContext = sslParameters.getSessionContext();
285            final long sslCtxNativePointer = sessionContext.sslCtxNativePointer;
286            sslNativePointer = NativeCrypto.SSL_new(sslCtxNativePointer);
287            guard.open("close");
288
289            boolean enableSessionCreation = getEnableSessionCreation();
290            if (!enableSessionCreation) {
291                NativeCrypto.SSL_set_session_creation_enabled(sslNativePointer,
292                        enableSessionCreation);
293            }
294
295            final OpenSSLSessionImpl sessionToReuse = sslParameters.getSessionToReuse(
296                    sslNativePointer, getHostname(), getPort());
297            sslParameters.setSSLParameters(sslCtxNativePointer, sslNativePointer, this, this,
298                    peerHostname);
299            sslParameters.setCertificateValidation(sslNativePointer);
300            sslParameters.setTlsChannelId(sslNativePointer, channelIdPrivateKey);
301
302            // Temporarily use a different timeout for the handshake process
303            int savedReadTimeoutMilliseconds = getSoTimeout();
304            int savedWriteTimeoutMilliseconds = getSoWriteTimeout();
305            if (handshakeTimeoutMilliseconds >= 0) {
306                setSoTimeout(handshakeTimeoutMilliseconds);
307                setSoWriteTimeout(handshakeTimeoutMilliseconds);
308            }
309
310            synchronized (stateLock) {
311                if (state == STATE_CLOSED) {
312                    return;
313                }
314            }
315
316            long sslSessionNativePointer;
317            try {
318                sslSessionNativePointer = NativeCrypto.SSL_do_handshake(sslNativePointer,
319                        Platform.getFileDescriptor(socket), this, getSoTimeout(), client,
320                        sslParameters.npnProtocols, client ? null : sslParameters.alpnProtocols);
321            } catch (CertificateException e) {
322                SSLHandshakeException wrapper = new SSLHandshakeException(e.getMessage());
323                wrapper.initCause(e);
324                throw wrapper;
325            } catch (SSLException e) {
326                // Swallow this exception if it's thrown as the result of an interruption.
327                //
328                // TODO: SSL_read and SSL_write return -1 when interrupted, but SSL_do_handshake
329                // will throw the last sslError that it saw before sslSelect, usually SSL_WANT_READ
330                // (or WANT_WRITE). Catching that exception here doesn't seem much worse than
331                // changing the native code to return a "special" native pointer value when that
332                // happens.
333                synchronized (stateLock) {
334                    if (state == STATE_CLOSED) {
335                        return;
336                    }
337                }
338
339                // Write CCS errors to EventLog
340                String message = e.getMessage();
341                // Must match error string of SSL_R_UNEXPECTED_CCS
342                if (message.contains("unexpected CCS")) {
343                    String logMessage = String.format("ssl_unexpected_ccs: host=%s",
344                            getHostname());
345                    Platform.logEvent(logMessage);
346                }
347
348                throw e;
349            }
350
351            boolean handshakeCompleted = false;
352            synchronized (stateLock) {
353                if (state == STATE_HANDSHAKE_COMPLETED) {
354                    handshakeCompleted = true;
355                } else if (state == STATE_CLOSED) {
356                    return;
357                }
358            }
359
360            sslSession = sslParameters.setupSession(sslSessionNativePointer, sslNativePointer,
361                    sessionToReuse, getHostname(), getPort(), handshakeCompleted);
362
363            // Restore the original timeout now that the handshake is complete
364            if (handshakeTimeoutMilliseconds >= 0) {
365                setSoTimeout(savedReadTimeoutMilliseconds);
366                setSoWriteTimeout(savedWriteTimeoutMilliseconds);
367            }
368
369            // if not, notifyHandshakeCompletedListeners later in handshakeCompleted() callback
370            if (handshakeCompleted) {
371                notifyHandshakeCompletedListeners();
372            }
373
374            synchronized (stateLock) {
375                releaseResources = (state == STATE_CLOSED);
376
377                if (state == STATE_HANDSHAKE_STARTED) {
378                    state = STATE_READY_HANDSHAKE_CUT_THROUGH;
379                } else if (state == STATE_HANDSHAKE_COMPLETED) {
380                    state = STATE_READY;
381                }
382
383                if (!releaseResources) {
384                    // Unblock threads that are waiting for our state to transition
385                    // into STATE_READY or STATE_READY_HANDSHAKE_CUT_THROUGH.
386                    stateLock.notifyAll();
387                }
388            }
389        } catch (SSLProtocolException e) {
390            throw (SSLHandshakeException) new SSLHandshakeException("Handshake failed")
391                    .initCause(e);
392        } finally {
393            // on exceptional exit, treat the socket as closed
394            if (releaseResources) {
395                synchronized (stateLock) {
396                    // Mark the socket as closed since we might have reached this as
397                    // a result on an exception thrown by the handshake process.
398                    //
399                    // The state will already be set to closed if we reach this as a result of
400                    // an early return or an interruption due to a concurrent call to close().
401                    state = STATE_CLOSED;
402                    stateLock.notifyAll();
403                }
404
405                try {
406                    shutdownAndFreeSslNative();
407                } catch (IOException ignored) {
408
409                }
410            }
411        }
412    }
413
414    /**
415     * Returns the hostname that was supplied during socket creation or tries to
416     * look up the hostname via the supplied socket address if possible. This
417     * may result in the return of a IP address and should not be used for
418     * Server Name Indication (SNI).
419     */
420    private String getHostname() {
421        if (peerHostname != null) {
422            return peerHostname;
423        }
424        if (resolvedHostname == null) {
425            InetAddress inetAddress = super.getInetAddress();
426            if (inetAddress != null) {
427                resolvedHostname = inetAddress.getHostName();
428            }
429        }
430        return resolvedHostname;
431    }
432
433    @Override
434    public int getPort() {
435        return peerPort == -1 ? super.getPort() : peerPort;
436    }
437
438    @Override
439    @SuppressWarnings("unused") // used by NativeCrypto.SSLHandshakeCallbacks / client_cert_cb
440    public void clientCertificateRequested(byte[] keyTypeBytes, byte[][] asn1DerEncodedPrincipals)
441            throws CertificateEncodingException, SSLException {
442        sslParameters.chooseClientCertificate(keyTypeBytes, asn1DerEncodedPrincipals,
443                sslNativePointer, this);
444    }
445
446    @Override
447    @SuppressWarnings("unused") // used by native psk_client_callback
448    public int clientPSKKeyRequested(String identityHint, byte[] identity, byte[] key) {
449        return sslParameters.clientPSKKeyRequested(identityHint, identity, key, this);
450    }
451
452    @Override
453    @SuppressWarnings("unused") // used by native psk_server_callback
454    public int serverPSKKeyRequested(String identityHint, String identity, byte[] key) {
455        return sslParameters.serverPSKKeyRequested(identityHint, identity, key, this);
456    }
457
458    @Override
459    @SuppressWarnings("unused") // used by NativeCrypto.SSLHandshakeCallbacks / info_callback
460    public void onSSLStateChange(long sslSessionNativePtr, int type, int val) {
461        if (type != NativeCrypto.SSL_CB_HANDSHAKE_DONE) {
462            return;
463        }
464
465        synchronized (stateLock) {
466            if (state == STATE_HANDSHAKE_STARTED) {
467                // If sslSession is null, the handshake was completed during
468                // the call to NativeCrypto.SSL_do_handshake and not during a
469                // later read operation. That means we do not need to fix up
470                // the SSLSession and session cache or notify
471                // HandshakeCompletedListeners, it will be done in
472                // startHandshake.
473
474                state = STATE_HANDSHAKE_COMPLETED;
475                return;
476            } else if (state == STATE_READY_HANDSHAKE_CUT_THROUGH) {
477                // We've returned from startHandshake, which means we've set a sslSession etc.
478                // we need to fix them up, which we'll do outside this lock.
479            } else if (state == STATE_CLOSED) {
480                // Someone called "close" but the handshake hasn't been interrupted yet.
481                return;
482            }
483        }
484
485        // reset session id from the native pointer and update the
486        // appropriate cache.
487        sslSession.resetId();
488        AbstractSessionContext sessionContext =
489            (sslParameters.getUseClientMode())
490            ? sslParameters.getClientSessionContext()
491                : sslParameters.getServerSessionContext();
492        sessionContext.putSession(sslSession);
493
494        // let listeners know we are finally done
495        notifyHandshakeCompletedListeners();
496
497        synchronized (stateLock) {
498            // Now that we've fixed up our state, we can tell waiting threads that
499            // we're ready.
500            state = STATE_READY;
501            // Notify all threads waiting for the handshake to complete.
502            stateLock.notifyAll();
503        }
504    }
505
506    private void notifyHandshakeCompletedListeners() {
507        if (listeners != null && !listeners.isEmpty()) {
508            // notify the listeners
509            HandshakeCompletedEvent event =
510                new HandshakeCompletedEvent(this, sslSession);
511            for (HandshakeCompletedListener listener : listeners) {
512                try {
513                    listener.handshakeCompleted(event);
514                } catch (RuntimeException e) {
515                    // The RI runs the handlers in a separate thread,
516                    // which we do not. But we try to preserve their
517                    // behavior of logging a problem and not killing
518                    // the handshaking thread just because a listener
519                    // has a problem.
520                    Thread thread = Thread.currentThread();
521                    thread.getUncaughtExceptionHandler().uncaughtException(thread, e);
522                }
523            }
524        }
525    }
526
527    @SuppressWarnings("unused") // used by NativeCrypto.SSLHandshakeCallbacks
528    @Override
529    public void verifyCertificateChain(long sslSessionNativePtr, long[] certRefs, String authMethod)
530            throws CertificateException {
531        try {
532            X509TrustManager x509tm = sslParameters.getX509TrustManager();
533            if (x509tm == null) {
534                throw new CertificateException("No X.509 TrustManager");
535            }
536            if (certRefs == null || certRefs.length == 0) {
537                throw new SSLException("Peer sent no certificate");
538            }
539            OpenSSLX509Certificate[] peerCertChain = new OpenSSLX509Certificate[certRefs.length];
540            for (int i = 0; i < certRefs.length; i++) {
541                peerCertChain[i] = new OpenSSLX509Certificate(certRefs[i]);
542            }
543
544            // Used for verifyCertificateChain callback
545            handshakeSession = new OpenSSLSessionImpl(sslSessionNativePtr, null, peerCertChain,
546                    getHostname(), getPort(), null);
547
548            boolean client = sslParameters.getUseClientMode();
549            if (client) {
550                Platform.checkServerTrusted(x509tm, peerCertChain, authMethod, getHostname());
551            } else {
552                String authType = peerCertChain[0].getPublicKey().getAlgorithm();
553                x509tm.checkClientTrusted(peerCertChain, authType);
554            }
555        } catch (CertificateException e) {
556            throw e;
557        } catch (Exception e) {
558            throw new CertificateException(e);
559        } finally {
560            // Clear this before notifying handshake completed listeners
561            handshakeSession = null;
562        }
563    }
564
565    @Override
566    public InputStream getInputStream() throws IOException {
567        checkOpen();
568
569        InputStream returnVal;
570        synchronized (stateLock) {
571            if (state == STATE_CLOSED) {
572                throw new SocketException("Socket is closed.");
573            }
574
575            if (is == null) {
576                is = new SSLInputStream();
577            }
578
579            returnVal = is;
580        }
581
582        // Block waiting for a handshake without a lock held. It's possible that the socket
583        // is closed at this point. If that happens, we'll still return the input stream but
584        // all reads on it will throw.
585        waitForHandshake();
586        return returnVal;
587    }
588
589    @Override
590    public OutputStream getOutputStream() throws IOException {
591        checkOpen();
592
593        OutputStream returnVal;
594        synchronized (stateLock) {
595            if (state == STATE_CLOSED) {
596                throw new SocketException("Socket is closed.");
597            }
598
599            if (os == null) {
600                os = new SSLOutputStream();
601            }
602
603            returnVal = os;
604        }
605
606        // Block waiting for a handshake without a lock held. It's possible that the socket
607        // is closed at this point. If that happens, we'll still return the output stream but
608        // all writes on it will throw.
609        waitForHandshake();
610        return returnVal;
611    }
612
613    private void assertReadableOrWriteableState() {
614        if (state == STATE_READY || state == STATE_READY_HANDSHAKE_CUT_THROUGH) {
615            return;
616        }
617
618        throw new AssertionError("Invalid state: " + state);
619    }
620
621
622    private void waitForHandshake() throws IOException {
623        startHandshake();
624
625        synchronized (stateLock) {
626            while (state != STATE_READY &&
627                    state != STATE_READY_HANDSHAKE_CUT_THROUGH &&
628                    state != STATE_CLOSED) {
629                try {
630                    stateLock.wait();
631                } catch (InterruptedException e) {
632                    Thread.currentThread().interrupt();
633                    IOException ioe = new IOException("Interrupted waiting for handshake");
634                    ioe.initCause(e);
635
636                    throw ioe;
637                }
638            }
639
640            if (state == STATE_CLOSED) {
641                throw new SocketException("Socket is closed");
642            }
643        }
644    }
645
646    /**
647     * This inner class provides input data stream functionality
648     * for the OpenSSL native implementation. It is used to
649     * read data received via SSL protocol.
650     */
651    private class SSLInputStream extends InputStream {
652        /**
653         * OpenSSL only lets one thread read at a time, so this is used to
654         * make sure we serialize callers of SSL_read. Thread is already
655         * expected to have completed handshaking.
656         */
657        private final Object readLock = new Object();
658
659        SSLInputStream() {
660        }
661
662        /**
663         * Reads one byte. If there is no data in the underlying buffer,
664         * this operation can block until the data will be
665         * available.
666         * @return read value.
667         * @throws IOException
668         */
669        @Override
670        public int read() throws IOException {
671            byte[] buffer = new byte[1];
672            int result = read(buffer, 0, 1);
673            return (result != -1) ? buffer[0] & 0xff : -1;
674        }
675
676        /**
677         * Method acts as described in spec for superclass.
678         * @see java.io.InputStream#read(byte[],int,int)
679         */
680        @Override
681        public int read(byte[] buf, int offset, int byteCount) throws IOException {
682            BlockGuard.getThreadPolicy().onNetwork();
683
684            checkOpen();
685            Arrays.checkOffsetAndCount(buf.length, offset, byteCount);
686            if (byteCount == 0) {
687                return 0;
688            }
689
690            synchronized (readLock) {
691                synchronized (stateLock) {
692                    if (state == STATE_CLOSED) {
693                        throw new SocketException("socket is closed");
694                    }
695
696                    if (DBG_STATE) assertReadableOrWriteableState();
697                }
698
699                return NativeCrypto.SSL_read(sslNativePointer, Platform.getFileDescriptor(socket),
700                        OpenSSLSocketImpl.this, buf, offset, byteCount, getSoTimeout());
701            }
702        }
703
704        public void awaitPendingOps() {
705            if (DBG_STATE) {
706                synchronized (stateLock) {
707                    if (state != STATE_CLOSED) throw new AssertionError("State is: " + state);
708                }
709            }
710
711            synchronized (readLock) { }
712        }
713    }
714
715    /**
716     * This inner class provides output data stream functionality
717     * for the OpenSSL native implementation. It is used to
718     * write data according to the encryption parameters given in SSL context.
719     */
720    private class SSLOutputStream extends OutputStream {
721
722        /**
723         * OpenSSL only lets one thread write at a time, so this is used
724         * to make sure we serialize callers of SSL_write. Thread is
725         * already expected to have completed handshaking.
726         */
727        private final Object writeLock = new Object();
728
729        SSLOutputStream() {
730        }
731
732        /**
733         * Method acts as described in spec for superclass.
734         * @see java.io.OutputStream#write(int)
735         */
736        @Override
737        public void write(int oneByte) throws IOException {
738            byte[] buffer = new byte[1];
739            buffer[0] = (byte) (oneByte & 0xff);
740            write(buffer);
741        }
742
743        /**
744         * Method acts as described in spec for superclass.
745         * @see java.io.OutputStream#write(byte[],int,int)
746         */
747        @Override
748        public void write(byte[] buf, int offset, int byteCount) throws IOException {
749            BlockGuard.getThreadPolicy().onNetwork();
750            checkOpen();
751            Arrays.checkOffsetAndCount(buf.length, offset, byteCount);
752            if (byteCount == 0) {
753                return;
754            }
755
756            synchronized (writeLock) {
757                synchronized (stateLock) {
758                    if (state == STATE_CLOSED) {
759                        throw new SocketException("socket is closed");
760                    }
761
762                    if (DBG_STATE) assertReadableOrWriteableState();
763                }
764
765                NativeCrypto.SSL_write(sslNativePointer, Platform.getFileDescriptor(socket),
766                        OpenSSLSocketImpl.this, buf, offset, byteCount, writeTimeoutMilliseconds);
767            }
768        }
769
770
771        public void awaitPendingOps() {
772            if (DBG_STATE) {
773                synchronized (stateLock) {
774                    if (state != STATE_CLOSED) throw new AssertionError("State is: " + state);
775                }
776            }
777
778            synchronized (writeLock) { }
779        }
780    }
781
782
783    @Override
784    public SSLSession getSession() {
785        if (sslSession == null) {
786            try {
787                waitForHandshake();
788            } catch (IOException e) {
789                // return an invalid session with
790                // invalid cipher suite of "SSL_NULL_WITH_NULL_NULL"
791                return SSLNullSession.getNullSession();
792            }
793        }
794        return sslSession;
795    }
796
797    @Override
798    public void addHandshakeCompletedListener(
799            HandshakeCompletedListener listener) {
800        if (listener == null) {
801            throw new IllegalArgumentException("Provided listener is null");
802        }
803        if (listeners == null) {
804            listeners = new ArrayList<HandshakeCompletedListener>();
805        }
806        listeners.add(listener);
807    }
808
809    @Override
810    public void removeHandshakeCompletedListener(
811            HandshakeCompletedListener listener) {
812        if (listener == null) {
813            throw new IllegalArgumentException("Provided listener is null");
814        }
815        if (listeners == null) {
816            throw new IllegalArgumentException(
817                    "Provided listener is not registered");
818        }
819        if (!listeners.remove(listener)) {
820            throw new IllegalArgumentException(
821                    "Provided listener is not registered");
822        }
823    }
824
825    @Override
826    public boolean getEnableSessionCreation() {
827        return sslParameters.getEnableSessionCreation();
828    }
829
830    @Override
831    public void setEnableSessionCreation(boolean flag) {
832        sslParameters.setEnableSessionCreation(flag);
833    }
834
835    @Override
836    public String[] getSupportedCipherSuites() {
837        return NativeCrypto.getSupportedCipherSuites();
838    }
839
840    @Override
841    public String[] getEnabledCipherSuites() {
842        return sslParameters.getEnabledCipherSuites();
843    }
844
845    @Override
846    public void setEnabledCipherSuites(String[] suites) {
847        sslParameters.setEnabledCipherSuites(suites);
848    }
849
850    @Override
851    public String[] getSupportedProtocols() {
852        return NativeCrypto.getSupportedProtocols();
853    }
854
855    @Override
856    public String[] getEnabledProtocols() {
857        return sslParameters.getEnabledProtocols();
858    }
859
860    @Override
861    public void setEnabledProtocols(String[] protocols) {
862        sslParameters.setEnabledProtocols(protocols);
863    }
864
865    /**
866     * This method enables session ticket support.
867     *
868     * @param useSessionTickets True to enable session tickets
869     */
870    public void setUseSessionTickets(boolean useSessionTickets) {
871        sslParameters.useSessionTickets = useSessionTickets;
872    }
873
874    /**
875     * This method enables Server Name Indication
876     *
877     * @param hostname the desired SNI hostname, or null to disable
878     */
879    public void setHostname(String hostname) {
880        sslParameters.setUseSni(hostname != null);
881        peerHostname = hostname;
882    }
883
884    /**
885     * Enables/disables TLS Channel ID for this server socket.
886     *
887     * <p>This method needs to be invoked before the handshake starts.
888     *
889     * @throws IllegalStateException if this is a client socket or if the handshake has already
890     *         started.
891
892     */
893    public void setChannelIdEnabled(boolean enabled) {
894        if (getUseClientMode()) {
895            throw new IllegalStateException("Client mode");
896        }
897
898        synchronized (stateLock) {
899            if (state != STATE_NEW) {
900                throw new IllegalStateException(
901                        "Could not enable/disable Channel ID after the initial handshake has"
902                                + " begun.");
903            }
904        }
905        sslParameters.channelIdEnabled = enabled;
906    }
907
908    /**
909     * Gets the TLS Channel ID for this server socket. Channel ID is only available once the
910     * handshake completes.
911     *
912     * @return channel ID or {@code null} if not available.
913     *
914     * @throws IllegalStateException if this is a client socket or if the handshake has not yet
915     *         completed.
916     * @throws SSLException if channel ID is available but could not be obtained.
917     */
918    public byte[] getChannelId() throws SSLException {
919        if (getUseClientMode()) {
920            throw new IllegalStateException("Client mode");
921        }
922
923        synchronized (stateLock) {
924            if (state != STATE_READY) {
925                throw new IllegalStateException(
926                        "Channel ID is only available after handshake completes");
927            }
928        }
929        return NativeCrypto.SSL_get_tls_channel_id(sslNativePointer);
930    }
931
932    /**
933     * Sets the {@link PrivateKey} to be used for TLS Channel ID by this client socket.
934     *
935     * <p>This method needs to be invoked before the handshake starts.
936     *
937     * @param privateKey private key (enables TLS Channel ID) or {@code null} for no key (disables
938     *        TLS Channel ID). The private key must be an Elliptic Curve (EC) key based on the NIST
939     *        P-256 curve (aka SECG secp256r1 or ANSI X9.62 prime256v1).
940     *
941     * @throws IllegalStateException if this is a server socket or if the handshake has already
942     *         started.
943     */
944    public void setChannelIdPrivateKey(PrivateKey privateKey) {
945        if (!getUseClientMode()) {
946            throw new IllegalStateException("Server mode");
947        }
948
949        synchronized (stateLock) {
950            if (state != STATE_NEW) {
951                throw new IllegalStateException(
952                        "Could not change Channel ID private key after the initial handshake has"
953                                + " begun.");
954            }
955        }
956
957        if (privateKey == null) {
958            sslParameters.channelIdEnabled = false;
959            channelIdPrivateKey = null;
960        } else {
961            sslParameters.channelIdEnabled = true;
962            try {
963                channelIdPrivateKey = OpenSSLKey.fromPrivateKey(privateKey);
964            } catch (InvalidKeyException e) {
965                // Will have error in startHandshake
966            }
967        }
968    }
969
970    @Override
971    public boolean getUseClientMode() {
972        return sslParameters.getUseClientMode();
973    }
974
975    @Override
976    public void setUseClientMode(boolean mode) {
977        synchronized (stateLock) {
978            if (state != STATE_NEW) {
979                throw new IllegalArgumentException(
980                        "Could not change the mode after the initial handshake has begun.");
981            }
982        }
983        sslParameters.setUseClientMode(mode);
984    }
985
986    @Override
987    public boolean getWantClientAuth() {
988        return sslParameters.getWantClientAuth();
989    }
990
991    @Override
992    public boolean getNeedClientAuth() {
993        return sslParameters.getNeedClientAuth();
994    }
995
996    @Override
997    public void setNeedClientAuth(boolean need) {
998        sslParameters.setNeedClientAuth(need);
999    }
1000
1001    @Override
1002    public void setWantClientAuth(boolean want) {
1003        sslParameters.setWantClientAuth(want);
1004    }
1005
1006    @Override
1007    public void sendUrgentData(int data) throws IOException {
1008        throw new SocketException("Method sendUrgentData() is not supported.");
1009    }
1010
1011    @Override
1012    public void setOOBInline(boolean on) throws SocketException {
1013        throw new SocketException("Methods sendUrgentData, setOOBInline are not supported.");
1014    }
1015
1016    @Override
1017    public void setSoTimeout(int readTimeoutMilliseconds) throws SocketException {
1018        super.setSoTimeout(readTimeoutMilliseconds);
1019        this.readTimeoutMilliseconds = readTimeoutMilliseconds;
1020    }
1021
1022    @Override
1023    public int getSoTimeout() throws SocketException {
1024        return readTimeoutMilliseconds;
1025    }
1026
1027    /**
1028     * Note write timeouts are not part of the javax.net.ssl.SSLSocket API
1029     */
1030    public void setSoWriteTimeout(int writeTimeoutMilliseconds) throws SocketException {
1031        this.writeTimeoutMilliseconds = writeTimeoutMilliseconds;
1032
1033        Platform.setSocketWriteTimeout(this, writeTimeoutMilliseconds);
1034    }
1035
1036    /**
1037     * Note write timeouts are not part of the javax.net.ssl.SSLSocket API
1038     */
1039    public int getSoWriteTimeout() throws SocketException {
1040        return writeTimeoutMilliseconds;
1041    }
1042
1043    /**
1044     * Set the handshake timeout on this socket.  This timeout is specified in
1045     * milliseconds and will be used only during the handshake process.
1046     */
1047    public void setHandshakeTimeout(int handshakeTimeoutMilliseconds) throws SocketException {
1048        this.handshakeTimeoutMilliseconds = handshakeTimeoutMilliseconds;
1049    }
1050
    /**
     * Closes this socket. Calls after the first one return immediately.
     *
     * <p>The cleanup performed depends on how far the connection had progressed:
     * <ul>
     * <li>{@code STATE_NEW}: no native SSL work has started, so only the underlying socket is
     *     closed (via {@code closeUnderlyingSocket}).
     * <li>Handshake still running (any state other than {@code STATE_READY} or
     *     {@code STATE_READY_HANDSHAKE_CUT_THROUGH}): the native handshake is interrupted and
     *     the thread inside {@code startHandshake} is left to do all cleanup.
     * <li>Handshake finished: any existing streams are interrupted and their in-flight
     *     operations awaited. Then {@code shutdownAndFreeSslNative} shuts down the SSL
     *     connection, frees the native state and closes the underlying socket.
     * </ul>
     *
     * <p>On the first call the state becomes {@code STATE_CLOSED} and, in the normal (non-throwing)
     * paths, threads waiting on {@code stateLock} are woken with {@code notifyAll()}.
     */
    @Override
    public void close() throws IOException {
        // TODO: Close SSL sockets using a background thread so they close gracefully.

        // Snapshots of the streams, taken under stateLock and used after it is released.
        SSLInputStream sslInputStream = null;
        SSLOutputStream sslOutputStream = null;

        synchronized (stateLock) {
            if (state == STATE_CLOSED) {
                // close() has already been called, so do nothing and return.
                return;
            }

            int oldState = state;
            state = STATE_CLOSED;

            if (oldState == STATE_NEW) {
                // The handshake hasn't been started yet, so there's no OpenSSL related
                // state to clean up. We still need to close the underlying socket if
                // we're wrapping it and were asked to autoClose.
                // NOTE(review): if closeUnderlyingSocket() throws, the notifyAll() below is
                // skipped while the state is already STATE_CLOSED; confirm no waiter depends
                // on being woken in that case.
                closeUnderlyingSocket();

                stateLock.notifyAll();
                return;
            }

            if (oldState != STATE_READY && oldState != STATE_READY_HANDSHAKE_CUT_THROUGH) {
                // In these states we still haven't returned from startHandshake. The state
                // has already been set to STATE_CLOSED above; SSL_interrupt breaks
                // SSL_do_handshake out of its wait. startHandshake handles all cleanup after
                // SSL_do_handshake returns, so there is nothing more to do here.
                NativeCrypto.SSL_interrupt(sslNativePointer);

                stateLock.notifyAll();
                return;
            }

            stateLock.notifyAll();
            // We've already returned from startHandshake, so we potentially have
            // input and output streams to clean up.
            sslInputStream = is;
            sslOutputStream = os;
        }

        // Don't bother interrupting unless we have something to interrupt.
        if (sslInputStream != null || sslOutputStream != null) {
            NativeCrypto.SSL_interrupt(sslNativePointer);
        }

        // Wait for the input and output streams to finish any reads/writes they have in
        // progress. Operations started after this point throw because state == STATE_CLOSED.
        if (sslInputStream != null) {
            sslInputStream.awaitPendingOps();
        }
        if (sslOutputStream != null) {
            sslOutputStream.awaitPendingOps();
        }

        // Only safe once no stream operation can still be using sslNativePointer.
        shutdownAndFreeSslNative();
    }
1112
    /**
     * Attempts an SSL shutdown (sending "close notify"), then always frees the native SSL state
     * and closes the underlying socket.
     *
     * <p>An {@link IOException} from {@code SSL_shutdown} is swallowed. Exceptions from
     * {@code closeUnderlyingSocket()} in the {@code finally} block do propagate.
     */
    private void shutdownAndFreeSslNative() throws IOException {
        try {
            // Report the upcoming network operation to the thread's BlockGuard policy.
            BlockGuard.getThreadPolicy().onNetwork();
            NativeCrypto.SSL_shutdown(sslNativePointer, Platform.getFileDescriptor(socket),
                    this);
        } catch (IOException ignored) {
            /*
            * Note that although close() can throw
            * IOException, the RI does not throw if there
            * is problem sending a "close notify" which
            * can happen if the underlying socket is closed.
            */
        } finally {
            // Free the native object before closing the transport it was bound to.
            free();
            closeUnderlyingSocket();
        }
    }
1130
1131    private void closeUnderlyingSocket() throws IOException {
1132        if (socket != this) {
1133            if (autoClose && !socket.isClosed()) {
1134                socket.close();
1135            }
1136        } else {
1137            if (!super.isClosed()) {
1138                super.close();
1139            }
1140        }
1141    }
1142
1143    private void free() {
1144        if (sslNativePointer == 0) {
1145            return;
1146        }
1147        NativeCrypto.SSL_free(sslNativePointer);
1148        sslNativePointer = 0;
1149        guard.close();
1150    }
1151
1152    @Override
1153    protected void finalize() throws Throwable {
1154        try {
1155            /*
1156             * Just worry about our own state. Notably we do not try and
1157             * close anything. The SocketImpl, either our own
1158             * PlainSocketImpl, or the Socket we are wrapping, will do
1159             * that. This might mean we do not properly SSL_shutdown, but
1160             * if you want to do that, properly close the socket yourself.
1161             *
1162             * The reason why we don't try to SSL_shutdown, is that there
1163             * can be a race between finalizers where the PlainSocketImpl
1164             * finalizer runs first and closes the socket. However, in the
1165             * meanwhile, the underlying file descriptor could be reused
1166             * for another purpose. If we call SSL_shutdown, the
1167             * underlying socket BIOs still have the old file descriptor
1168             * and will write the close notify to some unsuspecting
1169             * reader.
1170             */
1171            if (guard != null) {
1172                guard.warnIfOpen();
1173            }
1174            free();
1175        } finally {
1176            super.finalize();
1177        }
1178    }
1179
1180    /* @Override */
1181    public FileDescriptor getFileDescriptor$() {
1182        if (socket == this) {
1183            return Platform.getFileDescriptorFromSSLSocket(this);
1184        } else {
1185            return Platform.getFileDescriptor(socket);
1186        }
1187    }
1188
1189    /**
1190     * Returns the protocol agreed upon by client and server, or null if no
1191     * protocol was agreed upon.
1192     */
1193    public byte[] getNpnSelectedProtocol() {
1194        return NativeCrypto.SSL_get_npn_negotiated_protocol(sslNativePointer);
1195    }
1196
1197    /**
1198     * Returns the protocol agreed upon by client and server, or {@code null} if
1199     * no protocol was agreed upon.
1200     */
1201    public byte[] getAlpnSelectedProtocol() {
1202        return NativeCrypto.SSL_get0_alpn_selected(sslNativePointer);
1203    }
1204
1205    /**
1206     * Sets the list of protocols this peer is interested in. If null no
1207     * protocols will be used.
1208     *
1209     * @param npnProtocols a non-empty array of protocol names. From
1210     *     SSL_select_next_proto, "vector of 8-bit, length prefixed byte
1211     *     strings. The length byte itself is not included in the length. A byte
1212     *     string of length 0 is invalid. No byte string may be truncated.".
1213     */
1214    public void setNpnProtocols(byte[] npnProtocols) {
1215        if (npnProtocols != null && npnProtocols.length == 0) {
1216            throw new IllegalArgumentException("npnProtocols.length == 0");
1217        }
1218        sslParameters.npnProtocols = npnProtocols;
1219    }
1220
1221    /**
1222     * Sets the list of protocols this peer is interested in. If the list is
1223     * {@code null}, no protocols will be used.
1224     *
1225     * @param alpnProtocols a non-empty array of protocol names. From
1226     *            SSL_select_next_proto, "vector of 8-bit, length prefixed byte
1227     *            strings. The length byte itself is not included in the length.
1228     *            A byte string of length 0 is invalid. No byte string may be
1229     *            truncated.".
1230     */
1231    public void setAlpnProtocols(byte[] alpnProtocols) {
1232        if (alpnProtocols != null && alpnProtocols.length == 0) {
1233            throw new IllegalArgumentException("alpnProtocols.length == 0");
1234        }
1235        sslParameters.alpnProtocols = alpnProtocols;
1236    }
1237
1238    @Override
1239    public String chooseServerAlias(X509KeyManager keyManager, String keyType) {
1240        return keyManager.chooseServerAlias(keyType, null, this);
1241    }
1242
1243    @Override
1244    public String chooseClientAlias(X509KeyManager keyManager, X500Principal[] issuers,
1245            String[] keyTypes) {
1246        return keyManager.chooseClientAlias(keyTypes, null, this);
1247    }
1248
1249    @Override
1250    public String chooseServerPSKIdentityHint(PSKKeyManager keyManager) {
1251        return keyManager.chooseServerKeyIdentityHint(this);
1252    }
1253
1254    @Override
1255    public String chooseClientPSKIdentity(PSKKeyManager keyManager, String identityHint) {
1256        return keyManager.chooseClientKeyIdentity(identityHint, this);
1257    }
1258
1259    @Override
1260    public SecretKey getPSKKey(PSKKeyManager keyManager, String identityHint, String identity) {
1261        return keyManager.getKey(identityHint, identity, this);
1262    }
1263}
1264