1/*
2 * Copyright (C) 2007 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package org.conscrypt;
18
19import org.conscrypt.util.ArrayUtils;
20import dalvik.system.BlockGuard;
21import dalvik.system.CloseGuard;
22import java.io.FileDescriptor;
23import java.io.IOException;
24import java.io.InputStream;
25import java.io.OutputStream;
26import java.net.InetAddress;
27import java.net.Socket;
28import java.net.SocketException;
29import java.security.InvalidKeyException;
30import java.security.PrivateKey;
31import java.security.SecureRandom;
32import java.security.cert.CertificateEncodingException;
33import java.security.cert.CertificateException;
34import java.security.interfaces.ECKey;
35import java.security.spec.ECParameterSpec;
36import java.util.ArrayList;
37import javax.crypto.SecretKey;
38import javax.net.ssl.HandshakeCompletedEvent;
39import javax.net.ssl.HandshakeCompletedListener;
40import javax.net.ssl.SSLException;
41import javax.net.ssl.SSLHandshakeException;
42import javax.net.ssl.SSLProtocolException;
43import javax.net.ssl.SSLSession;
44import javax.net.ssl.X509KeyManager;
45import javax.net.ssl.X509TrustManager;
46import javax.security.auth.x500.X500Principal;
47
48/**
49 * Implementation of the class OpenSSLSocketImpl based on OpenSSL.
50 * <p>
51 * Extensions to SSLSocket include:
52 * <ul>
53 * <li>handshake timeout
54 * <li>session tickets
55 * <li>Server Name Indication
56 * </ul>
57 */
58public class OpenSSLSocketImpl
59        extends javax.net.ssl.SSLSocket
60        implements NativeCrypto.SSLHandshakeCallbacks, SSLParametersImpl.AliasChooser,
61        SSLParametersImpl.PSKCallbacks {
62
    /**
     * Debug flag: when true, the stream read/write paths and awaitPendingOps() perform
     * extra state checks and throw {@link AssertionError} on violation.
     */
    private static final boolean DBG_STATE = false;

    /**
     * Guards {@link #state} and the other fields annotated {@code @GuardedBy("stateLock")}
     * below. It is also the monitor on which threads wait for, and are notified of,
     * state transitions (see waitForHandshake and startHandshake).
     */
    private final Object stateLock = new Object();

    /**
     * The {@link OpenSSLSocketImpl} object is constructed, but {@link #startHandshake()}
     * has not yet been called.
     */
    private static final int STATE_NEW = 0;

    /**
     * {@link #startHandshake()} has been called at least once.
     */
    private static final int STATE_HANDSHAKE_STARTED = 1;

    /**
     * The native handshake-done callback ({@link #onSSLStateChange}) has fired, but
     * {@link #startHandshake()} hasn't returned yet.
     */
    private static final int STATE_HANDSHAKE_COMPLETED = 2;

    /**
     * {@link #startHandshake()} has completed but the native handshake-done callback
     * ({@link #onSSLStateChange}) hasn't fired yet. This is expected behaviour in
     * cut-through mode, where SSL_do_handshake returns before the handshake is complete.
     * We can now start writing data to the socket.
     */
    private static final int STATE_READY_HANDSHAKE_CUT_THROUGH = 3;

    /**
     * {@link #startHandshake()} has completed and the native handshake-done callback
     * ({@link #onSSLStateChange}) has fired.
     */
    private static final int STATE_READY = 4;

    /**
     * {@link #close()} has been called at least once, or the handshake exited
     * exceptionally (startHandshake treats a failed handshake as a closed socket).
     */
    private static final int STATE_CLOSED = 5;

    // @GuardedBy("stateLock");
    private int state = STATE_NEW;

    /**
     * Protected by synchronizing on stateLock. Starts as 0, set by
     * startHandshake, reset to 0 on close.
     */
    // @GuardedBy("stateLock");
    private long sslNativePointer;

    /**
     * Protected by synchronizing on stateLock. Starts as null, set by
     * getInputStream.
     */
    // @GuardedBy("stateLock");
    private SSLInputStream is;

    /**
     * Protected by synchronizing on stateLock. Starts as null, set by
     * getOutputStream.
     */
    // @GuardedBy("stateLock");
    private SSLOutputStream os;

    // The transport the TLS connection runs over: either this socket itself or a
    // wrapped socket supplied to the wrapping constructor.
    private final Socket socket;
    // Whether the wrapped socket should be closed along with this one; always false
    // for the non-wrapping constructors.
    private final boolean autoClose;

    /**
     * The peer's DNS hostname if it was supplied during creation.
     */
    private String peerHostname;

    /**
     * The DNS hostname from reverse lookup on the socket. Should never be used
     * for Server Name Indication (SNI).
     */
    private String resolvedHostname;

    /**
     * The peer's port if it was supplied during creation. Should only be set if
     * {@link #peerHostname} is also set.
     */
    private final int peerPort;

    private final SSLParametersImpl sslParameters;
    // Opened in startHandshake once the native SSL object has been created.
    private final CloseGuard guard = CloseGuard.get();

    // Lazily created by addHandshakeCompletedListener; may be null.
    private ArrayList<HandshakeCompletedListener> listeners;

    /**
     * Private key for the TLS Channel ID extension. This field is client-side
     * only. Read by startHandshake when configuring the native SSL object; it is
     * assigned elsewhere in this class (not shown here).
     */
    OpenSSLKey channelIdPrivateKey;

    /** Set during startHandshake. */
    private OpenSSLSessionImpl sslSession;

    /**
     * Used during handshake callbacks: set while verifyCertificateChain runs and
     * cleared before it returns.
     */
    private OpenSSLSessionImpl handshakeSession;

    /**
     * Local cache of timeout to avoid getsockopt on every read and
     * write for non-wrapped sockets. Note that
     * OpenSSLSocketImplWrapper overrides setSoTimeout and
     * getSoTimeout to delegate to the wrapped socket.
     */
    private int readTimeoutMilliseconds = 0;
    private int writeTimeoutMilliseconds = 0;

    private int handshakeTimeoutMilliseconds = -1;  // -1 = same as timeout; 0 = infinite
176
177    protected OpenSSLSocketImpl(SSLParametersImpl sslParameters) throws IOException {
178        this.socket = this;
179        this.peerHostname = null;
180        this.peerPort = -1;
181        this.autoClose = false;
182        this.sslParameters = sslParameters;
183    }
184
185    protected OpenSSLSocketImpl(String hostname, int port, SSLParametersImpl sslParameters)
186            throws IOException {
187        super(hostname, port);
188        this.socket = this;
189        this.peerHostname = hostname;
190        this.peerPort = port;
191        this.autoClose = false;
192        this.sslParameters = sslParameters;
193    }
194
195    protected OpenSSLSocketImpl(InetAddress address, int port, SSLParametersImpl sslParameters)
196            throws IOException {
197        super(address, port);
198        this.socket = this;
199        this.peerHostname = null;
200        this.peerPort = -1;
201        this.autoClose = false;
202        this.sslParameters = sslParameters;
203    }
204
205
206    protected OpenSSLSocketImpl(String hostname, int port,
207                                InetAddress clientAddress, int clientPort,
208                                SSLParametersImpl sslParameters) throws IOException {
209        super(hostname, port, clientAddress, clientPort);
210        this.socket = this;
211        this.peerHostname = hostname;
212        this.peerPort = port;
213        this.autoClose = false;
214        this.sslParameters = sslParameters;
215    }
216
217    protected OpenSSLSocketImpl(InetAddress address, int port,
218                                InetAddress clientAddress, int clientPort,
219                                SSLParametersImpl sslParameters) throws IOException {
220        super(address, port, clientAddress, clientPort);
221        this.socket = this;
222        this.peerHostname = null;
223        this.peerPort = -1;
224        this.autoClose = false;
225        this.sslParameters = sslParameters;
226    }
227
228    /**
229     * Create an SSL socket that wraps another socket. Invoked by
230     * OpenSSLSocketImplWrapper constructor.
231     */
232    protected OpenSSLSocketImpl(Socket socket, String hostname, int port,
233            boolean autoClose, SSLParametersImpl sslParameters) throws IOException {
234        this.socket = socket;
235        this.peerHostname = hostname;
236        this.peerPort = port;
237        this.autoClose = autoClose;
238        this.sslParameters = sslParameters;
239
240        // this.timeout is not set intentionally.
241        // OpenSSLSocketImplWrapper.getSoTimeout will delegate timeout
242        // to wrapped socket
243    }
244
245    private void checkOpen() throws SocketException {
246        if (isClosed()) {
247            throw new SocketException("Socket is closed");
248        }
249    }
250
251    /**
252     * Starts a TLS/SSL handshake on this connection using some native methods
253     * from the OpenSSL library. It can negotiate new encryption keys, change
254     * cipher suites, or initiate a new session. The certificate chain is
255     * verified if the correspondent property in java.Security is set. All
256     * listeners are notified at the end of the TLS/SSL handshake.
257     */
258    @Override
259    public void startHandshake() throws IOException {
260        checkOpen();
261        synchronized (stateLock) {
262            if (state == STATE_NEW) {
263                state = STATE_HANDSHAKE_STARTED;
264            } else {
265                // We've either started the handshake already or have been closed.
266                // Do nothing in both cases.
267                return;
268            }
269        }
270
271        // note that this modifies the global seed, not something specific to the connection
272        final int seedLengthInBytes = NativeCrypto.RAND_SEED_LENGTH_IN_BYTES;
273        final SecureRandom secureRandom = sslParameters.getSecureRandomMember();
274        if (secureRandom == null) {
275            NativeCrypto.RAND_load_file("/dev/urandom", seedLengthInBytes);
276        } else {
277            NativeCrypto.RAND_seed(secureRandom.generateSeed(seedLengthInBytes));
278        }
279
280        final boolean client = sslParameters.getUseClientMode();
281
282        sslNativePointer = 0;
283        boolean releaseResources = true;
284        try {
285            final AbstractSessionContext sessionContext = sslParameters.getSessionContext();
286            final long sslCtxNativePointer = sessionContext.sslCtxNativePointer;
287            sslNativePointer = NativeCrypto.SSL_new(sslCtxNativePointer);
288            guard.open("close");
289
290            boolean enableSessionCreation = getEnableSessionCreation();
291            if (!enableSessionCreation) {
292                NativeCrypto.SSL_set_session_creation_enabled(sslNativePointer,
293                        enableSessionCreation);
294            }
295
296            // Allow servers to trigger renegotiation. Some inadvisable server
297            // configurations cause them to attempt to renegotiate during
298            // certain protocols.
299            NativeCrypto.SSL_set_reject_peer_renegotiations(sslNativePointer, false);
300
301            final OpenSSLSessionImpl sessionToReuse = sslParameters.getSessionToReuse(
302                    sslNativePointer, getHostname(), getPort());
303            sslParameters.setSSLParameters(sslCtxNativePointer, sslNativePointer, this, this,
304                    peerHostname);
305            sslParameters.setCertificateValidation(sslNativePointer);
306            sslParameters.setTlsChannelId(sslNativePointer, channelIdPrivateKey);
307
308            // Temporarily use a different timeout for the handshake process
309            int savedReadTimeoutMilliseconds = getSoTimeout();
310            int savedWriteTimeoutMilliseconds = getSoWriteTimeout();
311            if (handshakeTimeoutMilliseconds >= 0) {
312                setSoTimeout(handshakeTimeoutMilliseconds);
313                setSoWriteTimeout(handshakeTimeoutMilliseconds);
314            }
315
316            synchronized (stateLock) {
317                if (state == STATE_CLOSED) {
318                    return;
319                }
320            }
321
322            long sslSessionNativePointer;
323            try {
324                sslSessionNativePointer = NativeCrypto.SSL_do_handshake(sslNativePointer,
325                        Platform.getFileDescriptor(socket), this, getSoTimeout(), client,
326                        sslParameters.npnProtocols, client ? null : sslParameters.alpnProtocols);
327            } catch (CertificateException e) {
328                SSLHandshakeException wrapper = new SSLHandshakeException(e.getMessage());
329                wrapper.initCause(e);
330                throw wrapper;
331            } catch (SSLException e) {
332                // Swallow this exception if it's thrown as the result of an interruption.
333                //
334                // TODO: SSL_read and SSL_write return -1 when interrupted, but SSL_do_handshake
335                // will throw the last sslError that it saw before sslSelect, usually SSL_WANT_READ
336                // (or WANT_WRITE). Catching that exception here doesn't seem much worse than
337                // changing the native code to return a "special" native pointer value when that
338                // happens.
339                synchronized (stateLock) {
340                    if (state == STATE_CLOSED) {
341                        return;
342                    }
343                }
344
345                // Write CCS errors to EventLog
346                String message = e.getMessage();
347                // Must match error string of SSL_R_UNEXPECTED_CCS
348                if (message.contains("unexpected CCS")) {
349                    String logMessage = String.format("ssl_unexpected_ccs: host=%s",
350                            getHostname());
351                    Platform.logEvent(logMessage);
352                }
353
354                throw e;
355            }
356
357            boolean handshakeCompleted = false;
358            synchronized (stateLock) {
359                if (state == STATE_HANDSHAKE_COMPLETED) {
360                    handshakeCompleted = true;
361                } else if (state == STATE_CLOSED) {
362                    return;
363                }
364            }
365
366            sslSession = sslParameters.setupSession(sslSessionNativePointer, sslNativePointer,
367                    sessionToReuse, getHostname(), getPort(), handshakeCompleted);
368
369            // Restore the original timeout now that the handshake is complete
370            if (handshakeTimeoutMilliseconds >= 0) {
371                setSoTimeout(savedReadTimeoutMilliseconds);
372                setSoWriteTimeout(savedWriteTimeoutMilliseconds);
373            }
374
375            // if not, notifyHandshakeCompletedListeners later in handshakeCompleted() callback
376            if (handshakeCompleted) {
377                notifyHandshakeCompletedListeners();
378            }
379
380            synchronized (stateLock) {
381                releaseResources = (state == STATE_CLOSED);
382
383                if (state == STATE_HANDSHAKE_STARTED) {
384                    state = STATE_READY_HANDSHAKE_CUT_THROUGH;
385                } else if (state == STATE_HANDSHAKE_COMPLETED) {
386                    state = STATE_READY;
387                }
388
389                if (!releaseResources) {
390                    // Unblock threads that are waiting for our state to transition
391                    // into STATE_READY or STATE_READY_HANDSHAKE_CUT_THROUGH.
392                    stateLock.notifyAll();
393                }
394            }
395        } catch (SSLProtocolException e) {
396            throw (SSLHandshakeException) new SSLHandshakeException("Handshake failed")
397                    .initCause(e);
398        } finally {
399            // on exceptional exit, treat the socket as closed
400            if (releaseResources) {
401                synchronized (stateLock) {
402                    // Mark the socket as closed since we might have reached this as
403                    // a result on an exception thrown by the handshake process.
404                    //
405                    // The state will already be set to closed if we reach this as a result of
406                    // an early return or an interruption due to a concurrent call to close().
407                    state = STATE_CLOSED;
408                    stateLock.notifyAll();
409                }
410
411                try {
412                    shutdownAndFreeSslNative();
413                } catch (IOException ignored) {
414
415                }
416            }
417        }
418    }
419
420    /**
421     * Returns the hostname that was supplied during socket creation or tries to
422     * look up the hostname via the supplied socket address if possible. This
423     * may result in the return of a IP address and should not be used for
424     * Server Name Indication (SNI).
425     */
426    private String getHostname() {
427        if (peerHostname != null) {
428            return peerHostname;
429        }
430        if (resolvedHostname == null) {
431            InetAddress inetAddress = super.getInetAddress();
432            if (inetAddress != null) {
433                resolvedHostname = inetAddress.getHostName();
434            }
435        }
436        return resolvedHostname;
437    }
438
439    @Override
440    public int getPort() {
441        return peerPort == -1 ? super.getPort() : peerPort;
442    }
443
444    @Override
445    @SuppressWarnings("unused") // used by NativeCrypto.SSLHandshakeCallbacks / client_cert_cb
446    public void clientCertificateRequested(byte[] keyTypeBytes, byte[][] asn1DerEncodedPrincipals)
447            throws CertificateEncodingException, SSLException {
448        sslParameters.chooseClientCertificate(keyTypeBytes, asn1DerEncodedPrincipals,
449                sslNativePointer, this);
450    }
451
452    @Override
453    @SuppressWarnings("unused") // used by native psk_client_callback
454    public int clientPSKKeyRequested(String identityHint, byte[] identity, byte[] key) {
455        return sslParameters.clientPSKKeyRequested(identityHint, identity, key, this);
456    }
457
458    @Override
459    @SuppressWarnings("unused") // used by native psk_server_callback
460    public int serverPSKKeyRequested(String identityHint, String identity, byte[] key) {
461        return sslParameters.serverPSKKeyRequested(identityHint, identity, key, this);
462    }
463
464    @Override
465    @SuppressWarnings("unused") // used by NativeCrypto.SSLHandshakeCallbacks / info_callback
466    public void onSSLStateChange(long sslSessionNativePtr, int type, int val) {
467        if (type != NativeConstants.SSL_CB_HANDSHAKE_DONE) {
468            return;
469        }
470
471        synchronized (stateLock) {
472            if (state == STATE_HANDSHAKE_STARTED) {
473                // If sslSession is null, the handshake was completed during
474                // the call to NativeCrypto.SSL_do_handshake and not during a
475                // later read operation. That means we do not need to fix up
476                // the SSLSession and session cache or notify
477                // HandshakeCompletedListeners, it will be done in
478                // startHandshake.
479
480                state = STATE_HANDSHAKE_COMPLETED;
481                return;
482            } else if (state == STATE_READY_HANDSHAKE_CUT_THROUGH) {
483                // We've returned from startHandshake, which means we've set a sslSession etc.
484                // we need to fix them up, which we'll do outside this lock.
485            } else if (state == STATE_CLOSED) {
486                // Someone called "close" but the handshake hasn't been interrupted yet.
487                return;
488            }
489        }
490
491        // reset session id from the native pointer and update the
492        // appropriate cache.
493        sslSession.resetId();
494        AbstractSessionContext sessionContext =
495            (sslParameters.getUseClientMode())
496            ? sslParameters.getClientSessionContext()
497                : sslParameters.getServerSessionContext();
498        sessionContext.putSession(sslSession);
499
500        // let listeners know we are finally done
501        notifyHandshakeCompletedListeners();
502
503        synchronized (stateLock) {
504            // Now that we've fixed up our state, we can tell waiting threads that
505            // we're ready.
506            state = STATE_READY;
507            // Notify all threads waiting for the handshake to complete.
508            stateLock.notifyAll();
509        }
510    }
511
512    private void notifyHandshakeCompletedListeners() {
513        if (listeners != null && !listeners.isEmpty()) {
514            // notify the listeners
515            HandshakeCompletedEvent event =
516                new HandshakeCompletedEvent(this, sslSession);
517            for (HandshakeCompletedListener listener : listeners) {
518                try {
519                    listener.handshakeCompleted(event);
520                } catch (RuntimeException e) {
521                    // The RI runs the handlers in a separate thread,
522                    // which we do not. But we try to preserve their
523                    // behavior of logging a problem and not killing
524                    // the handshaking thread just because a listener
525                    // has a problem.
526                    Thread thread = Thread.currentThread();
527                    thread.getUncaughtExceptionHandler().uncaughtException(thread, e);
528                }
529            }
530        }
531    }
532
533    @SuppressWarnings("unused") // used by NativeCrypto.SSLHandshakeCallbacks
534    @Override
535    public void verifyCertificateChain(long sslSessionNativePtr, long[] certRefs, String authMethod)
536            throws CertificateException {
537        try {
538            X509TrustManager x509tm = sslParameters.getX509TrustManager();
539            if (x509tm == null) {
540                throw new CertificateException("No X.509 TrustManager");
541            }
542            if (certRefs == null || certRefs.length == 0) {
543                throw new SSLException("Peer sent no certificate");
544            }
545            OpenSSLX509Certificate[] peerCertChain = new OpenSSLX509Certificate[certRefs.length];
546            for (int i = 0; i < certRefs.length; i++) {
547                peerCertChain[i] = new OpenSSLX509Certificate(certRefs[i]);
548            }
549
550            // Used for verifyCertificateChain callback
551            handshakeSession = new OpenSSLSessionImpl(sslSessionNativePtr, null, peerCertChain,
552                    getHostname(), getPort(), null);
553
554            boolean client = sslParameters.getUseClientMode();
555            if (client) {
556                Platform.checkServerTrusted(x509tm, peerCertChain, authMethod, getHostname());
557            } else {
558                String authType = peerCertChain[0].getPublicKey().getAlgorithm();
559                x509tm.checkClientTrusted(peerCertChain, authType);
560            }
561        } catch (CertificateException e) {
562            throw e;
563        } catch (Exception e) {
564            throw new CertificateException(e);
565        } finally {
566            // Clear this before notifying handshake completed listeners
567            handshakeSession = null;
568        }
569    }
570
571    @Override
572    public InputStream getInputStream() throws IOException {
573        checkOpen();
574
575        InputStream returnVal;
576        synchronized (stateLock) {
577            if (state == STATE_CLOSED) {
578                throw new SocketException("Socket is closed.");
579            }
580
581            if (is == null) {
582                is = new SSLInputStream();
583            }
584
585            returnVal = is;
586        }
587
588        // Block waiting for a handshake without a lock held. It's possible that the socket
589        // is closed at this point. If that happens, we'll still return the input stream but
590        // all reads on it will throw.
591        waitForHandshake();
592        return returnVal;
593    }
594
595    @Override
596    public OutputStream getOutputStream() throws IOException {
597        checkOpen();
598
599        OutputStream returnVal;
600        synchronized (stateLock) {
601            if (state == STATE_CLOSED) {
602                throw new SocketException("Socket is closed.");
603            }
604
605            if (os == null) {
606                os = new SSLOutputStream();
607            }
608
609            returnVal = os;
610        }
611
612        // Block waiting for a handshake without a lock held. It's possible that the socket
613        // is closed at this point. If that happens, we'll still return the output stream but
614        // all writes on it will throw.
615        waitForHandshake();
616        return returnVal;
617    }
618
619    private void assertReadableOrWriteableState() {
620        if (state == STATE_READY || state == STATE_READY_HANDSHAKE_CUT_THROUGH) {
621            return;
622        }
623
624        throw new AssertionError("Invalid state: " + state);
625    }
626
627
628    private void waitForHandshake() throws IOException {
629        startHandshake();
630
631        synchronized (stateLock) {
632            while (state != STATE_READY &&
633                    state != STATE_READY_HANDSHAKE_CUT_THROUGH &&
634                    state != STATE_CLOSED) {
635                try {
636                    stateLock.wait();
637                } catch (InterruptedException e) {
638                    Thread.currentThread().interrupt();
639                    IOException ioe = new IOException("Interrupted waiting for handshake");
640                    ioe.initCause(e);
641
642                    throw ioe;
643                }
644            }
645
646            if (state == STATE_CLOSED) {
647                throw new SocketException("Socket is closed");
648            }
649        }
650    }
651
652    /**
653     * This inner class provides input data stream functionality
654     * for the OpenSSL native implementation. It is used to
655     * read data received via SSL protocol.
656     */
657    private class SSLInputStream extends InputStream {
658        /**
659         * OpenSSL only lets one thread read at a time, so this is used to
660         * make sure we serialize callers of SSL_read. Thread is already
661         * expected to have completed handshaking.
662         */
663        private final Object readLock = new Object();
664
665        SSLInputStream() {
666        }
667
668        /**
669         * Reads one byte. If there is no data in the underlying buffer,
670         * this operation can block until the data will be
671         * available.
672         * @return read value.
673         * @throws IOException
674         */
675        @Override
676        public int read() throws IOException {
677            byte[] buffer = new byte[1];
678            int result = read(buffer, 0, 1);
679            return (result != -1) ? buffer[0] & 0xff : -1;
680        }
681
682        /**
683         * Method acts as described in spec for superclass.
684         * @see java.io.InputStream#read(byte[],int,int)
685         */
686        @Override
687        public int read(byte[] buf, int offset, int byteCount) throws IOException {
688            BlockGuard.getThreadPolicy().onNetwork();
689
690            checkOpen();
691            ArrayUtils.checkOffsetAndCount(buf.length, offset, byteCount);
692            if (byteCount == 0) {
693                return 0;
694            }
695
696            synchronized (readLock) {
697                synchronized (stateLock) {
698                    if (state == STATE_CLOSED) {
699                        throw new SocketException("socket is closed");
700                    }
701
702                    if (DBG_STATE) assertReadableOrWriteableState();
703                }
704
705                return NativeCrypto.SSL_read(sslNativePointer, Platform.getFileDescriptor(socket),
706                        OpenSSLSocketImpl.this, buf, offset, byteCount, getSoTimeout());
707            }
708        }
709
710        public void awaitPendingOps() {
711            if (DBG_STATE) {
712                synchronized (stateLock) {
713                    if (state != STATE_CLOSED) throw new AssertionError("State is: " + state);
714                }
715            }
716
717            synchronized (readLock) { }
718        }
719    }
720
721    /**
722     * This inner class provides output data stream functionality
723     * for the OpenSSL native implementation. It is used to
724     * write data according to the encryption parameters given in SSL context.
725     */
726    private class SSLOutputStream extends OutputStream {
727
728        /**
729         * OpenSSL only lets one thread write at a time, so this is used
730         * to make sure we serialize callers of SSL_write. Thread is
731         * already expected to have completed handshaking.
732         */
733        private final Object writeLock = new Object();
734
735        SSLOutputStream() {
736        }
737
738        /**
739         * Method acts as described in spec for superclass.
740         * @see java.io.OutputStream#write(int)
741         */
742        @Override
743        public void write(int oneByte) throws IOException {
744            byte[] buffer = new byte[1];
745            buffer[0] = (byte) (oneByte & 0xff);
746            write(buffer);
747        }
748
749        /**
750         * Method acts as described in spec for superclass.
751         * @see java.io.OutputStream#write(byte[],int,int)
752         */
753        @Override
754        public void write(byte[] buf, int offset, int byteCount) throws IOException {
755            BlockGuard.getThreadPolicy().onNetwork();
756            checkOpen();
757            ArrayUtils.checkOffsetAndCount(buf.length, offset, byteCount);
758            if (byteCount == 0) {
759                return;
760            }
761
762            synchronized (writeLock) {
763                synchronized (stateLock) {
764                    if (state == STATE_CLOSED) {
765                        throw new SocketException("socket is closed");
766                    }
767
768                    if (DBG_STATE) assertReadableOrWriteableState();
769                }
770
771                NativeCrypto.SSL_write(sslNativePointer, Platform.getFileDescriptor(socket),
772                        OpenSSLSocketImpl.this, buf, offset, byteCount, writeTimeoutMilliseconds);
773            }
774        }
775
776
777        public void awaitPendingOps() {
778            if (DBG_STATE) {
779                synchronized (stateLock) {
780                    if (state != STATE_CLOSED) throw new AssertionError("State is: " + state);
781                }
782            }
783
784            synchronized (writeLock) { }
785        }
786    }
787
788
789    @Override
790    public SSLSession getSession() {
791        if (sslSession == null) {
792            try {
793                waitForHandshake();
794            } catch (IOException e) {
795                // return an invalid session with
796                // invalid cipher suite of "SSL_NULL_WITH_NULL_NULL"
797                return SSLNullSession.getNullSession();
798            }
799        }
800        return sslSession;
801    }
802
803    @Override
804    public void addHandshakeCompletedListener(
805            HandshakeCompletedListener listener) {
806        if (listener == null) {
807            throw new IllegalArgumentException("Provided listener is null");
808        }
809        if (listeners == null) {
810            listeners = new ArrayList<HandshakeCompletedListener>();
811        }
812        listeners.add(listener);
813    }
814
815    @Override
816    public void removeHandshakeCompletedListener(
817            HandshakeCompletedListener listener) {
818        if (listener == null) {
819            throw new IllegalArgumentException("Provided listener is null");
820        }
821        if (listeners == null) {
822            throw new IllegalArgumentException(
823                    "Provided listener is not registered");
824        }
825        if (!listeners.remove(listener)) {
826            throw new IllegalArgumentException(
827                    "Provided listener is not registered");
828        }
829    }
830
831    @Override
832    public boolean getEnableSessionCreation() {
833        return sslParameters.getEnableSessionCreation();
834    }
835
836    @Override
837    public void setEnableSessionCreation(boolean flag) {
838        sslParameters.setEnableSessionCreation(flag);
839    }
840
841    @Override
842    public String[] getSupportedCipherSuites() {
843        return NativeCrypto.getSupportedCipherSuites();
844    }
845
846    @Override
847    public String[] getEnabledCipherSuites() {
848        return sslParameters.getEnabledCipherSuites();
849    }
850
851    @Override
852    public void setEnabledCipherSuites(String[] suites) {
853        sslParameters.setEnabledCipherSuites(suites);
854    }
855
856    @Override
857    public String[] getSupportedProtocols() {
858        return NativeCrypto.getSupportedProtocols();
859    }
860
861    @Override
862    public String[] getEnabledProtocols() {
863        return sslParameters.getEnabledProtocols();
864    }
865
866    @Override
867    public void setEnabledProtocols(String[] protocols) {
868        sslParameters.setEnabledProtocols(protocols);
869    }
870
871    /**
872     * This method enables session ticket support.
873     *
874     * @param useSessionTickets True to enable session tickets
875     */
876    public void setUseSessionTickets(boolean useSessionTickets) {
877        sslParameters.useSessionTickets = useSessionTickets;
878    }
879
880    /**
881     * This method enables Server Name Indication
882     *
883     * @param hostname the desired SNI hostname, or null to disable
884     */
885    public void setHostname(String hostname) {
886        sslParameters.setUseSni(hostname != null);
887        peerHostname = hostname;
888    }
889
890    /**
891     * Enables/disables TLS Channel ID for this server socket.
892     *
893     * <p>This method needs to be invoked before the handshake starts.
894     *
895     * @throws IllegalStateException if this is a client socket or if the handshake has already
896     *         started.
897
898     */
899    public void setChannelIdEnabled(boolean enabled) {
900        if (getUseClientMode()) {
901            throw new IllegalStateException("Client mode");
902        }
903
904        synchronized (stateLock) {
905            if (state != STATE_NEW) {
906                throw new IllegalStateException(
907                        "Could not enable/disable Channel ID after the initial handshake has"
908                                + " begun.");
909            }
910        }
911        sslParameters.channelIdEnabled = enabled;
912    }
913
914    /**
915     * Gets the TLS Channel ID for this server socket. Channel ID is only available once the
916     * handshake completes.
917     *
918     * @return channel ID or {@code null} if not available.
919     *
920     * @throws IllegalStateException if this is a client socket or if the handshake has not yet
921     *         completed.
922     * @throws SSLException if channel ID is available but could not be obtained.
923     */
924    public byte[] getChannelId() throws SSLException {
925        if (getUseClientMode()) {
926            throw new IllegalStateException("Client mode");
927        }
928
929        synchronized (stateLock) {
930            if (state != STATE_READY) {
931                throw new IllegalStateException(
932                        "Channel ID is only available after handshake completes");
933            }
934        }
935        return NativeCrypto.SSL_get_tls_channel_id(sslNativePointer);
936    }
937
938    /**
939     * Sets the {@link PrivateKey} to be used for TLS Channel ID by this client socket.
940     *
941     * <p>This method needs to be invoked before the handshake starts.
942     *
943     * @param privateKey private key (enables TLS Channel ID) or {@code null} for no key (disables
944     *        TLS Channel ID). The private key must be an Elliptic Curve (EC) key based on the NIST
945     *        P-256 curve (aka SECG secp256r1 or ANSI X9.62 prime256v1).
946     *
947     * @throws IllegalStateException if this is a server socket or if the handshake has already
948     *         started.
949     */
950    public void setChannelIdPrivateKey(PrivateKey privateKey) {
951        if (!getUseClientMode()) {
952            throw new IllegalStateException("Server mode");
953        }
954
955        synchronized (stateLock) {
956            if (state != STATE_NEW) {
957                throw new IllegalStateException(
958                        "Could not change Channel ID private key after the initial handshake has"
959                                + " begun.");
960            }
961        }
962
963        if (privateKey == null) {
964            sslParameters.channelIdEnabled = false;
965            channelIdPrivateKey = null;
966        } else {
967            sslParameters.channelIdEnabled = true;
968            try {
969                ECParameterSpec ecParams = null;
970                if (privateKey instanceof ECKey) {
971                    ecParams = ((ECKey) privateKey).getParams();
972                }
973                if (ecParams == null) {
974                    // Assume this is a P-256 key, as specified in the contract of this method.
975                    ecParams =
976                            OpenSSLECGroupContext.getCurveByName("prime256v1").getECParameterSpec();
977                }
978                channelIdPrivateKey =
979                        OpenSSLKey.fromECPrivateKeyForTLSStackOnly(privateKey, ecParams);
980            } catch (InvalidKeyException e) {
981                // Will have error in startHandshake
982            }
983        }
984    }
985
986    @Override
987    public boolean getUseClientMode() {
988        return sslParameters.getUseClientMode();
989    }
990
991    @Override
992    public void setUseClientMode(boolean mode) {
993        synchronized (stateLock) {
994            if (state != STATE_NEW) {
995                throw new IllegalArgumentException(
996                        "Could not change the mode after the initial handshake has begun.");
997            }
998        }
999        sslParameters.setUseClientMode(mode);
1000    }
1001
1002    @Override
1003    public boolean getWantClientAuth() {
1004        return sslParameters.getWantClientAuth();
1005    }
1006
1007    @Override
1008    public boolean getNeedClientAuth() {
1009        return sslParameters.getNeedClientAuth();
1010    }
1011
1012    @Override
1013    public void setNeedClientAuth(boolean need) {
1014        sslParameters.setNeedClientAuth(need);
1015    }
1016
1017    @Override
1018    public void setWantClientAuth(boolean want) {
1019        sslParameters.setWantClientAuth(want);
1020    }
1021
1022    @Override
1023    public void sendUrgentData(int data) throws IOException {
1024        throw new SocketException("Method sendUrgentData() is not supported.");
1025    }
1026
1027    @Override
1028    public void setOOBInline(boolean on) throws SocketException {
1029        throw new SocketException("Methods sendUrgentData, setOOBInline are not supported.");
1030    }
1031
1032    @Override
1033    public void setSoTimeout(int readTimeoutMilliseconds) throws SocketException {
1034        super.setSoTimeout(readTimeoutMilliseconds);
1035        this.readTimeoutMilliseconds = readTimeoutMilliseconds;
1036    }
1037
1038    @Override
1039    public int getSoTimeout() throws SocketException {
1040        return readTimeoutMilliseconds;
1041    }
1042
1043    /**
1044     * Note write timeouts are not part of the javax.net.ssl.SSLSocket API
1045     */
1046    public void setSoWriteTimeout(int writeTimeoutMilliseconds) throws SocketException {
1047        this.writeTimeoutMilliseconds = writeTimeoutMilliseconds;
1048
1049        Platform.setSocketWriteTimeout(this, writeTimeoutMilliseconds);
1050    }
1051
1052    /**
1053     * Note write timeouts are not part of the javax.net.ssl.SSLSocket API
1054     */
1055    public int getSoWriteTimeout() throws SocketException {
1056        return writeTimeoutMilliseconds;
1057    }
1058
1059    /**
1060     * Set the handshake timeout on this socket.  This timeout is specified in
1061     * milliseconds and will be used only during the handshake process.
1062     */
1063    public void setHandshakeTimeout(int handshakeTimeoutMilliseconds) throws SocketException {
1064        this.handshakeTimeoutMilliseconds = handshakeTimeoutMilliseconds;
1065    }
1066
    /**
     * Closes this socket. Calls after the first one return immediately.
     *
     * <p>The work done depends on the state at the time of the call:
     * <ul>
     * <li>Handshake not started ({@code STATE_NEW}): only the underlying socket is closed.
     * <li>Handshake in progress: the native SSL object is interrupted, and the remaining
     *     cleanup is left to the thread running startHandshake.
     * <li>Handshake finished: in-flight stream operations are interrupted and awaited. The SSL
     *     connection is then shut down, native resources are freed and the underlying socket
     *     is closed.
     * </ul>
     *
     * @throws IOException if closing the underlying socket fails
     */
    @Override
    public void close() throws IOException {
        // TODO: Close SSL sockets using a background thread so they close gracefully.

        SSLInputStream sslInputStream = null;
        SSLOutputStream sslOutputStream = null;

        synchronized (stateLock) {
            if (state == STATE_CLOSED) {
                // close() has already been called, so do nothing and return.
                return;
            }

            int oldState = state;
            // Publish STATE_CLOSED first so concurrent reads/writes/handshakes observe it.
            state = STATE_CLOSED;

            if (oldState == STATE_NEW) {
                // The handshake hasn't been started yet, so there's no OpenSSL related
                // state to clean up. We still need to close the underlying socket if
                // we're wrapping it and were asked to autoClose.
                closeUnderlyingSocket();

                // Wake any thread waiting on stateLock so it re-checks the state.
                stateLock.notifyAll();
                return;
            }

            if (oldState != STATE_READY && oldState != STATE_READY_HANDSHAKE_CUT_THROUGH) {
                // If we're in these states, we still haven't returned from startHandshake.
                // We call SSL_interrupt so that we can interrupt SSL_do_handshake and then
                // set the state to STATE_CLOSED. startHandshake will handle all cleanup
                // after SSL_do_handshake returns, so we don't have anything to do here.
                NativeCrypto.SSL_interrupt(sslNativePointer);

                stateLock.notifyAll();
                return;
            }

            stateLock.notifyAll();
            // We've already returned from startHandshake, so we potentially have
            // input and output streams to clean up.
            sslInputStream = is;
            sslOutputStream = os;
        }

        // The remaining work runs outside stateLock so a blocked reader/writer that needs
        // stateLock is not deadlocked against us.
        // Don't bother interrupting unless we have something to interrupt.
        if (sslInputStream != null || sslOutputStream != null) {
            NativeCrypto.SSL_interrupt(sslNativePointer);
        }

        // Wait for the input and output streams to finish any reads they have in
        // progress. If there are no reads in progress at this point, future reads will
        // throw because state == STATE_CLOSED
        if (sslInputStream != null) {
            sslInputStream.awaitPendingOps();
        }
        if (sslOutputStream != null) {
            sslOutputStream.awaitPendingOps();
        }

        // Only now is it safe to shut down and free the native SSL object.
        shutdownAndFreeSslNative();
    }
1128
    /**
     * Attempts an SSL_shutdown, which sends the "close notify" alert, as a best effort. It then
     * always frees the native SSL object and closes the underlying socket, even if the
     * shutdown failed.
     *
     * @throws IOException if closing the underlying socket fails
     */
    private void shutdownAndFreeSslNative() throws IOException {
        try {
            BlockGuard.getThreadPolicy().onNetwork();
            NativeCrypto.SSL_shutdown(sslNativePointer, Platform.getFileDescriptor(socket),
                    this);
        } catch (IOException ignored) {
            /*
            * Note that although close() can throw
            * IOException, the RI does not throw if there
            * is problem sending a "close notify" which
            * can happen if the underlying socket is closed.
            */
        } finally {
            // Order matters: release the native SSL object before closing its transport.
            free();
            closeUnderlyingSocket();
        }
    }
1146
1147    private void closeUnderlyingSocket() throws IOException {
1148        if (socket != this) {
1149            if (autoClose && !socket.isClosed()) {
1150                socket.close();
1151            }
1152        } else {
1153            if (!super.isClosed()) {
1154                super.close();
1155            }
1156        }
1157    }
1158
1159    private void free() {
1160        if (sslNativePointer == 0) {
1161            return;
1162        }
1163        NativeCrypto.SSL_free(sslNativePointer);
1164        sslNativePointer = 0;
1165        guard.close();
1166    }
1167
1168    @Override
1169    protected void finalize() throws Throwable {
1170        try {
1171            /*
1172             * Just worry about our own state. Notably we do not try and
1173             * close anything. The SocketImpl, either our own
1174             * PlainSocketImpl, or the Socket we are wrapping, will do
1175             * that. This might mean we do not properly SSL_shutdown, but
1176             * if you want to do that, properly close the socket yourself.
1177             *
1178             * The reason why we don't try to SSL_shutdown, is that there
1179             * can be a race between finalizers where the PlainSocketImpl
1180             * finalizer runs first and closes the socket. However, in the
1181             * meanwhile, the underlying file descriptor could be reused
1182             * for another purpose. If we call SSL_shutdown, the
1183             * underlying socket BIOs still have the old file descriptor
1184             * and will write the close notify to some unsuspecting
1185             * reader.
1186             */
1187            if (guard != null) {
1188                guard.warnIfOpen();
1189            }
1190            free();
1191        } finally {
1192            super.finalize();
1193        }
1194    }
1195
1196    /* @Override */
1197    public FileDescriptor getFileDescriptor$() {
1198        if (socket == this) {
1199            return Platform.getFileDescriptorFromSSLSocket(this);
1200        } else {
1201            return Platform.getFileDescriptor(socket);
1202        }
1203    }
1204
1205    /**
1206     * Returns the protocol agreed upon by client and server, or null if no
1207     * protocol was agreed upon.
1208     */
1209    public byte[] getNpnSelectedProtocol() {
1210        return NativeCrypto.SSL_get_npn_negotiated_protocol(sslNativePointer);
1211    }
1212
1213    /**
1214     * Returns the protocol agreed upon by client and server, or {@code null} if
1215     * no protocol was agreed upon.
1216     */
1217    public byte[] getAlpnSelectedProtocol() {
1218        return NativeCrypto.SSL_get0_alpn_selected(sslNativePointer);
1219    }
1220
1221    /**
1222     * Sets the list of protocols this peer is interested in. If null no
1223     * protocols will be used.
1224     *
1225     * @param npnProtocols a non-empty array of protocol names. From
1226     *     SSL_select_next_proto, "vector of 8-bit, length prefixed byte
1227     *     strings. The length byte itself is not included in the length. A byte
1228     *     string of length 0 is invalid. No byte string may be truncated.".
1229     */
1230    public void setNpnProtocols(byte[] npnProtocols) {
1231        if (npnProtocols != null && npnProtocols.length == 0) {
1232            throw new IllegalArgumentException("npnProtocols.length == 0");
1233        }
1234        sslParameters.npnProtocols = npnProtocols;
1235    }
1236
1237    /**
1238     * Sets the list of protocols this peer is interested in. If the list is
1239     * {@code null}, no protocols will be used.
1240     *
1241     * @param alpnProtocols a non-empty array of protocol names. From
1242     *            SSL_select_next_proto, "vector of 8-bit, length prefixed byte
1243     *            strings. The length byte itself is not included in the length.
1244     *            A byte string of length 0 is invalid. No byte string may be
1245     *            truncated.".
1246     */
1247    public void setAlpnProtocols(byte[] alpnProtocols) {
1248        if (alpnProtocols != null && alpnProtocols.length == 0) {
1249            throw new IllegalArgumentException("alpnProtocols.length == 0");
1250        }
1251        sslParameters.alpnProtocols = alpnProtocols;
1252    }
1253
1254    @Override
1255    public String chooseServerAlias(X509KeyManager keyManager, String keyType) {
1256        return keyManager.chooseServerAlias(keyType, null, this);
1257    }
1258
1259    @Override
1260    public String chooseClientAlias(X509KeyManager keyManager, X500Principal[] issuers,
1261            String[] keyTypes) {
1262        return keyManager.chooseClientAlias(keyTypes, null, this);
1263    }
1264
1265    @Override
1266    public String chooseServerPSKIdentityHint(PSKKeyManager keyManager) {
1267        return keyManager.chooseServerKeyIdentityHint(this);
1268    }
1269
1270    @Override
1271    public String chooseClientPSKIdentity(PSKKeyManager keyManager, String identityHint) {
1272        return keyManager.chooseClientKeyIdentity(identityHint, this);
1273    }
1274
1275    @Override
1276    public SecretKey getPSKKey(PSKKeyManager keyManager, String identityHint, String identity) {
1277        return keyManager.getKey(identityHint, identity, this);
1278    }
1279}
1280