1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package android.ext.services.resolver;
18
19import android.content.Context;
20import android.content.Intent;
21import android.content.SharedPreferences;
22import android.os.Environment;
23import android.os.IBinder;
24import android.os.storage.StorageManager;
25import android.service.resolver.ResolverRankerService;
26import android.service.resolver.ResolverTarget;
27import android.util.ArrayMap;
28import android.util.Log;
29
30import java.io.File;
31import java.util.Collection;
32import java.util.List;
33import java.util.Map;
34
35/**
36 * A Logistic Regression based {@link android.service.resolver.ResolverRankerService}, to be used
37 * in {@link ResolverComparator}.
38 */
39public final class LRResolverRankerService extends ResolverRankerService {
40    private static final String TAG = "LRResolverRankerService";
41
42    private static final boolean DEBUG = false;
43
44    private static final String PARAM_SHARED_PREF_NAME = "resolver_ranker_params";
45    private static final String BIAS_PREF_KEY = "bias";
46    private static final String VERSION_PREF_KEY = "version";
47
48    private static final String LAUNCH_SCORE = "launch";
49    private static final String TIME_SPENT_SCORE = "timeSpent";
50    private static final String RECENCY_SCORE = "recency";
51    private static final String CHOOSER_SCORE = "chooser";
52
53    // parameters for a pre-trained model, to initialize the app ranker. When updating the
54    // pre-trained model, please update these params, as well as initModel().
55    private static final int CURRENT_VERSION = 1;
56    private static final float LEARNING_RATE = 0.0001f;
57    private static final float REGULARIZER_PARAM = 0.0001f;
58
59    private SharedPreferences mParamSharedPref;
60    private ArrayMap<String, Float> mFeatureWeights;
61    private float mBias;
62
63    @Override
64    public IBinder onBind(Intent intent) {
65        initModel();
66        return super.onBind(intent);
67    }
68
69    @Override
70    public void onPredictSharingProbabilities(List<ResolverTarget> targets) {
71        final int size = targets.size();
72        for (int i = 0; i < size; ++i) {
73            ResolverTarget target = targets.get(i);
74            ArrayMap<String, Float> features = getFeatures(target);
75            target.setSelectProbability(predict(features));
76        }
77    }
78
79    @Override
80    public void onTrainRankingModel(List<ResolverTarget> targets, int selectedPosition) {
81        final int size = targets.size();
82        if (selectedPosition < 0 || selectedPosition >= size) {
83            if (DEBUG) {
84                Log.d(TAG, "Invalid Position of Selected App " + selectedPosition);
85            }
86            return;
87        }
88        final ArrayMap<String, Float> positive = getFeatures(targets.get(selectedPosition));
89        final float positiveProbability = targets.get(selectedPosition).getSelectProbability();
90        final int targetSize = targets.size();
91        for (int i = 0; i < targetSize; ++i) {
92            if (i == selectedPosition) {
93                continue;
94            }
95            final ArrayMap<String, Float> negative = getFeatures(targets.get(i));
96            final float negativeProbability = targets.get(i).getSelectProbability();
97            if (negativeProbability > positiveProbability) {
98                update(negative, negativeProbability, false);
99                update(positive, positiveProbability, true);
100            }
101        }
102        commitUpdate();
103    }
104
105    private void initModel() {
106        mParamSharedPref = getParamSharedPref();
107        mFeatureWeights = new ArrayMap<>(4);
108        if (mParamSharedPref == null ||
109                mParamSharedPref.getInt(VERSION_PREF_KEY, 0) < CURRENT_VERSION) {
110            // Initializing the app ranker to a pre-trained model. When updating the pre-trained
111            // model, please increment CURRENT_VERSION, and update LEARNING_RATE and
112            // REGULARIZER_PARAM.
113            mBias = -1.6568f;
114            mFeatureWeights.put(LAUNCH_SCORE, 2.5543f);
115            mFeatureWeights.put(TIME_SPENT_SCORE, 2.8412f);
116            mFeatureWeights.put(RECENCY_SCORE, 0.269f);
117            mFeatureWeights.put(CHOOSER_SCORE, 4.2222f);
118        } else {
119            mBias = mParamSharedPref.getFloat(BIAS_PREF_KEY, 0.0f);
120            mFeatureWeights.put(LAUNCH_SCORE, mParamSharedPref.getFloat(LAUNCH_SCORE, 0.0f));
121            mFeatureWeights.put(
122                    TIME_SPENT_SCORE, mParamSharedPref.getFloat(TIME_SPENT_SCORE, 0.0f));
123            mFeatureWeights.put(RECENCY_SCORE, mParamSharedPref.getFloat(RECENCY_SCORE, 0.0f));
124            mFeatureWeights.put(CHOOSER_SCORE, mParamSharedPref.getFloat(CHOOSER_SCORE, 0.0f));
125        }
126    }
127
128    private ArrayMap<String, Float> getFeatures(ResolverTarget target) {
129        ArrayMap<String, Float> features = new ArrayMap<>(4);
130        features.put(RECENCY_SCORE, target.getRecencyScore());
131        features.put(TIME_SPENT_SCORE, target.getTimeSpentScore());
132        features.put(LAUNCH_SCORE, target.getLaunchScore());
133        features.put(CHOOSER_SCORE, target.getChooserScore());
134        return features;
135    }
136
137    private float predict(ArrayMap<String, Float> target) {
138        if (target == null) {
139            return 0.0f;
140        }
141        final int featureSize = target.size();
142        float sum = 0.0f;
143        for (int i = 0; i < featureSize; i++) {
144            String featureName = target.keyAt(i);
145            float weight = mFeatureWeights.getOrDefault(featureName, 0.0f);
146            sum += weight * target.valueAt(i);
147        }
148        return (float) (1.0 / (1.0 + Math.exp(-mBias - sum)));
149    }
150
151    private void update(ArrayMap<String, Float> target, float predict, boolean isSelected) {
152        if (target == null) {
153            return;
154        }
155        final int featureSize = target.size();
156        float error = isSelected ? 1.0f - predict : -predict;
157        for (int i = 0; i < featureSize; i++) {
158            String featureName = target.keyAt(i);
159            float currentWeight = mFeatureWeights.getOrDefault(featureName, 0.0f);
160            mBias += LEARNING_RATE * error;
161            currentWeight = currentWeight - LEARNING_RATE * REGULARIZER_PARAM * currentWeight +
162                    LEARNING_RATE * error * target.valueAt(i);
163            mFeatureWeights.put(featureName, currentWeight);
164        }
165        if (DEBUG) {
166            Log.d(TAG, "Weights: " + mFeatureWeights + " Bias: " + mBias);
167        }
168    }
169
170    private void commitUpdate() {
171        try {
172            SharedPreferences.Editor editor = mParamSharedPref.edit();
173            editor.putFloat(BIAS_PREF_KEY, mBias);
174            final int size = mFeatureWeights.size();
175            for (int i = 0; i < size; i++) {
176                editor.putFloat(mFeatureWeights.keyAt(i), mFeatureWeights.valueAt(i));
177            }
178            editor.putInt(VERSION_PREF_KEY, CURRENT_VERSION);
179            editor.apply();
180        } catch (Exception e) {
181            Log.e(TAG, "Failed to commit update" + e);
182        }
183    }
184
185    private SharedPreferences getParamSharedPref() {
186        // The package info in the context isn't initialized in the way it is for normal apps,
187        // so the standard, name-based context.getSharedPreferences doesn't work. Instead, we
188        // build the path manually below using the same policy that appears in ContextImpl.
189        if (DEBUG) {
190            Log.d(TAG, "Context Package Name: " + getPackageName());
191        }
192        final File prefsFile = new File(new File(
193                Environment.getDataUserCePackageDirectory(
194                        StorageManager.UUID_PRIVATE_INTERNAL, getUserId(), getPackageName()),
195                "shared_prefs"),
196                PARAM_SHARED_PREF_NAME + ".xml");
197        return getSharedPreferences(prefsFile, Context.MODE_PRIVATE);
198    }
199}