NetworkStatsObservers.java revision 6965c1869aa8499706522d057b5143bbc240178b
1/*
2 * Copyright (C) 2016 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package com.android.server.net;
18
19import static android.net.TrafficStats.MB_IN_BYTES;
20import static com.android.internal.util.Preconditions.checkArgument;
21
22import android.app.usage.NetworkStatsManager;
23import android.net.DataUsageRequest;
24import android.net.NetworkStats;
25import android.net.NetworkStats.NonMonotonicObserver;
26import android.net.NetworkStatsHistory;
27import android.net.NetworkTemplate;
28import android.os.Binder;
29import android.os.Bundle;
30import android.os.Looper;
31import android.os.Message;
32import android.os.Messenger;
33import android.os.Handler;
34import android.os.HandlerThread;
35import android.os.IBinder;
36import android.os.Process;
37import android.os.RemoteException;
38import android.util.ArrayMap;
39import android.util.IntArray;
40import android.util.SparseArray;
41import android.util.Slog;
42
43import com.android.internal.annotations.VisibleForTesting;
44import com.android.internal.net.VpnInfo;
45
46import java.util.concurrent.atomic.AtomicInteger;
47
48/**
49 * Manages observers of {@link NetworkStats}. Allows observers to be notified when
50 * data usage has been reported in {@link NetworkStatsService}. An observer can set
51 * a threshold of how much data it cares about to be notified.
52 */
53class NetworkStatsObservers {
54    private static final String TAG = "NetworkStatsObservers";
55    private static final boolean LOGV = false;
56
57    private static final long MIN_THRESHOLD_BYTES = 2 * MB_IN_BYTES;
58
59    private static final int MSG_REGISTER = 1;
60    private static final int MSG_UNREGISTER = 2;
61    private static final int MSG_UPDATE_STATS = 3;
62
63    // All access to this map must be done from the handler thread.
64    // indexed by DataUsageRequest#requestId
65    private final SparseArray<RequestInfo> mDataUsageRequests = new SparseArray<>();
66
67    // Sequence number of DataUsageRequests
68    private final AtomicInteger mNextDataUsageRequestId = new AtomicInteger();
69
70    // Lazily instantiated when an observer is registered.
71    private Handler mHandler;
72
73    /**
74     * Creates a wrapper that contains the caller context and a normalized request.
75     * The request should be returned to the caller app, and the wrapper should be sent to this
76     * object through #addObserver by the service handler.
77     *
78     * <p>It will register the observer asynchronously, so it is safe to call from any thread.
79     *
80     * @return the normalized request wrapped within {@link RequestInfo}.
81     */
82    public DataUsageRequest register(DataUsageRequest inputRequest, Messenger messenger,
83                IBinder binder, int callingUid, @NetworkStatsAccess.Level int accessLevel) {
84        DataUsageRequest request = buildRequest(inputRequest);
85        RequestInfo requestInfo = buildRequestInfo(request, messenger, binder, callingUid,
86                accessLevel);
87
88        if (LOGV) Slog.v(TAG, "Registering observer for " + request);
89        getHandler().sendMessage(mHandler.obtainMessage(MSG_REGISTER, requestInfo));
90        return request;
91    }
92
93    /**
94     * Unregister a data usage observer.
95     *
96     * <p>It will unregister the observer asynchronously, so it is safe to call from any thread.
97     */
98    public void unregister(DataUsageRequest request, int callingUid) {
99        getHandler().sendMessage(mHandler.obtainMessage(MSG_UNREGISTER, callingUid, 0 /* ignore */,
100                request));
101    }
102
103    /**
104     * Updates data usage statistics of registered observers and notifies if limits are reached.
105     *
106     * <p>It will update stats asynchronously, so it is safe to call from any thread.
107     */
108    public void updateStats(NetworkStats xtSnapshot, NetworkStats uidSnapshot,
109                ArrayMap<String, NetworkIdentitySet> activeIfaces,
110                ArrayMap<String, NetworkIdentitySet> activeUidIfaces,
111                VpnInfo[] vpnArray, long currentTime) {
112        StatsContext statsContext = new StatsContext(xtSnapshot, uidSnapshot, activeIfaces,
113                activeUidIfaces, vpnArray, currentTime);
114        getHandler().sendMessage(mHandler.obtainMessage(MSG_UPDATE_STATS, statsContext));
115    }
116
117    private Handler getHandler() {
118        if (mHandler == null) {
119            synchronized (this) {
120                if (mHandler == null) {
121                    if (LOGV) Slog.v(TAG, "Creating handler");
122                    mHandler = new Handler(getHandlerLooperLocked(), mHandlerCallback);
123                }
124            }
125        }
126        return mHandler;
127    }
128
129    @VisibleForTesting
130    protected Looper getHandlerLooperLocked() {
131        HandlerThread handlerThread = new HandlerThread(TAG);
132        handlerThread.start();
133        return handlerThread.getLooper();
134    }
135
136    private Handler.Callback mHandlerCallback = new Handler.Callback() {
137        @Override
138        public boolean handleMessage(Message msg) {
139            switch (msg.what) {
140                case MSG_REGISTER: {
141                    handleRegister((RequestInfo) msg.obj);
142                    return true;
143                }
144                case MSG_UNREGISTER: {
145                    handleUnregister((DataUsageRequest) msg.obj, msg.arg1 /* callingUid */);
146                    return true;
147                }
148                case MSG_UPDATE_STATS: {
149                    handleUpdateStats((StatsContext) msg.obj);
150                    return true;
151                }
152                default: {
153                    return false;
154                }
155            }
156        }
157    };
158
159    /**
160     * Adds a {@link RequestInfo} as an observer.
161     * Should only be called from the handler thread otherwise there will be a race condition
162     * on mDataUsageRequests.
163     */
164    private void handleRegister(RequestInfo requestInfo) {
165        mDataUsageRequests.put(requestInfo.mRequest.requestId, requestInfo);
166    }
167
168    /**
169     * Removes a {@link DataUsageRequest} if the calling uid is authorized.
170     * Should only be called from the handler thread otherwise there will be a race condition
171     * on mDataUsageRequests.
172     */
173    private void handleUnregister(DataUsageRequest request, int callingUid) {
174        RequestInfo requestInfo;
175        requestInfo = mDataUsageRequests.get(request.requestId);
176        if (requestInfo == null) {
177            if (LOGV) Slog.v(TAG, "Trying to unregister unknown request " + request);
178            return;
179        }
180        if (Process.SYSTEM_UID != callingUid && requestInfo.mCallingUid != callingUid) {
181            Slog.w(TAG, "Caller uid " + callingUid + " is not owner of " + request);
182            return;
183        }
184
185        if (LOGV) Slog.v(TAG, "Unregistering " + request);
186        mDataUsageRequests.remove(request.requestId);
187        requestInfo.unlinkDeathRecipient();
188        requestInfo.callCallback(NetworkStatsManager.CALLBACK_RELEASED);
189    }
190
191    private void handleUpdateStats(StatsContext statsContext) {
192        if (mDataUsageRequests.size() == 0) {
193            if (LOGV) Slog.v(TAG, "No registered listeners of data usage");
194            return;
195        }
196
197        if (LOGV) Slog.v(TAG, "Checking if any registered observer needs to be notified");
198        for (int i = 0; i < mDataUsageRequests.size(); i++) {
199            RequestInfo requestInfo = mDataUsageRequests.valueAt(i);
200            requestInfo.updateStats(statsContext);
201        }
202    }
203
204    private DataUsageRequest buildRequest(DataUsageRequest request) {
205        // Cap the minimum threshold to a safe default to avoid too many callbacks
206        long thresholdInBytes = Math.max(MIN_THRESHOLD_BYTES, request.thresholdInBytes);
207        if (thresholdInBytes < request.thresholdInBytes) {
208            Slog.w(TAG, "Threshold was too low for " + request
209                    + ". Overriding to a safer default of " + thresholdInBytes + " bytes");
210        }
211        return new DataUsageRequest(mNextDataUsageRequestId.incrementAndGet(),
212                request.template, thresholdInBytes);
213    }
214
215    private RequestInfo buildRequestInfo(DataUsageRequest request,
216                Messenger messenger, IBinder binder, int callingUid,
217                @NetworkStatsAccess.Level int accessLevel) {
218        if (accessLevel <= NetworkStatsAccess.Level.USER) {
219            return new UserUsageRequestInfo(this, request, messenger, binder, callingUid,
220                    accessLevel);
221        } else {
222            // Safety check in case a new access level is added and we forgot to update this
223            checkArgument(accessLevel >= NetworkStatsAccess.Level.DEVICESUMMARY);
224            return new NetworkUsageRequestInfo(this, request, messenger, binder, callingUid,
225                    accessLevel);
226        }
227    }
228
229    /**
230     * Tracks information relevant to a data usage observer.
231     * It will notice when the calling process dies so we can self-expire.
232     */
233    private abstract static class RequestInfo implements IBinder.DeathRecipient {
234        private final NetworkStatsObservers mStatsObserver;
235        protected final DataUsageRequest mRequest;
236        private final Messenger mMessenger;
237        private final IBinder mBinder;
238        protected final int mCallingUid;
239        protected final @NetworkStatsAccess.Level int mAccessLevel;
240        protected NetworkStatsRecorder mRecorder;
241        protected NetworkStatsCollection mCollection;
242
243        RequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request,
244                    Messenger messenger, IBinder binder, int callingUid,
245                    @NetworkStatsAccess.Level int accessLevel) {
246            mStatsObserver = statsObserver;
247            mRequest = request;
248            mMessenger = messenger;
249            mBinder = binder;
250            mCallingUid = callingUid;
251            mAccessLevel = accessLevel;
252
253            try {
254                mBinder.linkToDeath(this, 0);
255            } catch (RemoteException e) {
256                binderDied();
257            }
258        }
259
260        @Override
261        public void binderDied() {
262            if (LOGV) Slog.v(TAG, "RequestInfo binderDied("
263                    + mRequest + ", " + mBinder + ")");
264            mStatsObserver.unregister(mRequest, Process.SYSTEM_UID);
265            callCallback(NetworkStatsManager.CALLBACK_RELEASED);
266        }
267
268        @Override
269        public String toString() {
270            return "RequestInfo from uid:" + mCallingUid
271                    + " for " + mRequest + " accessLevel:" + mAccessLevel;
272        }
273
274        private void unlinkDeathRecipient() {
275            if (mBinder != null) {
276                mBinder.unlinkToDeath(this, 0);
277            }
278        }
279
280        /**
281         * Update stats given the samples and interface to identity mappings.
282         */
283        private void updateStats(StatsContext statsContext) {
284            if (mRecorder == null) {
285                // First run; establish baseline stats
286                resetRecorder();
287                recordSample(statsContext);
288                return;
289            }
290            recordSample(statsContext);
291
292            if (checkStats()) {
293                resetRecorder();
294                callCallback(NetworkStatsManager.CALLBACK_LIMIT_REACHED);
295            }
296        }
297
298        private void callCallback(int callbackType) {
299            Bundle bundle = new Bundle();
300            bundle.putParcelable(DataUsageRequest.PARCELABLE_KEY, mRequest);
301            Message msg = Message.obtain();
302            msg.what = callbackType;
303            msg.setData(bundle);
304            try {
305                if (LOGV) {
306                    Slog.v(TAG, "sending notification " + callbackTypeToName(callbackType)
307                            + " for " + mRequest);
308                }
309                mMessenger.send(msg);
310            } catch (RemoteException e) {
311                // May occur naturally in the race of binder death.
312                Slog.w(TAG, "RemoteException caught trying to send a callback msg for " + mRequest);
313            }
314        }
315
316        private void resetRecorder() {
317            mRecorder = new NetworkStatsRecorder();
318            mCollection = mRecorder.getSinceBoot();
319        }
320
321        protected abstract boolean checkStats();
322
323        protected abstract void recordSample(StatsContext statsContext);
324
325        private String callbackTypeToName(int callbackType) {
326            switch (callbackType) {
327                case NetworkStatsManager.CALLBACK_LIMIT_REACHED:
328                    return "LIMIT_REACHED";
329                case NetworkStatsManager.CALLBACK_RELEASED:
330                    return "RELEASED";
331                default:
332                    return "UNKNOWN";
333            }
334        }
335    }
336
337    private static class NetworkUsageRequestInfo extends RequestInfo {
338        NetworkUsageRequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request,
339                    Messenger messenger, IBinder binder, int callingUid,
340                    @NetworkStatsAccess.Level int accessLevel) {
341            super(statsObserver, request, messenger, binder, callingUid, accessLevel);
342        }
343
344        @Override
345        protected boolean checkStats() {
346            long bytesSoFar = getTotalBytesForNetwork(mRequest.template);
347            if (LOGV) {
348                Slog.v(TAG, bytesSoFar + " bytes so far since notification for "
349                        + mRequest.template);
350            }
351            if (bytesSoFar > mRequest.thresholdInBytes) {
352                return true;
353            }
354            return false;
355        }
356
357        @Override
358        protected void recordSample(StatsContext statsContext) {
359            // Recorder does not need to be locked in this context since only the handler
360            // thread will update it
361            mRecorder.recordSnapshotLocked(statsContext.mXtSnapshot, statsContext.mActiveIfaces,
362                    statsContext.mVpnArray, statsContext.mCurrentTime);
363        }
364
365        /**
366         * Reads stats matching the given template. {@link NetworkStatsCollection} will aggregate
367         * over all buckets, which in this case should be only one since we built it big enough
368         * that it will outlive the caller. If it doesn't, then there will be multiple buckets.
369         */
370        private long getTotalBytesForNetwork(NetworkTemplate template) {
371            NetworkStats stats = mCollection.getSummary(template,
372                    Long.MIN_VALUE /* start */, Long.MAX_VALUE /* end */,
373                    mAccessLevel, mCallingUid);
374            if (LOGV) {
375                Slog.v(TAG, "Netstats for " + template + ": " + stats);
376            }
377            return stats.getTotalBytes();
378        }
379    }
380
381    private static class UserUsageRequestInfo extends RequestInfo {
382        UserUsageRequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request,
383                    Messenger messenger, IBinder binder, int callingUid,
384                    @NetworkStatsAccess.Level int accessLevel) {
385            super(statsObserver, request, messenger, binder, callingUid, accessLevel);
386        }
387
388        @Override
389        protected boolean checkStats() {
390            int[] uidsToMonitor = mCollection.getRelevantUids(mAccessLevel, mCallingUid);
391
392            for (int i = 0; i < uidsToMonitor.length; i++) {
393                long bytesSoFar = getTotalBytesForNetworkUid(mRequest.template, uidsToMonitor[i]);
394
395                if (LOGV) {
396                    Slog.v(TAG, bytesSoFar + " bytes so far since notification for "
397                            + mRequest.template + " for uid=" + uidsToMonitor[i]);
398                }
399                if (bytesSoFar > mRequest.thresholdInBytes) {
400                    return true;
401                }
402            }
403            return false;
404        }
405
406        @Override
407        protected void recordSample(StatsContext statsContext) {
408            // Recorder does not need to be locked in this context since only the handler
409            // thread will update it
410            mRecorder.recordSnapshotLocked(statsContext.mUidSnapshot, statsContext.mActiveUidIfaces,
411                    statsContext.mVpnArray, statsContext.mCurrentTime);
412        }
413
414        /**
415         * Reads all stats matching the given template and uid. Ther history will likely only
416         * contain one bucket per ident since we build it big enough that it will outlive the
417         * caller lifetime.
418         */
419        private long getTotalBytesForNetworkUid(NetworkTemplate template, int uid) {
420            try {
421                NetworkStatsHistory history = mCollection.getHistory(template, uid,
422                        NetworkStats.SET_ALL, NetworkStats.TAG_NONE,
423                        NetworkStatsHistory.FIELD_ALL,
424                        Long.MIN_VALUE /* start */, Long.MAX_VALUE /* end */,
425                        mAccessLevel, mCallingUid);
426                return history.getTotalBytes();
427            } catch (SecurityException e) {
428                if (LOGV) {
429                    Slog.w(TAG, "CallerUid " + mCallingUid + " may have lost access to uid "
430                            + uid);
431                }
432                return 0;
433            }
434        }
435    }
436
437    private static class StatsContext {
438        NetworkStats mXtSnapshot;
439        NetworkStats mUidSnapshot;
440        ArrayMap<String, NetworkIdentitySet> mActiveIfaces;
441        ArrayMap<String, NetworkIdentitySet> mActiveUidIfaces;
442        VpnInfo[] mVpnArray;
443        long mCurrentTime;
444
445        StatsContext(NetworkStats xtSnapshot, NetworkStats uidSnapshot,
446                ArrayMap<String, NetworkIdentitySet> activeIfaces,
447                ArrayMap<String, NetworkIdentitySet> activeUidIfaces,
448                VpnInfo[] vpnArray, long currentTime) {
449            mXtSnapshot = xtSnapshot;
450            mUidSnapshot = uidSnapshot;
451            mActiveIfaces = activeIfaces;
452            mActiveUidIfaces = activeUidIfaces;
453            mVpnArray = vpnArray;
454            mCurrentTime = currentTime;
455        }
456    }
457}
458