HistogramPredictor.java revision 1253e9fb0b5570ab8adaed222655a5b052aa072e
1/* 2 * Copyright (C) 2011 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 18package android.bordeaux.learning; 19 20import android.util.Log; 21import android.util.Pair; 22 23import java.io.ByteArrayInputStream; 24import java.io.ByteArrayOutputStream; 25import java.io.IOException; 26import java.io.ObjectInputStream; 27import java.io.ObjectOutputStream; 28import java.io.Serializable; 29import java.util.ArrayList; 30import java.util.Collections; 31import java.util.Comparator; 32import java.util.HashMap; 33import java.util.HashSet; 34import java.util.Iterator; 35import java.util.List; 36import java.util.Map; 37import java.util.Map.Entry; 38 39/** 40 * A histogram based predictor which records co-occurrences of applications with a specific feature, 41 * for example, location, time of day, etc. The histogram is kept in a two level hash table. 42 * The first level key is the feature value and the second level key is the app id. 43 */ 44// TODOS: 45// 1. Use forgetting factor to downweight instances proportional to the time 46// 2. Different features could have different weights on prediction scores. 47// 3. Make prediction (on each feature) only when the histogram has collected 48// sufficient counts. 
49public class HistogramPredictor { 50 final static String TAG = "HistogramPredictor"; 51 52 private HashMap<String, HistogramCounter> mPredictor = 53 new HashMap<String, HistogramCounter>(); 54 55 private HashMap<String, Integer> mClassCounts = new HashMap<String, Integer>(); 56 private int mTotalClassCount = 0; 57 58 private static final double FEATURE_INACTIVE_LIKELIHOOD = 0.00000001; 59 private static final double LOG_INACTIVE = Math.log(FEATURE_INACTIVE_LIKELIHOOD); 60 61 /* 62 * This class keeps the histogram counts for each feature and provide the 63 * joint probabilities of <feature, class>. 64 */ 65 private class HistogramCounter { 66 private HashMap<String, HashMap<String, Integer> > mCounter = 67 new HashMap<String, HashMap<String, Integer> >(); 68 private int mTotalCount; 69 70 public HistogramCounter() { 71 resetCounter(); 72 } 73 74 public void setCounter(HashMap<String, HashMap<String, Integer> > counter) { 75 resetCounter(); 76 mCounter.putAll(counter); 77 78 // get total count 79 for (Map.Entry<String, HashMap<String, Integer> > entry : counter.entrySet()) { 80 for (Integer value : entry.getValue().values()) { 81 mTotalCount += value.intValue(); 82 } 83 } 84 } 85 86 public void resetCounter() { 87 mCounter.clear(); 88 mTotalCount = 0; 89 } 90 91 public void addSample(String className, String featureValue) { 92 HashMap<String, Integer> classCounts; 93 94 if (!mCounter.containsKey(featureValue)) { 95 classCounts = new HashMap<String, Integer>(); 96 mCounter.put(featureValue, classCounts); 97 } else { 98 classCounts = mCounter.get(featureValue); 99 } 100 int count = (classCounts.containsKey(className)) ? 
101 classCounts.get(className) + 1 : 1; 102 classCounts.put(className, count); 103 mTotalCount++; 104 } 105 106 public HashMap<String, Double> getClassScores(String featureValue) { 107 HashMap<String, Double> classScores = new HashMap<String, Double>(); 108 109 double logTotalCount = Math.log((double) mTotalCount); 110 if (mCounter.containsKey(featureValue)) { 111 for(Map.Entry<String, Integer> entry : 112 mCounter.get(featureValue).entrySet()) { 113 double score = 114 Math.log((double) entry.getValue()) - logTotalCount; 115 classScores.put(entry.getKey(), score); 116 } 117 } 118 return classScores; 119 } 120 121 public HashMap<String, HashMap<String, Integer> > getCounter() { 122 return mCounter; 123 } 124 } 125 126 private double getDefaultLikelihood(Map<String, String> features) { 127 int featureCount = 0; 128 129 for(String featureName : features.keySet()) { 130 if (mPredictor.containsKey(featureName)) { 131 featureCount++; 132 } 133 } 134 return LOG_INACTIVE * featureCount; 135 } 136 137 /* 138 * Given a map of feature name -value pairs returns the mostly likely apps to 139 * be launched with corresponding likelihoods. 
140 */ 141 public List<Map.Entry<String, Double> > findTopClasses(Map<String, String> features, int topK) { 142 // Most sophisticated function in this class 143 HashMap<String, Double> appScores = new HashMap<String, Double>(); 144 double defaultLikelihood = getDefaultLikelihood(features); 145 146 HashMap<String, Integer> appearCounts = new HashMap<String, Integer>(); 147 148 // compute all app scores 149 for (Map.Entry<String, HistogramCounter> entry : mPredictor.entrySet()) { 150 String featureName = entry.getKey(); 151 HistogramCounter counter = entry.getValue(); 152 153 if (features.containsKey(featureName)) { 154 String featureValue = features.get(featureName); 155 HashMap<String, Double> scoreMap = counter.getClassScores(featureValue); 156 157 for (Map.Entry<String, Double> item : scoreMap.entrySet()) { 158 String appName = item.getKey(); 159 double appScore = item.getValue(); 160 double score = (appScores.containsKey(appName)) ? 161 appScores.get(appName) : defaultLikelihood; 162 score += appScore - LOG_INACTIVE; 163 appScores.put(appName, score); 164 165 int count = (appearCounts.containsKey(appName)) ? 
166 appearCounts.get(appName) + 1 : 1; 167 appearCounts.put(appName, count); 168 } 169 } 170 } 171 172 // TODO: this check should be unnecessary 173 if (mClassCounts.size() != 0 && mTotalClassCount != 0) { 174 for (Map.Entry<String, Double> entry : appScores.entrySet()) { 175 String appName = entry.getKey(); 176 double appScore = entry.getValue(); 177 if (!appearCounts.containsKey(appName)) { 178 throw new RuntimeException("appearance count error!"); 179 } 180 int appearCount = appearCounts.get(appName); 181 182 if (!mClassCounts.containsKey(appName)) { 183 throw new RuntimeException("class count error!"); 184 } 185 double appPrior = 186 Math.log(mClassCounts.get(appName)) - Math.log(mTotalClassCount); 187 appScores.put(appName, appScore - appPrior * (appearCount - 1)); 188 } 189 } 190 191 // sort app scores 192 List<Map.Entry<String, Double> > appList = 193 new ArrayList<Map.Entry<String, Double> >(appScores.size()); 194 appList.addAll(appScores.entrySet()); 195 Collections.sort(appList, new Comparator<Map.Entry<String, Double> >() { 196 public int compare(Map.Entry<String, Double> o1, 197 Map.Entry<String, Double> o2) { 198 return o2.getValue().compareTo(o1.getValue()); 199 } 200 }); 201 202 Log.v(TAG, "findTopApps appList: " + appList); 203 return appList; 204 } 205 206 /* 207 * Add a new observation of given sample id and features to the histograms 208 */ 209 public void addSample(String sampleId, Map<String, String> features) { 210 for (Map.Entry<String, HistogramCounter> entry : mPredictor.entrySet()) { 211 String featureName = entry.getKey(); 212 HistogramCounter counter = entry.getValue(); 213 214 if (features.containsKey(featureName)) { 215 String featureValue = features.get(featureName); 216 counter.addSample(sampleId, featureValue); 217 } 218 } 219 220 int sampleCount = (mClassCounts.containsKey(sampleId)) ? 
221 mClassCounts.get(sampleId) + 1 : 1; 222 mClassCounts.put(sampleId, sampleCount); 223 } 224 225 /* 226 * reset predictor to a empty model 227 */ 228 public void resetPredictor() { 229 // TODO: not sure this step would reduce memory waste 230 for (HistogramCounter counter : mPredictor.values()) { 231 counter.resetCounter(); 232 } 233 mPredictor.clear(); 234 235 mClassCounts.clear(); 236 mTotalClassCount = 0; 237 } 238 239 /* 240 * specify a feature to used for prediction 241 */ 242 public void useFeature(String featureName) { 243 if (!mPredictor.containsKey(featureName)) { 244 mPredictor.put(featureName, new HistogramCounter()); 245 } 246 } 247 248 /* 249 * convert the prediction model into a byte array 250 */ 251 public byte[] getModel() { 252 // TODO: convert model to a more memory efficient data structure. 253 HashMap<String, HashMap<String, HashMap<String, Integer > > > model = 254 new HashMap<String, HashMap<String, HashMap<String, Integer > > >(); 255 for(Map.Entry<String, HistogramCounter> entry : mPredictor.entrySet()) { 256 model.put(entry.getKey(), entry.getValue().getCounter()); 257 } 258 259 try { 260 ByteArrayOutputStream byteStream = new ByteArrayOutputStream(); 261 ObjectOutputStream objStream = new ObjectOutputStream(byteStream); 262 objStream.writeObject(model); 263 byte[] bytes = byteStream.toByteArray(); 264 //Log.i(TAG, "getModel: " + bytes); 265 return bytes; 266 } catch (IOException e) { 267 throw new RuntimeException("Can't get model"); 268 } 269 } 270 271 /* 272 * set the prediction model from a model data in the format of byte array 273 */ 274 public boolean setModel(final byte[] modelData) { 275 HashMap<String, HashMap<String, HashMap<String, Integer > > > model; 276 277 try { 278 ByteArrayInputStream input = new ByteArrayInputStream(modelData); 279 ObjectInputStream objStream = new ObjectInputStream(input); 280 model = (HashMap<String, HashMap<String, HashMap<String, Integer > > >) 281 objStream.readObject(); 282 } catch (IOException e) 
{ 283 throw new RuntimeException("Can't load model"); 284 } catch (ClassNotFoundException e) { 285 throw new RuntimeException("Learning class not found"); 286 } 287 288 resetPredictor(); 289 for (Map.Entry<String, HashMap<String, HashMap<String, Integer> > > entry : 290 model.entrySet()) { 291 useFeature(entry.getKey()); 292 mPredictor.get(entry.getKey()).setCounter(entry.getValue()); 293 } 294 295 // TODO: this is a temporary fix for now 296 loadClassCounter(); 297 298 return true; 299 } 300 301 private void loadClassCounter() { 302 String TIME_OF_WEEK = "Time of Week"; 303 304 if (!mPredictor.containsKey(TIME_OF_WEEK)) { 305 throw new RuntimeException("Precition model error: missing Time of Week!"); 306 } 307 308 HashMap<String, HashMap<String, Integer> > counter = 309 mPredictor.get(TIME_OF_WEEK).getCounter(); 310 311 mTotalClassCount = 0; 312 mClassCounts.clear(); 313 for (HashMap<String, Integer> map : counter.values()) { 314 for (Map.Entry<String, Integer> entry : map.entrySet()) { 315 int classCount = entry.getValue(); 316 String className = entry.getKey(); 317 mTotalClassCount += classCount; 318 319 if (mClassCounts.containsKey(className)) { 320 classCount += mClassCounts.get(className); 321 } 322 mClassCounts.put(className, classCount); 323 } 324 } 325 326 Log.e(TAG, "class counts: " + mClassCounts + ", total count: " + 327 mTotalClassCount); 328 } 329} 330