NetworkStatsObservers.java revision cd42acd9515bdce89d4f1401ee2888d684bf1918
1/*
2 * Copyright (C) 2016 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package com.android.server.net;
18
19import static android.net.TrafficStats.MB_IN_BYTES;
20import static com.android.internal.util.Preconditions.checkArgument;
21
22import android.app.usage.NetworkStatsManager;
23import android.net.DataUsageRequest;
24import android.net.NetworkStats;
25import android.net.NetworkStats.NonMonotonicObserver;
26import android.net.NetworkStatsHistory;
27import android.net.NetworkTemplate;
28import android.os.Binder;
29import android.os.Bundle;
30import android.os.Looper;
31import android.os.Message;
32import android.os.Messenger;
33import android.os.Handler;
34import android.os.HandlerThread;
35import android.os.IBinder;
36import android.os.Process;
37import android.os.RemoteException;
38import android.util.ArrayMap;
39import android.util.IntArray;
40import android.util.SparseArray;
41import android.util.Slog;
42
43import com.android.internal.annotations.VisibleForTesting;
44import com.android.internal.net.VpnInfo;
45
46import java.util.concurrent.atomic.AtomicInteger;
47
48/**
49 * Manages observers of {@link NetworkStats}. Allows observers to be notified when
50 * data usage has been reported in {@link NetworkStatsService}. An observer can set
51 * a threshold of how much data it cares about to be notified.
52 */
53class NetworkStatsObservers {
54    private static final String TAG = "NetworkStatsObservers";
55    private static final boolean LOGV = true;
56
57    private static final long MIN_THRESHOLD_BYTES = 2 * MB_IN_BYTES;
58
59    private static final int MSG_REGISTER = 1;
60    private static final int MSG_UNREGISTER = 2;
61    private static final int MSG_UPDATE_STATS = 3;
62
63    // All access to this map must be done from the handler thread.
64    // indexed by DataUsageRequest#requestId
65    private final SparseArray<RequestInfo> mDataUsageRequests = new SparseArray<>();
66
67    // Sequence number of DataUsageRequests
68    private final AtomicInteger mNextDataUsageRequestId = new AtomicInteger();
69
70    // Lazily instantiated when an observer is registered.
71    private Handler mHandler;
72
73    /**
74     * Creates a wrapper that contains the caller context and a normalized request.
75     * The request should be returned to the caller app, and the wrapper should be sent to this
76     * object through #addObserver by the service handler.
77     *
78     * <p>It will register the observer asynchronously, so it is safe to call from any thread.
79     *
80     * @return the normalized request wrapped within {@link RequestInfo}.
81     */
82    public DataUsageRequest register(DataUsageRequest inputRequest, Messenger messenger,
83                IBinder binder, int callingUid, @NetworkStatsAccess.Level int accessLevel) {
84        checkVisibilityUids(callingUid, accessLevel, inputRequest.uids);
85
86        DataUsageRequest request = buildRequest(inputRequest);
87        RequestInfo requestInfo = buildRequestInfo(request, messenger, binder, callingUid,
88                accessLevel);
89
90        if (LOGV) Slog.v(TAG, "Registering observer for " + request);
91        getHandler().sendMessage(mHandler.obtainMessage(MSG_REGISTER, requestInfo));
92        return request;
93    }
94
95    /**
96     * Unregister a data usage observer.
97     *
98     * <p>It will unregister the observer asynchronously, so it is safe to call from any thread.
99     */
100    public void unregister(DataUsageRequest request, int callingUid) {
101        getHandler().sendMessage(mHandler.obtainMessage(MSG_UNREGISTER, callingUid, 0 /* ignore */,
102                request));
103    }
104
105    /**
106     * Updates data usage statistics of registered observers and notifies if limits are reached.
107     *
108     * <p>It will update stats asynchronously, so it is safe to call from any thread.
109     */
110    public void updateStats(NetworkStats xtSnapshot, NetworkStats uidSnapshot,
111                ArrayMap<String, NetworkIdentitySet> activeIfaces,
112                ArrayMap<String, NetworkIdentitySet> activeUidIfaces,
113                VpnInfo[] vpnArray, long currentTime) {
114        StatsContext statsContext = new StatsContext(xtSnapshot, uidSnapshot, activeIfaces,
115                activeUidIfaces, vpnArray, currentTime);
116        getHandler().sendMessage(mHandler.obtainMessage(MSG_UPDATE_STATS, statsContext));
117    }
118
119    private Handler getHandler() {
120        if (mHandler == null) {
121            synchronized (this) {
122                if (mHandler == null) {
123                    if (LOGV) Slog.v(TAG, "Creating handler");
124                    mHandler = new Handler(getHandlerLooperLocked(), mHandlerCallback);
125                }
126            }
127        }
128        return mHandler;
129    }
130
131    @VisibleForTesting
132    protected Looper getHandlerLooperLocked() {
133        HandlerThread handlerThread = new HandlerThread(TAG);
134        handlerThread.start();
135        return handlerThread.getLooper();
136    }
137
138    private Handler.Callback mHandlerCallback = new Handler.Callback() {
139        @Override
140        public boolean handleMessage(Message msg) {
141            switch (msg.what) {
142                case MSG_REGISTER: {
143                    handleRegister((RequestInfo) msg.obj);
144                    return true;
145                }
146                case MSG_UNREGISTER: {
147                    handleUnregister((DataUsageRequest) msg.obj, msg.arg1 /* callingUid */);
148                    return true;
149                }
150                case MSG_UPDATE_STATS: {
151                    handleUpdateStats((StatsContext) msg.obj);
152                    return true;
153                }
154                default: {
155                    return false;
156                }
157            }
158        }
159    };
160
161    /**
162     * Adds a {@link RequestInfo} as an observer.
163     * Should only be called from the handler thread otherwise there will be a race condition
164     * on mDataUsageRequests.
165     */
166    private void handleRegister(RequestInfo requestInfo) {
167        mDataUsageRequests.put(requestInfo.mRequest.requestId, requestInfo);
168    }
169
170    /**
171     * Removes a {@link DataUsageRequest} if the calling uid is authorized.
172     * Should only be called from the handler thread otherwise there will be a race condition
173     * on mDataUsageRequests.
174     */
175    private void handleUnregister(DataUsageRequest request, int callingUid) {
176        RequestInfo requestInfo;
177        requestInfo = mDataUsageRequests.get(request.requestId);
178        if (requestInfo == null) {
179            if (LOGV) Slog.v(TAG, "Trying to unregister unknown request " + request);
180            return;
181        }
182        if (Process.SYSTEM_UID != callingUid && requestInfo.mCallingUid != callingUid) {
183            Slog.w(TAG, "Caller uid " + callingUid + " is not owner of " + request);
184            return;
185        }
186
187        if (LOGV) Slog.v(TAG, "Unregistering " + request);
188        mDataUsageRequests.remove(request.requestId);
189        requestInfo.unlinkDeathRecipient();
190        requestInfo.callCallback(NetworkStatsManager.CALLBACK_RELEASED);
191    }
192
193    private void handleUpdateStats(StatsContext statsContext) {
194        if (mDataUsageRequests.size() == 0) {
195            if (LOGV) Slog.v(TAG, "No registered listeners of data usage");
196            return;
197        }
198
199        if (LOGV) Slog.v(TAG, "Checking if any registered observer needs to be notified");
200        for (int i = 0; i < mDataUsageRequests.size(); i++) {
201            RequestInfo requestInfo = mDataUsageRequests.valueAt(i);
202            requestInfo.updateStats(statsContext);
203        }
204    }
205
206    private DataUsageRequest buildRequest(DataUsageRequest request) {
207        // Cap the minimum threshold to a safe default to avoid too many callbacks
208        long thresholdInBytes = Math.max(MIN_THRESHOLD_BYTES, request.thresholdInBytes);
209        if (thresholdInBytes < request.thresholdInBytes) {
210            Slog.w(TAG, "Threshold was too low for " + request
211                    + ". Overriding to a safer default of " + thresholdInBytes + " bytes");
212        }
213        return new DataUsageRequest(mNextDataUsageRequestId.incrementAndGet(),
214                request.templates, request.uids, thresholdInBytes);
215    }
216
217    private RequestInfo buildRequestInfo(DataUsageRequest request,
218                Messenger messenger, IBinder binder, int callingUid,
219                @NetworkStatsAccess.Level int accessLevel) {
220        if (accessLevel <= NetworkStatsAccess.Level.USER
221                || request.uids != null && request.uids.length > 0) {
222            return new UserUsageRequestInfo(this, request, messenger, binder, callingUid,
223                    accessLevel);
224        } else {
225            // Safety check in case a new access level is added and we forgot to update this
226            checkArgument(accessLevel >= NetworkStatsAccess.Level.DEVICESUMMARY);
227            return new NetworkUsageRequestInfo(this, request, messenger, binder, callingUid,
228                    accessLevel);
229        }
230    }
231
232    private void checkVisibilityUids(int callingUid, @NetworkStatsAccess.Level int accessLevel,
233                int[] uids) {
234        if (uids == null) {
235            return;
236        }
237        for (int i = 0; i < uids.length; i++) {
238            if (!NetworkStatsAccess.isAccessibleToUser(uids[i], callingUid, accessLevel)) {
239                throw new SecurityException("Caller " + callingUid + " cannot monitor network stats"
240                        + " for uid " + uids[i] + " with accessLevel " + accessLevel);
241            }
242        }
243    }
244
245    /**
246     * Tracks information relevant to a data usage observer.
247     * It will notice when the calling process dies so we can self-expire.
248     */
249    private abstract static class RequestInfo implements IBinder.DeathRecipient {
250        private final NetworkStatsObservers mStatsObserver;
251        protected final DataUsageRequest mRequest;
252        private final Messenger mMessenger;
253        private final IBinder mBinder;
254        protected final int mCallingUid;
255        protected final @NetworkStatsAccess.Level int mAccessLevel;
256        protected NetworkStatsRecorder mRecorder;
257        protected NetworkStatsCollection mCollection;
258
259        RequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request,
260                    Messenger messenger, IBinder binder, int callingUid,
261                    @NetworkStatsAccess.Level int accessLevel) {
262            mStatsObserver = statsObserver;
263            mRequest = request;
264            mMessenger = messenger;
265            mBinder = binder;
266            mCallingUid = callingUid;
267            mAccessLevel = accessLevel;
268
269            try {
270                mBinder.linkToDeath(this, 0);
271            } catch (RemoteException e) {
272                binderDied();
273            }
274        }
275
276        @Override
277        public void binderDied() {
278            if (LOGV) Slog.v(TAG, "RequestInfo binderDied("
279                    + mRequest + ", " + mBinder + ")");
280            mStatsObserver.unregister(mRequest, Process.SYSTEM_UID);
281            callCallback(NetworkStatsManager.CALLBACK_RELEASED);
282        }
283
284        @Override
285        public String toString() {
286            return "RequestInfo from uid:" + mCallingUid
287                    + " for " + mRequest + " accessLevel:" + mAccessLevel;
288        }
289
290        private void unlinkDeathRecipient() {
291            if (mBinder != null) {
292                mBinder.unlinkToDeath(this, 0);
293            }
294        }
295
296        /**
297         * Update stats given the samples and interface to identity mappings.
298         */
299        private void updateStats(StatsContext statsContext) {
300            if (mRecorder == null) {
301                // First run; establish baseline stats
302                resetRecorder();
303                recordSample(statsContext);
304                return;
305            }
306            recordSample(statsContext);
307
308            if (checkStats()) {
309                resetRecorder();
310                callCallback(NetworkStatsManager.CALLBACK_LIMIT_REACHED);
311            }
312        }
313
314        private void callCallback(int callbackType) {
315            Bundle bundle = new Bundle();
316            bundle.putParcelable(DataUsageRequest.PARCELABLE_KEY, mRequest);
317            Message msg = Message.obtain();
318            msg.what = callbackType;
319            msg.setData(bundle);
320            try {
321                if (LOGV) {
322                    Slog.v(TAG, "sending notification " + callbackTypeToName(callbackType)
323                            + " for " + mRequest);
324                }
325                mMessenger.send(msg);
326            } catch (RemoteException e) {
327                // May occur naturally in the race of binder death.
328                Slog.w(TAG, "RemoteException caught trying to send a callback msg for " + mRequest);
329            }
330        }
331
332        private void resetRecorder() {
333            mRecorder = new NetworkStatsRecorder();
334            mCollection = mRecorder.getSinceBoot();
335        }
336
337        protected abstract boolean checkStats();
338
339        protected abstract void recordSample(StatsContext statsContext);
340
341        private String callbackTypeToName(int callbackType) {
342            switch (callbackType) {
343                case NetworkStatsManager.CALLBACK_LIMIT_REACHED:
344                    return "LIMIT_REACHED";
345                case NetworkStatsManager.CALLBACK_RELEASED:
346                    return "RELEASED";
347                default:
348                    return "UNKNOWN";
349            }
350        }
351    }
352
353    private static class NetworkUsageRequestInfo extends RequestInfo {
354        NetworkUsageRequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request,
355                    Messenger messenger, IBinder binder, int callingUid,
356                    @NetworkStatsAccess.Level int accessLevel) {
357            super(statsObserver, request, messenger, binder, callingUid, accessLevel);
358        }
359
360        @Override
361        protected boolean checkStats() {
362            for (int i = 0; i < mRequest.templates.length; i++) {
363                long bytesSoFar = getTotalBytesForNetwork(mRequest.templates[i]);
364                if (LOGV) {
365                    Slog.v(TAG, bytesSoFar + " bytes so far since notification for "
366                            + mRequest.templates[i]);
367                }
368                if (bytesSoFar > mRequest.thresholdInBytes) {
369                    return true;
370                }
371            }
372            return false;
373        }
374
375        @Override
376        protected void recordSample(StatsContext statsContext) {
377            // Recorder does not need to be locked in this context since only the handler
378            // thread will update it
379            mRecorder.recordSnapshotLocked(statsContext.mXtSnapshot, statsContext.mActiveIfaces,
380                    statsContext.mVpnArray, statsContext.mCurrentTime);
381        }
382
383        /**
384         * Reads stats matching the given template. {@link NetworkStatsCollection} will aggregate
385         * over all buckets, which in this case should be only one since we built it big enough
386         * that it will outlive the caller. If it doesn't, then there will be multiple buckets.
387         */
388        private long getTotalBytesForNetwork(NetworkTemplate template) {
389            NetworkStats stats = mCollection.getSummary(template,
390                    Long.MIN_VALUE /* start */, Long.MAX_VALUE /* end */,
391                    mAccessLevel, mCallingUid);
392            if (LOGV) {
393                Slog.v(TAG, "Netstats for " + template + ": " + stats);
394            }
395            return stats.getTotalBytes();
396        }
397    }
398
399    private static class UserUsageRequestInfo extends RequestInfo {
400        UserUsageRequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request,
401                    Messenger messenger, IBinder binder, int callingUid,
402                    @NetworkStatsAccess.Level int accessLevel) {
403            super(statsObserver, request, messenger, binder, callingUid, accessLevel);
404        }
405
406        @Override
407        protected boolean checkStats() {
408            int[] uidsToMonitor = getUidsToMonitor();
409
410            for (int i = 0; i < mRequest.templates.length; i++) {
411                for (int j = 0; j < uidsToMonitor.length; j++) {
412                    long bytesSoFar = getTotalBytesForNetworkUid(mRequest.templates[i],
413                            uidsToMonitor[j]);
414
415                    if (LOGV) {
416                        Slog.v(TAG, bytesSoFar + " bytes so far since notification for "
417                                + mRequest.templates[i] + " for uid=" + uidsToMonitor[j]);
418                    }
419                    if (bytesSoFar > mRequest.thresholdInBytes) {
420                        return true;
421                    }
422                }
423            }
424            return false;
425        }
426
427        @Override
428        protected void recordSample(StatsContext statsContext) {
429            // Recorder does not need to be locked in this context since only the handler
430            // thread will update it
431            mRecorder.recordSnapshotLocked(statsContext.mUidSnapshot, statsContext.mActiveUidIfaces,
432                    statsContext.mVpnArray, statsContext.mCurrentTime);
433        }
434
435        /**
436         * Reads all stats matching the given template and uid. Ther history will likely only
437         * contain one bucket per ident since we build it big enough that it will outlive the
438         * caller lifetime.
439         */
440        private long getTotalBytesForNetworkUid(NetworkTemplate template, int uid) {
441            try {
442                NetworkStatsHistory history = mCollection.getHistory(template, uid,
443                        NetworkStats.SET_ALL, NetworkStats.TAG_NONE,
444                        NetworkStatsHistory.FIELD_ALL,
445                        Long.MIN_VALUE /* start */, Long.MAX_VALUE /* end */,
446                        mAccessLevel, mCallingUid);
447                return history.getTotalBytes();
448            } catch (SecurityException e) {
449                if (LOGV) {
450                    Slog.w(TAG, "CallerUid " + mCallingUid + " may have lost access to uid "
451                            + uid);
452                }
453                return 0;
454            }
455        }
456
457        private int[] getUidsToMonitor() {
458            if (mRequest.uids == null || mRequest.uids.length == 0) {
459                return mCollection.getRelevantUids(mAccessLevel, mCallingUid);
460            }
461            // Pick only uids from the request that are currently accessible to the user
462            IntArray accessibleUids = new IntArray(mRequest.uids.length);
463            for (int i = 0; i < mRequest.uids.length; i++) {
464                int uid = mRequest.uids[i];
465                if (NetworkStatsAccess.isAccessibleToUser(uid, mCallingUid, mAccessLevel)) {
466                    accessibleUids.add(uid);
467                }
468            }
469            return accessibleUids.toArray();
470        }
471    }
472
473    private static class StatsContext {
474        NetworkStats mXtSnapshot;
475        NetworkStats mUidSnapshot;
476        ArrayMap<String, NetworkIdentitySet> mActiveIfaces;
477        ArrayMap<String, NetworkIdentitySet> mActiveUidIfaces;
478        VpnInfo[] mVpnArray;
479        long mCurrentTime;
480
481        StatsContext(NetworkStats xtSnapshot, NetworkStats uidSnapshot,
482                ArrayMap<String, NetworkIdentitySet> activeIfaces,
483                ArrayMap<String, NetworkIdentitySet> activeUidIfaces,
484                VpnInfo[] vpnArray, long currentTime) {
485            mXtSnapshot = xtSnapshot;
486            mUidSnapshot = uidSnapshot;
487            mActiveIfaces = activeIfaces;
488            mActiveUidIfaces = activeUidIfaces;
489            mVpnArray = vpnArray;
490            mCurrentTime = currentTime;
491        }
492    }
493}
494