1/*
2 * Copyright 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package org.conscrypt;
18
19import java.io.EOFException;
20import java.io.IOException;
21import java.io.InputStream;
22import java.io.OutputStream;
23import java.net.ServerSocket;
24import java.net.SocketException;
25import java.nio.channels.ClosedChannelException;
26import java.util.concurrent.ExecutionException;
27import java.util.concurrent.ExecutorService;
28import java.util.concurrent.Executors;
29import java.util.concurrent.Future;
30import java.util.concurrent.TimeUnit;
31import java.util.concurrent.TimeoutException;
32import javax.net.ssl.SSLException;
33import javax.net.ssl.SSLServerSocketFactory;
34import javax.net.ssl.SSLSocket;
35import javax.net.ssl.SSLSocketFactory;
36
37/**
38 * A simple socket-based test server.
39 */
40final class ServerEndpoint {
41    /**
42     * A processor for receipt of a single message.
43     */
44    public interface MessageProcessor {
45        void processMessage(byte[] message, int numBytes, OutputStream os);
46    }
47
48    /**
49     * A {@link MessageProcessor} that simply echos back the received message to the client.
50     */
51    public static final class EchoProcessor implements MessageProcessor {
52        @Override
53        public void processMessage(byte[] message, int numBytes, OutputStream os) {
54            try {
55                os.write(message, 0, numBytes);
56                os.flush();
57            } catch (IOException e) {
58                throw new RuntimeException(e);
59            }
60        }
61    }
62
63    private final ServerSocket serverSocket;
64    private final ChannelType channelType;
65    private final SSLSocketFactory socketFactory;
66    private final int messageSize;
67    private final String[] protocols;
68    private final String[] cipherSuites;
69    private final byte[] buffer;
70    private SSLSocket socket;
71    private ExecutorService executor;
72    private InputStream inputStream;
73    private OutputStream outputStream;
74    private volatile boolean stopping;
75    private volatile MessageProcessor messageProcessor = new EchoProcessor();
76    private volatile Future<?> processFuture;
77
78    ServerEndpoint(SSLSocketFactory socketFactory, SSLServerSocketFactory serverSocketFactory,
79            ChannelType channelType, int messageSize, String[] protocols,
80            String[] cipherSuites) throws IOException {
81        this.serverSocket = channelType.newServerSocket(serverSocketFactory);
82        this.socketFactory = socketFactory;
83        this.channelType = channelType;
84        this.messageSize = messageSize;
85        this.protocols = protocols;
86        this.cipherSuites = cipherSuites;
87        buffer = new byte[messageSize];
88    }
89
90    void setMessageProcessor(MessageProcessor messageProcessor) {
91        this.messageProcessor = messageProcessor;
92    }
93
94    Future<?> start() throws IOException {
95        executor = Executors.newSingleThreadExecutor();
96        return executor.submit(new AcceptTask());
97    }
98
99    void stop() {
100        try {
101            stopping = true;
102
103            if (socket != null) {
104                socket.close();
105                socket = null;
106            }
107
108            if (processFuture != null) {
109                processFuture.get(5, TimeUnit.SECONDS);
110            }
111
112            serverSocket.close();
113
114            if (executor != null) {
115                executor.shutdown();
116                executor.awaitTermination(5, TimeUnit.SECONDS);
117                executor = null;
118            }
119        } catch (IOException | InterruptedException | ExecutionException | TimeoutException e) {
120            throw new RuntimeException(e);
121        }
122    }
123
124    public int port() {
125        return serverSocket.getLocalPort();
126    }
127
128    private final class AcceptTask implements Runnable {
129        @Override
130        public void run() {
131            try {
132                if (stopping) {
133                    return;
134                }
135                socket = channelType.accept(serverSocket, socketFactory);
136                socket.setEnabledProtocols(protocols);
137                socket.setEnabledCipherSuites(cipherSuites);
138
139                socket.startHandshake();
140
141                inputStream = socket.getInputStream();
142                outputStream = socket.getOutputStream();
143
144                if (stopping) {
145                    return;
146                }
147                processFuture = executor.submit(new ProcessTask());
148            } catch (IOException e) {
149                e.printStackTrace();
150                throw new RuntimeException(e);
151            }
152        }
153    }
154
155    private final class ProcessTask implements Runnable {
156        @Override
157        public void run() {
158            try {
159                Thread thread = Thread.currentThread();
160                while (!stopping && !thread.isInterrupted()) {
161                    int bytesRead = readMessage();
162                    if (!stopping && !thread.isInterrupted()) {
163                        messageProcessor.processMessage(buffer, bytesRead, outputStream);
164                    }
165                }
166            } catch (Throwable e) {
167                throw new RuntimeException(e);
168            }
169        }
170
171        private int readMessage() throws IOException {
172            int totalBytesRead = 0;
173            while (!stopping && totalBytesRead < messageSize) {
174                try {
175                    int remaining = messageSize - totalBytesRead;
176                    int bytesRead = inputStream.read(buffer, totalBytesRead, remaining);
177                    if (bytesRead == -1) {
178                        break;
179                    }
180                    totalBytesRead += bytesRead;
181                } catch (SSLException e) {
182                    if (e.getCause() instanceof EOFException) {
183                        break;
184                    }
185                    throw e;
186                } catch (ClosedChannelException e) {
187                    // Thrown for channel-based sockets. Just treat like EOF.
188                    break;
189                } catch (SocketException e) {
190                    // The socket was broken. Just treat like EOF.
191                    break;
192                }
193            }
194            return totalBytesRead;
195        }
196    }
197}
198