1/*
2 * Copyright (C) 2007 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package org.conscrypt;
18
19import java.io.FileDescriptor;
20import java.io.IOException;
21import java.io.InputStream;
22import java.io.OutputStream;
23import java.net.InetAddress;
24import java.net.InetSocketAddress;
25import java.net.Socket;
26import java.net.SocketAddress;
27import java.net.SocketException;
28import java.security.InvalidKeyException;
29import java.security.PrivateKey;
30import java.security.cert.CertificateEncodingException;
31import java.security.cert.CertificateException;
32import java.security.interfaces.ECKey;
33import java.security.spec.ECParameterSpec;
34import java.util.ArrayList;
35import javax.crypto.SecretKey;
36import javax.net.ssl.HandshakeCompletedEvent;
37import javax.net.ssl.HandshakeCompletedListener;
38import javax.net.ssl.SSLException;
39import javax.net.ssl.SSLHandshakeException;
40import javax.net.ssl.SSLParameters;
41import javax.net.ssl.SSLProtocolException;
42import javax.net.ssl.SSLSession;
43import javax.net.ssl.X509KeyManager;
44import javax.net.ssl.X509TrustManager;
45import javax.security.auth.x500.X500Principal;
46
47/**
 * Implementation of {@link javax.net.ssl.SSLSocket} based on OpenSSL.
49 * <p>
50 * Extensions to SSLSocket include:
51 * <ul>
52 * <li>handshake timeout
53 * <li>session tickets
54 * <li>Server Name Indication
55 * </ul>
56 *
57 * @hide
58 */
59@Internal
60public class OpenSSLSocketImpl
61        extends javax.net.ssl.SSLSocket
62        implements NativeCrypto.SSLHandshakeCallbacks, SSLParametersImpl.AliasChooser,
63        SSLParametersImpl.PSKCallbacks {
64
    /** When true, enables extra state-machine assertions in the I/O streams. */
    private static final boolean DBG_STATE = false;

    /**
     * Monitor guarding {@link #state} and the other fields marked
     * {@code @GuardedBy("stateLock")}. Threads waiting for a handshake state
     * transition also {@code wait()} on this object and are woken with
     * {@code notifyAll()}.
     */
    private final Object stateLock = new Object();

    /**
     * The {@link OpenSSLSocketImpl} object is constructed, but {@link #startHandshake()}
     * has not yet been called.
     */
    private static final int STATE_NEW = 0;

    /**
     * {@link #startHandshake()} has been called at least once.
     */
    private static final int STATE_HANDSHAKE_STARTED = 1;

    /**
     * The native layer reported that the handshake finished (SSL_CB_HANDSHAKE_DONE)
     * while {@link #startHandshake()} was still running. {@link #startHandshake()}
     * notifies the {@link HandshakeCompletedListener}s itself before it returns.
     */
    private static final int STATE_HANDSHAKE_COMPLETED = 2;

    /**
     * {@link #startHandshake()} has completed but
     * {@link HandshakeCompletedListener#handshakeCompleted} hasn't been called. This is expected
     * behaviour in cut-through mode, where SSL_do_handshake returns before the handshake is
     * complete. We can now start writing data to the socket.
     */
    private static final int STATE_READY_HANDSHAKE_CUT_THROUGH = 3;

    /**
     * {@link #startHandshake()} has completed and
     * {@link HandshakeCompletedListener#handshakeCompleted} has been called.
     */
    private static final int STATE_READY = 4;

    /**
     * {@link #close()} has been called at least once, or the handshake failed
     * (startHandshake moves the socket to this state on any exceptional exit).
     */
    private static final int STATE_CLOSED = 5;

    // @GuardedBy("stateLock");
    private int state = STATE_NEW;

    /**
     * Protected by synchronizing on stateLock. Starts as 0, set by
     * startHandshake, reset to 0 on close.
     */
    // @GuardedBy("stateLock");
    private long sslNativePointer;

    /**
     * Protected by synchronizing on stateLock. Starts as null, set by
     * getInputStream.
     */
    // @GuardedBy("stateLock");
    private SSLInputStream is;

    /**
     * Protected by synchronizing on stateLock. Starts as null, set by
     * getOutputStream.
     */
    // @GuardedBy("stateLock");
    private SSLOutputStream os;

    /** The socket TLS runs over: {@code this}, or the wrapped socket in wrapper mode. */
    private final Socket socket;
    /** Wrapper mode only; always false for the non-wrapping constructors. */
    private final boolean autoClose;

    /**
     * The peer's DNS hostname if it was supplied during creation, or taken from the
     * endpoint in {@link #connect(SocketAddress, int)} when none was supplied. Note that
     * this may be a raw IP address, so it should be checked before use with
     * extensions that don't use it like Server Name Indication (SNI).
     */
    private String peerHostname;

    /**
     * The peer's port if it was supplied during creation. Should only be set if
     * {@link #peerHostname} is also set. -1 means "not supplied"; see {@link #getPort()}.
     */
    private final int peerPort;

    private final SSLParametersImpl sslParameters;

    /*
     * A CloseGuard object on Android. On other platforms, this is nothing.
     */
    private final Object guard = Platform.closeGuardGet();

    /** Lazily created by addHandshakeCompletedListener; null until a listener is added. */
    private ArrayList<HandshakeCompletedListener> listeners;

    /**
     * Private key for the TLS Channel ID extension. This field is client-side
     * only. Set during startHandshake.
     */
    private OpenSSLKey channelIdPrivateKey;

    /** Set during startHandshake. */
    private AbstractOpenSSLSession sslSession;

    /**
     * Used during handshake callbacks: set by verifyCertificateChain while the trust
     * manager runs and cleared again before that method returns.
     */
    private AbstractOpenSSLSession handshakeSession;

    /**
     * Local cache of timeout to avoid getsockopt on every read and
     * write for non-wrapped sockets. Note that
     * OpenSSLSocketImplWrapper overrides setSoTimeout and
     * getSoTimeout to delegate to the wrapped socket.
     */
    private int readTimeoutMilliseconds = 0;
    private int writeTimeoutMilliseconds = 0;

    private int handshakeTimeoutMilliseconds = -1;  // -1 = same as timeout; 0 = infinite
179
180    protected OpenSSLSocketImpl(SSLParametersImpl sslParameters) throws IOException {
181        this.socket = this;
182        this.peerHostname = null;
183        this.peerPort = -1;
184        this.autoClose = false;
185        this.sslParameters = sslParameters;
186    }
187
188    protected OpenSSLSocketImpl(String hostname, int port, SSLParametersImpl sslParameters)
189            throws IOException {
190        super(hostname, port);
191        this.socket = this;
192        this.peerHostname = hostname;
193        this.peerPort = port;
194        this.autoClose = false;
195        this.sslParameters = sslParameters;
196    }
197
198    protected OpenSSLSocketImpl(InetAddress address, int port, SSLParametersImpl sslParameters)
199            throws IOException {
200        super(address, port);
201        this.socket = this;
202        this.peerHostname = null;
203        this.peerPort = -1;
204        this.autoClose = false;
205        this.sslParameters = sslParameters;
206    }
207
208
209    protected OpenSSLSocketImpl(String hostname, int port,
210                                InetAddress clientAddress, int clientPort,
211                                SSLParametersImpl sslParameters) throws IOException {
212        super(hostname, port, clientAddress, clientPort);
213        this.socket = this;
214        this.peerHostname = hostname;
215        this.peerPort = port;
216        this.autoClose = false;
217        this.sslParameters = sslParameters;
218    }
219
220    protected OpenSSLSocketImpl(InetAddress address, int port,
221                                InetAddress clientAddress, int clientPort,
222                                SSLParametersImpl sslParameters) throws IOException {
223        super(address, port, clientAddress, clientPort);
224        this.socket = this;
225        this.peerHostname = null;
226        this.peerPort = -1;
227        this.autoClose = false;
228        this.sslParameters = sslParameters;
229    }
230
231    /**
232     * Create an SSL socket that wraps another socket. Invoked by
233     * OpenSSLSocketImplWrapper constructor.
234     */
235    protected OpenSSLSocketImpl(Socket socket, String hostname, int port,
236            boolean autoClose, SSLParametersImpl sslParameters) throws IOException {
237        this.socket = socket;
238        this.peerHostname = hostname;
239        this.peerPort = port;
240        this.autoClose = autoClose;
241        this.sslParameters = sslParameters;
242
243        // this.timeout is not set intentionally.
244        // OpenSSLSocketImplWrapper.getSoTimeout will delegate timeout
245        // to wrapped socket
246    }
247
248    @Override
249    public void connect(SocketAddress endpoint) throws IOException {
250        connect(endpoint, 0);
251    }
252
253    /**
254     * Try to extract the peer's hostname if it's available from the endpoint address.
255     */
256    @Override
257    public void connect(SocketAddress endpoint, int timeout) throws IOException {
258        if (peerHostname == null && endpoint instanceof InetSocketAddress) {
259            peerHostname = Platform.getHostStringFromInetSocketAddress(
260                    (InetSocketAddress) endpoint);
261        }
262
263        super.connect(endpoint, timeout);
264    }
265
266    private void checkOpen() throws SocketException {
267        if (isClosed()) {
268            throw new SocketException("Socket is closed");
269        }
270    }
271
272    /**
273     * Starts a TLS/SSL handshake on this connection using some native methods
274     * from the OpenSSL library. It can negotiate new encryption keys, change
275     * cipher suites, or initiate a new session. The certificate chain is
276     * verified if the correspondent property in java.Security is set. All
277     * listeners are notified at the end of the TLS/SSL handshake.
278     */
279    @Override
280    public void startHandshake() throws IOException {
281        checkOpen();
282        synchronized (stateLock) {
283            if (state == STATE_NEW) {
284                state = STATE_HANDSHAKE_STARTED;
285            } else {
286                // We've either started the handshake already or have been closed.
287                // Do nothing in both cases.
288                return;
289            }
290        }
291
292        final boolean client = sslParameters.getUseClientMode();
293
294        sslNativePointer = 0;
295        boolean releaseResources = true;
296        try {
297            final AbstractSessionContext sessionContext = sslParameters.getSessionContext();
298            sslNativePointer = NativeCrypto.SSL_new(sessionContext.sslCtxNativePointer);
299            Platform.closeGuardOpen(guard, "close");
300
301            boolean enableSessionCreation = getEnableSessionCreation();
302            if (!enableSessionCreation) {
303                NativeCrypto.SSL_set_session_creation_enabled(sslNativePointer,
304                        enableSessionCreation);
305            }
306
307            // Allow servers to trigger renegotiation. Some inadvisable server
308            // configurations cause them to attempt to renegotiate during
309            // certain protocols.
310            NativeCrypto.SSL_accept_renegotiations(sslNativePointer);
311
312            if (client) {
313                NativeCrypto.SSL_set_connect_state(sslNativePointer);
314
315                // Configure OCSP and CT extensions for client
316                NativeCrypto.SSL_enable_ocsp_stapling(sslNativePointer);
317                if (sslParameters.isCTVerificationEnabled(getHostname())) {
318                    NativeCrypto.SSL_enable_signed_cert_timestamps(sslNativePointer);
319                }
320            } else {
321                NativeCrypto.SSL_set_accept_state(sslNativePointer);
322
323                // Configure OCSP for server
324                if (sslParameters.getOCSPResponse() != null) {
325                    NativeCrypto.SSL_enable_ocsp_stapling(sslNativePointer);
326                }
327            }
328
329            final AbstractOpenSSLSession sessionToReuse =
330                    sslParameters.getSessionToReuse(sslNativePointer, getHostnameOrIP(), getPort());
331            sslParameters.setSSLParameters(sslNativePointer, this, this, getHostname());
332            sslParameters.setCertificateValidation(sslNativePointer);
333            sslParameters.setTlsChannelId(sslNativePointer, channelIdPrivateKey);
334
335            // Temporarily use a different timeout for the handshake process
336            int savedReadTimeoutMilliseconds = getSoTimeout();
337            int savedWriteTimeoutMilliseconds = getSoWriteTimeout();
338            if (handshakeTimeoutMilliseconds >= 0) {
339                setSoTimeout(handshakeTimeoutMilliseconds);
340                setSoWriteTimeout(handshakeTimeoutMilliseconds);
341            }
342
343            synchronized (stateLock) {
344                if (state == STATE_CLOSED) {
345                    return;
346                }
347            }
348
349            long sslSessionNativePointer;
350            try {
351                NativeCrypto.SSL_do_handshake(
352                        sslNativePointer, Platform.getFileDescriptor(socket), this, getSoTimeout());
353                sslSessionNativePointer = NativeCrypto.SSL_get1_session(sslNativePointer);
354            } catch (CertificateException e) {
355                SSLHandshakeException wrapper = new SSLHandshakeException(e.getMessage());
356                wrapper.initCause(e);
357                throw wrapper;
358            } catch (SSLException e) {
359                // Swallow this exception if it's thrown as the result of an interruption.
360                //
361                // TODO: SSL_read and SSL_write return -1 when interrupted, but SSL_do_handshake
362                // will throw the last sslError that it saw before sslSelect, usually SSL_WANT_READ
363                // (or WANT_WRITE). Catching that exception here doesn't seem much worse than
364                // changing the native code to return a "special" native pointer value when that
365                // happens.
366                synchronized (stateLock) {
367                    if (state == STATE_CLOSED) {
368                        return;
369                    }
370                }
371
372                // Write CCS errors to EventLog
373                String message = e.getMessage();
374                // Must match error string of SSL_R_UNEXPECTED_CCS
375                if (message.contains("unexpected CCS")) {
376                    String logMessage = String.format("ssl_unexpected_ccs: host=%s",
377                            getHostnameOrIP());
378                    Platform.logEvent(logMessage);
379                }
380
381                throw e;
382            }
383
384            boolean handshakeCompleted = false;
385            synchronized (stateLock) {
386                if (state == STATE_HANDSHAKE_COMPLETED) {
387                    handshakeCompleted = true;
388                } else if (state == STATE_CLOSED) {
389                    return;
390                }
391            }
392
393            sslSession = sslParameters.setupSession(sslSessionNativePointer, sslNativePointer,
394                    sessionToReuse, getHostnameOrIP(), getPort(), handshakeCompleted);
395
396            // Restore the original timeout now that the handshake is complete
397            if (handshakeTimeoutMilliseconds >= 0) {
398                setSoTimeout(savedReadTimeoutMilliseconds);
399                setSoWriteTimeout(savedWriteTimeoutMilliseconds);
400            }
401
402            // if not, notifyHandshakeCompletedListeners later in handshakeCompleted() callback
403            if (handshakeCompleted) {
404                notifyHandshakeCompletedListeners();
405            }
406
407            synchronized (stateLock) {
408                releaseResources = (state == STATE_CLOSED);
409
410                if (state == STATE_HANDSHAKE_STARTED) {
411                    state = STATE_READY_HANDSHAKE_CUT_THROUGH;
412                } else if (state == STATE_HANDSHAKE_COMPLETED) {
413                    state = STATE_READY;
414                }
415
416                if (!releaseResources) {
417                    // Unblock threads that are waiting for our state to transition
418                    // into STATE_READY or STATE_READY_HANDSHAKE_CUT_THROUGH.
419                    stateLock.notifyAll();
420                }
421            }
422        } catch (SSLProtocolException e) {
423            throw (SSLHandshakeException) new SSLHandshakeException("Handshake failed")
424                    .initCause(e);
425        } finally {
426            // on exceptional exit, treat the socket as closed
427            if (releaseResources) {
428                synchronized (stateLock) {
429                    // Mark the socket as closed since we might have reached this as
430                    // a result on an exception thrown by the handshake process.
431                    //
432                    // The state will already be set to closed if we reach this as a result of
433                    // an early return or an interruption due to a concurrent call to close().
434                    state = STATE_CLOSED;
435                    stateLock.notifyAll();
436                }
437
438                try {
439                    shutdownAndFreeSslNative();
440                } catch (IOException ignored) {
441
442                }
443            }
444        }
445    }
446
447    /**
448     * Returns the hostname that was supplied during socket creation. No DNS resolution is
449     * attempted before returning the hostname.
450     */
451    public String getHostname() {
452        return peerHostname;
453    }
454
455    /**
456     * For the purposes of an SSLSession, we want a way to represent the supplied hostname
457     * or the IP address in a textual representation. We do not want to perform reverse DNS
458     * lookups on this address.
459     */
460    public String getHostnameOrIP() {
461        if (peerHostname != null) {
462            return peerHostname;
463        }
464
465        InetAddress peerAddress = getInetAddress();
466        if (peerAddress != null) {
467            return peerAddress.getHostAddress();
468        }
469
470        return null;
471    }
472
473    @Override
474    public int getPort() {
475        return peerPort == -1 ? super.getPort() : peerPort;
476    }
477
478    @Override
479    @SuppressWarnings("unused") // used by NativeCrypto.SSLHandshakeCallbacks / client_cert_cb
480    public void clientCertificateRequested(byte[] keyTypeBytes, byte[][] asn1DerEncodedPrincipals)
481            throws CertificateEncodingException, SSLException {
482        sslParameters.chooseClientCertificate(keyTypeBytes, asn1DerEncodedPrincipals,
483                sslNativePointer, this);
484    }
485
486    @Override
487    @SuppressWarnings("unused") // used by native psk_client_callback
488    public int clientPSKKeyRequested(String identityHint, byte[] identity, byte[] key) {
489        return sslParameters.clientPSKKeyRequested(identityHint, identity, key, this);
490    }
491
492    @Override
493    @SuppressWarnings("unused") // used by native psk_server_callback
494    public int serverPSKKeyRequested(String identityHint, String identity, byte[] key) {
495        return sslParameters.serverPSKKeyRequested(identityHint, identity, key, this);
496    }
497
    /**
     * Native info callback. Only {@code NativeConstants.SSL_CB_HANDSHAKE_DONE} is handled;
     * every other {@code type} is ignored.
     *
     * <p>Behavior by current state:
     * <ul>
     * <li>{@code STATE_HANDSHAKE_STARTED}: the handshake finished inside
     * {@code SSL_do_handshake}. Only the state is advanced, to
     * {@code STATE_HANDSHAKE_COMPLETED}. {@link #startHandshake()} then sets up the session
     * and notifies listeners itself.
     * <li>{@code STATE_CLOSED}: nothing is done.
     * <li>{@code STATE_READY_HANDSHAKE_CUT_THROUGH} (and any other state): the session ID is
     * reset from native state, and the session is stored in the client or server session
     * context. Listeners are then notified, the state becomes {@code STATE_READY}, and
     * threads waiting on {@code stateLock} are woken.
     * </ul>
     *
     * <p>NOTE(review): the fall-through path dereferences {@code sslSession}, which
     * startHandshake assigns only after SSL_do_handshake returns. Confirm that no state other
     * than the cut-through/ready states can reach that path while it is still null.
     *
     * @param type the native callback type
     * @param val the native callback value; not used here
     */
    @Override
    @SuppressWarnings("unused") // used by NativeCrypto.SSLHandshakeCallbacks / info_callback
    public void onSSLStateChange(int type, int val) {
        if (type != NativeConstants.SSL_CB_HANDSHAKE_DONE) {
            return;
        }

        synchronized (stateLock) {
            if (state == STATE_HANDSHAKE_STARTED) {
                // The handshake completed during the call to
                // NativeCrypto.SSL_do_handshake and not during a later read
                // operation. startHandshake will see STATE_HANDSHAKE_COMPLETED,
                // set up the SSLSession (via sslParameters.setupSession) and
                // notify HandshakeCompletedListeners itself, so nothing more
                // needs to be done here.

                state = STATE_HANDSHAKE_COMPLETED;
                return;
            } else if (state == STATE_READY_HANDSHAKE_CUT_THROUGH) {
                // We've returned from startHandshake, which means we've set a sslSession etc.
                // we need to fix them up, which we'll do outside this lock.
            } else if (state == STATE_CLOSED) {
                // Someone called "close" but the handshake hasn't been interrupted yet.
                return;
            }
        }

        // reset session id from the native pointer and update the
        // appropriate cache.
        sslSession.resetId();
        AbstractSessionContext sessionContext =
            (sslParameters.getUseClientMode())
            ? sslParameters.getClientSessionContext()
                : sslParameters.getServerSessionContext();
        sessionContext.putSession(sslSession);

        // let listeners know we are finally done
        notifyHandshakeCompletedListeners();

        synchronized (stateLock) {
            // Now that we've fixed up our state, we can tell waiting threads that
            // we're ready.
            state = STATE_READY;
            // Notify all threads waiting for the handshake to complete.
            stateLock.notifyAll();
        }
    }
545
546    void notifyHandshakeCompletedListeners() {
547        if (listeners != null && !listeners.isEmpty()) {
548            // notify the listeners
549            HandshakeCompletedEvent event =
550                new HandshakeCompletedEvent(this, sslSession);
551            for (HandshakeCompletedListener listener : listeners) {
552                try {
553                    listener.handshakeCompleted(event);
554                } catch (RuntimeException e) {
555                    // The RI runs the handlers in a separate thread,
556                    // which we do not. But we try to preserve their
557                    // behavior of logging a problem and not killing
558                    // the handshaking thread just because a listener
559                    // has a problem.
560                    Thread thread = Thread.currentThread();
561                    thread.getUncaughtExceptionHandler().uncaughtException(thread, e);
562                }
563            }
564        }
565    }
566
    /**
     * Native callback that validates the peer's certificate chain during the handshake.
     *
     * <p>Builds the peer chain from the native certificate references. It reads any OCSP
     * response and signed certificate timestamp list from the native SSL object, and exposes
     * them in a temporary handshake session through {@link #getHandshakeSession()} while the
     * trust manager runs. It then asks the configured {@link X509TrustManager} to check the
     * chain:
     * <ul>
     * <li>in client mode, as a server check, using {@code authMethod};
     * <li>in server mode, as a client check, using the public-key algorithm of
     * {@code peerCertChain[0]} as the auth type.
     * </ul>
     * The temporary handshake session is always cleared before returning.
     *
     * @param certRefs native references to the peer's certificates. Index 0 is presumably
     *     the leaf; TODO confirm against the native caller.
     * @param authMethod the authentication method string, used only in client mode
     * @throws CertificateException if there is no trust manager, the peer sent no
     *     certificates, or the trust manager rejects the chain. Any other exception is
     *     wrapped in a CertificateException.
     */
    @SuppressWarnings("unused") // used by NativeCrypto.SSLHandshakeCallbacks
    @Override
    public void verifyCertificateChain(long[] certRefs, String authMethod)
            throws CertificateException {
        try {
            X509TrustManager x509tm = sslParameters.getX509TrustManager();
            if (x509tm == null) {
                throw new CertificateException("No X.509 TrustManager");
            }
            if (certRefs == null || certRefs.length == 0) {
                // Wrapped into a CertificateException by the catch (Exception) below.
                throw new SSLException("Peer sent no certificate");
            }
            OpenSSLX509Certificate[] peerCertChain =
                    OpenSSLX509Certificate.createCertChain(certRefs);

            byte[] ocspData = NativeCrypto.SSL_get_ocsp_response(sslNativePointer);
            byte[] tlsSctData = NativeCrypto.SSL_get_signed_cert_timestamp_list(sslNativePointer);

            // Used for verifyCertificateChain callback
            handshakeSession = new OpenSSLSessionImpl(
                    NativeCrypto.SSL_get1_session(sslNativePointer), null, peerCertChain, ocspData,
                    tlsSctData, getHostnameOrIP(), getPort(), null);

            boolean client = sslParameters.getUseClientMode();
            if (client) {
                Platform.checkServerTrusted(x509tm, peerCertChain, authMethod, this);
            } else {
                String authType = peerCertChain[0].getPublicKey().getAlgorithm();
                Platform.checkClientTrusted(x509tm, peerCertChain, authType, this);
            }
        } catch (CertificateException e) {
            throw e;
        } catch (Exception e) {
            throw new CertificateException(e);
        } finally {
            // Clear this before notifying handshake completed listeners
            handshakeSession = null;
        }
    }
606
607    @Override
608    public InputStream getInputStream() throws IOException {
609        checkOpen();
610
611        InputStream returnVal;
612        synchronized (stateLock) {
613            if (state == STATE_CLOSED) {
614                throw new SocketException("Socket is closed.");
615            }
616
617            if (is == null) {
618                is = new SSLInputStream();
619            }
620
621            returnVal = is;
622        }
623
624        // Block waiting for a handshake without a lock held. It's possible that the socket
625        // is closed at this point. If that happens, we'll still return the input stream but
626        // all reads on it will throw.
627        waitForHandshake();
628        return returnVal;
629    }
630
631    @Override
632    public OutputStream getOutputStream() throws IOException {
633        checkOpen();
634
635        OutputStream returnVal;
636        synchronized (stateLock) {
637            if (state == STATE_CLOSED) {
638                throw new SocketException("Socket is closed.");
639            }
640
641            if (os == null) {
642                os = new SSLOutputStream();
643            }
644
645            returnVal = os;
646        }
647
648        // Block waiting for a handshake without a lock held. It's possible that the socket
649        // is closed at this point. If that happens, we'll still return the output stream but
650        // all writes on it will throw.
651        waitForHandshake();
652        return returnVal;
653    }
654
655    private void assertReadableOrWriteableState() {
656        if (state == STATE_READY || state == STATE_READY_HANDSHAKE_CUT_THROUGH) {
657            return;
658        }
659
660        throw new AssertionError("Invalid state: " + state);
661    }
662
663
664    private void waitForHandshake() throws IOException {
665        startHandshake();
666
667        synchronized (stateLock) {
668            while (state != STATE_READY &&
669                    state != STATE_READY_HANDSHAKE_CUT_THROUGH &&
670                    state != STATE_CLOSED) {
671                try {
672                    stateLock.wait();
673                } catch (InterruptedException e) {
674                    Thread.currentThread().interrupt();
675                    IOException ioe = new IOException("Interrupted waiting for handshake");
676                    ioe.initCause(e);
677
678                    throw ioe;
679                }
680            }
681
682            if (state == STATE_CLOSED) {
683                throw new SocketException("Socket is closed");
684            }
685        }
686    }
687
688    /**
689     * This inner class provides input data stream functionality
690     * for the OpenSSL native implementation. It is used to
691     * read data received via SSL protocol.
692     */
693    private class SSLInputStream extends InputStream {
694        /**
695         * OpenSSL only lets one thread read at a time, so this is used to
696         * make sure we serialize callers of SSL_read. Thread is already
697         * expected to have completed handshaking.
698         */
699        private final Object readLock = new Object();
700
701        SSLInputStream() {
702        }
703
704        /**
705         * Reads one byte. If there is no data in the underlying buffer,
706         * this operation can block until the data will be
707         * available.
708         * @return read value.
709         * @throws IOException
710         */
711        @Override
712        public int read() throws IOException {
713            byte[] buffer = new byte[1];
714            int result = read(buffer, 0, 1);
715            return (result != -1) ? buffer[0] & 0xff : -1;
716        }
717
718        /**
719         * Method acts as described in spec for superclass.
720         * @see java.io.InputStream#read(byte[],int,int)
721         */
722        @Override
723        public int read(byte[] buf, int offset, int byteCount) throws IOException {
724            Platform.blockGuardOnNetwork();
725
726            checkOpen();
727            ArrayUtils.checkOffsetAndCount(buf.length, offset, byteCount);
728            if (byteCount == 0) {
729                return 0;
730            }
731
732            synchronized (readLock) {
733                synchronized (stateLock) {
734                    if (state == STATE_CLOSED) {
735                        throw new SocketException("socket is closed");
736                    }
737
738                    if (DBG_STATE) assertReadableOrWriteableState();
739                }
740
741                return NativeCrypto.SSL_read(sslNativePointer, Platform.getFileDescriptor(socket),
742                        OpenSSLSocketImpl.this, buf, offset, byteCount, getSoTimeout());
743            }
744        }
745
746        public void awaitPendingOps() {
747            if (DBG_STATE) {
748                synchronized (stateLock) {
749                    if (state != STATE_CLOSED) throw new AssertionError("State is: " + state);
750                }
751            }
752
753            synchronized (readLock) { }
754        }
755    }
756
757    /**
758     * This inner class provides output data stream functionality
759     * for the OpenSSL native implementation. It is used to
760     * write data according to the encryption parameters given in SSL context.
761     */
762    private class SSLOutputStream extends OutputStream {
763
764        /**
765         * OpenSSL only lets one thread write at a time, so this is used
766         * to make sure we serialize callers of SSL_write. Thread is
767         * already expected to have completed handshaking.
768         */
769        private final Object writeLock = new Object();
770
771        SSLOutputStream() {
772        }
773
774        /**
775         * Method acts as described in spec for superclass.
776         * @see java.io.OutputStream#write(int)
777         */
778        @Override
779        public void write(int oneByte) throws IOException {
780            byte[] buffer = new byte[1];
781            buffer[0] = (byte) (oneByte & 0xff);
782            write(buffer);
783        }
784
785        /**
786         * Method acts as described in spec for superclass.
787         * @see java.io.OutputStream#write(byte[],int,int)
788         */
789        @Override
790        public void write(byte[] buf, int offset, int byteCount) throws IOException {
791            Platform.blockGuardOnNetwork();
792            checkOpen();
793            ArrayUtils.checkOffsetAndCount(buf.length, offset, byteCount);
794            if (byteCount == 0) {
795                return;
796            }
797
798            synchronized (writeLock) {
799                synchronized (stateLock) {
800                    if (state == STATE_CLOSED) {
801                        throw new SocketException("socket is closed");
802                    }
803
804                    if (DBG_STATE) assertReadableOrWriteableState();
805                }
806
807                NativeCrypto.SSL_write(sslNativePointer, Platform.getFileDescriptor(socket),
808                        OpenSSLSocketImpl.this, buf, offset, byteCount, writeTimeoutMilliseconds);
809            }
810        }
811
812
813        public void awaitPendingOps() {
814            if (DBG_STATE) {
815                synchronized (stateLock) {
816                    if (state != STATE_CLOSED) throw new AssertionError("State is: " + state);
817                }
818            }
819
820            synchronized (writeLock) { }
821        }
822    }
823
824
825    @Override
826    public SSLSession getSession() {
827        if (sslSession == null) {
828            try {
829                waitForHandshake();
830            } catch (IOException e) {
831                // return an invalid session with
832                // invalid cipher suite of "SSL_NULL_WITH_NULL_NULL"
833                return SSLNullSession.getNullSession();
834            }
835        }
836        return Platform.wrapSSLSession(sslSession);
837    }
838
839    /* @Override */
840    @SuppressWarnings("MissingOverride")  // For compilation with Java 6.
841    public SSLSession getHandshakeSession() {
842        return handshakeSession;
843    }
844
845    @Override
846    public void addHandshakeCompletedListener(
847            HandshakeCompletedListener listener) {
848        if (listener == null) {
849            throw new IllegalArgumentException("Provided listener is null");
850        }
851        if (listeners == null) {
852            listeners = new ArrayList<HandshakeCompletedListener>();
853        }
854        listeners.add(listener);
855    }
856
857    @Override
858    public void removeHandshakeCompletedListener(
859            HandshakeCompletedListener listener) {
860        if (listener == null) {
861            throw new IllegalArgumentException("Provided listener is null");
862        }
863        if (listeners == null) {
864            throw new IllegalArgumentException(
865                    "Provided listener is not registered");
866        }
867        if (!listeners.remove(listener)) {
868            throw new IllegalArgumentException(
869                    "Provided listener is not registered");
870        }
871    }
872
873    @Override
874    public boolean getEnableSessionCreation() {
875        return sslParameters.getEnableSessionCreation();
876    }
877
878    @Override
879    public void setEnableSessionCreation(boolean flag) {
880        sslParameters.setEnableSessionCreation(flag);
881    }
882
883    @Override
884    public String[] getSupportedCipherSuites() {
885        return NativeCrypto.getSupportedCipherSuites();
886    }
887
888    @Override
889    public String[] getEnabledCipherSuites() {
890        return sslParameters.getEnabledCipherSuites();
891    }
892
893    @Override
894    public void setEnabledCipherSuites(String[] suites) {
895        sslParameters.setEnabledCipherSuites(suites);
896    }
897
898    @Override
899    public String[] getSupportedProtocols() {
900        return NativeCrypto.getSupportedProtocols();
901    }
902
903    @Override
904    public String[] getEnabledProtocols() {
905        return sslParameters.getEnabledProtocols();
906    }
907
908    @Override
909    public void setEnabledProtocols(String[] protocols) {
910        sslParameters.setEnabledProtocols(protocols);
911    }
912
913    /**
914     * This method enables session ticket support.
915     *
916     * @param useSessionTickets True to enable session tickets
917     */
918    public void setUseSessionTickets(boolean useSessionTickets) {
919        sslParameters.setUseSessionTickets(useSessionTickets);
920    }
921
922    /**
923     * This method enables Server Name Indication
924     *
925     * @param hostname the desired SNI hostname, or null to disable
926     */
927    public void setHostname(String hostname) {
928        sslParameters.setUseSni(hostname != null);
929        peerHostname = hostname;
930    }
931
932    /**
933     * Enables/disables TLS Channel ID for this server socket.
934     *
935     * <p>This method needs to be invoked before the handshake starts.
936     *
937     * @throws IllegalStateException if this is a client socket or if the handshake has already
938     *         started.
939     */
940    public void setChannelIdEnabled(boolean enabled) {
941        if (getUseClientMode()) {
942            throw new IllegalStateException("Client mode");
943        }
944
945        synchronized (stateLock) {
946            if (state != STATE_NEW) {
947                throw new IllegalStateException(
948                        "Could not enable/disable Channel ID after the initial handshake has"
949                                + " begun.");
950            }
951        }
952        sslParameters.channelIdEnabled = enabled;
953    }
954
955    /**
956     * Gets the TLS Channel ID for this server socket. Channel ID is only available once the
957     * handshake completes.
958     *
959     * @return channel ID or {@code null} if not available.
960     *
961     * @throws IllegalStateException if this is a client socket or if the handshake has not yet
962     *         completed.
963     * @throws SSLException if channel ID is available but could not be obtained.
964     */
965    public byte[] getChannelId() throws SSLException {
966        if (getUseClientMode()) {
967            throw new IllegalStateException("Client mode");
968        }
969
970        synchronized (stateLock) {
971            if (state != STATE_READY) {
972                throw new IllegalStateException(
973                        "Channel ID is only available after handshake completes");
974            }
975        }
976        return NativeCrypto.SSL_get_tls_channel_id(sslNativePointer);
977    }
978
979    /**
980     * Sets the {@link PrivateKey} to be used for TLS Channel ID by this client socket.
981     *
982     * <p>This method needs to be invoked before the handshake starts.
983     *
984     * @param privateKey private key (enables TLS Channel ID) or {@code null} for no key (disables
985     *        TLS Channel ID). The private key must be an Elliptic Curve (EC) key based on the NIST
986     *        P-256 curve (aka SECG secp256r1 or ANSI X9.62 prime256v1).
987     *
988     * @throws IllegalStateException if this is a server socket or if the handshake has already
989     *         started.
990     */
991    public void setChannelIdPrivateKey(PrivateKey privateKey) {
992        if (!getUseClientMode()) {
993            throw new IllegalStateException("Server mode");
994        }
995
996        synchronized (stateLock) {
997            if (state != STATE_NEW) {
998                throw new IllegalStateException(
999                        "Could not change Channel ID private key after the initial handshake has"
1000                                + " begun.");
1001            }
1002        }
1003
1004        if (privateKey == null) {
1005            sslParameters.channelIdEnabled = false;
1006            channelIdPrivateKey = null;
1007        } else {
1008            sslParameters.channelIdEnabled = true;
1009            try {
1010                ECParameterSpec ecParams = null;
1011                if (privateKey instanceof ECKey) {
1012                    ecParams = ((ECKey) privateKey).getParams();
1013                }
1014                if (ecParams == null) {
1015                    // Assume this is a P-256 key, as specified in the contract of this method.
1016                    ecParams =
1017                            OpenSSLECGroupContext.getCurveByName("prime256v1").getECParameterSpec();
1018                }
1019                channelIdPrivateKey =
1020                        OpenSSLKey.fromECPrivateKeyForTLSStackOnly(privateKey, ecParams);
1021            } catch (InvalidKeyException e) {
1022                // Will have error in startHandshake
1023            }
1024        }
1025    }
1026
1027    @Override
1028    public boolean getUseClientMode() {
1029        return sslParameters.getUseClientMode();
1030    }
1031
1032    @Override
1033    public void setUseClientMode(boolean mode) {
1034        synchronized (stateLock) {
1035            if (state != STATE_NEW) {
1036                throw new IllegalArgumentException(
1037                        "Could not change the mode after the initial handshake has begun.");
1038            }
1039        }
1040        sslParameters.setUseClientMode(mode);
1041    }
1042
1043    @Override
1044    public boolean getWantClientAuth() {
1045        return sslParameters.getWantClientAuth();
1046    }
1047
1048    @Override
1049    public boolean getNeedClientAuth() {
1050        return sslParameters.getNeedClientAuth();
1051    }
1052
1053    @Override
1054    public void setNeedClientAuth(boolean need) {
1055        sslParameters.setNeedClientAuth(need);
1056    }
1057
1058    @Override
1059    public void setWantClientAuth(boolean want) {
1060        sslParameters.setWantClientAuth(want);
1061    }
1062
1063    @Override
1064    public void sendUrgentData(int data) throws IOException {
1065        throw new SocketException("Method sendUrgentData() is not supported.");
1066    }
1067
1068    @Override
1069    public void setOOBInline(boolean on) throws SocketException {
1070        throw new SocketException("Methods sendUrgentData, setOOBInline are not supported.");
1071    }
1072
1073    @Override
1074    @SuppressWarnings("UnsynchronizedOverridesSynchronized")
1075    public void setSoTimeout(int readTimeoutMilliseconds) throws SocketException {
1076        if (socket != this) {
1077            socket.setSoTimeout(readTimeoutMilliseconds);
1078        } else {
1079            super.setSoTimeout(readTimeoutMilliseconds);
1080        }
1081
1082        this.readTimeoutMilliseconds = readTimeoutMilliseconds;
1083    }
1084
1085    @Override
1086    @SuppressWarnings("UnsynchronizedOverridesSynchronized")
1087    public int getSoTimeout() throws SocketException {
1088        return readTimeoutMilliseconds;
1089    }
1090
1091    /**
1092     * Note write timeouts are not part of the javax.net.ssl.SSLSocket API
1093     */
1094    public void setSoWriteTimeout(int writeTimeoutMilliseconds) throws SocketException {
1095        this.writeTimeoutMilliseconds = writeTimeoutMilliseconds;
1096
1097        Platform.setSocketWriteTimeout(this, writeTimeoutMilliseconds);
1098    }
1099
1100    /**
1101     * Note write timeouts are not part of the javax.net.ssl.SSLSocket API
1102     */
1103    public int getSoWriteTimeout() throws SocketException {
1104        return writeTimeoutMilliseconds;
1105    }
1106
1107    /**
1108     * Set the handshake timeout on this socket.  This timeout is specified in
1109     * milliseconds and will be used only during the handshake process.
1110     */
1111    public void setHandshakeTimeout(int handshakeTimeoutMilliseconds) throws SocketException {
1112        this.handshakeTimeoutMilliseconds = handshakeTimeoutMilliseconds;
1113    }
1114
    /**
     * Closes this socket. Calling it again is a no-op.
     *
     * <p>The state is switched to STATE_CLOSED under {@code stateLock}, and waiters on
     * {@code stateLock} are notified. The rest depends on the state close() found:
     * <ul>
     * <li>STATE_NEW: no native SSL state exists yet, so only the underlying socket is closed.
     * <li>Any state other than STATE_READY / STATE_READY_HANDSHAKE_CUT_THROUGH: the in-progress
     *     handshake is interrupted, and startHandshake is left to do the cleanup.
     * <li>Otherwise: pending stream operations are interrupted and awaited, then the native
     *     SSL object is shut down and freed, and the underlying socket is closed.
     * </ul>
     *
     * @throws IOException if closing the underlying socket fails
     */
    @Override
    @SuppressWarnings("UnsynchronizedOverridesSynchronized")
    public void close() throws IOException {
        // TODO: Close SSL sockets using a background thread so they close gracefully.

        // Snapshots of the streams, taken under stateLock so they can be drained without it.
        SSLInputStream sslInputStream = null;
        SSLOutputStream sslOutputStream = null;

        synchronized (stateLock) {
            if (state == STATE_CLOSED) {
                // close() has already been called, so do nothing and return.
                return;
            }

            int oldState = state;
            state = STATE_CLOSED;

            if (oldState == STATE_NEW) {
                // The handshake hasn't been started yet, so there's no OpenSSL related
                // state to clean up. We still need to close the underlying socket if
                // we're wrapping it and were asked to autoClose.
                closeUnderlyingSocket();

                stateLock.notifyAll();
                return;
            }

            if (oldState != STATE_READY && oldState != STATE_READY_HANDSHAKE_CUT_THROUGH) {
                // In any of these remaining (handshaking) states we still haven't returned
                // from startHandshake. We call SSL_interrupt so that we can interrupt
                // SSL_do_handshake; the state is already STATE_CLOSED. startHandshake will
                // handle all cleanup after SSL_do_handshake returns, so we don't have anything
                // to do here.
                NativeCrypto.SSL_interrupt(sslNativePointer);

                stateLock.notifyAll();
                return;
            }

            stateLock.notifyAll();
            // We've already returned from startHandshake, so we potentially have
            // input and output streams to clean up.
            sslInputStream = is;
            sslOutputStream = os;
        }

        // Don't bother interrupting unless we have something to interrupt.
        if (sslInputStream != null || sslOutputStream != null) {
            NativeCrypto.SSL_interrupt(sslNativePointer);
        }

        // Wait for the input and output streams to finish any reads they have in
        // progress. If there are no reads in progress at this point, future reads will
        // throw because state == STATE_CLOSED
        if (sslInputStream != null) {
            sslInputStream.awaitPendingOps();
        }
        if (sslOutputStream != null) {
            sslOutputStream.awaitPendingOps();
        }

        // Only reached once no stream operation can still be using the native pointer.
        shutdownAndFreeSslNative();
    }
1177
1178    private void shutdownAndFreeSslNative() throws IOException {
1179        try {
1180            Platform.blockGuardOnNetwork();
1181            NativeCrypto.SSL_shutdown(sslNativePointer, Platform.getFileDescriptor(socket),
1182                    this);
1183        } catch (IOException ignored) {
1184            /*
1185            * Note that although close() can throw
1186            * IOException, the RI does not throw if there
1187            * is problem sending a "close notify" which
1188            * can happen if the underlying socket is closed.
1189            */
1190        } finally {
1191            free();
1192            closeUnderlyingSocket();
1193        }
1194    }
1195
1196    private void closeUnderlyingSocket() throws IOException {
1197        if (socket != this) {
1198            if (autoClose && !socket.isClosed()) {
1199                socket.close();
1200            }
1201        } else {
1202            if (!super.isClosed()) {
1203                super.close();
1204            }
1205        }
1206    }
1207
1208    private void free() {
1209        if (sslNativePointer == 0) {
1210            return;
1211        }
1212        NativeCrypto.SSL_free(sslNativePointer);
1213        sslNativePointer = 0;
1214        Platform.closeGuardClose(guard);
1215    }
1216
1217    @Override
1218    protected void finalize() throws Throwable {
1219        try {
1220            /*
1221             * Just worry about our own state. Notably we do not try and
1222             * close anything. The SocketImpl, either our own
1223             * PlainSocketImpl, or the Socket we are wrapping, will do
1224             * that. This might mean we do not properly SSL_shutdown, but
1225             * if you want to do that, properly close the socket yourself.
1226             *
1227             * The reason why we don't try to SSL_shutdown, is that there
1228             * can be a race between finalizers where the PlainSocketImpl
1229             * finalizer runs first and closes the socket. However, in the
1230             * meanwhile, the underlying file descriptor could be reused
1231             * for another purpose. If we call SSL_shutdown, the
1232             * underlying socket BIOs still have the old file descriptor
1233             * and will write the close notify to some unsuspecting
1234             * reader.
1235             */
1236            if (guard != null) {
1237                Platform.closeGuardWarnIfOpen(guard);
1238            }
1239            free();
1240        } finally {
1241            super.finalize();
1242        }
1243    }
1244
1245    /* @Override */
1246    public FileDescriptor getFileDescriptor$() {
1247        if (socket == this) {
1248            return Platform.getFileDescriptorFromSSLSocket(this);
1249        } else {
1250            return Platform.getFileDescriptor(socket);
1251        }
1252    }
1253
1254    /**
1255     * Returns null always for backward compatibility.
1256     */
1257    public byte[] getNpnSelectedProtocol() {
1258        return null;
1259    }
1260
1261    /**
1262     * Returns the protocol agreed upon by client and server, or {@code null} if
1263     * no protocol was agreed upon.
1264     */
1265    public byte[] getAlpnSelectedProtocol() {
1266        return NativeCrypto.SSL_get0_alpn_selected(sslNativePointer);
1267    }
1268
1269    /**
1270     * This method does nothing and is kept for backward compatibility.
1271     */
1272    public void setNpnProtocols(byte[] npnProtocols) {
1273    }
1274
1275    /**
1276     * Sets the list of ALPN protocols. This method internally converts the protocols to their
1277     * wire-format form.
1278     *
1279     * @param alpnProtocols the list of ALPN protocols
1280     * @see #setAlpnProtocols(byte[])
1281     */
1282    public void setAlpnProtocols(String[] alpnProtocols) {
1283        sslParameters.setAlpnProtocols(alpnProtocols);
1284    }
1285
1286    /**
1287     * Alternate version of {@link #setAlpnProtocols(String[])} that directly sets the list of
1288     * ALPN in the wire-format form used by BoringSSL (length-prefixed 8-bit strings).
1289     * Requires that all strings be encoded with US-ASCII.
1290     *
1291     * @param alpnProtocols the encoded form of the ALPN protocol list
1292     * @see #setAlpnProtocols(String[])
1293     */
1294    public void setAlpnProtocols(byte[] alpnProtocols) {
1295        sslParameters.setAlpnProtocols(alpnProtocols);
1296    }
1297
1298    @Override
1299    public SSLParameters getSSLParameters() {
1300        SSLParameters params = super.getSSLParameters();
1301        Platform.getSSLParameters(params, sslParameters, this);
1302        return params;
1303    }
1304
1305    @Override
1306    public void setSSLParameters(SSLParameters p) {
1307        super.setSSLParameters(p);
1308        Platform.setSSLParameters(p, sslParameters, this);
1309    }
1310
1311    @Override
1312    public String chooseServerAlias(X509KeyManager keyManager, String keyType) {
1313        return keyManager.chooseServerAlias(keyType, null, this);
1314    }
1315
1316    @Override
1317    public String chooseClientAlias(X509KeyManager keyManager, X500Principal[] issuers,
1318            String[] keyTypes) {
1319        return keyManager.chooseClientAlias(keyTypes, null, this);
1320    }
1321
1322    @Override
1323    @SuppressWarnings("deprecation") // PSKKeyManager is deprecated, but in our own package
1324    public String chooseServerPSKIdentityHint(PSKKeyManager keyManager) {
1325        return keyManager.chooseServerKeyIdentityHint(this);
1326    }
1327
1328    @Override
1329    @SuppressWarnings("deprecation") // PSKKeyManager is deprecated, but in our own package
1330    public String chooseClientPSKIdentity(PSKKeyManager keyManager, String identityHint) {
1331        return keyManager.chooseClientKeyIdentity(identityHint, this);
1332    }
1333
1334    @Override
1335    @SuppressWarnings("deprecation") // PSKKeyManager is deprecated, but in our own package
1336    public SecretKey getPSKKey(PSKKeyManager keyManager, String identityHint, String identity) {
1337        return keyManager.getKey(identityHint, identity, this);
1338    }
1339}
1340