1/*
2 * Copyright (C) 2016 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package com.android.server.net;
18
19import static android.net.TrafficStats.MB_IN_BYTES;
20import static com.android.internal.util.Preconditions.checkArgument;
21
22import android.app.usage.NetworkStatsManager;
23import android.net.DataUsageRequest;
24import android.net.NetworkStats;
25import android.net.NetworkStats.NonMonotonicObserver;
26import android.net.NetworkStatsHistory;
27import android.net.NetworkTemplate;
28import android.os.Binder;
29import android.os.Bundle;
30import android.os.Looper;
31import android.os.Message;
32import android.os.Messenger;
33import android.os.Handler;
34import android.os.HandlerThread;
35import android.os.IBinder;
36import android.os.Process;
37import android.os.RemoteException;
38import android.util.ArrayMap;
39import android.util.IntArray;
40import android.util.SparseArray;
41import android.util.Slog;
42
43import com.android.internal.annotations.VisibleForTesting;
44import com.android.internal.net.VpnInfo;
45
46import java.util.concurrent.atomic.AtomicInteger;
47
48/**
49 * Manages observers of {@link NetworkStats}. Allows observers to be notified when
50 * data usage has been reported in {@link NetworkStatsService}. An observer can set
51 * a threshold of how much data it cares about to be notified.
52 */
53class NetworkStatsObservers {
54    private static final String TAG = "NetworkStatsObservers";
55    private static final boolean LOGV = false;
56
57    private static final long MIN_THRESHOLD_BYTES = 2 * MB_IN_BYTES;
58
59    private static final int MSG_REGISTER = 1;
60    private static final int MSG_UNREGISTER = 2;
61    private static final int MSG_UPDATE_STATS = 3;
62
63    // All access to this map must be done from the handler thread.
64    // indexed by DataUsageRequest#requestId
65    private final SparseArray<RequestInfo> mDataUsageRequests = new SparseArray<>();
66
67    // Sequence number of DataUsageRequests
68    private final AtomicInteger mNextDataUsageRequestId = new AtomicInteger();
69
70    // Lazily instantiated when an observer is registered.
71    private Handler mHandler;
72
73    /**
74     * Creates a wrapper that contains the caller context and a normalized request.
75     * The request should be returned to the caller app, and the wrapper should be sent to this
76     * object through #addObserver by the service handler.
77     *
78     * <p>It will register the observer asynchronously, so it is safe to call from any thread.
79     *
80     * @return the normalized request wrapped within {@link RequestInfo}.
81     */
82    public DataUsageRequest register(DataUsageRequest inputRequest, Messenger messenger,
83                IBinder binder, int callingUid, @NetworkStatsAccess.Level int accessLevel) {
84        DataUsageRequest request = buildRequest(inputRequest);
85        RequestInfo requestInfo = buildRequestInfo(request, messenger, binder, callingUid,
86                accessLevel);
87
88        if (LOGV) Slog.v(TAG, "Registering observer for " + request);
89        getHandler().sendMessage(mHandler.obtainMessage(MSG_REGISTER, requestInfo));
90        return request;
91    }
92
93    /**
94     * Unregister a data usage observer.
95     *
96     * <p>It will unregister the observer asynchronously, so it is safe to call from any thread.
97     */
98    public void unregister(DataUsageRequest request, int callingUid) {
99        getHandler().sendMessage(mHandler.obtainMessage(MSG_UNREGISTER, callingUid, 0 /* ignore */,
100                request));
101    }
102
103    /**
104     * Updates data usage statistics of registered observers and notifies if limits are reached.
105     *
106     * <p>It will update stats asynchronously, so it is safe to call from any thread.
107     */
108    public void updateStats(NetworkStats xtSnapshot, NetworkStats uidSnapshot,
109                ArrayMap<String, NetworkIdentitySet> activeIfaces,
110                ArrayMap<String, NetworkIdentitySet> activeUidIfaces,
111                VpnInfo[] vpnArray, long currentTime) {
112        StatsContext statsContext = new StatsContext(xtSnapshot, uidSnapshot, activeIfaces,
113                activeUidIfaces, vpnArray, currentTime);
114        getHandler().sendMessage(mHandler.obtainMessage(MSG_UPDATE_STATS, statsContext));
115    }
116
117    private Handler getHandler() {
118        if (mHandler == null) {
119            synchronized (this) {
120                if (mHandler == null) {
121                    if (LOGV) Slog.v(TAG, "Creating handler");
122                    mHandler = new Handler(getHandlerLooperLocked(), mHandlerCallback);
123                }
124            }
125        }
126        return mHandler;
127    }
128
129    @VisibleForTesting
130    protected Looper getHandlerLooperLocked() {
131        HandlerThread handlerThread = new HandlerThread(TAG);
132        handlerThread.start();
133        return handlerThread.getLooper();
134    }
135
136    private Handler.Callback mHandlerCallback = new Handler.Callback() {
137        @Override
138        public boolean handleMessage(Message msg) {
139            switch (msg.what) {
140                case MSG_REGISTER: {
141                    handleRegister((RequestInfo) msg.obj);
142                    return true;
143                }
144                case MSG_UNREGISTER: {
145                    handleUnregister((DataUsageRequest) msg.obj, msg.arg1 /* callingUid */);
146                    return true;
147                }
148                case MSG_UPDATE_STATS: {
149                    handleUpdateStats((StatsContext) msg.obj);
150                    return true;
151                }
152                default: {
153                    return false;
154                }
155            }
156        }
157    };
158
159    /**
160     * Adds a {@link RequestInfo} as an observer.
161     * Should only be called from the handler thread otherwise there will be a race condition
162     * on mDataUsageRequests.
163     */
164    private void handleRegister(RequestInfo requestInfo) {
165        mDataUsageRequests.put(requestInfo.mRequest.requestId, requestInfo);
166    }
167
168    /**
169     * Removes a {@link DataUsageRequest} if the calling uid is authorized.
170     * Should only be called from the handler thread otherwise there will be a race condition
171     * on mDataUsageRequests.
172     */
173    private void handleUnregister(DataUsageRequest request, int callingUid) {
174        RequestInfo requestInfo;
175        requestInfo = mDataUsageRequests.get(request.requestId);
176        if (requestInfo == null) {
177            if (LOGV) Slog.v(TAG, "Trying to unregister unknown request " + request);
178            return;
179        }
180        if (Process.SYSTEM_UID != callingUid && requestInfo.mCallingUid != callingUid) {
181            Slog.w(TAG, "Caller uid " + callingUid + " is not owner of " + request);
182            return;
183        }
184
185        if (LOGV) Slog.v(TAG, "Unregistering " + request);
186        mDataUsageRequests.remove(request.requestId);
187        requestInfo.unlinkDeathRecipient();
188        requestInfo.callCallback(NetworkStatsManager.CALLBACK_RELEASED);
189    }
190
191    private void handleUpdateStats(StatsContext statsContext) {
192        if (mDataUsageRequests.size() == 0) {
193            return;
194        }
195
196        for (int i = 0; i < mDataUsageRequests.size(); i++) {
197            RequestInfo requestInfo = mDataUsageRequests.valueAt(i);
198            requestInfo.updateStats(statsContext);
199        }
200    }
201
202    private DataUsageRequest buildRequest(DataUsageRequest request) {
203        // Cap the minimum threshold to a safe default to avoid too many callbacks
204        long thresholdInBytes = Math.max(MIN_THRESHOLD_BYTES, request.thresholdInBytes);
205        if (thresholdInBytes < request.thresholdInBytes) {
206            Slog.w(TAG, "Threshold was too low for " + request
207                    + ". Overriding to a safer default of " + thresholdInBytes + " bytes");
208        }
209        return new DataUsageRequest(mNextDataUsageRequestId.incrementAndGet(),
210                request.template, thresholdInBytes);
211    }
212
213    private RequestInfo buildRequestInfo(DataUsageRequest request,
214                Messenger messenger, IBinder binder, int callingUid,
215                @NetworkStatsAccess.Level int accessLevel) {
216        if (accessLevel <= NetworkStatsAccess.Level.USER) {
217            return new UserUsageRequestInfo(this, request, messenger, binder, callingUid,
218                    accessLevel);
219        } else {
220            // Safety check in case a new access level is added and we forgot to update this
221            checkArgument(accessLevel >= NetworkStatsAccess.Level.DEVICESUMMARY);
222            return new NetworkUsageRequestInfo(this, request, messenger, binder, callingUid,
223                    accessLevel);
224        }
225    }
226
227    /**
228     * Tracks information relevant to a data usage observer.
229     * It will notice when the calling process dies so we can self-expire.
230     */
231    private abstract static class RequestInfo implements IBinder.DeathRecipient {
232        private final NetworkStatsObservers mStatsObserver;
233        protected final DataUsageRequest mRequest;
234        private final Messenger mMessenger;
235        private final IBinder mBinder;
236        protected final int mCallingUid;
237        protected final @NetworkStatsAccess.Level int mAccessLevel;
238        protected NetworkStatsRecorder mRecorder;
239        protected NetworkStatsCollection mCollection;
240
241        RequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request,
242                    Messenger messenger, IBinder binder, int callingUid,
243                    @NetworkStatsAccess.Level int accessLevel) {
244            mStatsObserver = statsObserver;
245            mRequest = request;
246            mMessenger = messenger;
247            mBinder = binder;
248            mCallingUid = callingUid;
249            mAccessLevel = accessLevel;
250
251            try {
252                mBinder.linkToDeath(this, 0);
253            } catch (RemoteException e) {
254                binderDied();
255            }
256        }
257
258        @Override
259        public void binderDied() {
260            if (LOGV) Slog.v(TAG, "RequestInfo binderDied("
261                    + mRequest + ", " + mBinder + ")");
262            mStatsObserver.unregister(mRequest, Process.SYSTEM_UID);
263            callCallback(NetworkStatsManager.CALLBACK_RELEASED);
264        }
265
266        @Override
267        public String toString() {
268            return "RequestInfo from uid:" + mCallingUid
269                    + " for " + mRequest + " accessLevel:" + mAccessLevel;
270        }
271
272        private void unlinkDeathRecipient() {
273            if (mBinder != null) {
274                mBinder.unlinkToDeath(this, 0);
275            }
276        }
277
278        /**
279         * Update stats given the samples and interface to identity mappings.
280         */
281        private void updateStats(StatsContext statsContext) {
282            if (mRecorder == null) {
283                // First run; establish baseline stats
284                resetRecorder();
285                recordSample(statsContext);
286                return;
287            }
288            recordSample(statsContext);
289
290            if (checkStats()) {
291                resetRecorder();
292                callCallback(NetworkStatsManager.CALLBACK_LIMIT_REACHED);
293            }
294        }
295
296        private void callCallback(int callbackType) {
297            Bundle bundle = new Bundle();
298            bundle.putParcelable(DataUsageRequest.PARCELABLE_KEY, mRequest);
299            Message msg = Message.obtain();
300            msg.what = callbackType;
301            msg.setData(bundle);
302            try {
303                if (LOGV) {
304                    Slog.v(TAG, "sending notification " + callbackTypeToName(callbackType)
305                            + " for " + mRequest);
306                }
307                mMessenger.send(msg);
308            } catch (RemoteException e) {
309                // May occur naturally in the race of binder death.
310                Slog.w(TAG, "RemoteException caught trying to send a callback msg for " + mRequest);
311            }
312        }
313
314        private void resetRecorder() {
315            mRecorder = new NetworkStatsRecorder();
316            mCollection = mRecorder.getSinceBoot();
317        }
318
319        protected abstract boolean checkStats();
320
321        protected abstract void recordSample(StatsContext statsContext);
322
323        private String callbackTypeToName(int callbackType) {
324            switch (callbackType) {
325                case NetworkStatsManager.CALLBACK_LIMIT_REACHED:
326                    return "LIMIT_REACHED";
327                case NetworkStatsManager.CALLBACK_RELEASED:
328                    return "RELEASED";
329                default:
330                    return "UNKNOWN";
331            }
332        }
333    }
334
335    private static class NetworkUsageRequestInfo extends RequestInfo {
336        NetworkUsageRequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request,
337                    Messenger messenger, IBinder binder, int callingUid,
338                    @NetworkStatsAccess.Level int accessLevel) {
339            super(statsObserver, request, messenger, binder, callingUid, accessLevel);
340        }
341
342        @Override
343        protected boolean checkStats() {
344            long bytesSoFar = getTotalBytesForNetwork(mRequest.template);
345            if (LOGV) {
346                Slog.v(TAG, bytesSoFar + " bytes so far since notification for "
347                        + mRequest.template);
348            }
349            if (bytesSoFar > mRequest.thresholdInBytes) {
350                return true;
351            }
352            return false;
353        }
354
355        @Override
356        protected void recordSample(StatsContext statsContext) {
357            // Recorder does not need to be locked in this context since only the handler
358            // thread will update it. We pass a null VPN array because usage is aggregated by uid
359            // for this snapshot, so VPN traffic can't be reattributed to responsible apps.
360            mRecorder.recordSnapshotLocked(statsContext.mXtSnapshot, statsContext.mActiveIfaces,
361                    null /* vpnArray */, statsContext.mCurrentTime);
362        }
363
364        /**
365         * Reads stats matching the given template. {@link NetworkStatsCollection} will aggregate
366         * over all buckets, which in this case should be only one since we built it big enough
367         * that it will outlive the caller. If it doesn't, then there will be multiple buckets.
368         */
369        private long getTotalBytesForNetwork(NetworkTemplate template) {
370            NetworkStats stats = mCollection.getSummary(template,
371                    Long.MIN_VALUE /* start */, Long.MAX_VALUE /* end */,
372                    mAccessLevel, mCallingUid);
373            return stats.getTotalBytes();
374        }
375    }
376
377    private static class UserUsageRequestInfo extends RequestInfo {
378        UserUsageRequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request,
379                    Messenger messenger, IBinder binder, int callingUid,
380                    @NetworkStatsAccess.Level int accessLevel) {
381            super(statsObserver, request, messenger, binder, callingUid, accessLevel);
382        }
383
384        @Override
385        protected boolean checkStats() {
386            int[] uidsToMonitor = mCollection.getRelevantUids(mAccessLevel, mCallingUid);
387
388            for (int i = 0; i < uidsToMonitor.length; i++) {
389                long bytesSoFar = getTotalBytesForNetworkUid(mRequest.template, uidsToMonitor[i]);
390                if (bytesSoFar > mRequest.thresholdInBytes) {
391                    return true;
392                }
393            }
394            return false;
395        }
396
397        @Override
398        protected void recordSample(StatsContext statsContext) {
399            // Recorder does not need to be locked in this context since only the handler
400            // thread will update it. We pass the VPN info so VPN traffic is reattributed to
401            // responsible apps.
402            mRecorder.recordSnapshotLocked(statsContext.mUidSnapshot, statsContext.mActiveUidIfaces,
403                    statsContext.mVpnArray, statsContext.mCurrentTime);
404        }
405
406        /**
407         * Reads all stats matching the given template and uid. Ther history will likely only
408         * contain one bucket per ident since we build it big enough that it will outlive the
409         * caller lifetime.
410         */
411        private long getTotalBytesForNetworkUid(NetworkTemplate template, int uid) {
412            try {
413                NetworkStatsHistory history = mCollection.getHistory(template, uid,
414                        NetworkStats.SET_ALL, NetworkStats.TAG_NONE,
415                        NetworkStatsHistory.FIELD_ALL,
416                        Long.MIN_VALUE /* start */, Long.MAX_VALUE /* end */,
417                        mAccessLevel, mCallingUid);
418                return history.getTotalBytes();
419            } catch (SecurityException e) {
420                if (LOGV) {
421                    Slog.w(TAG, "CallerUid " + mCallingUid + " may have lost access to uid "
422                            + uid);
423                }
424                return 0;
425            }
426        }
427    }
428
429    private static class StatsContext {
430        NetworkStats mXtSnapshot;
431        NetworkStats mUidSnapshot;
432        ArrayMap<String, NetworkIdentitySet> mActiveIfaces;
433        ArrayMap<String, NetworkIdentitySet> mActiveUidIfaces;
434        VpnInfo[] mVpnArray;
435        long mCurrentTime;
436
437        StatsContext(NetworkStats xtSnapshot, NetworkStats uidSnapshot,
438                ArrayMap<String, NetworkIdentitySet> activeIfaces,
439                ArrayMap<String, NetworkIdentitySet> activeUidIfaces,
440                VpnInfo[] vpnArray, long currentTime) {
441            mXtSnapshot = xtSnapshot;
442            mUidSnapshot = uidSnapshot;
443            mActiveIfaces = activeIfaces;
444            mActiveUidIfaces = activeUidIfaces;
445            mVpnArray = vpnArray;
446            mCurrentTime = currentTime;
447        }
448    }
449}
450