HistogramPredictor.java revision 5d42ffa9462f87edbbdc61a8719f6c521c700de5
1/*
2 * Copyright (C) 2011 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17
18package android.bordeaux.learning;
19
20import android.util.Log;
21import android.util.Pair;
22
23import java.io.ByteArrayInputStream;
24import java.io.ByteArrayOutputStream;
25import java.io.IOException;
26import java.io.ObjectInputStream;
27import java.io.ObjectOutputStream;
28import java.io.Serializable;
29import java.util.ArrayList;
30import java.util.Collections;
31import java.util.Comparator;
32import java.util.HashMap;
33import java.util.HashSet;
34import java.util.Iterator;
35import java.util.List;
36import java.util.Map;
37import java.util.Map.Entry;
38
39
/**
 * A histogram based predictor which records co-occurrences of applications with a specific
 * feature, for example location, time of day, etc. The histogram is kept in a two level hash
 * table: the first level key is the feature value and the second level key is the app id.
 *
 * <p>One histogram is kept per feature name registered through {@link #useFeature(String)};
 * samples and queries only touch features that have been registered. This class performs no
 * synchronization of its own.
 */

// TODO: use forgetting factor to downweight instances proportional to the time
// difference between the occurrence and now.
public class HistogramPredictor {
    /** Log tag for this class. */
    final static String TAG = "HistogramPredictor";

    /** Maps a registered feature name to the histogram collected for that feature. */
    private final HashMap<String, HistogramCounter> mPredictor =
            new HashMap<String, HistogramCounter>();

    /** Likelihood assigned to a (feature, class) pair that has never been observed. */
    private static final double FEATURE_INACTIVE_LIKELIHOOD = 0.00000001;
    /** Natural log of {@link #FEATURE_INACTIVE_LIKELIHOOD}; scores are accumulated in log space. */
    private static final double LOG_FEATURE_INACTIVE = Math.log(FEATURE_INACTIVE_LIKELIHOOD);

    /*
     * This class keeps the histogram counts for one feature and provides the
     * joint log-probabilities of <feature value, class>.
     *
     * It is a static nested class because it never reads state of the enclosing predictor.
     */
    private static class HistogramCounter {
        /** feature value -> (class name -> number of observations). */
        private final HashMap<String, HashMap<String, Integer> > mCounter =
                new HashMap<String, HashMap<String, Integer> >();
        /** Sum of all observation counts stored in mCounter. */
        private int mTotalCount;

        /*
         * Replaces the current histogram with the given one and recomputes the total count.
         * The per-feature-value maps are adopted by reference (shallow copy of the outer map).
         */
        public void setCounter(HashMap<String, HashMap<String, Integer> > counter) {
            resetCounter();
            mCounter.putAll(counter);

            // Recompute the total number of observations from the adopted counts.
            for (HashMap<String, Integer> classCounts : counter.values()) {
                for (Integer value : classCounts.values()) {
                    mTotalCount += value.intValue();
                }
            }
        }

        /* Drops all counts. */
        public void resetCounter() {
            mCounter.clear();
            mTotalCount = 0;
        }

        /* Records one more observation of className together with featureValue. */
        public void addSample(String className, String featureValue) {
            HashMap<String, Integer> classCounts = mCounter.get(featureValue);
            if (classCounts == null) {
                classCounts = new HashMap<String, Integer>();
                mCounter.put(featureValue, classCounts);
            }

            Integer previous = classCounts.get(className);
            classCounts.put(className, (previous == null) ? 1 : previous + 1);
            mTotalCount++;
        }

        /*
         * Returns, for every class observed with featureValue, log(count / totalCount).
         * The map is empty when featureValue has never been observed.
         */
        public HashMap<String, Double> getClassScores(String featureValue) {
            HashMap<String, Double> classScores = new HashMap<String, Double>();

            HashMap<String, Integer> classCounts = mCounter.get(featureValue);
            if (classCounts != null) {
                double logTotalCount = Math.log((double) mTotalCount);
                for (Map.Entry<String, Integer> entry : classCounts.entrySet()) {
                    double score = Math.log((double) entry.getValue()) - logTotalCount;
                    classScores.put(entry.getKey(), score);
                }
            }
            return classScores;
        }

        /* Returns the live internal histogram (not a copy). */
        public HashMap<String, HashMap<String, Integer> > getCounter() {
            return mCounter;
        }
    }

    /**
     * Given a map of feature name - value pairs, returns the most likely apps to be launched
     * together with their log-likelihood scores, best first.
     *
     * <p>An app's score is the sum over all registered features of either the log joint
     * probability of (feature value, app) when that pair has been observed, or the log of the
     * inactive-feature likelihood otherwise. Apps that were never observed with any of the
     * supplied feature values do not appear in the result.
     *
     * @param features feature name to current feature value; names that were not registered
     *                 with {@link #useFeature(String)} are ignored
     * @param topK     maximum number of entries to return; a value of zero or less returns
     *                 every scored app
     * @return the scored apps sorted by descending score, truncated to {@code topK} entries
     */
    public List<Map.Entry<String, Double> > findTopClasses(Map<String, String> features, int topK) {
        HashMap<String, Double> appScores = new HashMap<String, Double>();
        // Baseline: every registered feature is assumed inactive for the app.
        double defaultLikelihood = mPredictor.size() * LOG_FEATURE_INACTIVE;

        // compute all app scores
        for (Map.Entry<String, HistogramCounter> entry : mPredictor.entrySet()) {
            String featureName = entry.getKey();
            if (!features.containsKey(featureName)) {
                continue;
            }

            HashMap<String, Double> scoreMap =
                    entry.getValue().getClassScores(features.get(featureName));
            for (Map.Entry<String, Double> item : scoreMap.entrySet()) {
                String appName = item.getKey();
                double appScore = item.getValue();

                double score = (appScores.containsKey(appName)) ?
                        appScores.get(appName) : defaultLikelihood;
                // Swap this feature's inactive contribution for the observed log-likelihood.
                score += appScore - LOG_FEATURE_INACTIVE;

                appScores.put(appName, score);
            }
        }

        // sort app scores, highest first
        List<Map.Entry<String, Double> > appList =
                new ArrayList<Map.Entry<String, Double> >(appScores.entrySet());
        Collections.sort(appList, new Comparator<Map.Entry<String, Double> >() {
            public int compare(Map.Entry<String, Double> o1,
                               Map.Entry<String, Double> o2) {
                return o2.getValue().compareTo(o1.getValue());
            }
        });

        // Honor topK: drop everything past the first topK entries.
        if (topK > 0 && topK < appList.size()) {
            appList.subList(topK, appList.size()).clear();
        }
        return appList;
    }

    /**
     * Adds a new observation of the given sample id to the histogram of every registered
     * feature that is present in {@code features}.
     *
     * @param sampleId the class (app) that was observed
     * @param features feature name to observed feature value
     */
    public void addSample(String sampleId, Map<String, String> features) {
        for (Map.Entry<String, HistogramCounter> entry : mPredictor.entrySet()) {
            String featureName = entry.getKey();
            if (features.containsKey(featureName)) {
                entry.getValue().addSample(sampleId, features.get(featureName));
            }
        }
    }

    /**
     * Resets the predictor to an empty model: all counts are dropped and no feature remains
     * registered.
     */
    public void resetPredictor() {
        // TODO: not sure this step would reduce memory waste
        for (HistogramCounter counter : mPredictor.values()) {
            counter.resetCounter();
        }
        mPredictor.clear();
    }

    /**
     * Registers a feature to be used for prediction. Registering a name that is already in use
     * keeps its existing histogram.
     *
     * @param featureName name of the feature to track
     */
    public void useFeature(String featureName) {
        if (!mPredictor.containsKey(featureName)) {
            mPredictor.put(featureName, new HistogramCounter());
        }
    }

    /**
     * Converts the prediction model into a byte array using Java object serialization. The
     * serialized object is a map of feature name to (feature value to (class name to count)).
     *
     * @return the serialized model
     * @throws RuntimeException if serialization fails; the underlying {@link IOException} is
     *                          attached as the cause
     */
    public byte[] getModel() {
        // TODO: convert model to a more memory efficient data structure.
        HashMap<String, HashMap<String, HashMap<String, Integer > > > model =
                new HashMap<String, HashMap<String, HashMap<String, Integer > > >();
        for (Map.Entry<String, HistogramCounter> entry : mPredictor.entrySet()) {
            model.put(entry.getKey(), entry.getValue().getCounter());
        }

        try {
            ByteArrayOutputStream byteStream = new ByteArrayOutputStream();
            ObjectOutputStream objStream = new ObjectOutputStream(byteStream);
            try {
                objStream.writeObject(model);
            } finally {
                // Closing flushes anything still buffered before the bytes are snapshotted.
                objStream.close();
            }
            return byteStream.toByteArray();
        } catch (IOException e) {
            throw new RuntimeException("Can't get model", e);
        }
    }

    /**
     * Sets the prediction model from model data in the byte array format produced by
     * {@link #getModel()}. On success the previous model is discarded and exactly the features
     * contained in the data are registered. If the data cannot be read, an exception is thrown
     * before the current model is modified.
     *
     * @param modelData serialized model
     * @return {@code true} when the model was loaded
     * @throws RuntimeException if the data cannot be deserialized; the underlying exception is
     *                          attached as the cause
     */
    public boolean setModel(final byte[] modelData) {
        HashMap<String, HashMap<String, HashMap<String, Integer > > > model;

        // NOTE(review): this is Java-native deserialization; it presumably only ever sees
        // bytes produced by getModel() -- confirm modelData never comes from an untrusted source.
        try {
            ObjectInputStream objStream =
                    new ObjectInputStream(new ByteArrayInputStream(modelData));
            try {
                // The stream carries no generic type information, so the cast cannot be checked.
                @SuppressWarnings("unchecked")
                HashMap<String, HashMap<String, HashMap<String, Integer > > > loaded =
                        (HashMap<String, HashMap<String, HashMap<String, Integer > > >)
                        objStream.readObject();
                model = loaded;
            } finally {
                objStream.close();
            }
        } catch (IOException e) {
            throw new RuntimeException("Can't load model", e);
        } catch (ClassNotFoundException e) {
            throw new RuntimeException("Learning class not found", e);
        }

        resetPredictor();
        for (Map.Entry<String, HashMap<String, HashMap<String, Integer> > > entry :
                model.entrySet()) {
            useFeature(entry.getKey());
            mPredictor.get(entry.getKey()).setCounter(entry.getValue());
        }
        return true;
    }
}
252