1/*
2 * Copyright 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package org.conscrypt;
18
19import java.io.IOException;
20import java.io.InputStream;
21import java.io.OutputStream;
22import java.net.ServerSocket;
23import java.util.concurrent.ExecutorService;
24import java.util.concurrent.Executors;
25import java.util.concurrent.Future;
26import java.util.concurrent.TimeUnit;
27import javax.net.ssl.SSLServerSocketFactory;
28import javax.net.ssl.SSLSocket;
29import javax.net.ssl.SSLSocketFactory;
30
31/**
32 * A simple socket-based test server.
33 */
34final class ServerEndpoint {
35    /**
36     * A processor for receipt of a single message.
37     */
38    public interface MessageProcessor {
39        void processMessage(byte[] message, int numBytes, OutputStream os);
40    }
41
42    /**
43     * A {@link MessageProcessor} that simply echos back the received message to the client.
44     */
45    public static final class EchoProcessor implements MessageProcessor {
46        @Override
47        public void processMessage(byte[] message, int numBytes, OutputStream os) {
48            try {
49                os.write(message, 0, numBytes);
50                os.flush();
51            } catch (IOException e) {
52                throw new RuntimeException(e);
53            }
54        }
55    }
56
57    private final ServerSocket serverSocket;
58    private final ChannelType channelType;
59    private final SSLSocketFactory socketFactory;
60    private final int messageSize;
61    private final String[] protocols;
62    private final String[] cipherSuites;
63    private final byte[] buffer;
64    private SSLSocket socket;
65    private ExecutorService executor;
66    private InputStream inputStream;
67    private OutputStream outputStream;
68    private volatile boolean stopping;
69    private volatile MessageProcessor messageProcessor = new EchoProcessor();
70
71    ServerEndpoint(SSLSocketFactory socketFactory, SSLServerSocketFactory serverSocketFactory,
72            ChannelType channelType, int messageSize, String[] protocols,
73            String[] cipherSuites) throws IOException {
74        this.serverSocket = channelType.newServerSocket(serverSocketFactory);
75        this.socketFactory = socketFactory;
76        this.channelType = channelType;
77        this.messageSize = messageSize;
78        this.protocols = protocols;
79        this.cipherSuites = cipherSuites;
80        buffer = new byte[messageSize];
81    }
82
83    void setMessageProcessor(MessageProcessor messageProcessor) {
84        this.messageProcessor = messageProcessor;
85    }
86
87    Future<?> start() throws IOException {
88        executor = Executors.newSingleThreadExecutor();
89        return executor.submit(new AcceptTask());
90    }
91
92    void stop() {
93        try {
94            stopping = true;
95
96            if (socket != null) {
97                socket.close();
98                socket = null;
99            }
100            serverSocket.close();
101
102            if (executor != null) {
103                executor.shutdown();
104                executor.awaitTermination(5, TimeUnit.SECONDS);
105                executor = null;
106            }
107        } catch (IOException | InterruptedException e) {
108            throw new RuntimeException(e);
109        }
110    }
111
112    public int port() {
113        return serverSocket.getLocalPort();
114    }
115
116    private final class AcceptTask implements Runnable {
117        @Override
118        public void run() {
119            try {
120                if (stopping) {
121                    return;
122                }
123                socket = channelType.accept(serverSocket, socketFactory);
124                socket.setEnabledProtocols(protocols);
125                socket.setEnabledCipherSuites(cipherSuites);
126
127                socket.startHandshake();
128
129                inputStream = socket.getInputStream();
130                outputStream = socket.getOutputStream();
131
132                if (stopping) {
133                    return;
134                }
135                executor.execute(new ProcessTask());
136            } catch (IOException e) {
137                e.printStackTrace();
138                throw new RuntimeException(e);
139            }
140        }
141    }
142
143    private final class ProcessTask implements Runnable {
144        @Override
145        public void run() {
146            try {
147                Thread thread = Thread.currentThread();
148                while (!stopping && !thread.isInterrupted()) {
149                    int bytesRead = readMessage();
150                    if (!stopping && !thread.isInterrupted()) {
151                        messageProcessor.processMessage(buffer, bytesRead, outputStream);
152                    }
153                }
154            } catch (Throwable e) {
155                throw new RuntimeException(e);
156            }
157        }
158
159        private int readMessage() throws IOException {
160            int totalBytesRead = 0;
161            while (totalBytesRead < messageSize) {
162                int remaining = messageSize - totalBytesRead;
163                int bytesRead = inputStream.read(buffer, totalBytesRead, remaining);
164                if (bytesRead == -1) {
165                    break;
166                }
167                totalBytesRead += bytesRead;
168            }
169            return totalBytesRead;
170        }
171    }
172}
173