1/*
2 * Copyright (C) 2011 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package android.bordeaux.learning;
18
19import android.util.Log;
20
21import java.io.ByteArrayInputStream;
22import java.io.ByteArrayOutputStream;
23import java.io.IOException;
24import java.io.ObjectInputStream;
25import java.io.ObjectOutputStream;
26import java.io.Serializable;
27import java.util.ArrayList;
28import java.util.Collections;
29import java.util.Comparator;
30import java.util.HashMap;
31import java.util.HashSet;
32import java.util.Iterator;
33import java.util.List;
34import java.util.Map;
35import java.util.Map.Entry;
36import java.util.concurrent.ConcurrentHashMap;
/**
 * A histogram-based predictor that records co-occurrences of applications with specific
 * features, for example location, time of day, etc. One histogram is kept per feature name;
 * each histogram is a two-level hash table whose first-level key is the feature value and
 * whose second-level key is the app id.
 */
// TODOs:
// 1. Use a forgetting factor to down-weight instances in proportion to their age.
// 2. Different features could have different weights on prediction scores.
// 3. Add a function to remove a sample id (i.e. remove apps that are uninstalled).
47
48
49public class HistogramPredictor {
50    final static String TAG = "HistogramPredictor";
51
52    private HashMap<String, HistogramCounter> mPredictor =
53            new HashMap<String, HistogramCounter>();
54
55    private HashMap<String, Integer> mClassCounts = new HashMap<String, Integer>();
56    private HashSet<String> mBlacklist = new HashSet<String>();
57
58    private static final int MINIMAL_FEATURE_VALUE_COUNTS = 5;
59    private static final int MINIMAL_APP_APPEARANCE_COUNTS = 5;
60
61    // This parameter ranges from 0 to 1 which determines the effect of app prior.
62    // When it is set to 0, app prior means completely neglected. When it is set to 1
63    // the predictor is a standard naive bayes model.
64    private static final int PRIOR_K_VALUE = 1;
65
66    private static final String[] APP_BLACKLIST = {
67        "com.android.contacts",
68        "com.android.chrome",
69        "com.android.providers.downloads.ui",
70        "com.android.settings",
71        "com.android.vending",
72        "com.android.mms",
73        "com.google.android.gm",
74        "com.google.android.gallery3d",
75        "com.google.android.apps.googlevoice",
76    };
77
78    public HistogramPredictor(String[] blackList) {
79        for (String appName : blackList) {
80            mBlacklist.add(appName);
81        }
82    }
83
84    /*
85     * This class keeps the histogram counts for each feature and provide the
86     * joint probabilities of <feature, class>.
87     */
88    private class HistogramCounter {
89        private HashMap<String, HashMap<String, Integer> > mCounter =
90                new HashMap<String, HashMap<String, Integer> >();
91
92        public HistogramCounter() {
93            mCounter.clear();
94        }
95
96        public void setCounter(HashMap<String, HashMap<String, Integer> > counter) {
97            resetCounter();
98            mCounter.putAll(counter);
99        }
100
101        public void resetCounter() {
102            mCounter.clear();
103        }
104
105        public void addSample(String className, String featureValue) {
106            HashMap<String, Integer> classCounts;
107
108            if (!mCounter.containsKey(featureValue)) {
109                classCounts = new HashMap<String, Integer>();
110                mCounter.put(featureValue, classCounts);
111            } else {
112                classCounts = mCounter.get(featureValue);
113            }
114            int count = (classCounts.containsKey(className)) ?
115                    classCounts.get(className) + 1 : 1;
116            classCounts.put(className, count);
117        }
118
119        public HashMap<String, Double> getClassScores(String featureValue) {
120            HashMap<String, Double> classScores = new HashMap<String, Double>();
121
122            if (mCounter.containsKey(featureValue)) {
123                int totalCount = 0;
124                for(Map.Entry<String, Integer> entry :
125                        mCounter.get(featureValue).entrySet()) {
126                    String app = entry.getKey();
127                    int count = entry.getValue();
128
129                    // For apps with counts less than or equal to one, we treated
130                    // those as having count one. Hence their score, i.e. log(count)
131                    // would be zero. classScroes stores only apps with non-zero scores.
132                    // Note that totalCount also neglect app with single occurrence.
133                    if (count > 1) {
134                        double score = Math.log((double) count);
135                        classScores.put(app, score);
136                        totalCount += count;
137                    }
138                }
139                if (totalCount < MINIMAL_FEATURE_VALUE_COUNTS) {
140                    classScores.clear();
141                }
142            }
143            return classScores;
144        }
145
146        public byte[] getModel() {
147            try {
148                ByteArrayOutputStream byteStream = new ByteArrayOutputStream();
149                ObjectOutputStream objStream = new ObjectOutputStream(byteStream);
150                synchronized(mCounter) {
151                    objStream.writeObject(mCounter);
152                }
153                byte[] bytes = byteStream.toByteArray();
154                return bytes;
155            } catch (IOException e) {
156                throw new RuntimeException("Can't get model");
157            }
158        }
159
160        public boolean setModel(final byte[] modelData) {
161            mCounter.clear();
162            HashMap<String, HashMap<String, Integer> > model;
163
164            try {
165                ByteArrayInputStream input = new ByteArrayInputStream(modelData);
166                ObjectInputStream objStream = new ObjectInputStream(input);
167                model = (HashMap<String, HashMap<String, Integer> >) objStream.readObject();
168            } catch (IOException e) {
169                throw new RuntimeException("Can't load model");
170            } catch (ClassNotFoundException e) {
171                throw new RuntimeException("Learning class not found");
172            }
173
174            synchronized(mCounter) {
175                mCounter.putAll(model);
176            }
177
178            return true;
179        }
180
181
182        public HashMap<String, HashMap<String, Integer> > getCounter() {
183            return mCounter;
184        }
185
186        public String toString() {
187            String result = "";
188            for (Map.Entry<String, HashMap<String, Integer> > entry :
189                     mCounter.entrySet()) {
190                result += "{ " + entry.getKey() + " : " +
191                    entry.getValue().toString() + " }";
192            }
193            return result;
194        }
195    }
196
197    /*
198     * Given a map of feature name -value pairs returns topK mostly likely apps to
199     * be launched with corresponding likelihoods. If topK is set zero, it will return
200     * the whole list.
201     */
202    public List<Map.Entry<String, Double> > findTopClasses(Map<String, String> features, int topK) {
203        // Most sophisticated function in this class
204        HashMap<String, Double> appScores = new HashMap<String, Double>();
205        int validFeatureCount = 0;
206
207        // compute all app scores
208        for (Map.Entry<String, HistogramCounter> entry : mPredictor.entrySet()) {
209            String featureName = entry.getKey();
210            HistogramCounter counter = entry.getValue();
211
212            if (features.containsKey(featureName)) {
213                String featureValue = features.get(featureName);
214                HashMap<String, Double> scoreMap = counter.getClassScores(featureValue);
215
216                if (scoreMap.isEmpty()) {
217                  continue;
218                }
219                validFeatureCount++;
220
221                for (Map.Entry<String, Double> item : scoreMap.entrySet()) {
222                    String appName = item.getKey();
223                    double appScore = item.getValue();
224                    if (appScores.containsKey(appName)) {
225                        appScore += appScores.get(appName);
226                    }
227                    appScores.put(appName, appScore);
228                }
229            }
230        }
231
232        HashMap<String, Double> appCandidates = new HashMap<String, Double>();
233        for (Map.Entry<String, Double> entry : appScores.entrySet()) {
234            String appName = entry.getKey();
235            if (mBlacklist.contains(appName)) {
236                Log.i(TAG, appName + " is in blacklist");
237                continue;
238            }
239            if (!mClassCounts.containsKey(appName)) {
240                throw new RuntimeException("class count error!");
241            }
242            int appCount = mClassCounts.get(appName);
243            if (appCount < MINIMAL_APP_APPEARANCE_COUNTS) {
244                Log.i(TAG, appName + " doesn't have enough counts");
245                continue;
246            }
247
248            double appScore = entry.getValue();
249            double appPrior = Math.log((double) appCount);
250            appCandidates.put(appName,
251                              appScore - appPrior * (validFeatureCount - PRIOR_K_VALUE));
252        }
253
254        // sort app scores
255        List<Map.Entry<String, Double> > appList =
256               new ArrayList<Map.Entry<String, Double> >(appCandidates.size());
257        appList.addAll(appCandidates.entrySet());
258        Collections.sort(appList, new  Comparator<Map.Entry<String, Double> >() {
259            public int compare(Map.Entry<String, Double> o1,
260                               Map.Entry<String, Double> o2) {
261                return o2.getValue().compareTo(o1.getValue());
262            }
263        });
264
265        if (topK == 0) {
266            topK = appList.size();
267        }
268        return appList.subList(0, Math.min(topK, appList.size()));
269    }
270
271    /*
272     * Add a new observation of given sample id and features to the histograms
273     */
274    public void addSample(String sampleId, Map<String, String> features) {
275        for (Map.Entry<String, String> entry : features.entrySet()) {
276            String featureName = entry.getKey();
277            String featureValue = entry.getValue();
278
279            useFeature(featureName);
280            HistogramCounter counter = mPredictor.get(featureName);
281            counter.addSample(sampleId, featureValue);
282        }
283
284        int sampleCount = (mClassCounts.containsKey(sampleId)) ?
285            mClassCounts.get(sampleId) + 1 : 1;
286        mClassCounts.put(sampleId, sampleCount);
287    }
288
289    /*
290     * reset predictor to a empty model
291     */
292    public void resetPredictor() {
293        // TODO: not sure this step would reduce memory waste
294        for (HistogramCounter counter : mPredictor.values()) {
295            counter.resetCounter();
296        }
297        mPredictor.clear();
298        mClassCounts.clear();
299    }
300
301    /*
302     * convert the prediction model into a byte array
303     */
304    public byte[] getModel() {
305        // TODO: convert model to a more memory efficient data structure.
306        HashMap<String, HashMap<String, HashMap<String, Integer > > > model =
307                new HashMap<String, HashMap<String, HashMap<String, Integer > > >();
308        for(Map.Entry<String, HistogramCounter> entry : mPredictor.entrySet()) {
309            model.put(entry.getKey(), entry.getValue().getCounter());
310        }
311
312        try {
313            ByteArrayOutputStream byteStream = new ByteArrayOutputStream();
314            ObjectOutputStream objStream = new ObjectOutputStream(byteStream);
315            objStream.writeObject(model);
316            byte[] bytes = byteStream.toByteArray();
317            return bytes;
318        } catch (IOException e) {
319            throw new RuntimeException("Can't get model");
320        }
321    }
322
323    /*
324     * set the prediction model from a model data in the format of byte array
325     */
326    public boolean setModel(final byte[] modelData) {
327        HashMap<String, HashMap<String, HashMap<String, Integer > > > model;
328
329        try {
330            ByteArrayInputStream input = new ByteArrayInputStream(modelData);
331            ObjectInputStream objStream = new ObjectInputStream(input);
332            model = (HashMap<String, HashMap<String, HashMap<String, Integer > > >)
333                    objStream.readObject();
334        } catch (IOException e) {
335            throw new RuntimeException("Can't load model");
336        } catch (ClassNotFoundException e) {
337            throw new RuntimeException("Learning class not found");
338        }
339
340        resetPredictor();
341        for (Map.Entry<String, HashMap<String, HashMap<String, Integer> > > entry :
342                model.entrySet()) {
343            useFeature(entry.getKey());
344            mPredictor.get(entry.getKey()).setCounter(entry.getValue());
345        }
346
347        // TODO: this is a temporary fix for now
348        loadClassCounter();
349
350        return true;
351    }
352
353    private void loadClassCounter() {
354        String TIME_OF_WEEK = "Time of Week";
355
356        if (!mPredictor.containsKey(TIME_OF_WEEK)) {
357            throw new RuntimeException("Precition model error: missing Time of Week!");
358        }
359
360        HashMap<String, HashMap<String, Integer> > counter =
361            mPredictor.get(TIME_OF_WEEK).getCounter();
362
363        mClassCounts.clear();
364        for (HashMap<String, Integer> map : counter.values()) {
365            for (Map.Entry<String, Integer> entry : map.entrySet()) {
366                int classCount = entry.getValue();
367                String className = entry.getKey();
368                // mTotalClassCount += classCount;
369
370                if (mClassCounts.containsKey(className)) {
371                    classCount += mClassCounts.get(className);
372                }
373                mClassCounts.put(className, classCount);
374            }
375        }
376        Log.i(TAG, "class counts: " + mClassCounts);
377    }
378
379    private void useFeature(String featureName) {
380        if (!mPredictor.containsKey(featureName)) {
381            mPredictor.put(featureName, new HistogramCounter());
382        }
383    }
384}
385