1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package com.android.server.connectivity.tethering;
18
19import static android.net.ConnectivityManager.TYPE_MOBILE_DUN;
20import static android.net.ConnectivityManager.TYPE_MOBILE_HIPRI;
21import static android.net.NetworkCapabilities.NET_CAPABILITY_DUN;
22import static org.junit.Assert.assertEquals;
23import static org.junit.Assert.assertFalse;
24import static org.junit.Assert.assertTrue;
25import static org.junit.Assert.fail;
26import static org.mockito.Mockito.any;
27import static org.mockito.Mockito.anyInt;
28import static org.mockito.Mockito.anyString;
29import static org.mockito.Mockito.reset;
30import static org.mockito.Mockito.spy;
31import static org.mockito.Mockito.times;
32import static org.mockito.Mockito.verify;
33import static org.mockito.Mockito.verifyNoMoreInteractions;
34import static org.mockito.Mockito.when;
35
36import android.content.Context;
37import android.os.Handler;
38import android.os.Message;
39import android.net.ConnectivityManager;
40import android.net.ConnectivityManager.NetworkCallback;
41import android.net.IConnectivityManager;
42import android.net.Network;
43import android.net.NetworkCapabilities;
44import android.net.NetworkRequest;
45import android.net.util.SharedLog;
46
47import android.support.test.filters.SmallTest;
48import android.support.test.runner.AndroidJUnit4;
49
50import com.android.internal.util.State;
51import com.android.internal.util.StateMachine;
52
53import org.junit.After;
54import org.junit.Before;
55import org.junit.runner.RunWith;
56import org.junit.Test;
57import org.mockito.Mock;
58import org.mockito.Mockito;
59import org.mockito.MockitoAnnotations;
60
61import java.util.ArrayList;
62import java.util.HashMap;
63import java.util.HashSet;
64import java.util.Map;
65import java.util.Set;
66
67
68@RunWith(AndroidJUnit4.class)
69@SmallTest
70public class UpstreamNetworkMonitorTest {
71    private static final int EVENT_UNM_UPDATE = 1;
72
73    @Mock private Context mContext;
74    @Mock private IConnectivityManager mCS;
75    @Mock private SharedLog mLog;
76
77    private TestStateMachine mSM;
78    private TestConnectivityManager mCM;
79    private UpstreamNetworkMonitor mUNM;
80
81    @Before public void setUp() throws Exception {
82        MockitoAnnotations.initMocks(this);
83        reset(mContext);
84        reset(mCS);
85        reset(mLog);
86        when(mLog.forSubComponent(anyString())).thenReturn(mLog);
87
88        mCM = spy(new TestConnectivityManager(mContext, mCS));
89        mSM = new TestStateMachine();
90        mUNM = new UpstreamNetworkMonitor(mSM, EVENT_UNM_UPDATE, (ConnectivityManager) mCM, mLog);
91    }
92
93    @After public void tearDown() throws Exception {
94        if (mSM != null) {
95            mSM.quit();
96            mSM = null;
97        }
98    }
99
100    @Test
101    public void testDoesNothingBeforeStarted() {
102        assertTrue(mCM.hasNoCallbacks());
103        assertFalse(mUNM.mobileNetworkRequested());
104
105        mUNM.updateMobileRequiresDun(true);
106        assertTrue(mCM.hasNoCallbacks());
107        mUNM.updateMobileRequiresDun(false);
108        assertTrue(mCM.hasNoCallbacks());
109    }
110
111    @Test
112    public void testDefaultNetworkIsTracked() throws Exception {
113        assertEquals(0, mCM.trackingDefault.size());
114
115        mUNM.start();
116        assertEquals(1, mCM.trackingDefault.size());
117
118        mUNM.stop();
119        assertTrue(mCM.hasNoCallbacks());
120    }
121
122    @Test
123    public void testListensForAllNetworks() throws Exception {
124        assertTrue(mCM.listening.isEmpty());
125
126        mUNM.start();
127        assertFalse(mCM.listening.isEmpty());
128        assertTrue(mCM.isListeningForAll());
129
130        mUNM.stop();
131        assertTrue(mCM.hasNoCallbacks());
132    }
133
134    @Test
135    public void testRequestsMobileNetwork() throws Exception {
136        assertFalse(mUNM.mobileNetworkRequested());
137        assertEquals(0, mCM.requested.size());
138
139        mUNM.start();
140        assertFalse(mUNM.mobileNetworkRequested());
141        assertEquals(0, mCM.requested.size());
142
143        mUNM.updateMobileRequiresDun(false);
144        assertFalse(mUNM.mobileNetworkRequested());
145        assertEquals(0, mCM.requested.size());
146
147        mUNM.registerMobileNetworkRequest();
148        assertTrue(mUNM.mobileNetworkRequested());
149        assertUpstreamTypeRequested(TYPE_MOBILE_HIPRI);
150        assertFalse(mCM.isDunRequested());
151
152        mUNM.stop();
153        assertFalse(mUNM.mobileNetworkRequested());
154        assertTrue(mCM.hasNoCallbacks());
155    }
156
157    @Test
158    public void testDuplicateMobileRequestsIgnored() throws Exception {
159        assertFalse(mUNM.mobileNetworkRequested());
160        assertEquals(0, mCM.requested.size());
161
162        mUNM.start();
163        verify(mCM, Mockito.times(1)).registerNetworkCallback(
164                any(NetworkRequest.class), any(NetworkCallback.class), any(Handler.class));
165        verify(mCM, Mockito.times(1)).registerDefaultNetworkCallback(
166                any(NetworkCallback.class), any(Handler.class));
167        assertFalse(mUNM.mobileNetworkRequested());
168        assertEquals(0, mCM.requested.size());
169
170        mUNM.updateMobileRequiresDun(true);
171        mUNM.registerMobileNetworkRequest();
172        verify(mCM, Mockito.times(1)).requestNetwork(
173                any(NetworkRequest.class), any(NetworkCallback.class), anyInt(), anyInt(),
174                any(Handler.class));
175
176        assertTrue(mUNM.mobileNetworkRequested());
177        assertUpstreamTypeRequested(TYPE_MOBILE_DUN);
178        assertTrue(mCM.isDunRequested());
179
180        // Try a few things that must not result in any state change.
181        mUNM.registerMobileNetworkRequest();
182        mUNM.updateMobileRequiresDun(true);
183        mUNM.registerMobileNetworkRequest();
184
185        assertTrue(mUNM.mobileNetworkRequested());
186        assertUpstreamTypeRequested(TYPE_MOBILE_DUN);
187        assertTrue(mCM.isDunRequested());
188
189        mUNM.stop();
190        verify(mCM, times(3)).unregisterNetworkCallback(any(NetworkCallback.class));
191
192        verifyNoMoreInteractions(mCM);
193    }
194
195    @Test
196    public void testRequestsDunNetwork() throws Exception {
197        assertFalse(mUNM.mobileNetworkRequested());
198        assertEquals(0, mCM.requested.size());
199
200        mUNM.start();
201        assertFalse(mUNM.mobileNetworkRequested());
202        assertEquals(0, mCM.requested.size());
203
204        mUNM.updateMobileRequiresDun(true);
205        assertFalse(mUNM.mobileNetworkRequested());
206        assertEquals(0, mCM.requested.size());
207
208        mUNM.registerMobileNetworkRequest();
209        assertTrue(mUNM.mobileNetworkRequested());
210        assertUpstreamTypeRequested(TYPE_MOBILE_DUN);
211        assertTrue(mCM.isDunRequested());
212
213        mUNM.stop();
214        assertFalse(mUNM.mobileNetworkRequested());
215        assertTrue(mCM.hasNoCallbacks());
216    }
217
218    @Test
219    public void testUpdateMobileRequiresDun() throws Exception {
220        mUNM.start();
221
222        // Test going from no-DUN to DUN correctly re-registers callbacks.
223        mUNM.updateMobileRequiresDun(false);
224        mUNM.registerMobileNetworkRequest();
225        assertTrue(mUNM.mobileNetworkRequested());
226        assertUpstreamTypeRequested(TYPE_MOBILE_HIPRI);
227        assertFalse(mCM.isDunRequested());
228        mUNM.updateMobileRequiresDun(true);
229        assertTrue(mUNM.mobileNetworkRequested());
230        assertUpstreamTypeRequested(TYPE_MOBILE_DUN);
231        assertTrue(mCM.isDunRequested());
232
233        // Test going from DUN to no-DUN correctly re-registers callbacks.
234        mUNM.updateMobileRequiresDun(false);
235        assertTrue(mUNM.mobileNetworkRequested());
236        assertUpstreamTypeRequested(TYPE_MOBILE_HIPRI);
237        assertFalse(mCM.isDunRequested());
238
239        mUNM.stop();
240        assertFalse(mUNM.mobileNetworkRequested());
241    }
242
243    private void assertUpstreamTypeRequested(int upstreamType) throws Exception {
244        assertEquals(1, mCM.requested.size());
245        assertEquals(1, mCM.legacyTypeMap.size());
246        assertEquals(Integer.valueOf(upstreamType),
247                mCM.legacyTypeMap.values().iterator().next());
248    }
249
250    public static class TestConnectivityManager extends ConnectivityManager {
251        public Map<NetworkCallback, Handler> allCallbacks = new HashMap<>();
252        public Set<NetworkCallback> trackingDefault = new HashSet<>();
253        public Map<NetworkCallback, NetworkRequest> listening = new HashMap<>();
254        public Map<NetworkCallback, NetworkRequest> requested = new HashMap<>();
255        public Map<NetworkCallback, Integer> legacyTypeMap = new HashMap<>();
256
257        public TestConnectivityManager(Context ctx, IConnectivityManager svc) {
258            super(ctx, svc);
259        }
260
261        boolean hasNoCallbacks() {
262            return allCallbacks.isEmpty() &&
263                   trackingDefault.isEmpty() &&
264                   listening.isEmpty() &&
265                   requested.isEmpty() &&
266                   legacyTypeMap.isEmpty();
267        }
268
269        boolean isListeningForAll() {
270            final NetworkCapabilities empty = new NetworkCapabilities();
271            empty.clearAll();
272
273            for (NetworkRequest req : listening.values()) {
274                if (req.networkCapabilities.equalRequestableCapabilities(empty)) {
275                    return true;
276                }
277            }
278            return false;
279        }
280
281        boolean isDunRequested() {
282            for (NetworkRequest req : requested.values()) {
283                if (req.networkCapabilities.hasCapability(NET_CAPABILITY_DUN)) {
284                    return true;
285                }
286            }
287            return false;
288        }
289
290        @Override
291        public void requestNetwork(NetworkRequest req, NetworkCallback cb, Handler h) {
292            assertFalse(allCallbacks.containsKey(cb));
293            allCallbacks.put(cb, h);
294            assertFalse(requested.containsKey(cb));
295            requested.put(cb, req);
296        }
297
298        @Override
299        public void requestNetwork(NetworkRequest req, NetworkCallback cb) {
300            fail("Should never be called.");
301        }
302
303        @Override
304        public void requestNetwork(NetworkRequest req, NetworkCallback cb,
305                int timeoutMs, int legacyType, Handler h) {
306            assertFalse(allCallbacks.containsKey(cb));
307            allCallbacks.put(cb, h);
308            assertFalse(requested.containsKey(cb));
309            requested.put(cb, req);
310            assertFalse(legacyTypeMap.containsKey(cb));
311            if (legacyType != ConnectivityManager.TYPE_NONE) {
312                legacyTypeMap.put(cb, legacyType);
313            }
314        }
315
316        @Override
317        public void registerNetworkCallback(NetworkRequest req, NetworkCallback cb, Handler h) {
318            assertFalse(allCallbacks.containsKey(cb));
319            allCallbacks.put(cb, h);
320            assertFalse(listening.containsKey(cb));
321            listening.put(cb, req);
322        }
323
324        @Override
325        public void registerNetworkCallback(NetworkRequest req, NetworkCallback cb) {
326            fail("Should never be called.");
327        }
328
329        @Override
330        public void registerDefaultNetworkCallback(NetworkCallback cb, Handler h) {
331            assertFalse(allCallbacks.containsKey(cb));
332            allCallbacks.put(cb, h);
333            assertFalse(trackingDefault.contains(cb));
334            trackingDefault.add(cb);
335        }
336
337        @Override
338        public void registerDefaultNetworkCallback(NetworkCallback cb) {
339            fail("Should never be called.");
340        }
341
342        @Override
343        public void unregisterNetworkCallback(NetworkCallback cb) {
344            if (trackingDefault.contains(cb)) {
345                trackingDefault.remove(cb);
346            } else if (listening.containsKey(cb)) {
347                listening.remove(cb);
348            } else if (requested.containsKey(cb)) {
349                requested.remove(cb);
350                legacyTypeMap.remove(cb);
351            } else {
352                fail("Unexpected callback removed");
353            }
354            allCallbacks.remove(cb);
355
356            assertFalse(allCallbacks.containsKey(cb));
357            assertFalse(trackingDefault.contains(cb));
358            assertFalse(listening.containsKey(cb));
359            assertFalse(requested.containsKey(cb));
360        }
361    }
362
363    public static class TestStateMachine extends StateMachine {
364        public final ArrayList<Message> messages = new ArrayList<>();
365        private final State mLoggingState = new LoggingState();
366
367        class LoggingState extends State {
368            @Override public void enter() { messages.clear(); }
369
370            @Override public void exit() { messages.clear(); }
371
372            @Override public boolean processMessage(Message msg) {
373                messages.add(msg);
374                return true;
375            }
376        }
377
378        public TestStateMachine() {
379            super("UpstreamNetworkMonitor.TestStateMachine");
380            addState(mLoggingState);
381            setInitialState(mLoggingState);
382            super.start();
383        }
384    }
385}
386