1/*
2 * Copyright (C) 2016 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package com.android.server.net;
18
19import static android.app.usage.NetworkStatsManager.MIN_THRESHOLD_BYTES;
20
21import static com.android.internal.util.Preconditions.checkArgument;
22
23import android.app.usage.NetworkStatsManager;
24import android.net.DataUsageRequest;
25import android.net.NetworkStats;
26import android.net.NetworkStatsHistory;
27import android.net.NetworkTemplate;
28import android.os.Bundle;
29import android.os.Handler;
30import android.os.HandlerThread;
31import android.os.IBinder;
32import android.os.Looper;
33import android.os.Message;
34import android.os.Messenger;
35import android.os.Process;
36import android.os.RemoteException;
37import android.util.ArrayMap;
38import android.util.Slog;
39import android.util.SparseArray;
40
41import com.android.internal.annotations.VisibleForTesting;
42import com.android.internal.net.VpnInfo;
43
44import java.util.concurrent.atomic.AtomicInteger;
45
46/**
47 * Manages observers of {@link NetworkStats}. Allows observers to be notified when
48 * data usage has been reported in {@link NetworkStatsService}. An observer can set
49 * a threshold of how much data it cares about to be notified.
50 */
51class NetworkStatsObservers {
52    private static final String TAG = "NetworkStatsObservers";
53    private static final boolean LOGV = false;
54
55    private static final int MSG_REGISTER = 1;
56    private static final int MSG_UNREGISTER = 2;
57    private static final int MSG_UPDATE_STATS = 3;
58
59    // All access to this map must be done from the handler thread.
60    // indexed by DataUsageRequest#requestId
61    private final SparseArray<RequestInfo> mDataUsageRequests = new SparseArray<>();
62
63    // Sequence number of DataUsageRequests
64    private final AtomicInteger mNextDataUsageRequestId = new AtomicInteger();
65
66    // Lazily instantiated when an observer is registered.
67    private volatile Handler mHandler;
68
69    /**
70     * Creates a wrapper that contains the caller context and a normalized request.
71     * The request should be returned to the caller app, and the wrapper should be sent to this
72     * object through #addObserver by the service handler.
73     *
74     * <p>It will register the observer asynchronously, so it is safe to call from any thread.
75     *
76     * @return the normalized request wrapped within {@link RequestInfo}.
77     */
78    public DataUsageRequest register(DataUsageRequest inputRequest, Messenger messenger,
79                IBinder binder, int callingUid, @NetworkStatsAccess.Level int accessLevel) {
80        DataUsageRequest request = buildRequest(inputRequest);
81        RequestInfo requestInfo = buildRequestInfo(request, messenger, binder, callingUid,
82                accessLevel);
83
84        if (LOGV) Slog.v(TAG, "Registering observer for " + request);
85        getHandler().sendMessage(mHandler.obtainMessage(MSG_REGISTER, requestInfo));
86        return request;
87    }
88
89    /**
90     * Unregister a data usage observer.
91     *
92     * <p>It will unregister the observer asynchronously, so it is safe to call from any thread.
93     */
94    public void unregister(DataUsageRequest request, int callingUid) {
95        getHandler().sendMessage(mHandler.obtainMessage(MSG_UNREGISTER, callingUid, 0 /* ignore */,
96                request));
97    }
98
99    /**
100     * Updates data usage statistics of registered observers and notifies if limits are reached.
101     *
102     * <p>It will update stats asynchronously, so it is safe to call from any thread.
103     */
104    public void updateStats(NetworkStats xtSnapshot, NetworkStats uidSnapshot,
105                ArrayMap<String, NetworkIdentitySet> activeIfaces,
106                ArrayMap<String, NetworkIdentitySet> activeUidIfaces,
107                VpnInfo[] vpnArray, long currentTime) {
108        StatsContext statsContext = new StatsContext(xtSnapshot, uidSnapshot, activeIfaces,
109                activeUidIfaces, vpnArray, currentTime);
110        getHandler().sendMessage(mHandler.obtainMessage(MSG_UPDATE_STATS, statsContext));
111    }
112
113    private Handler getHandler() {
114        if (mHandler == null) {
115            synchronized (this) {
116                if (mHandler == null) {
117                    if (LOGV) Slog.v(TAG, "Creating handler");
118                    mHandler = new Handler(getHandlerLooperLocked(), mHandlerCallback);
119                }
120            }
121        }
122        return mHandler;
123    }
124
125    @VisibleForTesting
126    protected Looper getHandlerLooperLocked() {
127        HandlerThread handlerThread = new HandlerThread(TAG);
128        handlerThread.start();
129        return handlerThread.getLooper();
130    }
131
132    private Handler.Callback mHandlerCallback = new Handler.Callback() {
133        @Override
134        public boolean handleMessage(Message msg) {
135            switch (msg.what) {
136                case MSG_REGISTER: {
137                    handleRegister((RequestInfo) msg.obj);
138                    return true;
139                }
140                case MSG_UNREGISTER: {
141                    handleUnregister((DataUsageRequest) msg.obj, msg.arg1 /* callingUid */);
142                    return true;
143                }
144                case MSG_UPDATE_STATS: {
145                    handleUpdateStats((StatsContext) msg.obj);
146                    return true;
147                }
148                default: {
149                    return false;
150                }
151            }
152        }
153    };
154
155    /**
156     * Adds a {@link RequestInfo} as an observer.
157     * Should only be called from the handler thread otherwise there will be a race condition
158     * on mDataUsageRequests.
159     */
160    private void handleRegister(RequestInfo requestInfo) {
161        mDataUsageRequests.put(requestInfo.mRequest.requestId, requestInfo);
162    }
163
164    /**
165     * Removes a {@link DataUsageRequest} if the calling uid is authorized.
166     * Should only be called from the handler thread otherwise there will be a race condition
167     * on mDataUsageRequests.
168     */
169    private void handleUnregister(DataUsageRequest request, int callingUid) {
170        RequestInfo requestInfo;
171        requestInfo = mDataUsageRequests.get(request.requestId);
172        if (requestInfo == null) {
173            if (LOGV) Slog.v(TAG, "Trying to unregister unknown request " + request);
174            return;
175        }
176        if (Process.SYSTEM_UID != callingUid && requestInfo.mCallingUid != callingUid) {
177            Slog.w(TAG, "Caller uid " + callingUid + " is not owner of " + request);
178            return;
179        }
180
181        if (LOGV) Slog.v(TAG, "Unregistering " + request);
182        mDataUsageRequests.remove(request.requestId);
183        requestInfo.unlinkDeathRecipient();
184        requestInfo.callCallback(NetworkStatsManager.CALLBACK_RELEASED);
185    }
186
187    private void handleUpdateStats(StatsContext statsContext) {
188        if (mDataUsageRequests.size() == 0) {
189            return;
190        }
191
192        for (int i = 0; i < mDataUsageRequests.size(); i++) {
193            RequestInfo requestInfo = mDataUsageRequests.valueAt(i);
194            requestInfo.updateStats(statsContext);
195        }
196    }
197
198    private DataUsageRequest buildRequest(DataUsageRequest request) {
199        // Cap the minimum threshold to a safe default to avoid too many callbacks
200        long thresholdInBytes = Math.max(MIN_THRESHOLD_BYTES, request.thresholdInBytes);
201        if (thresholdInBytes < request.thresholdInBytes) {
202            Slog.w(TAG, "Threshold was too low for " + request
203                    + ". Overriding to a safer default of " + thresholdInBytes + " bytes");
204        }
205        return new DataUsageRequest(mNextDataUsageRequestId.incrementAndGet(),
206                request.template, thresholdInBytes);
207    }
208
209    private RequestInfo buildRequestInfo(DataUsageRequest request,
210                Messenger messenger, IBinder binder, int callingUid,
211                @NetworkStatsAccess.Level int accessLevel) {
212        if (accessLevel <= NetworkStatsAccess.Level.USER) {
213            return new UserUsageRequestInfo(this, request, messenger, binder, callingUid,
214                    accessLevel);
215        } else {
216            // Safety check in case a new access level is added and we forgot to update this
217            checkArgument(accessLevel >= NetworkStatsAccess.Level.DEVICESUMMARY);
218            return new NetworkUsageRequestInfo(this, request, messenger, binder, callingUid,
219                    accessLevel);
220        }
221    }
222
223    /**
224     * Tracks information relevant to a data usage observer.
225     * It will notice when the calling process dies so we can self-expire.
226     */
227    private abstract static class RequestInfo implements IBinder.DeathRecipient {
228        private final NetworkStatsObservers mStatsObserver;
229        protected final DataUsageRequest mRequest;
230        private final Messenger mMessenger;
231        private final IBinder mBinder;
232        protected final int mCallingUid;
233        protected final @NetworkStatsAccess.Level int mAccessLevel;
234        protected NetworkStatsRecorder mRecorder;
235        protected NetworkStatsCollection mCollection;
236
237        RequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request,
238                    Messenger messenger, IBinder binder, int callingUid,
239                    @NetworkStatsAccess.Level int accessLevel) {
240            mStatsObserver = statsObserver;
241            mRequest = request;
242            mMessenger = messenger;
243            mBinder = binder;
244            mCallingUid = callingUid;
245            mAccessLevel = accessLevel;
246
247            try {
248                mBinder.linkToDeath(this, 0);
249            } catch (RemoteException e) {
250                binderDied();
251            }
252        }
253
254        @Override
255        public void binderDied() {
256            if (LOGV) Slog.v(TAG, "RequestInfo binderDied("
257                    + mRequest + ", " + mBinder + ")");
258            mStatsObserver.unregister(mRequest, Process.SYSTEM_UID);
259            callCallback(NetworkStatsManager.CALLBACK_RELEASED);
260        }
261
262        @Override
263        public String toString() {
264            return "RequestInfo from uid:" + mCallingUid
265                    + " for " + mRequest + " accessLevel:" + mAccessLevel;
266        }
267
268        private void unlinkDeathRecipient() {
269            if (mBinder != null) {
270                mBinder.unlinkToDeath(this, 0);
271            }
272        }
273
274        /**
275         * Update stats given the samples and interface to identity mappings.
276         */
277        private void updateStats(StatsContext statsContext) {
278            if (mRecorder == null) {
279                // First run; establish baseline stats
280                resetRecorder();
281                recordSample(statsContext);
282                return;
283            }
284            recordSample(statsContext);
285
286            if (checkStats()) {
287                resetRecorder();
288                callCallback(NetworkStatsManager.CALLBACK_LIMIT_REACHED);
289            }
290        }
291
292        private void callCallback(int callbackType) {
293            Bundle bundle = new Bundle();
294            bundle.putParcelable(DataUsageRequest.PARCELABLE_KEY, mRequest);
295            Message msg = Message.obtain();
296            msg.what = callbackType;
297            msg.setData(bundle);
298            try {
299                if (LOGV) {
300                    Slog.v(TAG, "sending notification " + callbackTypeToName(callbackType)
301                            + " for " + mRequest);
302                }
303                mMessenger.send(msg);
304            } catch (RemoteException e) {
305                // May occur naturally in the race of binder death.
306                Slog.w(TAG, "RemoteException caught trying to send a callback msg for " + mRequest);
307            }
308        }
309
310        private void resetRecorder() {
311            mRecorder = new NetworkStatsRecorder();
312            mCollection = mRecorder.getSinceBoot();
313        }
314
315        protected abstract boolean checkStats();
316
317        protected abstract void recordSample(StatsContext statsContext);
318
319        private String callbackTypeToName(int callbackType) {
320            switch (callbackType) {
321                case NetworkStatsManager.CALLBACK_LIMIT_REACHED:
322                    return "LIMIT_REACHED";
323                case NetworkStatsManager.CALLBACK_RELEASED:
324                    return "RELEASED";
325                default:
326                    return "UNKNOWN";
327            }
328        }
329    }
330
331    private static class NetworkUsageRequestInfo extends RequestInfo {
332        NetworkUsageRequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request,
333                    Messenger messenger, IBinder binder, int callingUid,
334                    @NetworkStatsAccess.Level int accessLevel) {
335            super(statsObserver, request, messenger, binder, callingUid, accessLevel);
336        }
337
338        @Override
339        protected boolean checkStats() {
340            long bytesSoFar = getTotalBytesForNetwork(mRequest.template);
341            if (LOGV) {
342                Slog.v(TAG, bytesSoFar + " bytes so far since notification for "
343                        + mRequest.template);
344            }
345            if (bytesSoFar > mRequest.thresholdInBytes) {
346                return true;
347            }
348            return false;
349        }
350
351        @Override
352        protected void recordSample(StatsContext statsContext) {
353            // Recorder does not need to be locked in this context since only the handler
354            // thread will update it. We pass a null VPN array because usage is aggregated by uid
355            // for this snapshot, so VPN traffic can't be reattributed to responsible apps.
356            mRecorder.recordSnapshotLocked(statsContext.mXtSnapshot, statsContext.mActiveIfaces,
357                    null /* vpnArray */, statsContext.mCurrentTime);
358        }
359
360        /**
361         * Reads stats matching the given template. {@link NetworkStatsCollection} will aggregate
362         * over all buckets, which in this case should be only one since we built it big enough
363         * that it will outlive the caller. If it doesn't, then there will be multiple buckets.
364         */
365        private long getTotalBytesForNetwork(NetworkTemplate template) {
366            NetworkStats stats = mCollection.getSummary(template,
367                    Long.MIN_VALUE /* start */, Long.MAX_VALUE /* end */,
368                    mAccessLevel, mCallingUid);
369            return stats.getTotalBytes();
370        }
371    }
372
373    private static class UserUsageRequestInfo extends RequestInfo {
374        UserUsageRequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request,
375                    Messenger messenger, IBinder binder, int callingUid,
376                    @NetworkStatsAccess.Level int accessLevel) {
377            super(statsObserver, request, messenger, binder, callingUid, accessLevel);
378        }
379
380        @Override
381        protected boolean checkStats() {
382            int[] uidsToMonitor = mCollection.getRelevantUids(mAccessLevel, mCallingUid);
383
384            for (int i = 0; i < uidsToMonitor.length; i++) {
385                long bytesSoFar = getTotalBytesForNetworkUid(mRequest.template, uidsToMonitor[i]);
386                if (bytesSoFar > mRequest.thresholdInBytes) {
387                    return true;
388                }
389            }
390            return false;
391        }
392
393        @Override
394        protected void recordSample(StatsContext statsContext) {
395            // Recorder does not need to be locked in this context since only the handler
396            // thread will update it. We pass the VPN info so VPN traffic is reattributed to
397            // responsible apps.
398            mRecorder.recordSnapshotLocked(statsContext.mUidSnapshot, statsContext.mActiveUidIfaces,
399                    statsContext.mVpnArray, statsContext.mCurrentTime);
400        }
401
402        /**
403         * Reads all stats matching the given template and uid. Ther history will likely only
404         * contain one bucket per ident since we build it big enough that it will outlive the
405         * caller lifetime.
406         */
407        private long getTotalBytesForNetworkUid(NetworkTemplate template, int uid) {
408            try {
409                NetworkStatsHistory history = mCollection.getHistory(template, null, uid,
410                        NetworkStats.SET_ALL, NetworkStats.TAG_NONE,
411                        NetworkStatsHistory.FIELD_ALL,
412                        Long.MIN_VALUE /* start */, Long.MAX_VALUE /* end */,
413                        mAccessLevel, mCallingUid);
414                return history.getTotalBytes();
415            } catch (SecurityException e) {
416                if (LOGV) {
417                    Slog.w(TAG, "CallerUid " + mCallingUid + " may have lost access to uid "
418                            + uid);
419                }
420                return 0;
421            }
422        }
423    }
424
425    private static class StatsContext {
426        NetworkStats mXtSnapshot;
427        NetworkStats mUidSnapshot;
428        ArrayMap<String, NetworkIdentitySet> mActiveIfaces;
429        ArrayMap<String, NetworkIdentitySet> mActiveUidIfaces;
430        VpnInfo[] mVpnArray;
431        long mCurrentTime;
432
433        StatsContext(NetworkStats xtSnapshot, NetworkStats uidSnapshot,
434                ArrayMap<String, NetworkIdentitySet> activeIfaces,
435                ArrayMap<String, NetworkIdentitySet> activeUidIfaces,
436                VpnInfo[] vpnArray, long currentTime) {
437            mXtSnapshot = xtSnapshot;
438            mUidSnapshot = uidSnapshot;
439            mActiveIfaces = activeIfaces;
440            mActiveUidIfaces = activeUidIfaces;
441            mVpnArray = vpnArray;
442            mCurrentTime = currentTime;
443        }
444    }
445}
446