HistogramPredictor.java revision 5d42ffa9462f87edbbdc61a8719f6c521c700de5
1/* 2 * Copyright (C) 2011 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 18package android.bordeaux.learning; 19 20import android.util.Log; 21import android.util.Pair; 22 23import java.io.ByteArrayInputStream; 24import java.io.ByteArrayOutputStream; 25import java.io.IOException; 26import java.io.ObjectInputStream; 27import java.io.ObjectOutputStream; 28import java.io.Serializable; 29import java.util.ArrayList; 30import java.util.Collections; 31import java.util.Comparator; 32import java.util.HashMap; 33import java.util.HashSet; 34import java.util.Iterator; 35import java.util.List; 36import java.util.Map; 37import java.util.Map.Entry; 38 39 40/** 41 * A histogram based predictor which records co-occurrences of applications with a specific feature, 42 * for example, location, time of day, etc. The histogram is kept in a two level hash table. 43 * The first level key is the feature value and the second level key is the app id. 44 */ 45 46// TODO: use forgetting factor to downweight instances proportionally to the time 47// difference between the occurrence and now. 
48public class HistogramPredictor { 49 final static String TAG = "HistogramPredictor"; 50 51 private HashMap<String, HistogramCounter> mPredictor = 52 new HashMap<String, HistogramCounter>(); 53 54 private static final double FEATURE_INACTIVE_LIKELIHOOD = 0.00000001; 55 private final double logInactive = Math.log(FEATURE_INACTIVE_LIKELIHOOD); 56 57 /* 58 * This class keeps the histogram counts for each feature and provide the 59 * joint probabilities of <feature, class>. 60 */ 61 private class HistogramCounter { 62 private HashMap<String, HashMap<String, Integer> > mCounter = 63 new HashMap<String, HashMap<String, Integer> >(); 64 private int mTotalCount; 65 66 public HistogramCounter() { 67 resetCounter(); 68 } 69 70 public void setCounter(HashMap<String, HashMap<String, Integer> > counter) { 71 resetCounter(); 72 mCounter.putAll(counter); 73 74 // get total count 75 for (Map.Entry<String, HashMap<String, Integer> > entry : counter.entrySet()) { 76 for (Integer value : entry.getValue().values()) { 77 mTotalCount += value.intValue(); 78 } 79 } 80 } 81 82 public void resetCounter() { 83 mCounter.clear(); 84 mTotalCount = 0; 85 } 86 87 public void addSample(String className, String featureValue) { 88 HashMap<String, Integer> classCounts; 89 90 if (!mCounter.containsKey(featureValue)) { 91 classCounts = new HashMap<String, Integer>(); 92 mCounter.put(featureValue, classCounts); 93 } 94 classCounts = mCounter.get(featureValue); 95 96 int count = (classCounts.containsKey(className)) ? 
97 classCounts.get(className) + 1 : 1; 98 classCounts.put(className, count); 99 mTotalCount++; 100 } 101 102 public HashMap<String, Double> getClassScores(String featureValue) { 103 HashMap<String, Double> classScores = new HashMap<String, Double>(); 104 105 double logTotalCount = Math.log((double) mTotalCount); 106 if (mCounter.containsKey(featureValue)) { 107 for(Map.Entry<String, Integer> entry : 108 mCounter.get(featureValue).entrySet()) { 109 double score = 110 Math.log((double) entry.getValue()) - logTotalCount; 111 classScores.put(entry.getKey(), score); 112 } 113 } 114 return classScores; 115 } 116 117 public HashMap<String, HashMap<String, Integer> > getCounter() { 118 return mCounter; 119 } 120 } 121 122 /* 123 * Given a map of feature name -value pairs returns the mostly likely apps to 124 * be launched with corresponding likelihoods. 125 */ 126 public List<Map.Entry<String, Double> > findTopClasses(Map<String, String> features, int topK) { 127 // Most sophisticated function in this class 128 HashMap<String, Double> appScores = new HashMap<String, Double>(); 129 double defaultLikelihood = mPredictor.size() * logInactive; 130 131 // compute all app scores 132 for (Map.Entry<String, HistogramCounter> entry : mPredictor.entrySet()) { 133 String featureName = entry.getKey(); 134 HistogramCounter counter = entry.getValue(); 135 136 if (features.containsKey(featureName)) { 137 String featureValue = features.get(featureName); 138 HashMap<String, Double> scoreMap = counter.getClassScores(featureValue); 139 140 for (Map.Entry<String, Double> item : scoreMap.entrySet()) { 141 String appName = item.getKey(); 142 double appScore = item.getValue(); 143 144 double score = (appScores.containsKey(appName)) ? 
145 appScores.get(appName) : defaultLikelihood; 146 score += appScore - logInactive; 147 148 appScores.put(appName, score); 149 } 150 } 151 } 152 153 // sort app scores 154 List<Map.Entry<String, Double> > appList = 155 new ArrayList<Map.Entry<String, Double> >(appScores.size()); 156 appList.addAll(appScores.entrySet()); 157 Collections.sort(appList, new Comparator<Map.Entry<String, Double> >() { 158 public int compare(Map.Entry<String, Double> o1, 159 Map.Entry<String, Double> o2) { 160 return o2.getValue().compareTo(o1.getValue()); 161 } 162 }); 163 164 Log.e(TAG, "findTopApps appList: " + appList); 165 return appList; 166 } 167 168 /* 169 * Add a new observation of given sample id and features to the histograms 170 */ 171 public void addSample(String sampleId, Map<String, String> features) { 172 for (Map.Entry<String, HistogramCounter> entry : mPredictor.entrySet()) { 173 String featureName = entry.getKey(); 174 HistogramCounter counter = entry.getValue(); 175 176 if (features.containsKey(featureName)) { 177 String featureValue = features.get(featureName); 178 counter.addSample(sampleId, featureValue); 179 } 180 } 181 } 182 183 /* 184 * reset predictor to a empty model 185 */ 186 public void resetPredictor() { 187 // TODO: not sure this step would reduce memory waste 188 for (HistogramCounter counter : mPredictor.values()) { 189 counter.resetCounter(); 190 } 191 mPredictor.clear(); 192 } 193 194 /* 195 * specify a feature to used for prediction 196 */ 197 public void useFeature(String featureName) { 198 if (!mPredictor.containsKey(featureName)) { 199 mPredictor.put(featureName, new HistogramCounter()); 200 } 201 } 202 203 /* 204 * convert the prediction model into a byte array 205 */ 206 public byte[] getModel() { 207 // TODO: convert model to a more memory efficient data structure. 
208 HashMap<String, HashMap<String, HashMap<String, Integer > > > model = 209 new HashMap<String, HashMap<String, HashMap<String, Integer > > >(); 210 for(Map.Entry<String, HistogramCounter> entry : mPredictor.entrySet()) { 211 model.put(entry.getKey(), entry.getValue().getCounter()); 212 } 213 214 try { 215 ByteArrayOutputStream byteStream = new ByteArrayOutputStream(); 216 ObjectOutputStream objStream = new ObjectOutputStream(byteStream); 217 objStream.writeObject(model); 218 byte[] bytes = byteStream.toByteArray(); 219 //Log.i(TAG, "getModel: " + bytes); 220 return bytes; 221 } catch (IOException e) { 222 throw new RuntimeException("Can't get model"); 223 } 224 } 225 226 /* 227 * set the prediction model from a model data in the format of byte array 228 */ 229 public boolean setModel(final byte[] modelData) { 230 HashMap<String, HashMap<String, HashMap<String, Integer > > > model; 231 232 try { 233 ByteArrayInputStream input = new ByteArrayInputStream(modelData); 234 ObjectInputStream objStream = new ObjectInputStream(input); 235 model = (HashMap<String, HashMap<String, HashMap<String, Integer > > >) 236 objStream.readObject(); 237 } catch (IOException e) { 238 throw new RuntimeException("Can't load model"); 239 } catch (ClassNotFoundException e) { 240 throw new RuntimeException("Learning class not found"); 241 } 242 243 resetPredictor(); 244 for (Map.Entry<String, HashMap<String, HashMap<String, Integer> > > entry : 245 model.entrySet()) { 246 useFeature(entry.getKey()); 247 mPredictor.get(entry.getKey()).setCounter(entry.getValue()); 248 } 249 return true; 250 } 251} 252