1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package org.conscrypt;
18
19import static org.conscrypt.Preconditions.checkArgument;
20import static org.conscrypt.Preconditions.checkNotNull;
21
22import java.io.FileDescriptor;
23import java.io.IOException;
24import java.io.InputStream;
25import java.io.OutputStream;
26import java.net.InetAddress;
27import java.net.InetSocketAddress;
28import java.net.Socket;
29import java.net.SocketAddress;
30import java.net.SocketException;
31import java.nio.channels.SocketChannel;
32import java.security.PrivateKey;
33import java.util.ArrayList;
34import java.util.List;
35import javax.net.ssl.HandshakeCompletedEvent;
36import javax.net.ssl.HandshakeCompletedListener;
37import javax.net.ssl.SSLException;
38import javax.net.ssl.SSLSession;
39import javax.net.ssl.SSLSocket;
40
41/**
42 * Abstract base class for all Conscrypt sockets that extends the basic {@link SSLSocket} API.
43 */
44abstract class AbstractConscryptSocket extends SSLSocket {
45    final Socket socket;
46    private final boolean autoClose;
47
48    /**
49     * The peer's DNS hostname if it was supplied during creation. Note that
50     * this may be a raw IP address, so it should be checked before use with
51     * extensions that don't use it like Server Name Indication (SNI).
52     */
53    private String peerHostname;
54
55    /**
56     * The peer's port if it was supplied during creation. Should only be set if
57     * {@link #peerHostname} is also set.
58     */
59    private final int peerPort;
60
61    private final PeerInfoProvider peerInfoProvider = new PeerInfoProvider() {
62        @Override
63        String getHostname() {
64            return AbstractConscryptSocket.this.getHostname();
65        }
66
67        @Override
68        String getHostnameOrIP() {
69            return AbstractConscryptSocket.this.getHostnameOrIP();
70        }
71
72        @Override
73        int getPort() {
74            return AbstractConscryptSocket.this.getPort();
75        }
76    };
77
78    private final List<HandshakeCompletedListener> listeners =
79            new ArrayList<HandshakeCompletedListener>(2);
80
81    /**
82     * Local cache of timeout to avoid getsockopt on every read and
83     * write for non-wrapped sockets. Note that this is not used when delegating
84     * to another socket.
85     */
86    private int readTimeoutMilliseconds;
87
88    AbstractConscryptSocket() throws IOException {
89        this.socket = this;
90        this.peerHostname = null;
91        this.peerPort = -1;
92        this.autoClose = false;
93    }
94
95    AbstractConscryptSocket(String hostname, int port) throws IOException {
96        super(hostname, port);
97        this.socket = this;
98        this.peerHostname = hostname;
99        this.peerPort = port;
100        this.autoClose = false;
101    }
102
103    AbstractConscryptSocket(InetAddress address, int port) throws IOException {
104        super(address, port);
105        this.socket = this;
106        this.peerHostname = null;
107        this.peerPort = -1;
108        this.autoClose = false;
109    }
110
111    AbstractConscryptSocket(String hostname, int port, InetAddress clientAddress, int clientPort)
112            throws IOException {
113        super(hostname, port, clientAddress, clientPort);
114        this.socket = this;
115        this.peerHostname = hostname;
116        this.peerPort = port;
117        this.autoClose = false;
118    }
119
120    AbstractConscryptSocket(InetAddress address, int port, InetAddress clientAddress,
121            int clientPort) throws IOException {
122        super(address, port, clientAddress, clientPort);
123        this.socket = this;
124        this.peerHostname = null;
125        this.peerPort = -1;
126        this.autoClose = false;
127    }
128
129    AbstractConscryptSocket(Socket socket, String hostname, int port, boolean autoClose)
130            throws IOException {
131        this.socket = checkNotNull(socket, "socket");
132        this.peerHostname = hostname;
133        this.peerPort = port;
134        this.autoClose = autoClose;
135    }
136
137    @Override
138    public final void connect(SocketAddress endpoint) throws IOException {
139        connect(endpoint, 0);
140    }
141
142    /**
143     * Try to extract the peer's hostname if it's available from the endpoint address.
144     */
145    @Override
146    public final void connect(SocketAddress endpoint, int timeout) throws IOException {
147        if (peerHostname == null && endpoint instanceof InetSocketAddress) {
148            peerHostname =
149                    Platform.getHostStringFromInetSocketAddress((InetSocketAddress) endpoint);
150        }
151
152        if (isDelegating()) {
153            socket.connect(endpoint, timeout);
154        } else {
155            super.connect(endpoint, timeout);
156        }
157    }
158
159    @Override
160    public void bind(SocketAddress bindpoint) throws IOException {
161        if (isDelegating()) {
162            socket.bind(bindpoint);
163        } else {
164            super.bind(bindpoint);
165        }
166    }
167
168    @Override
169    @SuppressWarnings("UnsynchronizedOverridesSynchronized")
170    public void close() throws IOException {
171        if (isDelegating()) {
172            if (autoClose && !socket.isClosed()) {
173                socket.close();
174            }
175        } else {
176            if (!super.isClosed()) {
177                super.close();
178            }
179        }
180    }
181
182    @Override
183    public InetAddress getInetAddress() {
184        if (isDelegating()) {
185            return socket.getInetAddress();
186        }
187        return super.getInetAddress();
188    }
189
190    @Override
191    public InetAddress getLocalAddress() {
192        if (isDelegating()) {
193            return socket.getLocalAddress();
194        }
195        return super.getLocalAddress();
196    }
197
198    @Override
199    public int getLocalPort() {
200        if (isDelegating()) {
201            return socket.getLocalPort();
202        }
203        return super.getLocalPort();
204    }
205
206    @Override
207    public SocketAddress getRemoteSocketAddress() {
208        if (isDelegating()) {
209            return socket.getRemoteSocketAddress();
210        }
211        return super.getRemoteSocketAddress();
212    }
213
214    @Override
215    public SocketAddress getLocalSocketAddress() {
216        if (isDelegating()) {
217            return socket.getLocalSocketAddress();
218        }
219        return super.getLocalSocketAddress();
220    }
221
222    @Override
223    public final int getPort() {
224        if (isDelegating()) {
225            return socket.getPort();
226        }
227
228        if (peerPort != -1) {
229            // Return the port that has been explicitly set in the constructor.
230            return peerPort;
231        }
232        return super.getPort();
233    }
234
235    @Override
236    public void addHandshakeCompletedListener(HandshakeCompletedListener listener) {
237        checkArgument(listener != null, "Provided listener is null");
238        listeners.add(listener);
239    }
240
241    @Override
242    public void removeHandshakeCompletedListener(HandshakeCompletedListener listener) {
243        checkArgument(listener != null, "Provided listener is null");
244        if (!listeners.remove(listener)) {
245            throw new IllegalArgumentException("Provided listener is not registered");
246        }
247    }
248
249    /* @Override */
250    @SuppressWarnings("MissingOverride") // For compilation with Java 6.
251    public abstract SSLSession getHandshakeSession();
252
253    /* @Override */
254    public FileDescriptor getFileDescriptor$() {
255        if (isDelegating()) {
256            return Platform.getFileDescriptor(socket);
257        }
258        return Platform.getFileDescriptorFromSSLSocket(this);
259    }
260
261    @Override
262    @SuppressWarnings("UnsynchronizedOverridesSynchronized")
263    public final void setSoTimeout(int readTimeoutMilliseconds) throws SocketException {
264        if (isDelegating()) {
265            socket.setSoTimeout(readTimeoutMilliseconds);
266        } else {
267            super.setSoTimeout(readTimeoutMilliseconds);
268            this.readTimeoutMilliseconds = readTimeoutMilliseconds;
269        }
270    }
271
272    @Override
273    @SuppressWarnings("UnsynchronizedOverridesSynchronized")
274    public final int getSoTimeout() throws SocketException {
275        if (isDelegating()) {
276            return socket.getSoTimeout();
277        }
278        return readTimeoutMilliseconds;
279    }
280
281    @Override
282    public final void sendUrgentData(int data) throws IOException {
283        throw new SocketException("Method sendUrgentData() is not supported.");
284    }
285
286    @Override
287    public final void setOOBInline(boolean on) throws SocketException {
288        throw new SocketException("Method setOOBInline() is not supported.");
289    }
290
291    @Override
292    public boolean getOOBInline() throws SocketException {
293        return false;
294    }
295
296    @Override
297    public SocketChannel getChannel() {
298        // TODO(nmittler): Support channels?
299        return null;
300    }
301
302    @Override
303    public InputStream getInputStream() throws IOException {
304        if (isDelegating()) {
305            return socket.getInputStream();
306        }
307        return super.getInputStream();
308    }
309
310    @Override
311    public OutputStream getOutputStream() throws IOException {
312        if (isDelegating()) {
313            return socket.getOutputStream();
314        }
315        return super.getOutputStream();
316    }
317
318    @Override
319    public void setTcpNoDelay(boolean on) throws SocketException {
320        if (isDelegating()) {
321            socket.setTcpNoDelay(on);
322        } else {
323            super.setTcpNoDelay(on);
324        }
325    }
326
327    @Override
328    public boolean getTcpNoDelay() throws SocketException {
329        if (isDelegating()) {
330            return socket.getTcpNoDelay();
331        }
332        return super.getTcpNoDelay();
333    }
334
335    @Override
336    public void setSoLinger(boolean on, int linger) throws SocketException {
337        if (isDelegating()) {
338            socket.setSoLinger(on, linger);
339        } else {
340            super.setSoLinger(on, linger);
341        }
342    }
343
344    @Override
345    public int getSoLinger() throws SocketException {
346        if (isDelegating()) {
347            return socket.getSoLinger();
348        }
349        return super.getSoLinger();
350    }
351
352    @Override
353    @SuppressWarnings("UnsynchronizedOverridesSynchronized")
354    public void setSendBufferSize(int size) throws SocketException {
355        if (isDelegating()) {
356            socket.setSendBufferSize(size);
357        } else {
358            super.setSendBufferSize(size);
359        }
360    }
361
362    @Override
363    @SuppressWarnings("UnsynchronizedOverridesSynchronized")
364    public int getSendBufferSize() throws SocketException {
365        if (isDelegating()) {
366            return socket.getSendBufferSize();
367        }
368        return super.getSendBufferSize();
369    }
370
371    @Override
372    @SuppressWarnings("UnsynchronizedOverridesSynchronized")
373    public void setReceiveBufferSize(int size) throws SocketException {
374        if (isDelegating()) {
375            socket.setReceiveBufferSize(size);
376        } else {
377            super.setReceiveBufferSize(size);
378        }
379    }
380
381    @Override
382    @SuppressWarnings("UnsynchronizedOverridesSynchronized")
383    public int getReceiveBufferSize() throws SocketException {
384        if (isDelegating()) {
385            return socket.getReceiveBufferSize();
386        }
387        return super.getReceiveBufferSize();
388    }
389
390    @Override
391    public void setKeepAlive(boolean on) throws SocketException {
392        if (isDelegating()) {
393            socket.setKeepAlive(on);
394        } else {
395            super.setKeepAlive(on);
396        }
397    }
398
399    @Override
400    public boolean getKeepAlive() throws SocketException {
401        if (isDelegating()) {
402            return socket.getKeepAlive();
403        }
404        return super.getKeepAlive();
405    }
406
407    @Override
408    public void setTrafficClass(int tc) throws SocketException {
409        if (isDelegating()) {
410            socket.setTrafficClass(tc);
411        } else {
412            super.setTrafficClass(tc);
413        }
414    }
415
416    @Override
417    public int getTrafficClass() throws SocketException {
418        if (isDelegating()) {
419            return socket.getTrafficClass();
420        }
421        return super.getTrafficClass();
422    }
423
424    @Override
425    public void setReuseAddress(boolean on) throws SocketException {
426        if (isDelegating()) {
427            socket.setReuseAddress(on);
428        } else {
429            super.setReuseAddress(on);
430        }
431    }
432
433    @Override
434    public boolean getReuseAddress() throws SocketException {
435        if (isDelegating()) {
436            return socket.getReuseAddress();
437        }
438        return super.getReuseAddress();
439    }
440
441    @Override
442    public void shutdownInput() throws IOException {
443        if (isDelegating()) {
444            socket.shutdownInput();
445        } else {
446            super.shutdownInput();
447        }
448    }
449
450    @Override
451    public void shutdownOutput() throws IOException {
452        if (isDelegating()) {
453            socket.shutdownOutput();
454        } else {
455            super.shutdownOutput();
456        }
457    }
458
459    @Override
460    public boolean isConnected() {
461        if (isDelegating()) {
462            return socket.isConnected();
463        }
464        return super.isConnected();
465    }
466
467    @Override
468    public boolean isBound() {
469        if (isDelegating()) {
470            return socket.isBound();
471        }
472        return super.isBound();
473    }
474
475    @Override
476    public boolean isClosed() {
477        if (isDelegating()) {
478            return socket.isClosed();
479        }
480        return super.isClosed();
481    }
482
483    @Override
484    public boolean isInputShutdown() {
485        if (isDelegating()) {
486            return socket.isInputShutdown();
487        }
488        return super.isInputShutdown();
489    }
490
491    @Override
492    public boolean isOutputShutdown() {
493        if (isDelegating()) {
494            return socket.isOutputShutdown();
495        }
496        return super.isOutputShutdown();
497    }
498
499    @Override
500    public void setPerformancePreferences(int connectionTime, int latency, int bandwidth) {
501        if (isDelegating()) {
502            socket.setPerformancePreferences(connectionTime, latency, bandwidth);
503        } else {
504            super.setPerformancePreferences(connectionTime, latency, bandwidth);
505        }
506    }
507
508    @Override
509    public String toString() {
510        StringBuilder builder = new StringBuilder("SSL socket over ");
511        if (isDelegating()) {
512            builder.append(socket.toString());
513        } else {
514            builder.append(super.toString());
515        }
516        return builder.toString();
517    }
518
519    /**
520     * Returns the hostname that was supplied during socket creation. No DNS resolution is
521     * attempted before returning the hostname.
522     */
523    String getHostname() {
524        return peerHostname;
525    }
526
527    /**
528     * This method enables Server Name Indication
529     *
530     * @param hostname the desired SNI hostname, or null to disable
531     */
532    void setHostname(String hostname) {
533        peerHostname = hostname;
534    }
535
536    /**
537     * For the purposes of an SSLSession, we want a way to represent the supplied hostname
538     * or the IP address in a textual representation. We do not want to perform reverse DNS
539     * lookups on this address.
540     */
541    String getHostnameOrIP() {
542        if (peerHostname != null) {
543            return peerHostname;
544        }
545
546        InetAddress peerAddress = getInetAddress();
547        if (peerAddress != null) {
548            return peerAddress.getHostAddress();
549        }
550
551        return null;
552    }
553
554    /**
555     * Note write timeouts are not part of the javax.net.ssl.SSLSocket API
556     */
557    void setSoWriteTimeout(int writeTimeoutMilliseconds) throws SocketException {
558        throw new SocketException("Method setSoWriteTimeout() is not supported.");
559    }
560
561    /**
562     * Note write timeouts are not part of the javax.net.ssl.SSLSocket API
563     */
564    int getSoWriteTimeout() throws SocketException {
565        return 0;
566    }
567
568    /**
569     * Set the handshake timeout on this socket.  This timeout is specified in
570     * milliseconds and will be used only during the handshake process.
571     */
572    void setHandshakeTimeout(int handshakeTimeoutMilliseconds) throws SocketException {
573        throw new SocketException("Method setHandshakeTimeout() is not supported.");
574    }
575
576    /**
577     * This method enables session ticket support.
578     *
579     * @param useSessionTickets True to enable session tickets
580     */
581    abstract void setUseSessionTickets(boolean useSessionTickets);
582
583    /**
584     * Enables/disables TLS Channel ID for this server socket.
585     *
586     * <p>This method needs to be invoked before the handshake starts.
587     *
588     * @throws IllegalStateException if this is a client socket or if the handshake has already
589     *         started.
590     */
591    abstract void setChannelIdEnabled(boolean enabled);
592
593    /**
594     * Gets the TLS Channel ID for this server socket. Channel ID is only available once the
595     * handshake completes.
596     *
597     * @return channel ID or {@code null} if not available.
598     *
599     * @throws IllegalStateException if this is a client socket or if the handshake has not yet
600     *         completed.
601     * @throws SSLException if channel ID is available but could not be obtained.
602     */
603    abstract byte[] getChannelId() throws SSLException;
604
605    /**
606     * Sets the {@link PrivateKey} to be used for TLS Channel ID by this client socket.
607     *
608     * <p>This method needs to be invoked before the handshake starts.
609     *
610     * @param privateKey private key (enables TLS Channel ID) or {@code null} for no key (disables
611     *        TLS Channel ID). The private key must be an Elliptic Curve (EC) key based on the NIST
612     *        P-256 curve (aka SECG secp256r1 or ANSI X9.62 prime256v1).
613     *
614     * @throws IllegalStateException if this is a server socket or if the handshake has already
615     *         started.
616     */
617    abstract void setChannelIdPrivateKey(PrivateKey privateKey);
618
619    /**
620     * Returns null always for backward compatibility.
621     */
622    byte[] getNpnSelectedProtocol() {
623        return null;
624    }
625
626    /**
627     * This method does nothing and is kept for backward compatibility.
628     */
629    void setNpnProtocols(byte[] npnProtocols) {}
630
631    /**
632     * Returns the protocol agreed upon by client and server, or {@code null} if
633     * no protocol was agreed upon.
634     */
635    abstract byte[] getAlpnSelectedProtocol();
636
637    /**
638     * Sets the list of ALPN protocols. This method internally converts the protocols to their
639     * wire-format form.
640     *
641     * @param alpnProtocols the list of ALPN protocols
642     * @see #setAlpnProtocols(byte[])
643     */
644    abstract void setAlpnProtocols(String[] alpnProtocols);
645
646    /**
647     * Alternate version of {@link #setAlpnProtocols(String[])} that directly sets the list of
648     * ALPN in the wire-format form used by BoringSSL (length-prefixed 8-bit strings).
649     * Requires that all strings be encoded with US-ASCII.
650     *
651     * @param alpnProtocols the encoded form of the ALPN protocol list
652     * @see #setAlpnProtocols(String[])
653     */
654    abstract void setAlpnProtocols(byte[] alpnProtocols);
655
656    /**
657     * Called by {@link #notifyHandshakeCompletedListeners()} to get the currently active session.
658     * Unlike {@link #getSession()}, this method must not block.
659     */
660    abstract SSLSession getActiveSession();
661
662    final PeerInfoProvider peerInfoProvider() {
663        return peerInfoProvider;
664    }
665
666    final void checkOpen() throws SocketException {
667        if (isClosed()) {
668            throw new SocketException("Socket is closed");
669        }
670    }
671
672    final void notifyHandshakeCompletedListeners() {
673        if (listeners != null && !listeners.isEmpty()) {
674            // notify the listeners
675            HandshakeCompletedEvent event = new HandshakeCompletedEvent(this, getActiveSession());
676            for (HandshakeCompletedListener listener : listeners) {
677                try {
678                    listener.handshakeCompleted(event);
679                } catch (RuntimeException e) {
680                    // The RI runs the handlers in a separate thread,
681                    // which we do not. But we try to preserve their
682                    // behavior of logging a problem and not killing
683                    // the handshaking thread just because a listener
684                    // has a problem.
685                    Thread thread = Thread.currentThread();
686                    thread.getUncaughtExceptionHandler().uncaughtException(thread, e);
687                }
688            }
689        }
690    }
691
692    private boolean isDelegating() {
693        // Checking for null to handle the case of calling virtual methods in the super class
694        // constructor.
695        return socket != null && socket != this;
696    }
697}
698