1/* 2 * Copyright (C) 2012 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17package android.bordeaux.services; 18import android.util.Log; 19 20import android.bordeaux.learning.StochasticLinearRanker; 21import java.util.HashMap; 22import java.util.Map; 23import java.io.Serializable; 24 25public class StochasticLinearRankerWithPrior extends StochasticLinearRanker { 26 private final String TAG = "StochasticLinearRankerWithPrior"; 27 private final float EPSILON = 0.0001f; 28 29 /* If the is parameter is true, the final score would be a 30 linear combination of user model and prior model */ 31 private final String USE_PRIOR = "usePriorInformation"; 32 33 /* When prior model is used, this parmaeter will set the mixing factor, alpha. */ 34 private final String SET_ALPHA = "setAlpha"; 35 36 /* When prior model is used, If this parameter is true then algorithm will use 37 the automatic cross validated alpha for mixing user model and prior model */ 38 private final String USE_AUTO_ALPHA = "useAutoAlpha"; 39 40 /* When automatic cross validation is active, this parameter will 41 set the forget rate in cross validation. */ 42 private final String SET_FORGET_RATE = "setForgetRate"; 43 44 /* When automatic cross validation is active, this parameter will 45 set the minium number of required training pairs before using the user model */ 46 private final String SET_MIN_TRAIN_PAIR = "setMinTrainingPair"; 47 48 private final String SET_USER_PERF = "setUserPerformance"; 49 private final String SET_PRIOR_PERF = "setPriorPerformance"; 50 private final String SET_NUM_TRAIN_PAIR = "setNumberTrainingPairs"; 51 private final String SET_AUTO_ALPHA = "setAutoAlpha"; 52 53 54 55 private HashMap<String, Float> mPriorWeights = new HashMap<String, Float>(); 56 private float mAlpha = 0; 57 private float mAutoAlpha = 0; 58 private float mForgetRate = 0; 59 private float mUserRankerPerf = 0; 60 private float mPriorRankerPerf = 0; 61 private int mMinReqTrainingPair = 0; 62 private int mNumTrainPair = 0; 63 private boolean mUsePrior = false; 64 private boolean mUseAutoAlpha = false; 65 66 static public class Model implements Serializable { 67 public StochasticLinearRanker.Model uModel = new StochasticLinearRanker.Model(); 68 public HashMap<String, Float> priorWeights = new HashMap<String, Float>(); 69 public HashMap<String, String> priorParameters = new HashMap<String, String>(); 70 } 71 72 @Override 73 public void resetRanker(){ 74 super.resetRanker(); 75 mPriorWeights.clear(); 76 mAlpha = 0; 77 mAutoAlpha = 0; 78 mForgetRate = 0; 79 mMinReqTrainingPair = 0; 80 mUserRankerPerf = 0; 81 mPriorRankerPerf = 0; 82 mNumTrainPair = 0; 83 mUsePrior = false; 84 mUseAutoAlpha = false; 85 } 86 87 @Override 88 public float scoreSample(String[] keys, float[] values) { 89 if (!mUsePrior){ 90 return super.scoreSample(keys, values); 91 } else { 92 if (mUseAutoAlpha) { 93 if (mNumTrainPair > mMinReqTrainingPair) 94 return (1 - mAutoAlpha) * super.scoreSample(keys,values) + 95 mAutoAlpha * priorScoreSample(keys,values); 96 else 97 return priorScoreSample(keys,values); 98 } else 99 return (1 - mAlpha) * super.scoreSample(keys,values) + 100 mAlpha * priorScoreSample(keys,values); 101 } 102 } 103 104 public float priorScoreSample(String[] keys, float[] values) { 105 float score = 0; 106 for (int i=0; i< keys.length; i++){ 107 if (mPriorWeights.get(keys[i]) != null ) 108 score = score + mPriorWeights.get(keys[i]) * values[i]; 109 } 110 return score; 111 } 112 113 @Override 114 public boolean updateClassifier(String[] keys_positive, 115 float[] values_positive, 116 String[] keys_negative, 117 float[] values_negative){ 118 if (mUsePrior && mUseAutoAlpha && (mNumTrainPair > mMinReqTrainingPair)) 119 updateAutoAlpha(keys_positive, values_positive, keys_negative, values_negative); 120 mNumTrainPair ++; 121 return super.updateClassifier(keys_positive, values_positive, 122 keys_negative, values_negative); 123 } 124 125 void updateAutoAlpha(String[] keys_positive, 126 float[] values_positive, 127 String[] keys_negative, 128 float[] values_negative) { 129 float positiveUserScore = super.scoreSample(keys_positive, values_positive); 130 float negativeUserScore = super.scoreSample(keys_negative, values_negative); 131 float positivePriorScore = priorScoreSample(keys_positive, values_positive); 132 float negativePriorScore = priorScoreSample(keys_negative, values_negative); 133 float userDecision = 0; 134 float priorDecision = 0; 135 if (positiveUserScore > negativeUserScore) 136 userDecision = 1; 137 if (positivePriorScore > negativePriorScore) 138 priorDecision = 1; 139 mUserRankerPerf = (1 - mForgetRate) * mUserRankerPerf + userDecision; 140 mPriorRankerPerf = (1 - mForgetRate) * mPriorRankerPerf + priorDecision; 141 mAutoAlpha = (mPriorRankerPerf + EPSILON) / (mUserRankerPerf + mPriorRankerPerf + EPSILON); 142 } 143 144 public Model getModel(){ 145 Model m = new Model(); 146 m.uModel = super.getUModel(); 147 m.priorWeights.putAll(mPriorWeights); 148 m.priorParameters.put(SET_ALPHA, String.valueOf(mAlpha)); 149 m.priorParameters.put(SET_AUTO_ALPHA, String.valueOf(mAutoAlpha)); 150 m.priorParameters.put(SET_FORGET_RATE, String.valueOf(mForgetRate)); 151 m.priorParameters.put(SET_MIN_TRAIN_PAIR, String.valueOf(mMinReqTrainingPair)); 152 m.priorParameters.put(SET_USER_PERF, String.valueOf(mUserRankerPerf)); 153 m.priorParameters.put(SET_PRIOR_PERF, String.valueOf(mPriorRankerPerf)); 154 m.priorParameters.put(SET_NUM_TRAIN_PAIR, String.valueOf(mNumTrainPair)); 155 m.priorParameters.put(USE_AUTO_ALPHA, String.valueOf(mUseAutoAlpha)); 156 m.priorParameters.put(USE_PRIOR, String.valueOf(mUsePrior)); 157 return m; 158 } 159 160 public boolean loadModel(Model m) { 161 mPriorWeights.clear(); 162 mPriorWeights.putAll(m.priorWeights); 163 for (Map.Entry<String, String> e : m.priorParameters.entrySet()) { 164 boolean res = setModelParameter(e.getKey(), e.getValue()); 165 if (!res) return false; 166 } 167 return super.loadModel(m.uModel); 168 } 169 170 public boolean setModelPriorWeights(HashMap<String, Float> pw){ 171 mPriorWeights.clear(); 172 mPriorWeights.putAll(pw); 173 return true; 174 } 175 176 public boolean setModelParameter(String key, String value){ 177 if (key.equals(USE_AUTO_ALPHA)){ 178 mUseAutoAlpha = Boolean.parseBoolean(value); 179 } else if (key.equals(USE_PRIOR)){ 180 mUsePrior = Boolean.parseBoolean(value); 181 } else if (key.equals(SET_ALPHA)){ 182 mAlpha = Float.valueOf(value.trim()).floatValue(); 183 }else if (key.equals(SET_AUTO_ALPHA)){ 184 mAutoAlpha = Float.valueOf(value.trim()).floatValue(); 185 }else if (key.equals(SET_FORGET_RATE)){ 186 mForgetRate = Float.valueOf(value.trim()).floatValue(); 187 }else if (key.equals(SET_MIN_TRAIN_PAIR)){ 188 mMinReqTrainingPair = (int) Float.valueOf(value.trim()).floatValue(); 189 }else if (key.equals(SET_USER_PERF)){ 190 mUserRankerPerf = Float.valueOf(value.trim()).floatValue(); 191 }else if (key.equals(SET_PRIOR_PERF)){ 192 mPriorRankerPerf = Float.valueOf(value.trim()).floatValue(); 193 }else if (key.equals(SET_NUM_TRAIN_PAIR)){ 194 mNumTrainPair = (int) Float.valueOf(value.trim()).floatValue(); 195 }else 196 return super.setModelParameter(key, value); 197 return true; 198 } 199 200 public void print(Model m){ 201 super.print(m.uModel); 202 String Spw = ""; 203 for (Map.Entry<String, Float> e : m.priorWeights.entrySet()) 204 Spw = Spw + "<" + e.getKey() + "," + e.getValue() + "> "; 205 Log.i(TAG, "Prior model is " + Spw); 206 String Spp = ""; 207 for (Map.Entry<String, String> e : m.priorParameters.entrySet()) 208 Spp = Spp + "<" + e.getKey() + "," + e.getValue() + "> "; 209 Log.i(TAG, "Prior parameters are " + Spp); 210 } 211} 212