1/*
2 * Copyright 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package org.conscrypt.benchmarks;
18
19import static org.conscrypt.testing.TestUtil.LOCALHOST;
20import static org.conscrypt.testing.TestUtil.getConscryptServerSocketFactory;
21import static org.conscrypt.testing.TestUtil.getConscryptSocketFactory;
22import static org.conscrypt.testing.TestUtil.getJdkServerSocketFactory;
23import static org.conscrypt.testing.TestUtil.getJdkSocketFactory;
24import static org.conscrypt.testing.TestUtil.getProtocols;
25import static org.conscrypt.testing.TestUtil.newTextMessage;
26import static org.conscrypt.testing.TestUtil.pickUnusedPort;
27
28import java.io.IOException;
29import java.io.OutputStream;
30import java.util.concurrent.ExecutorService;
31import java.util.concurrent.Executors;
32import java.util.concurrent.Future;
33import java.util.concurrent.TimeUnit;
34import java.util.concurrent.atomic.AtomicBoolean;
35import java.util.concurrent.atomic.AtomicLong;
36import javax.net.SocketFactory;
37import javax.net.ssl.SSLServerSocket;
38import javax.net.ssl.SSLServerSocketFactory;
39import javax.net.ssl.SSLSocket;
40import javax.net.ssl.SSLSocketFactory;
41import org.conscrypt.testing.TestClient;
42import org.conscrypt.testing.TestServer;
43import org.openjdk.jmh.annotations.AuxCounters;
44import org.openjdk.jmh.annotations.Benchmark;
45import org.openjdk.jmh.annotations.Fork;
46import org.openjdk.jmh.annotations.Level;
47import org.openjdk.jmh.annotations.Param;
48import org.openjdk.jmh.annotations.Scope;
49import org.openjdk.jmh.annotations.Setup;
50import org.openjdk.jmh.annotations.State;
51import org.openjdk.jmh.annotations.TearDown;
52
53/**
54 * Benchmark for comparing performance of client socket implementations. All benchmarks use Netty
55 * with tcnative as the server.
56 */
57@State(Scope.Benchmark)
58@Fork(1)
59public class ClientSocketThroughputBenchmark {
60    /**
61     * Use an AuxCounter so we can measure that bytes per second as they accumulate without
62     * consuming CPU in the benchmark method.
63     */
64    @AuxCounters
65    @State(Scope.Thread)
66    public static class BytesPerSecondCounter {
67        @Setup(Level.Iteration)
68        public void clean() {
69            bytesCounter.set(0);
70        }
71
72        public long bytesPerSecond() {
73            return bytesCounter.get();
74        }
75    }
76
77    /**
78     * Various factories for SSL sockets.
79     */
80    public enum SslProvider {
81        JDK(getJdkSocketFactory(), getJdkServerSocketFactory()),
82        CONSCRYPT(getConscryptSocketFactory(false), getConscryptServerSocketFactory(false)),
83        CONSCRYPT_ENGINE(getConscryptSocketFactory(true), getConscryptServerSocketFactory(true)) {
84            @Override
85            SSLSocket newClientSocket(String host, int port, SSLSocketFactory socketFactory)  throws IOException {
86                return (SSLSocket) socketFactory.createSocket(
87                    SocketFactory.getDefault().createSocket(host, port), host, port, true);
88            }
89        };
90
91        private final SSLSocketFactory clientSocketFactory;
92        private final SSLServerSocketFactory serverSocketFactory;
93
94        SslProvider(SSLSocketFactory clientSocketFactory, SSLServerSocketFactory serverSocketFactory) {
95            this.clientSocketFactory = clientSocketFactory;
96            this.serverSocketFactory = serverSocketFactory;
97        }
98
99        final SSLSocket newClientSocket(String host, int port, String cipher) {
100            try {
101                SSLSocket sslSocket = newClientSocket(host, port, clientSocketFactory);
102                sslSocket.setEnabledProtocols(getProtocols());
103                sslSocket.setEnabledCipherSuites(new String[] {cipher});
104                return sslSocket;
105            } catch (Exception e) {
106                throw new RuntimeException(e);
107            }
108        }
109
110        SSLSocket newClientSocket(String host, int port, SSLSocketFactory socketFactory)  throws IOException {
111            return (SSLSocket) socketFactory.createSocket(host, port);
112        }
113
114        final SSLServerSocket newServerSocket(String cipher) {
115            try {
116                int port = pickUnusedPort();
117                SSLServerSocket sslSocket =
118                    (SSLServerSocket) serverSocketFactory.createServerSocket(port);
119                sslSocket.setEnabledProtocols(getProtocols());
120                sslSocket.setEnabledCipherSuites(new String[] {cipher});
121                return sslSocket;
122            } catch (IOException e) {
123                throw new RuntimeException(e);
124            }
125        }
126    }
127
128    @Param public SslProvider sslProvider;
129
130    @Param({"64", "1024"}) public int messageSize;
131
132    @Param({"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"}) public String cipher;
133
134    private TestClient client;
135    private TestServer server;
136    private byte[] message;
137    private ExecutorService executor;
138    private volatile boolean stopping;
139
140    private static final AtomicLong bytesCounter = new AtomicLong();
141    private AtomicBoolean recording = new AtomicBoolean();
142
143    @Setup(Level.Trial)
144    public void setup() throws Exception {
145        recording.set(false);
146
147        message = newTextMessage(messageSize);
148
149        server = new TestServer(sslProvider.newServerSocket(cipher), messageSize);
150        server.setMessageProcessor(new TestServer.MessageProcessor() {
151            @Override
152            public void processMessage(byte[] inMessage, int numBytes, OutputStream os) {
153                if (recording.get()) {
154                    // Server received a message, increment the count.
155                    bytesCounter.addAndGet(numBytes);
156                }
157            }
158        });
159        Future<?> connectedFuture = server.start();
160
161        client = new TestClient(sslProvider.newClientSocket(LOCALHOST, server.port(), cipher));
162        client.start();
163
164        // Wait for the initial connection to complete.
165        connectedFuture.get(5, TimeUnit.SECONDS);
166
167        executor = Executors.newSingleThreadExecutor();
168        executor.submit(new Runnable() {
169            @Override
170            public void run() {
171                Thread thread = Thread.currentThread();
172                while (!stopping && !thread.isInterrupted()) {
173                    client.sendMessage(message);
174                }
175            }
176        });
177    }
178
179    @TearDown(Level.Trial)
180    public void teardown() throws Exception {
181        stopping = true;
182        client.stop();
183        server.stop();
184        executor.shutdown();
185        executor.awaitTermination(5, TimeUnit.SECONDS);
186    }
187
188    @Benchmark
189    public final void throughput(BytesPerSecondCounter counter) throws Exception {
190        recording.set(true);
191        // No need to do anything, just sleep here.
192        Thread.sleep(1001);
193        recording.set(false);
194    }
195}
196