1/* 2 * Copyright 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17package org.conscrypt; 18 19import java.io.IOException; 20import java.io.InputStream; 21import java.io.OutputStream; 22import java.net.ServerSocket; 23import java.util.concurrent.ExecutorService; 24import java.util.concurrent.Executors; 25import java.util.concurrent.Future; 26import java.util.concurrent.TimeUnit; 27import javax.net.ssl.SSLServerSocketFactory; 28import javax.net.ssl.SSLSocket; 29import javax.net.ssl.SSLSocketFactory; 30 31/** 32 * A simple socket-based test server. 33 */ 34final class ServerEndpoint { 35 /** 36 * A processor for receipt of a single message. 37 */ 38 public interface MessageProcessor { 39 void processMessage(byte[] message, int numBytes, OutputStream os); 40 } 41 42 /** 43 * A {@link MessageProcessor} that simply echos back the received message to the client. 44 */ 45 public static final class EchoProcessor implements MessageProcessor { 46 @Override 47 public void processMessage(byte[] message, int numBytes, OutputStream os) { 48 try { 49 os.write(message, 0, numBytes); 50 os.flush(); 51 } catch (IOException e) { 52 throw new RuntimeException(e); 53 } 54 } 55 } 56 57 private final ServerSocket serverSocket; 58 private final ChannelType channelType; 59 private final SSLSocketFactory socketFactory; 60 private final int messageSize; 61 private final String[] protocols; 62 private final String[] cipherSuites; 63 private final byte[] buffer; 64 private SSLSocket socket; 65 private ExecutorService executor; 66 private InputStream inputStream; 67 private OutputStream outputStream; 68 private volatile boolean stopping; 69 private volatile MessageProcessor messageProcessor = new EchoProcessor(); 70 71 ServerEndpoint(SSLSocketFactory socketFactory, SSLServerSocketFactory serverSocketFactory, 72 ChannelType channelType, int messageSize, String[] protocols, 73 String[] cipherSuites) throws IOException { 74 this.serverSocket = channelType.newServerSocket(serverSocketFactory); 75 this.socketFactory = socketFactory; 76 this.channelType = channelType; 77 this.messageSize = messageSize; 78 this.protocols = protocols; 79 this.cipherSuites = cipherSuites; 80 buffer = new byte[messageSize]; 81 } 82 83 void setMessageProcessor(MessageProcessor messageProcessor) { 84 this.messageProcessor = messageProcessor; 85 } 86 87 Future<?> start() throws IOException { 88 executor = Executors.newSingleThreadExecutor(); 89 return executor.submit(new AcceptTask()); 90 } 91 92 void stop() { 93 try { 94 stopping = true; 95 96 if (socket != null) { 97 socket.close(); 98 socket = null; 99 } 100 serverSocket.close(); 101 102 if (executor != null) { 103 executor.shutdown(); 104 executor.awaitTermination(5, TimeUnit.SECONDS); 105 executor = null; 106 } 107 } catch (IOException | InterruptedException e) { 108 throw new RuntimeException(e); 109 } 110 } 111 112 public int port() { 113 return serverSocket.getLocalPort(); 114 } 115 116 private final class AcceptTask implements Runnable { 117 @Override 118 public void run() { 119 try { 120 if (stopping) { 121 return; 122 } 123 socket = channelType.accept(serverSocket, socketFactory); 124 socket.setEnabledProtocols(protocols); 125 socket.setEnabledCipherSuites(cipherSuites); 126 127 socket.startHandshake(); 128 129 inputStream = socket.getInputStream(); 130 outputStream = socket.getOutputStream(); 131 132 if (stopping) { 133 return; 134 } 135 executor.execute(new ProcessTask()); 136 } catch (IOException e) { 137 e.printStackTrace(); 138 throw new RuntimeException(e); 139 } 140 } 141 } 142 143 private final class ProcessTask implements Runnable { 144 @Override 145 public void run() { 146 try { 147 Thread thread = Thread.currentThread(); 148 while (!stopping && !thread.isInterrupted()) { 149 int bytesRead = readMessage(); 150 if (!stopping && !thread.isInterrupted()) { 151 messageProcessor.processMessage(buffer, bytesRead, outputStream); 152 } 153 } 154 } catch (Throwable e) { 155 throw new RuntimeException(e); 156 } 157 } 158 159 private int readMessage() throws IOException { 160 int totalBytesRead = 0; 161 while (totalBytesRead < messageSize) { 162 int remaining = messageSize - totalBytesRead; 163 int bytesRead = inputStream.read(buffer, totalBytesRead, remaining); 164 if (bytesRead == -1) { 165 break; 166 } 167 totalBytesRead += bytesRead; 168 } 169 return totalBytesRead; 170 } 171 } 172} 173