1/*
2 *  Licensed to the Apache Software Foundation (ASF) under one or more
3 *  contributor license agreements.  See the NOTICE file distributed with
4 *  this work for additional information regarding copyright ownership.
5 *  The ASF licenses this file to You under the Apache License, Version 2.0
6 *  (the "License"); you may not use this file except in compliance with
7 *  the License.  You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 *  Unless required by applicable law or agreed to in writing, software
12 *  distributed under the License is distributed on an "AS IS" BASIS,
13 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 *  See the License for the specific language governing permissions and
15 *  limitations under the License.
16 */
17
18package org.apache.harmony.xnet.provider.jsse;
19
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.Arrays;
import java.util.List;
import javax.net.ssl.SSLServerSocket;
import javax.net.ssl.SSLServerSocketFactory;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;

import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;
35
36/**
37 * SSLSocketImplTest
38 */
39public class SSLSocketFactoriesTest extends TestCase {
40
41    // turn on/off the debug logging
42    private boolean doLog = false;
43
44    /**
45     * Sets up the test case.
46     */
47    @Override
48    public void setUp() throws Exception {
49        super.setUp();
50        if (doLog) {
51            System.out.println("========================");
52            System.out.println("====== Running the test: " + getName());
53            System.out.println("========================");
54        }
55    }
56
57    @Override
58    public void tearDown() throws Exception {
59        super.tearDown();
60    }
61
62    /**
63     * Tests default initialized factories.
64     */
65    public void testDefaultInitialized() throws Exception {
66
67        SSLServerSocketFactory ssfactory =
68            (SSLServerSocketFactory) SSLServerSocketFactory.getDefault();
69        SSLSocketFactory sfactory =
70            (SSLSocketFactory) SSLSocketFactory.getDefault();
71
72        assertNotNull(ssfactory.getDefaultCipherSuites());
73        assertNotNull(ssfactory.getSupportedCipherSuites());
74        assertNotNull(ssfactory.createServerSocket());
75
76        assertNotNull(sfactory.getDefaultCipherSuites());
77        assertNotNull(sfactory.getSupportedCipherSuites());
78        assertNotNull(sfactory.createSocket());
79    }
80
81    public void testSocketCreation() throws Throwable {
82        SSLSocketFactory socketFactory
83            = new SSLSocketFactoryImpl(JSSETestData.getSSLParameters());
84        SSLServerSocketFactory serverSocketFactory
85            = new SSLServerSocketFactoryImpl(JSSETestData.getSSLParameters());
86
87        String[] enabled = {"TLS_RSA_WITH_RC4_128_MD5"};
88        for (int i=0; i<4; i++) {
89            SSLServerSocket ssocket;
90            switch (i) {
91                case 0:
92                    if (doLog) {
93                        System.out.println(
94                            "*** ServerSocketFactory.createServerSocket()");
95                    }
96                    ssocket = (SSLServerSocket)
97                        serverSocketFactory.createServerSocket();
98                    ssocket.bind(null);
99                    break;
100                case 1:
101                    if (doLog) {
102                        System.out.println(
103                            "*** ServerSocketFactory.createServerSocket(int)");
104                    }
105                    ssocket = (SSLServerSocket)
106                        serverSocketFactory.createServerSocket(0);
107                    break;
108                case 2:
109                    if (doLog) {
110                        System.out.println(
111                        "*** ServerSocketFactory.createServerSocket(int,int)");
112                    }
113                    ssocket = (SSLServerSocket)
114                        serverSocketFactory.createServerSocket(0, 6);
115                    break;
116                default:
117                    if (doLog) {
118                        System.out.println("*** ServerSocketFactory."
119                                + "createServerSocket(int,int,InetAddress)");
120                    }
121                    ssocket = (SSLServerSocket)
122                        serverSocketFactory.createServerSocket(0, 6, null);
123                    break;
124            }
125            ssocket.setUseClientMode(false);
126            ssocket.setEnabledCipherSuites(enabled);
127            for (int j=0; j<6; j++) {
128                SSLSocket csocket;
129                switch (j) {
130                    case 0:
131                        if (doLog) {
132                            System.out.println(
133                                "=== SocketFactory.createSocket()");
134                        }
135                        csocket = (SSLSocket) socketFactory.createSocket();
136                        csocket.connect(
137                                new InetSocketAddress("localhost",
138                                    ssocket.getLocalPort()));
139                        break;
140                    case 1:
141                        if (doLog) {
142                            System.out.println(
143                                "=== SocketFactory.createSocket(String,int)");
144                        }
145                        csocket = (SSLSocket)
146                            socketFactory.createSocket("localhost",
147                                    ssocket.getLocalPort());
148                        break;
149                    case 2:
150                        if (doLog) {
151                            System.out.println("=== SocketFactory.createSocket("
152                                    + "String,int,InetAddress,int)");
153                        }
154                        csocket = (SSLSocket)
155                            socketFactory.createSocket("localhost",
156                                ssocket.getLocalPort(),
157                                InetAddress.getByName("localhost"), 0);
158                        break;
159                    case 3:
160                        if (doLog) {
161                            System.out.println("=== SocketFactory.createSocket("
162                                    + "InetAddress,int)");
163                        }
164                        csocket = (SSLSocket) socketFactory.createSocket(
165                                InetAddress.getByName("localhost"),
166                                ssocket.getLocalPort());
167                        break;
168                    case 4:
169                        if (doLog) {
170                            System.out.println("=== SocketFactory.createSocket("
171                                    + "InetAddress,int,InetAddress,int)");
172                        }
173                        csocket = (SSLSocket) socketFactory.createSocket(
174                                InetAddress.getByName("localhost"),
175                                ssocket.getLocalPort(),
176                                InetAddress.getByName("localhost"),
177                                0);
178                        break;
179                    default:
180                        if (doLog) {
181                            System.out.println(
182                                    "=== SSLSocketFactory.createSocket("
183                                    + "socket,String,int,boolean)");
184                        }
185                        Socket socket = new Socket(
186                                InetAddress.getByName("localhost"),
187                                ssocket.getLocalPort());
188                        csocket = (SSLSocket) socketFactory.createSocket(
189                                socket, "localhost", ssocket.getLocalPort(),
190                                true);
191                        break;
192
193                }
194                csocket.setUseClientMode(true);
195                csocket.setEnabledCipherSuites(enabled);
196                doTest(ssocket, csocket);
197            }
198        }
199    }
200
201    /**
202     * SSLSocketFactory.getSupportedCipherSuites() method testing.
203     */
204    public void testGetSupportedCipherSuites1() throws Exception {
205        SSLSocketFactory socketFactory
206            = new SSLSocketFactoryImpl(JSSETestData.getSSLParameters());
207        String[] supported = socketFactory.getSupportedCipherSuites();
208        assertNotNull(supported);
209        supported[0] = "NOT_SUPPORTED_CIPHER_SUITE";
210        supported = socketFactory.getSupportedCipherSuites();
211        for (int i=0; i<supported.length; i++) {
212            if ("NOT_SUPPORTED_CIPHER_SUITE".equals(supported[i])) {
213                fail("Modification of the returned result "
214                        + "causes the modification of the internal state");
215            }
216        }
217    }
218
219    /**
220     * SSLServerSocketFactory.getSupportedCipherSuites() method testing.
221     */
222    public void testGetSupportedCipherSuites2() throws Exception {
223        SSLServerSocketFactory serverSocketFactory
224            = new SSLServerSocketFactoryImpl(JSSETestData.getSSLParameters());
225        String[] supported = serverSocketFactory.getSupportedCipherSuites();
226        assertNotNull(supported);
227        supported[0] = "NOT_SUPPORTED_CIPHER_SUITE";
228        supported = serverSocketFactory.getSupportedCipherSuites();
229        for (int i=0; i<supported.length; i++) {
230            if ("NOT_SUPPORTED_CIPHER_SUITE".equals(supported[i])) {
231                fail("Modification of the returned result "
232                        + "causes the modification of the internal state");
233            }
234        }
235    }
236
237    /**
238     * SSLSocketFactory.getDefaultCipherSuites() method testing.
239     */
240    public void testGetDefaultCipherSuites1() throws Exception {
241        SSLSocketFactory socketFactory
242            = new SSLSocketFactoryImpl(JSSETestData.getSSLParameters());
243        String[] supported = socketFactory.getSupportedCipherSuites();
244        String[] defaultcs = socketFactory.getDefaultCipherSuites();
245        assertNotNull(supported);
246        assertNotNull(defaultcs);
247        for (int i=0; i<defaultcs.length; i++) {
248            found: {
249                for (int j=0; j<supported.length; j++) {
250                    if (defaultcs[i].equals(supported[j])) {
251                        break found;
252                    }
253                }
254                fail("Default suite does not belong to the set "
255                        + "of supported cipher suites: " + defaultcs[i]);
256            }
257        }
258    }
259
260    /**
261     * SSLServerSocketFactory.getDefaultCipherSuites() method testing.
262     */
263    public void testGetDefaultCipherSuites2() throws Exception {
264        SSLServerSocketFactory serverSocketFactory
265            = new SSLServerSocketFactoryImpl(JSSETestData.getSSLParameters());
266        String[] supported = serverSocketFactory.getSupportedCipherSuites();
267        String[] defaultcs = serverSocketFactory.getDefaultCipherSuites();
268        assertNotNull(supported);
269        assertNotNull(defaultcs);
270        for (int i=0; i<defaultcs.length; i++) {
271            found: {
272                for (int j=0; j<supported.length; j++) {
273                    if (defaultcs[i].equals(supported[j])) {
274                        break found;
275                    }
276                }
277                fail("Default suite does not belong to the set "
278                        + "of supported cipher suites: " + defaultcs[i]);
279            }
280        }
281    }
282
283    /**
284     * Performs SSL connection between the sockets
285     * @return
286     */
287    public void doTest(SSLServerSocket ssocket, SSLSocket csocket)
288            throws Throwable {
289        final String server_message = "Hello from SSL Server Socket!";
290        final String client_message = "Hello from SSL Socket!";
291        Thread server = null;
292        Thread client = null;
293        final Throwable[] throwed = new Throwable[1];
294        try {
295            final SSLServerSocket ss = ssocket;
296            final SSLSocket s = csocket;
297            server = new Thread() {
298                @Override
299                public void run() {
300                    InputStream is = null;
301                    OutputStream os = null;
302                    SSLSocket s = null;
303                    try {
304                        s = (SSLSocket) ss.accept();
305                        if (doLog) {
306                            System.out.println("Socket accepted: " + s);
307                        }
308                        is = s.getInputStream();
309                        os = s.getOutputStream();
310                        // send the message to the client
311                        os.write(server_message.getBytes());
312                        // read the response
313                        byte[] buff = new byte[client_message.length()];
314                        int len = is.read(buff);
315                        if (doLog) {
316                            System.out.println("Received message of length "
317                                + len + ": '" + new String(buff, 0, len)+"'");
318                        }
319                        assertTrue("Read message does not equal to expected",
320                                Arrays.equals(client_message.getBytes(), buff));
321                        os.write(-1);
322                        assertEquals("Read data differs from expected",
323                                255, is.read());
324                        if (doLog) {
325                            System.out.println("Server is closed: "
326                                    +s.isClosed());
327                        }
328                        assertEquals("Returned value should be -1",
329                        // initiate an exchange of closure alerts
330                                -1, is.read());
331                        if (doLog) {
332                            System.out.println("Server is closed: "
333                                    +s.isClosed());
334                        }
335                        assertEquals("Returned value should be -1",
336                        // initiate an exchange of closure alerts
337                                -1, is.read());
338                    } catch (Throwable e) {
339                        synchronized (throwed) {
340                            if (doLog) {
341                                e.printStackTrace();
342                            }
343                            if (throwed[0] == null) {
344                                throwed[0] = e;
345                            }
346                        }
347                    } finally {
348                        try {
349                            if (is != null) {
350                                is.close();
351                            }
352                        } catch (IOException ex) {}
353                        try {
354                            if (os != null) {
355                                os.close();
356                            }
357                        } catch (IOException ex) {}
358                        try {
359                            if (s != null) {
360                                s.close();
361                            }
362                        } catch (IOException ex) {}
363                    }
364                }
365            };
366
367            client = new Thread() {
368                @Override
369                public void run() {
370                    InputStream is = null;
371                    OutputStream os = null;
372                    try {
373                        assertTrue("Client was not connected", s.isConnected());
374                        if (doLog) {
375                            System.out.println("Client connected");
376                        }
377                        is = s.getInputStream();
378                        os = s.getOutputStream();
379                        s.startHandshake();
380                        if (doLog) {
381                            System.out.println("Client: HS was done");
382                        }
383                        // read the message from the server
384                        byte[] buff = new byte[server_message.length()];
385                        int len = is.read(buff);
386                        if (doLog) {
387                            System.out.println("Received message of length "
388                                + len + ": '" + new String(buff, 0, len)+"'");
389                        }
390                        assertTrue("Read message does not equal to expected",
391                                Arrays.equals(server_message.getBytes(), buff));
392                        // send the response
393                        buff = (" "+client_message+" ").getBytes();
394                        os.write(buff, 1, buff.length-2);
395                        assertEquals("Read data differs from expected",
396                                255, is.read());
397                        os.write(-1);
398                        if (doLog) {
399                            System.out.println("Client is closed: "
400                                    +s.isClosed());
401                        }
402                        s.close();
403                        if (doLog) {
404                            System.out.println("Client is closed: "
405                                    +s.isClosed());
406                        }
407                    } catch (Throwable e) {
408                        synchronized (throwed) {
409                            if (doLog) {
410                                e.printStackTrace();
411                            }
412                            if (throwed[0] == null) {
413                                throwed[0] = e;
414                            }
415                        }
416                    } finally {
417                        try {
418                            if (is != null) {
419                                is.close();
420                            }
421                        } catch (IOException ex) {}
422                        try {
423                            if (os != null) {
424                                os.close();
425                            }
426                        } catch (IOException ex) {}
427                        try {
428                            if (s != null) {
429                                s.close();
430                            }
431                        } catch (IOException ex) {}
432                    }
433                }
434            };
435
436            server.start();
437            client.start();
438
439            while (server.isAlive() || client.isAlive()) {
440                if (throwed[0] != null) {
441                    throw throwed[0];
442                }
443                try {
444                    Thread.sleep(500);
445                } catch (Exception e) { }
446            }
447        } finally {
448            if (server != null) {
449                server.stop();
450            }
451            if (client != null) {
452                client.stop();
453            }
454        }
455        if (throwed[0] != null) {
456            throw throwed[0];
457        }
458    }
459
460    public static Test suite() {
461        return new TestSuite(SSLSocketFactoriesTest.class);
462    }
463
464}
465