1/* 2 * Copyright 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16package org.conscrypt; 17 18import static org.junit.Assert.assertFalse; 19import static org.junit.Assert.assertNull; 20 21import java.io.IOException; 22import java.net.InetAddress; 23import java.net.InetSocketAddress; 24import java.net.ServerSocket; 25import java.net.Socket; 26import java.nio.channels.ServerSocketChannel; 27import java.nio.channels.SocketChannel; 28import javax.net.ServerSocketFactory; 29import javax.net.ssl.SSLServerSocket; 30import javax.net.ssl.SSLServerSocketFactory; 31import javax.net.ssl.SSLSocket; 32import javax.net.ssl.SSLSocketFactory; 33 34/** 35 * The type of socket to be wrapped by the Conscrypt socket. 36 */ 37@SuppressWarnings("unused") 38public enum ChannelType { 39 NONE { 40 @Override 41 SSLSocket newClientSocket(SSLSocketFactory factory, InetAddress address, int port) 42 throws IOException { 43 return clientMode(factory.createSocket(address, port)); 44 } 45 46 @Override 47 ServerSocket newServerSocket(SSLServerSocketFactory factory) throws IOException { 48 return factory.createServerSocket(0, 50, InetAddress.getLoopbackAddress()); 49 } 50 51 @Override 52 SSLSocket accept(ServerSocket socket, SSLSocketFactory unused) throws IOException { 53 return serverMode(socket.accept()); 54 } 55 }, 56 NO_CHANNEL { 57 @Override 58 SSLSocket newClientSocket(SSLSocketFactory factory, InetAddress address, int port) 59 throws IOException { 60 Socket wrapped = new Socket(address, port); 61 assertNull(wrapped.getChannel()); 62 63 return clientMode(factory.createSocket(wrapped, address.getHostName(), port, true)); 64 } 65 66 @Override 67 ServerSocket newServerSocket(SSLServerSocketFactory unused) throws IOException { 68 return ServerSocketFactory.getDefault().createServerSocket( 69 0, 50, InetAddress.getLoopbackAddress()); 70 } 71 72 @Override 73 SSLSocket accept(ServerSocket serverSocket, SSLSocketFactory factory) throws IOException { 74 assertFalse(serverSocket instanceof SSLServerSocket); 75 Socket wrapped = serverSocket.accept(); 76 assertNull(wrapped.getChannel()); 77 78 return serverMode(factory.createSocket( 79 wrapped, wrapped.getInetAddress().getHostAddress(), wrapped.getPort(), true)); 80 } 81 }, 82 CHANNEL { 83 @Override 84 SSLSocket newClientSocket(SSLSocketFactory factory, InetAddress address, int port) 85 throws IOException { 86 Socket wrapped = SocketChannel.open(new InetSocketAddress(address, port)).socket(); 87 return clientMode(factory.createSocket(wrapped, address.getHostName(), port, true)); 88 } 89 90 @Override 91 ServerSocket newServerSocket(SSLServerSocketFactory unused) throws IOException { 92 return ServerSocketChannel.open() 93 .bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0)) 94 .socket(); 95 } 96 97 @Override 98 SSLSocket accept(ServerSocket serverSocket, SSLSocketFactory factory) throws IOException { 99 assertFalse(serverSocket instanceof SSLServerSocket); 100 ServerSocketChannel serverChannel = serverSocket.getChannel(); 101 102 // Just loop until the accept completes. 103 SocketChannel channel; 104 do { 105 channel = serverChannel.accept(); 106 } while (channel == null); 107 108 Socket wrapped = channel.socket(); 109 return serverMode(factory.createSocket( 110 wrapped, wrapped.getInetAddress().getHostAddress(), wrapped.getPort(), true)); 111 } 112 }; 113 114 abstract SSLSocket newClientSocket(SSLSocketFactory factory, InetAddress address, int port) 115 throws IOException; 116 abstract ServerSocket newServerSocket(SSLServerSocketFactory factory) throws IOException; 117 abstract SSLSocket accept(ServerSocket socket, SSLSocketFactory factory) throws IOException; 118 119 private static SSLSocket clientMode(Socket socket) { 120 SSLSocket sslSocket = (SSLSocket) socket; 121 sslSocket.setUseClientMode(true); 122 return sslSocket; 123 } 124 125 private static SSLSocket serverMode(Socket socket) { 126 SSLSocket sslSocket = (SSLSocket) socket; 127 sslSocket.setUseClientMode(false); 128 return sslSocket; 129 } 130} 131