1b019e89cbea221598c482b05ab68b7660b41aa23saberian/* 2b019e89cbea221598c482b05ab68b7660b41aa23saberian * Copyright (C) 2012 The Android Open Source Project 3b019e89cbea221598c482b05ab68b7660b41aa23saberian * 4b019e89cbea221598c482b05ab68b7660b41aa23saberian * Licensed under the Apache License, Version 2.0 (the "License"); 5b019e89cbea221598c482b05ab68b7660b41aa23saberian * you may not use this file except in compliance with the License. 6b019e89cbea221598c482b05ab68b7660b41aa23saberian * You may obtain a copy of the License at 7b019e89cbea221598c482b05ab68b7660b41aa23saberian * 8b019e89cbea221598c482b05ab68b7660b41aa23saberian * http://www.apache.org/licenses/LICENSE-2.0 9b019e89cbea221598c482b05ab68b7660b41aa23saberian * 10b019e89cbea221598c482b05ab68b7660b41aa23saberian * Unless required by applicable law or agreed to in writing, software 11b019e89cbea221598c482b05ab68b7660b41aa23saberian * distributed under the License is distributed on an "AS IS" BASIS, 12b019e89cbea221598c482b05ab68b7660b41aa23saberian * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13b019e89cbea221598c482b05ab68b7660b41aa23saberian * See the License for the specific language governing permissions and 14b019e89cbea221598c482b05ab68b7660b41aa23saberian * limitations under the License. 15b019e89cbea221598c482b05ab68b7660b41aa23saberian */ 16b019e89cbea221598c482b05ab68b7660b41aa23saberian 17b019e89cbea221598c482b05ab68b7660b41aa23saberianpackage android.bordeaux.services; 18b019e89cbea221598c482b05ab68b7660b41aa23saberianimport android.util.Log; 19b019e89cbea221598c482b05ab68b7660b41aa23saberian 20b019e89cbea221598c482b05ab68b7660b41aa23saberianimport android.bordeaux.learning.StochasticLinearRanker; 21b019e89cbea221598c482b05ab68b7660b41aa23saberianimport java.util.HashMap; 22b019e89cbea221598c482b05ab68b7660b41aa23saberianimport java.util.Map; 23b019e89cbea221598c482b05ab68b7660b41aa23saberianimport java.io.Serializable; 24b019e89cbea221598c482b05ab68b7660b41aa23saberian 25b019e89cbea221598c482b05ab68b7660b41aa23saberianpublic class StochasticLinearRankerWithPrior extends StochasticLinearRanker { 26b019e89cbea221598c482b05ab68b7660b41aa23saberian private final String TAG = "StochasticLinearRankerWithPrior"; 27b019e89cbea221598c482b05ab68b7660b41aa23saberian private final float EPSILON = 0.0001f; 28b019e89cbea221598c482b05ab68b7660b41aa23saberian 29b019e89cbea221598c482b05ab68b7660b41aa23saberian /* If the is parameter is true, the final score would be a 30b019e89cbea221598c482b05ab68b7660b41aa23saberian linear combination of user model and prior model */ 31b019e89cbea221598c482b05ab68b7660b41aa23saberian private final String USE_PRIOR = "usePriorInformation"; 32b019e89cbea221598c482b05ab68b7660b41aa23saberian 33b019e89cbea221598c482b05ab68b7660b41aa23saberian /* When prior model is used, this parmaeter will set the mixing factor, alpha. */ 34b019e89cbea221598c482b05ab68b7660b41aa23saberian private final String SET_ALPHA = "setAlpha"; 35b019e89cbea221598c482b05ab68b7660b41aa23saberian 36b019e89cbea221598c482b05ab68b7660b41aa23saberian /* When prior model is used, If this parameter is true then algorithm will use 37b019e89cbea221598c482b05ab68b7660b41aa23saberian the automatic cross validated alpha for mixing user model and prior model */ 38b019e89cbea221598c482b05ab68b7660b41aa23saberian private final String USE_AUTO_ALPHA = "useAutoAlpha"; 39b019e89cbea221598c482b05ab68b7660b41aa23saberian 40b019e89cbea221598c482b05ab68b7660b41aa23saberian /* When automatic cross validation is active, this parameter will 41b019e89cbea221598c482b05ab68b7660b41aa23saberian set the forget rate in cross validation. */ 42b019e89cbea221598c482b05ab68b7660b41aa23saberian private final String SET_FORGET_RATE = "setForgetRate"; 43b019e89cbea221598c482b05ab68b7660b41aa23saberian 44b019e89cbea221598c482b05ab68b7660b41aa23saberian /* When automatic cross validation is active, this parameter will 45b019e89cbea221598c482b05ab68b7660b41aa23saberian set the minium number of required training pairs before using the user model */ 46b019e89cbea221598c482b05ab68b7660b41aa23saberian private final String SET_MIN_TRAIN_PAIR = "setMinTrainingPair"; 47b019e89cbea221598c482b05ab68b7660b41aa23saberian 48b019e89cbea221598c482b05ab68b7660b41aa23saberian private final String SET_USER_PERF = "setUserPerformance"; 49b019e89cbea221598c482b05ab68b7660b41aa23saberian private final String SET_PRIOR_PERF = "setPriorPerformance"; 50b019e89cbea221598c482b05ab68b7660b41aa23saberian private final String SET_NUM_TRAIN_PAIR = "setNumberTrainingPairs"; 51b019e89cbea221598c482b05ab68b7660b41aa23saberian private final String SET_AUTO_ALPHA = "setAutoAlpha"; 52b019e89cbea221598c482b05ab68b7660b41aa23saberian 53b019e89cbea221598c482b05ab68b7660b41aa23saberian 54b019e89cbea221598c482b05ab68b7660b41aa23saberian 55b019e89cbea221598c482b05ab68b7660b41aa23saberian private HashMap<String, Float> mPriorWeights = new HashMap<String, Float>(); 56b019e89cbea221598c482b05ab68b7660b41aa23saberian private float mAlpha = 0; 57b019e89cbea221598c482b05ab68b7660b41aa23saberian private float mAutoAlpha = 0; 58b019e89cbea221598c482b05ab68b7660b41aa23saberian private float mForgetRate = 0; 59b019e89cbea221598c482b05ab68b7660b41aa23saberian private float mUserRankerPerf = 0; 60b019e89cbea221598c482b05ab68b7660b41aa23saberian private float mPriorRankerPerf = 0; 61b019e89cbea221598c482b05ab68b7660b41aa23saberian private int mMinReqTrainingPair = 0; 62b019e89cbea221598c482b05ab68b7660b41aa23saberian private int mNumTrainPair = 0; 63b019e89cbea221598c482b05ab68b7660b41aa23saberian private boolean mUsePrior = false; 64b019e89cbea221598c482b05ab68b7660b41aa23saberian private boolean mUseAutoAlpha = false; 65b019e89cbea221598c482b05ab68b7660b41aa23saberian 66b019e89cbea221598c482b05ab68b7660b41aa23saberian static public class Model implements Serializable { 67b019e89cbea221598c482b05ab68b7660b41aa23saberian public StochasticLinearRanker.Model uModel = new StochasticLinearRanker.Model(); 68b019e89cbea221598c482b05ab68b7660b41aa23saberian public HashMap<String, Float> priorWeights = new HashMap<String, Float>(); 69b019e89cbea221598c482b05ab68b7660b41aa23saberian public HashMap<String, String> priorParameters = new HashMap<String, String>(); 70b019e89cbea221598c482b05ab68b7660b41aa23saberian } 71b019e89cbea221598c482b05ab68b7660b41aa23saberian 72b019e89cbea221598c482b05ab68b7660b41aa23saberian @Override 73b019e89cbea221598c482b05ab68b7660b41aa23saberian public void resetRanker(){ 74b019e89cbea221598c482b05ab68b7660b41aa23saberian super.resetRanker(); 75b019e89cbea221598c482b05ab68b7660b41aa23saberian mPriorWeights.clear(); 76b019e89cbea221598c482b05ab68b7660b41aa23saberian mAlpha = 0; 77b019e89cbea221598c482b05ab68b7660b41aa23saberian mAutoAlpha = 0; 78b019e89cbea221598c482b05ab68b7660b41aa23saberian mForgetRate = 0; 79b019e89cbea221598c482b05ab68b7660b41aa23saberian mMinReqTrainingPair = 0; 80b019e89cbea221598c482b05ab68b7660b41aa23saberian mUserRankerPerf = 0; 81b019e89cbea221598c482b05ab68b7660b41aa23saberian mPriorRankerPerf = 0; 82b019e89cbea221598c482b05ab68b7660b41aa23saberian mNumTrainPair = 0; 83b019e89cbea221598c482b05ab68b7660b41aa23saberian mUsePrior = false; 84b019e89cbea221598c482b05ab68b7660b41aa23saberian mUseAutoAlpha = false; 85b019e89cbea221598c482b05ab68b7660b41aa23saberian } 86b019e89cbea221598c482b05ab68b7660b41aa23saberian 87b019e89cbea221598c482b05ab68b7660b41aa23saberian @Override 88b019e89cbea221598c482b05ab68b7660b41aa23saberian public float scoreSample(String[] keys, float[] values) { 89b019e89cbea221598c482b05ab68b7660b41aa23saberian if (!mUsePrior){ 90b019e89cbea221598c482b05ab68b7660b41aa23saberian return super.scoreSample(keys, values); 91b019e89cbea221598c482b05ab68b7660b41aa23saberian } else { 92b019e89cbea221598c482b05ab68b7660b41aa23saberian if (mUseAutoAlpha) { 93b019e89cbea221598c482b05ab68b7660b41aa23saberian if (mNumTrainPair > mMinReqTrainingPair) 94b019e89cbea221598c482b05ab68b7660b41aa23saberian return (1 - mAutoAlpha) * super.scoreSample(keys,values) + 95b019e89cbea221598c482b05ab68b7660b41aa23saberian mAutoAlpha * priorScoreSample(keys,values); 96b019e89cbea221598c482b05ab68b7660b41aa23saberian else 97b019e89cbea221598c482b05ab68b7660b41aa23saberian return priorScoreSample(keys,values); 98b019e89cbea221598c482b05ab68b7660b41aa23saberian } else 99b019e89cbea221598c482b05ab68b7660b41aa23saberian return (1 - mAlpha) * super.scoreSample(keys,values) + 100b019e89cbea221598c482b05ab68b7660b41aa23saberian mAlpha * priorScoreSample(keys,values); 101b019e89cbea221598c482b05ab68b7660b41aa23saberian } 102b019e89cbea221598c482b05ab68b7660b41aa23saberian } 103b019e89cbea221598c482b05ab68b7660b41aa23saberian 104b019e89cbea221598c482b05ab68b7660b41aa23saberian public float priorScoreSample(String[] keys, float[] values) { 105b019e89cbea221598c482b05ab68b7660b41aa23saberian float score = 0; 106b019e89cbea221598c482b05ab68b7660b41aa23saberian for (int i=0; i< keys.length; i++){ 107b019e89cbea221598c482b05ab68b7660b41aa23saberian if (mPriorWeights.get(keys[i]) != null ) 108b019e89cbea221598c482b05ab68b7660b41aa23saberian score = score + mPriorWeights.get(keys[i]) * values[i]; 109b019e89cbea221598c482b05ab68b7660b41aa23saberian } 110b019e89cbea221598c482b05ab68b7660b41aa23saberian return score; 111b019e89cbea221598c482b05ab68b7660b41aa23saberian } 112b019e89cbea221598c482b05ab68b7660b41aa23saberian 113b019e89cbea221598c482b05ab68b7660b41aa23saberian @Override 114b019e89cbea221598c482b05ab68b7660b41aa23saberian public boolean updateClassifier(String[] keys_positive, 115b019e89cbea221598c482b05ab68b7660b41aa23saberian float[] values_positive, 116b019e89cbea221598c482b05ab68b7660b41aa23saberian String[] keys_negative, 117b019e89cbea221598c482b05ab68b7660b41aa23saberian float[] values_negative){ 118b019e89cbea221598c482b05ab68b7660b41aa23saberian if (mUsePrior && mUseAutoAlpha && (mNumTrainPair > mMinReqTrainingPair)) 119b019e89cbea221598c482b05ab68b7660b41aa23saberian updateAutoAlpha(keys_positive, values_positive, keys_negative, values_negative); 120b019e89cbea221598c482b05ab68b7660b41aa23saberian mNumTrainPair ++; 121b019e89cbea221598c482b05ab68b7660b41aa23saberian return super.updateClassifier(keys_positive, values_positive, 122b019e89cbea221598c482b05ab68b7660b41aa23saberian keys_negative, values_negative); 123b019e89cbea221598c482b05ab68b7660b41aa23saberian } 124b019e89cbea221598c482b05ab68b7660b41aa23saberian 125b019e89cbea221598c482b05ab68b7660b41aa23saberian void updateAutoAlpha(String[] keys_positive, 126b019e89cbea221598c482b05ab68b7660b41aa23saberian float[] values_positive, 127b019e89cbea221598c482b05ab68b7660b41aa23saberian String[] keys_negative, 128b019e89cbea221598c482b05ab68b7660b41aa23saberian float[] values_negative) { 129b019e89cbea221598c482b05ab68b7660b41aa23saberian float positiveUserScore = super.scoreSample(keys_positive, values_positive); 130b019e89cbea221598c482b05ab68b7660b41aa23saberian float negativeUserScore = super.scoreSample(keys_negative, values_negative); 131b019e89cbea221598c482b05ab68b7660b41aa23saberian float positivePriorScore = priorScoreSample(keys_positive, values_positive); 132b019e89cbea221598c482b05ab68b7660b41aa23saberian float negativePriorScore = priorScoreSample(keys_negative, values_negative); 133b019e89cbea221598c482b05ab68b7660b41aa23saberian float userDecision = 0; 134b019e89cbea221598c482b05ab68b7660b41aa23saberian float priorDecision = 0; 135b019e89cbea221598c482b05ab68b7660b41aa23saberian if (positiveUserScore > negativeUserScore) 136b019e89cbea221598c482b05ab68b7660b41aa23saberian userDecision = 1; 137b019e89cbea221598c482b05ab68b7660b41aa23saberian if (positivePriorScore > negativePriorScore) 138b019e89cbea221598c482b05ab68b7660b41aa23saberian priorDecision = 1; 139b019e89cbea221598c482b05ab68b7660b41aa23saberian mUserRankerPerf = (1 - mForgetRate) * mUserRankerPerf + userDecision; 140b019e89cbea221598c482b05ab68b7660b41aa23saberian mPriorRankerPerf = (1 - mForgetRate) * mPriorRankerPerf + priorDecision; 141b019e89cbea221598c482b05ab68b7660b41aa23saberian mAutoAlpha = (mPriorRankerPerf + EPSILON) / (mUserRankerPerf + mPriorRankerPerf + EPSILON); 142b019e89cbea221598c482b05ab68b7660b41aa23saberian } 143b019e89cbea221598c482b05ab68b7660b41aa23saberian 144b019e89cbea221598c482b05ab68b7660b41aa23saberian public Model getModel(){ 145b019e89cbea221598c482b05ab68b7660b41aa23saberian Model m = new Model(); 146b019e89cbea221598c482b05ab68b7660b41aa23saberian m.uModel = super.getUModel(); 147b019e89cbea221598c482b05ab68b7660b41aa23saberian m.priorWeights.putAll(mPriorWeights); 148b019e89cbea221598c482b05ab68b7660b41aa23saberian m.priorParameters.put(SET_ALPHA, String.valueOf(mAlpha)); 149b019e89cbea221598c482b05ab68b7660b41aa23saberian m.priorParameters.put(SET_AUTO_ALPHA, String.valueOf(mAutoAlpha)); 150b019e89cbea221598c482b05ab68b7660b41aa23saberian m.priorParameters.put(SET_FORGET_RATE, String.valueOf(mForgetRate)); 151b019e89cbea221598c482b05ab68b7660b41aa23saberian m.priorParameters.put(SET_MIN_TRAIN_PAIR, String.valueOf(mMinReqTrainingPair)); 152b019e89cbea221598c482b05ab68b7660b41aa23saberian m.priorParameters.put(SET_USER_PERF, String.valueOf(mUserRankerPerf)); 153b019e89cbea221598c482b05ab68b7660b41aa23saberian m.priorParameters.put(SET_PRIOR_PERF, String.valueOf(mPriorRankerPerf)); 154b019e89cbea221598c482b05ab68b7660b41aa23saberian m.priorParameters.put(SET_NUM_TRAIN_PAIR, String.valueOf(mNumTrainPair)); 155b019e89cbea221598c482b05ab68b7660b41aa23saberian m.priorParameters.put(USE_AUTO_ALPHA, String.valueOf(mUseAutoAlpha)); 156b019e89cbea221598c482b05ab68b7660b41aa23saberian m.priorParameters.put(USE_PRIOR, String.valueOf(mUsePrior)); 157b019e89cbea221598c482b05ab68b7660b41aa23saberian return m; 158b019e89cbea221598c482b05ab68b7660b41aa23saberian } 159b019e89cbea221598c482b05ab68b7660b41aa23saberian 160b019e89cbea221598c482b05ab68b7660b41aa23saberian public boolean loadModel(Model m) { 161b019e89cbea221598c482b05ab68b7660b41aa23saberian mPriorWeights.clear(); 162b019e89cbea221598c482b05ab68b7660b41aa23saberian mPriorWeights.putAll(m.priorWeights); 163b019e89cbea221598c482b05ab68b7660b41aa23saberian for (Map.Entry<String, String> e : m.priorParameters.entrySet()) { 164b019e89cbea221598c482b05ab68b7660b41aa23saberian boolean res = setModelParameter(e.getKey(), e.getValue()); 165b019e89cbea221598c482b05ab68b7660b41aa23saberian if (!res) return false; 166b019e89cbea221598c482b05ab68b7660b41aa23saberian } 167b019e89cbea221598c482b05ab68b7660b41aa23saberian return super.loadModel(m.uModel); 168b019e89cbea221598c482b05ab68b7660b41aa23saberian } 169b019e89cbea221598c482b05ab68b7660b41aa23saberian 170b019e89cbea221598c482b05ab68b7660b41aa23saberian public boolean setModelPriorWeights(HashMap<String, Float> pw){ 171b019e89cbea221598c482b05ab68b7660b41aa23saberian mPriorWeights.clear(); 172b019e89cbea221598c482b05ab68b7660b41aa23saberian mPriorWeights.putAll(pw); 173b019e89cbea221598c482b05ab68b7660b41aa23saberian return true; 174b019e89cbea221598c482b05ab68b7660b41aa23saberian } 175b019e89cbea221598c482b05ab68b7660b41aa23saberian 176b019e89cbea221598c482b05ab68b7660b41aa23saberian public boolean setModelParameter(String key, String value){ 177b019e89cbea221598c482b05ab68b7660b41aa23saberian if (key.equals(USE_AUTO_ALPHA)){ 178b019e89cbea221598c482b05ab68b7660b41aa23saberian mUseAutoAlpha = Boolean.parseBoolean(value); 179b019e89cbea221598c482b05ab68b7660b41aa23saberian } else if (key.equals(USE_PRIOR)){ 180b019e89cbea221598c482b05ab68b7660b41aa23saberian mUsePrior = Boolean.parseBoolean(value); 181b019e89cbea221598c482b05ab68b7660b41aa23saberian } else if (key.equals(SET_ALPHA)){ 182b019e89cbea221598c482b05ab68b7660b41aa23saberian mAlpha = Float.valueOf(value.trim()).floatValue(); 183b019e89cbea221598c482b05ab68b7660b41aa23saberian }else if (key.equals(SET_AUTO_ALPHA)){ 184b019e89cbea221598c482b05ab68b7660b41aa23saberian mAutoAlpha = Float.valueOf(value.trim()).floatValue(); 185b019e89cbea221598c482b05ab68b7660b41aa23saberian }else if (key.equals(SET_FORGET_RATE)){ 186b019e89cbea221598c482b05ab68b7660b41aa23saberian mForgetRate = Float.valueOf(value.trim()).floatValue(); 187b019e89cbea221598c482b05ab68b7660b41aa23saberian }else if (key.equals(SET_MIN_TRAIN_PAIR)){ 188b019e89cbea221598c482b05ab68b7660b41aa23saberian mMinReqTrainingPair = (int) Float.valueOf(value.trim()).floatValue(); 189b019e89cbea221598c482b05ab68b7660b41aa23saberian }else if (key.equals(SET_USER_PERF)){ 190b019e89cbea221598c482b05ab68b7660b41aa23saberian mUserRankerPerf = Float.valueOf(value.trim()).floatValue(); 191b019e89cbea221598c482b05ab68b7660b41aa23saberian }else if (key.equals(SET_PRIOR_PERF)){ 192b019e89cbea221598c482b05ab68b7660b41aa23saberian mPriorRankerPerf = Float.valueOf(value.trim()).floatValue(); 193b019e89cbea221598c482b05ab68b7660b41aa23saberian }else if (key.equals(SET_NUM_TRAIN_PAIR)){ 194b019e89cbea221598c482b05ab68b7660b41aa23saberian mNumTrainPair = (int) Float.valueOf(value.trim()).floatValue(); 195b019e89cbea221598c482b05ab68b7660b41aa23saberian }else 196b019e89cbea221598c482b05ab68b7660b41aa23saberian return super.setModelParameter(key, value); 197b019e89cbea221598c482b05ab68b7660b41aa23saberian return true; 198b019e89cbea221598c482b05ab68b7660b41aa23saberian } 199b019e89cbea221598c482b05ab68b7660b41aa23saberian 200b019e89cbea221598c482b05ab68b7660b41aa23saberian public void print(Model m){ 201b019e89cbea221598c482b05ab68b7660b41aa23saberian super.print(m.uModel); 202b019e89cbea221598c482b05ab68b7660b41aa23saberian String Spw = ""; 203b019e89cbea221598c482b05ab68b7660b41aa23saberian for (Map.Entry<String, Float> e : m.priorWeights.entrySet()) 204b019e89cbea221598c482b05ab68b7660b41aa23saberian Spw = Spw + "<" + e.getKey() + "," + e.getValue() + "> "; 205b019e89cbea221598c482b05ab68b7660b41aa23saberian Log.i(TAG, "Prior model is " + Spw); 206b019e89cbea221598c482b05ab68b7660b41aa23saberian String Spp = ""; 207b019e89cbea221598c482b05ab68b7660b41aa23saberian for (Map.Entry<String, String> e : m.priorParameters.entrySet()) 208b019e89cbea221598c482b05ab68b7660b41aa23saberian Spp = Spp + "<" + e.getKey() + "," + e.getValue() + "> "; 209b019e89cbea221598c482b05ab68b7660b41aa23saberian Log.i(TAG, "Prior parameters are " + Spp); 210b019e89cbea221598c482b05ab68b7660b41aa23saberian } 211b019e89cbea221598c482b05ab68b7660b41aa23saberian} 212