1/*
2 * Copyright (C) 2010 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package libcore.javax.net.ssl;
18
19import java.lang.reflect.Field;
20import java.lang.reflect.Method;
21import java.net.InetAddress;
22import java.net.InetSocketAddress;
23import java.net.ServerSocket;
24import java.net.Socket;
25import java.net.SocketException;
26import java.security.KeyManagementException;
27import java.security.Provider;
28import java.security.SecureRandom;
29import java.security.Security;
30import java.util.Properties;
31import javax.net.ServerSocketFactory;
32import javax.net.SocketFactory;
33import javax.net.ssl.KeyManager;
34import javax.net.ssl.SSLContext;
35import javax.net.ssl.SSLContextSpi;
36import javax.net.ssl.SSLEngine;
37import javax.net.ssl.SSLServerSocketFactory;
38import javax.net.ssl.SSLSessionContext;
39import javax.net.ssl.SSLSocket;
40import javax.net.ssl.SSLSocketFactory;
41import javax.net.ssl.TrustManager;
42import junit.framework.TestCase;
43import libcore.java.security.StandardNames;
44
45public class SSLSocketFactoryTest extends TestCase {
46    private static final String SSL_PROPERTY = "ssl.SocketFactory.provider";
47
48    public void test_SSLSocketFactory_getDefault() {
49        SocketFactory sf = SSLSocketFactory.getDefault();
50        assertNotNull(sf);
51        assertTrue(SSLSocketFactory.class.isAssignableFrom(sf.getClass()));
52    }
53
54    public static class FakeSSLSocketProvider extends Provider {
55        public FakeSSLSocketProvider() {
56            super("FakeSSLSocketProvider", 1.0, "Testing provider");
57            put("SSLContext.Default", FakeSSLContextSpi.class.getName());
58        }
59    }
60
61    public static final class FakeSSLContextSpi extends SSLContextSpi {
62        @Override
63        protected void engineInit(KeyManager[] keyManagers, TrustManager[] trustManagers,
64                SecureRandom secureRandom) throws KeyManagementException {
65            throw new UnsupportedOperationException();
66        }
67
68        @Override
69        protected SSLSocketFactory engineGetSocketFactory() {
70            return new FakeSSLSocketFactory();
71        }
72
73        @Override
74        protected SSLServerSocketFactory engineGetServerSocketFactory() {
75            throw new UnsupportedOperationException();
76        }
77
78        @Override
79        protected SSLEngine engineCreateSSLEngine(String s, int i) {
80            throw new UnsupportedOperationException();
81        }
82
83        @Override
84        protected SSLEngine engineCreateSSLEngine() {
85            throw new UnsupportedOperationException();
86        }
87
88        @Override
89        protected SSLSessionContext engineGetServerSessionContext() {
90            throw new UnsupportedOperationException();
91        }
92
93        @Override
94        protected SSLSessionContext engineGetClientSessionContext() {
95            throw new UnsupportedOperationException();
96        }
97    }
98
99    public static class FakeSSLSocketFactory extends SSLSocketFactory {
100        public FakeSSLSocketFactory() {
101        }
102
103        @Override
104        public String[] getDefaultCipherSuites() {
105            throw new UnsupportedOperationException();
106        }
107
108        @Override
109        public String[] getSupportedCipherSuites() {
110            throw new UnsupportedOperationException();
111        }
112
113        @Override
114        public Socket createSocket(Socket s, String host, int port, boolean autoClose) {
115            throw new UnsupportedOperationException();
116        }
117
118        @Override
119        public Socket createSocket(InetAddress address, int port, InetAddress localAddress,
120                int localPort) {
121            throw new UnsupportedOperationException();
122        }
123
124        @Override
125        public Socket createSocket(InetAddress host, int port) {
126            throw new UnsupportedOperationException();
127        }
128
129        @Override
130        public Socket createSocket(String host, int port, InetAddress localHost, int localPort) {
131            throw new UnsupportedOperationException();
132        }
133
134        @Override
135        public Socket createSocket(String host, int port) {
136            throw new UnsupportedOperationException();
137        }
138    }
139
140    public void test_SSLSocketFactory_getDefault_cacheInvalidate() throws Exception {
141        String origProvider = resetSslProvider();
142        try {
143            SocketFactory sf1 = SSLSocketFactory.getDefault();
144            assertNotNull(sf1);
145            assertTrue(SSLSocketFactory.class.isAssignableFrom(sf1.getClass()));
146
147            Provider fakeProvider = new FakeSSLSocketProvider();
148            SocketFactory sf4 = null;
149            SSLContext origContext = null;
150            try {
151                origContext = SSLContext.getDefault();
152                Security.insertProviderAt(fakeProvider, 1);
153                SSLContext.setDefault(SSLContext.getInstance("Default", fakeProvider));
154
155                sf4 = SSLSocketFactory.getDefault();
156                assertNotNull(sf4);
157                assertTrue(SSLSocketFactory.class.isAssignableFrom(sf4.getClass()));
158
159                assertFalse(sf1.getClass() + " should not be " + sf4.getClass(),
160                        sf1.getClass().equals(sf4.getClass()));
161            } finally {
162                SSLContext.setDefault(origContext);
163                Security.removeProvider(fakeProvider.getName());
164            }
165
166            SocketFactory sf3 = SSLSocketFactory.getDefault();
167            assertNotNull(sf3);
168            assertTrue(SSLSocketFactory.class.isAssignableFrom(sf3.getClass()));
169
170            assertTrue(sf1.getClass() + " should be " + sf3.getClass(),
171                    sf1.getClass().equals(sf3.getClass()));
172
173            if (!StandardNames.IS_RI) {
174                Security.setProperty(SSL_PROPERTY, FakeSSLSocketFactory.class.getName());
175                SocketFactory sf2 = SSLSocketFactory.getDefault();
176                assertNotNull(sf2);
177                assertTrue(SSLSocketFactory.class.isAssignableFrom(sf2.getClass()));
178
179                assertFalse(sf2.getClass().getName() + " should not be " + Security.getProperty(SSL_PROPERTY),
180                        sf1.getClass().equals(sf2.getClass()));
181                assertTrue(sf2.getClass().equals(sf4.getClass()));
182
183                resetSslProvider();
184            }
185        } finally {
186            Security.setProperty(SSL_PROPERTY, origProvider);
187        }
188    }
189
190    /**
191     * Should only run on Android.
192     */
193    private String resetSslProvider() {
194        String origProvider = Security.getProperty(SSL_PROPERTY);
195
196        try {
197            Field field_secprops = Security.class.getDeclaredField("props");
198            field_secprops.setAccessible(true);
199            Properties secprops = (Properties) field_secprops.get(null);
200            secprops.remove(SSL_PROPERTY);
201            Method m_increaseVersion = Security.class.getDeclaredMethod("increaseVersion");
202            m_increaseVersion.invoke(null);
203        } catch (Exception e) {
204            e.printStackTrace();
205            throw new RuntimeException("Could not clear security provider", e);
206        }
207
208        assertNull(Security.getProperty(SSL_PROPERTY));
209        return origProvider;
210    }
211
212    public void test_SSLSocketFactory_defaultConfiguration() throws Exception {
213        SSLConfigurationAsserts.assertSSLSocketFactoryDefaultConfiguration(
214                (SSLSocketFactory) SSLSocketFactory.getDefault());
215    }
216
217    public void test_SSLSocketFactory_getDefaultCipherSuitesReturnsCopies() {
218        SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
219        assertNotSame(sf.getDefaultCipherSuites(), sf.getDefaultCipherSuites());
220    }
221
222    public void test_SSLSocketFactory_getSupportedCipherSuitesReturnsCopies() {
223        SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
224        assertNotSame(sf.getSupportedCipherSuites(), sf.getSupportedCipherSuites());
225    }
226
227    public void test_SSLSocketFactory_createSocket() throws Exception {
228        try {
229            SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
230            Socket s = sf.createSocket(null, null, -1, false);
231            fail();
232        } catch (NullPointerException expected) {
233        }
234
235        try {
236            SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
237            Socket ssl = sf.createSocket(new Socket(), null, -1, false);
238            fail();
239        } catch (SocketException expected) {
240        }
241
242        ServerSocket ss = ServerSocketFactory.getDefault().createServerSocket(0);
243        InetSocketAddress sa = (InetSocketAddress) ss.getLocalSocketAddress();
244        InetAddress host = sa.getAddress();
245        int port = sa.getPort();
246        Socket s = new Socket(host, port);
247        SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
248        Socket ssl = sf.createSocket(s, null, -1, false);
249        assertNotNull(ssl);
250        assertTrue(SSLSocket.class.isAssignableFrom(ssl.getClass()));
251    }
252}
253