HistogramPredictor.java revision 1253e9fb0b5570ab8adaed222655a5b052aa072e
1/*
2 * Copyright (C) 2011 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17
18package android.bordeaux.learning;
19
20import android.util.Log;
21import android.util.Pair;
22
23import java.io.ByteArrayInputStream;
24import java.io.ByteArrayOutputStream;
25import java.io.IOException;
26import java.io.ObjectInputStream;
27import java.io.ObjectOutputStream;
28import java.io.Serializable;
29import java.util.ArrayList;
30import java.util.Collections;
31import java.util.Comparator;
32import java.util.HashMap;
33import java.util.HashSet;
34import java.util.Iterator;
35import java.util.List;
36import java.util.Map;
37import java.util.Map.Entry;
38
/**
 * A histogram-based predictor which records co-occurrences of applications with a specific
 * feature, for example, location, time of day, etc. The histogram is kept in a two-level hash
 * table. The first-level key is the feature value and the second-level key is the app id.
 */
// TODOS:
// 1. Use a forgetting factor to downweight instances proportionally to the time.
// 2. Different features could have different weights on prediction scores.
// 3. Make prediction (on each feature) only when the histogram has collected
//    sufficient counts.
49public class HistogramPredictor {
50    final static String TAG = "HistogramPredictor";
51
52    private HashMap<String, HistogramCounter> mPredictor =
53            new HashMap<String, HistogramCounter>();
54
55    private HashMap<String, Integer> mClassCounts = new HashMap<String, Integer>();
56    private int mTotalClassCount = 0;
57
58    private static final double FEATURE_INACTIVE_LIKELIHOOD = 0.00000001;
59    private static final double LOG_INACTIVE = Math.log(FEATURE_INACTIVE_LIKELIHOOD);
60
61    /*
62     * This class keeps the histogram counts for each feature and provide the
63     * joint probabilities of <feature, class>.
64     */
65    private class HistogramCounter {
66        private HashMap<String, HashMap<String, Integer> > mCounter =
67                new HashMap<String, HashMap<String, Integer> >();
68        private int mTotalCount;
69
70        public HistogramCounter() {
71            resetCounter();
72        }
73
74        public void setCounter(HashMap<String, HashMap<String, Integer> > counter) {
75            resetCounter();
76            mCounter.putAll(counter);
77
78            // get total count
79            for (Map.Entry<String, HashMap<String, Integer> > entry : counter.entrySet()) {
80                for (Integer value : entry.getValue().values()) {
81                    mTotalCount += value.intValue();
82                }
83            }
84        }
85
86        public void resetCounter() {
87            mCounter.clear();
88            mTotalCount = 0;
89        }
90
91        public void addSample(String className, String featureValue) {
92            HashMap<String, Integer> classCounts;
93
94            if (!mCounter.containsKey(featureValue)) {
95                classCounts = new HashMap<String, Integer>();
96                mCounter.put(featureValue, classCounts);
97            } else {
98                classCounts = mCounter.get(featureValue);
99            }
100            int count = (classCounts.containsKey(className)) ?
101                    classCounts.get(className) + 1 : 1;
102            classCounts.put(className, count);
103            mTotalCount++;
104        }
105
106        public HashMap<String, Double> getClassScores(String featureValue) {
107            HashMap<String, Double> classScores = new HashMap<String, Double>();
108
109            double logTotalCount = Math.log((double) mTotalCount);
110            if (mCounter.containsKey(featureValue)) {
111                for(Map.Entry<String, Integer> entry :
112                        mCounter.get(featureValue).entrySet()) {
113                    double score =
114                            Math.log((double) entry.getValue()) - logTotalCount;
115                    classScores.put(entry.getKey(), score);
116                }
117            }
118            return classScores;
119        }
120
121        public HashMap<String, HashMap<String, Integer> > getCounter() {
122            return mCounter;
123        }
124    }
125
126    private double getDefaultLikelihood(Map<String, String> features) {
127        int featureCount = 0;
128
129        for(String featureName : features.keySet()) {
130            if (mPredictor.containsKey(featureName)) {
131                featureCount++;
132            }
133        }
134        return LOG_INACTIVE * featureCount;
135    }
136
137    /*
138     * Given a map of feature name -value pairs returns the mostly likely apps to
139     * be launched with corresponding likelihoods.
140     */
141    public List<Map.Entry<String, Double> > findTopClasses(Map<String, String> features, int topK) {
142        // Most sophisticated function in this class
143        HashMap<String, Double> appScores = new HashMap<String, Double>();
144        double defaultLikelihood = getDefaultLikelihood(features);
145
146        HashMap<String, Integer> appearCounts = new HashMap<String, Integer>();
147
148        // compute all app scores
149        for (Map.Entry<String, HistogramCounter> entry : mPredictor.entrySet()) {
150            String featureName = entry.getKey();
151            HistogramCounter counter = entry.getValue();
152
153            if (features.containsKey(featureName)) {
154                String featureValue = features.get(featureName);
155                HashMap<String, Double> scoreMap = counter.getClassScores(featureValue);
156
157                for (Map.Entry<String, Double> item : scoreMap.entrySet()) {
158                    String appName = item.getKey();
159                    double appScore = item.getValue();
160                    double score = (appScores.containsKey(appName)) ?
161                        appScores.get(appName) : defaultLikelihood;
162                    score += appScore - LOG_INACTIVE;
163                    appScores.put(appName, score);
164
165                    int count = (appearCounts.containsKey(appName)) ?
166                        appearCounts.get(appName) + 1 : 1;
167                    appearCounts.put(appName, count);
168                }
169            }
170        }
171
172        // TODO: this check should be unnecessary
173        if (mClassCounts.size() != 0 && mTotalClassCount != 0) {
174            for (Map.Entry<String, Double> entry : appScores.entrySet()) {
175                String appName = entry.getKey();
176                double appScore = entry.getValue();
177                if (!appearCounts.containsKey(appName)) {
178                    throw new RuntimeException("appearance count error!");
179                }
180                int appearCount = appearCounts.get(appName);
181
182                if (!mClassCounts.containsKey(appName)) {
183                    throw new RuntimeException("class count error!");
184                }
185                double appPrior =
186                    Math.log(mClassCounts.get(appName)) - Math.log(mTotalClassCount);
187                appScores.put(appName, appScore - appPrior * (appearCount - 1));
188            }
189        }
190
191        // sort app scores
192        List<Map.Entry<String, Double> > appList =
193               new ArrayList<Map.Entry<String, Double> >(appScores.size());
194        appList.addAll(appScores.entrySet());
195        Collections.sort(appList, new  Comparator<Map.Entry<String, Double> >() {
196            public int compare(Map.Entry<String, Double> o1,
197                               Map.Entry<String, Double> o2) {
198                return o2.getValue().compareTo(o1.getValue());
199            }
200        });
201
202        Log.v(TAG, "findTopApps appList: " + appList);
203        return appList;
204    }
205
206    /*
207     * Add a new observation of given sample id and features to the histograms
208     */
209    public void addSample(String sampleId, Map<String, String> features) {
210        for (Map.Entry<String, HistogramCounter> entry : mPredictor.entrySet()) {
211            String featureName = entry.getKey();
212            HistogramCounter counter = entry.getValue();
213
214            if (features.containsKey(featureName)) {
215                String featureValue = features.get(featureName);
216                counter.addSample(sampleId, featureValue);
217            }
218        }
219
220        int sampleCount = (mClassCounts.containsKey(sampleId)) ?
221            mClassCounts.get(sampleId) + 1 : 1;
222        mClassCounts.put(sampleId, sampleCount);
223    }
224
225    /*
226     * reset predictor to a empty model
227     */
228    public void resetPredictor() {
229        // TODO: not sure this step would reduce memory waste
230        for (HistogramCounter counter : mPredictor.values()) {
231            counter.resetCounter();
232        }
233        mPredictor.clear();
234
235        mClassCounts.clear();
236        mTotalClassCount = 0;
237    }
238
239    /*
240     * specify a feature to used for prediction
241     */
242    public void useFeature(String featureName) {
243        if (!mPredictor.containsKey(featureName)) {
244            mPredictor.put(featureName, new HistogramCounter());
245        }
246    }
247
248    /*
249     * convert the prediction model into a byte array
250     */
251    public byte[] getModel() {
252        // TODO: convert model to a more memory efficient data structure.
253        HashMap<String, HashMap<String, HashMap<String, Integer > > > model =
254                new HashMap<String, HashMap<String, HashMap<String, Integer > > >();
255        for(Map.Entry<String, HistogramCounter> entry : mPredictor.entrySet()) {
256            model.put(entry.getKey(), entry.getValue().getCounter());
257        }
258
259        try {
260            ByteArrayOutputStream byteStream = new ByteArrayOutputStream();
261            ObjectOutputStream objStream = new ObjectOutputStream(byteStream);
262            objStream.writeObject(model);
263            byte[] bytes = byteStream.toByteArray();
264            //Log.i(TAG, "getModel: " + bytes);
265            return bytes;
266        } catch (IOException e) {
267            throw new RuntimeException("Can't get model");
268        }
269    }
270
271    /*
272     * set the prediction model from a model data in the format of byte array
273     */
274    public boolean setModel(final byte[] modelData) {
275        HashMap<String, HashMap<String, HashMap<String, Integer > > > model;
276
277        try {
278            ByteArrayInputStream input = new ByteArrayInputStream(modelData);
279            ObjectInputStream objStream = new ObjectInputStream(input);
280            model = (HashMap<String, HashMap<String, HashMap<String, Integer > > >)
281                    objStream.readObject();
282        } catch (IOException e) {
283            throw new RuntimeException("Can't load model");
284        } catch (ClassNotFoundException e) {
285            throw new RuntimeException("Learning class not found");
286        }
287
288        resetPredictor();
289        for (Map.Entry<String, HashMap<String, HashMap<String, Integer> > > entry :
290                model.entrySet()) {
291            useFeature(entry.getKey());
292            mPredictor.get(entry.getKey()).setCounter(entry.getValue());
293        }
294
295        // TODO: this is a temporary fix for now
296        loadClassCounter();
297
298        return true;
299    }
300
301    private void loadClassCounter() {
302        String TIME_OF_WEEK = "Time of Week";
303
304        if (!mPredictor.containsKey(TIME_OF_WEEK)) {
305            throw new RuntimeException("Precition model error: missing Time of Week!");
306        }
307
308        HashMap<String, HashMap<String, Integer> > counter =
309            mPredictor.get(TIME_OF_WEEK).getCounter();
310
311        mTotalClassCount = 0;
312        mClassCounts.clear();
313        for (HashMap<String, Integer> map : counter.values()) {
314            for (Map.Entry<String, Integer> entry : map.entrySet()) {
315                int classCount = entry.getValue();
316                String className = entry.getKey();
317                mTotalClassCount += classCount;
318
319                if (mClassCounts.containsKey(className)) {
320                    classCount += mClassCounts.get(className);
321                }
322                mClassCounts.put(className, classCount);
323            }
324        }
325
326        Log.e(TAG, "class counts: " + mClassCounts + ", total count: " +
327              mTotalClassCount);
328    }
329}
330