1/*
2 * Copyright 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16package org.conscrypt;
17
18import static org.junit.Assert.assertFalse;
19import static org.junit.Assert.assertNull;
20
21import java.io.IOException;
22import java.net.InetAddress;
23import java.net.InetSocketAddress;
24import java.net.ServerSocket;
25import java.net.Socket;
26import java.nio.channels.ServerSocketChannel;
27import java.nio.channels.SocketChannel;
28import javax.net.ServerSocketFactory;
29import javax.net.ssl.SSLServerSocket;
30import javax.net.ssl.SSLServerSocketFactory;
31import javax.net.ssl.SSLSocket;
32import javax.net.ssl.SSLSocketFactory;
33
34/**
35 * The type of socket to be wrapped by the Conscrypt socket.
36 */
37@SuppressWarnings("unused")
38public enum ChannelType {
39    NONE {
40        @Override
41        SSLSocket newClientSocket(SSLSocketFactory factory, InetAddress address, int port)
42                throws IOException {
43            return clientMode(factory.createSocket(address, port));
44        }
45
46        @Override
47        ServerSocket newServerSocket(SSLServerSocketFactory factory) throws IOException {
48            return factory.createServerSocket(0, 50, InetAddress.getLoopbackAddress());
49        }
50
51        @Override
52        SSLSocket accept(ServerSocket socket, SSLSocketFactory unused) throws IOException {
53            return serverMode(socket.accept());
54        }
55    },
56    NO_CHANNEL {
57        @Override
58        SSLSocket newClientSocket(SSLSocketFactory factory, InetAddress address, int port)
59                throws IOException {
60            Socket wrapped = new Socket(address, port);
61            assertNull(wrapped.getChannel());
62
63            return clientMode(factory.createSocket(wrapped, address.getHostName(), port, true));
64        }
65
66        @Override
67        ServerSocket newServerSocket(SSLServerSocketFactory unused) throws IOException {
68            return ServerSocketFactory.getDefault().createServerSocket(
69                    0, 50, InetAddress.getLoopbackAddress());
70        }
71
72        @Override
73        SSLSocket accept(ServerSocket serverSocket, SSLSocketFactory factory) throws IOException {
74            assertFalse(serverSocket instanceof SSLServerSocket);
75            Socket wrapped = serverSocket.accept();
76            assertNull(wrapped.getChannel());
77
78            return serverMode(factory.createSocket(
79                    wrapped, wrapped.getInetAddress().getHostAddress(), wrapped.getPort(), true));
80        }
81    },
82    CHANNEL {
83        @Override
84        SSLSocket newClientSocket(SSLSocketFactory factory, InetAddress address, int port)
85                throws IOException {
86            Socket wrapped = SocketChannel.open(new InetSocketAddress(address, port)).socket();
87            return clientMode(factory.createSocket(wrapped, address.getHostName(), port, true));
88        }
89
90        @Override
91        ServerSocket newServerSocket(SSLServerSocketFactory unused) throws IOException {
92            return ServerSocketChannel.open()
93                    .bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0))
94                    .socket();
95        }
96
97        @Override
98        SSLSocket accept(ServerSocket serverSocket, SSLSocketFactory factory) throws IOException {
99            assertFalse(serverSocket instanceof SSLServerSocket);
100            ServerSocketChannel serverChannel = serverSocket.getChannel();
101
102            // Just loop until the accept completes.
103            SocketChannel channel;
104            do {
105                channel = serverChannel.accept();
106            } while (channel == null);
107
108            Socket wrapped = channel.socket();
109            return serverMode(factory.createSocket(
110                    wrapped, wrapped.getInetAddress().getHostAddress(), wrapped.getPort(), true));
111        }
112    };
113
114    abstract SSLSocket newClientSocket(SSLSocketFactory factory, InetAddress address, int port)
115            throws IOException;
116    abstract ServerSocket newServerSocket(SSLServerSocketFactory factory) throws IOException;
117    abstract SSLSocket accept(ServerSocket socket, SSLSocketFactory factory) throws IOException;
118
119    private static SSLSocket clientMode(Socket socket) {
120        SSLSocket sslSocket = (SSLSocket) socket;
121        sslSocket.setUseClientMode(true);
122        return sslSocket;
123    }
124
125    private static SSLSocket serverMode(Socket socket) {
126        SSLSocket sslSocket = (SSLSocket) socket;
127        sslSocket.setUseClientMode(false);
128        return sslSocket;
129    }
130}
131