1/*
2 * Copyright (C) 2012 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package android.bordeaux.services;
18import android.util.Log;
19
20import android.bordeaux.learning.StochasticLinearRanker;
21import java.util.HashMap;
22import java.util.Map;
23import java.io.Serializable;
24
25public class StochasticLinearRankerWithPrior extends StochasticLinearRanker {
26    private final String TAG = "StochasticLinearRankerWithPrior";
27    private final float EPSILON = 0.0001f;
28
29    /* If the is parameter is true, the final score would be a
30    linear combination of user model and prior model */
31    private final String USE_PRIOR = "usePriorInformation";
32
33    /* When prior model is used, this parmaeter will set the mixing factor, alpha. */
34    private final String SET_ALPHA = "setAlpha";
35
36    /* When prior model is used, If this parameter is true then algorithm will use
37    the automatic cross validated alpha for mixing user model and prior model */
38    private final String USE_AUTO_ALPHA = "useAutoAlpha";
39
40    /* When automatic cross validation is active, this parameter will
41    set the forget rate in cross validation. */
42    private final String SET_FORGET_RATE = "setForgetRate";
43
44    /* When automatic cross validation is active, this parameter will
45    set the minium number of required training pairs before using the user model */
46    private final String SET_MIN_TRAIN_PAIR = "setMinTrainingPair";
47
48    private final String SET_USER_PERF = "setUserPerformance";
49    private final String SET_PRIOR_PERF = "setPriorPerformance";
50    private final String SET_NUM_TRAIN_PAIR = "setNumberTrainingPairs";
51    private final String SET_AUTO_ALPHA = "setAutoAlpha";
52
53
54
55    private HashMap<String, Float> mPriorWeights = new HashMap<String, Float>();
56    private float mAlpha = 0;
57    private float mAutoAlpha = 0;
58    private float mForgetRate = 0;
59    private float mUserRankerPerf = 0;
60    private float mPriorRankerPerf = 0;
61    private int mMinReqTrainingPair = 0;
62    private int mNumTrainPair = 0;
63    private boolean mUsePrior = false;
64    private boolean mUseAutoAlpha = false;
65
66    static public class Model implements Serializable {
67        public StochasticLinearRanker.Model uModel = new StochasticLinearRanker.Model();
68        public HashMap<String, Float> priorWeights = new HashMap<String, Float>();
69        public HashMap<String, String> priorParameters = new HashMap<String, String>();
70    }
71
72    @Override
73    public void resetRanker(){
74        super.resetRanker();
75        mPriorWeights.clear();
76        mAlpha = 0;
77        mAutoAlpha = 0;
78        mForgetRate = 0;
79        mMinReqTrainingPair = 0;
80        mUserRankerPerf = 0;
81        mPriorRankerPerf = 0;
82        mNumTrainPair = 0;
83        mUsePrior = false;
84        mUseAutoAlpha = false;
85    }
86
87    @Override
88    public float scoreSample(String[] keys, float[] values) {
89        if (!mUsePrior){
90            return super.scoreSample(keys, values);
91        } else {
92            if (mUseAutoAlpha) {
93                if (mNumTrainPair > mMinReqTrainingPair)
94                    return (1 - mAutoAlpha) * super.scoreSample(keys,values) +
95                            mAutoAlpha * priorScoreSample(keys,values);
96                else
97                    return priorScoreSample(keys,values);
98            } else
99                return (1 - mAlpha) * super.scoreSample(keys,values) +
100                        mAlpha * priorScoreSample(keys,values);
101        }
102    }
103
104    public float priorScoreSample(String[] keys, float[] values) {
105        float score = 0;
106        for (int i=0; i< keys.length; i++){
107            if (mPriorWeights.get(keys[i]) != null )
108                score = score + mPriorWeights.get(keys[i]) * values[i];
109        }
110        return score;
111    }
112
113    @Override
114    public boolean updateClassifier(String[] keys_positive,
115                                    float[] values_positive,
116                                    String[] keys_negative,
117                                    float[] values_negative){
118        if (mUsePrior && mUseAutoAlpha && (mNumTrainPair > mMinReqTrainingPair))
119            updateAutoAlpha(keys_positive, values_positive, keys_negative, values_negative);
120        mNumTrainPair ++;
121        return super.updateClassifier(keys_positive, values_positive,
122                                      keys_negative, values_negative);
123    }
124
125    void updateAutoAlpha(String[] keys_positive,
126                     float[] values_positive,
127                     String[] keys_negative,
128                     float[] values_negative) {
129        float positiveUserScore = super.scoreSample(keys_positive, values_positive);
130        float negativeUserScore = super.scoreSample(keys_negative, values_negative);
131        float positivePriorScore = priorScoreSample(keys_positive, values_positive);
132        float negativePriorScore = priorScoreSample(keys_negative, values_negative);
133        float userDecision = 0;
134        float priorDecision = 0;
135        if (positiveUserScore > negativeUserScore)
136            userDecision = 1;
137        if (positivePriorScore > negativePriorScore)
138            priorDecision = 1;
139        mUserRankerPerf = (1 - mForgetRate) * mUserRankerPerf + userDecision;
140        mPriorRankerPerf = (1 - mForgetRate) * mPriorRankerPerf + priorDecision;
141        mAutoAlpha = (mPriorRankerPerf + EPSILON) / (mUserRankerPerf + mPriorRankerPerf + EPSILON);
142    }
143
144    public Model getModel(){
145        Model m = new Model();
146        m.uModel = super.getUModel();
147        m.priorWeights.putAll(mPriorWeights);
148        m.priorParameters.put(SET_ALPHA, String.valueOf(mAlpha));
149        m.priorParameters.put(SET_AUTO_ALPHA, String.valueOf(mAutoAlpha));
150        m.priorParameters.put(SET_FORGET_RATE, String.valueOf(mForgetRate));
151        m.priorParameters.put(SET_MIN_TRAIN_PAIR, String.valueOf(mMinReqTrainingPair));
152        m.priorParameters.put(SET_USER_PERF, String.valueOf(mUserRankerPerf));
153        m.priorParameters.put(SET_PRIOR_PERF, String.valueOf(mPriorRankerPerf));
154        m.priorParameters.put(SET_NUM_TRAIN_PAIR, String.valueOf(mNumTrainPair));
155        m.priorParameters.put(USE_AUTO_ALPHA, String.valueOf(mUseAutoAlpha));
156        m.priorParameters.put(USE_PRIOR, String.valueOf(mUsePrior));
157        return m;
158    }
159
160    public boolean loadModel(Model m) {
161        mPriorWeights.clear();
162        mPriorWeights.putAll(m.priorWeights);
163        for (Map.Entry<String, String> e : m.priorParameters.entrySet()) {
164            boolean res = setModelParameter(e.getKey(), e.getValue());
165            if (!res) return false;
166        }
167        return super.loadModel(m.uModel);
168    }
169
170    public boolean setModelPriorWeights(HashMap<String, Float> pw){
171        mPriorWeights.clear();
172        mPriorWeights.putAll(pw);
173        return true;
174    }
175
176    public boolean setModelParameter(String key, String value){
177        if (key.equals(USE_AUTO_ALPHA)){
178            mUseAutoAlpha = Boolean.parseBoolean(value);
179        } else if (key.equals(USE_PRIOR)){
180            mUsePrior = Boolean.parseBoolean(value);
181        } else if (key.equals(SET_ALPHA)){
182            mAlpha = Float.valueOf(value.trim()).floatValue();
183        }else if (key.equals(SET_AUTO_ALPHA)){
184            mAutoAlpha = Float.valueOf(value.trim()).floatValue();
185        }else if (key.equals(SET_FORGET_RATE)){
186            mForgetRate = Float.valueOf(value.trim()).floatValue();
187        }else if (key.equals(SET_MIN_TRAIN_PAIR)){
188            mMinReqTrainingPair = (int) Float.valueOf(value.trim()).floatValue();
189        }else if (key.equals(SET_USER_PERF)){
190            mUserRankerPerf = Float.valueOf(value.trim()).floatValue();
191        }else if (key.equals(SET_PRIOR_PERF)){
192            mPriorRankerPerf = Float.valueOf(value.trim()).floatValue();
193        }else if (key.equals(SET_NUM_TRAIN_PAIR)){
194            mNumTrainPair = (int) Float.valueOf(value.trim()).floatValue();
195        }else
196            return super.setModelParameter(key, value);
197        return true;
198    }
199
200    public void print(Model m){
201        super.print(m.uModel);
202        String Spw = "";
203        for (Map.Entry<String, Float> e : m.priorWeights.entrySet())
204            Spw = Spw + "<" + e.getKey() + "," + e.getValue() + "> ";
205        Log.i(TAG, "Prior model is " + Spw);
206        String Spp = "";
207        for (Map.Entry<String, String> e : m.priorParameters.entrySet())
208            Spp = Spp + "<" + e.getKey() + "," + e.getValue() + "> ";
209        Log.i(TAG, "Prior parameters are " + Spp);
210    }
211}
212