/*
 * Copyright (C) 2012 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16b019e89cbea221598c482b05ab68b7660b41aa23saberian
17b019e89cbea221598c482b05ab68b7660b41aa23saberianpackage android.bordeaux.services;
18b019e89cbea221598c482b05ab68b7660b41aa23saberianimport android.util.Log;
19b019e89cbea221598c482b05ab68b7660b41aa23saberian
20b019e89cbea221598c482b05ab68b7660b41aa23saberianimport android.bordeaux.learning.StochasticLinearRanker;
21b019e89cbea221598c482b05ab68b7660b41aa23saberianimport java.util.HashMap;
22b019e89cbea221598c482b05ab68b7660b41aa23saberianimport java.util.Map;
23b019e89cbea221598c482b05ab68b7660b41aa23saberianimport java.io.Serializable;
24b019e89cbea221598c482b05ab68b7660b41aa23saberian
25b019e89cbea221598c482b05ab68b7660b41aa23saberianpublic class StochasticLinearRankerWithPrior extends StochasticLinearRanker {
26b019e89cbea221598c482b05ab68b7660b41aa23saberian    private final String TAG = "StochasticLinearRankerWithPrior";
27b019e89cbea221598c482b05ab68b7660b41aa23saberian    private final float EPSILON = 0.0001f;
28b019e89cbea221598c482b05ab68b7660b41aa23saberian
29b019e89cbea221598c482b05ab68b7660b41aa23saberian    /* If the is parameter is true, the final score would be a
30b019e89cbea221598c482b05ab68b7660b41aa23saberian    linear combination of user model and prior model */
31b019e89cbea221598c482b05ab68b7660b41aa23saberian    private final String USE_PRIOR = "usePriorInformation";
32b019e89cbea221598c482b05ab68b7660b41aa23saberian
33b019e89cbea221598c482b05ab68b7660b41aa23saberian    /* When prior model is used, this parmaeter will set the mixing factor, alpha. */
34b019e89cbea221598c482b05ab68b7660b41aa23saberian    private final String SET_ALPHA = "setAlpha";
35b019e89cbea221598c482b05ab68b7660b41aa23saberian
36b019e89cbea221598c482b05ab68b7660b41aa23saberian    /* When prior model is used, If this parameter is true then algorithm will use
37b019e89cbea221598c482b05ab68b7660b41aa23saberian    the automatic cross validated alpha for mixing user model and prior model */
38b019e89cbea221598c482b05ab68b7660b41aa23saberian    private final String USE_AUTO_ALPHA = "useAutoAlpha";
39b019e89cbea221598c482b05ab68b7660b41aa23saberian
40b019e89cbea221598c482b05ab68b7660b41aa23saberian    /* When automatic cross validation is active, this parameter will
41b019e89cbea221598c482b05ab68b7660b41aa23saberian    set the forget rate in cross validation. */
42b019e89cbea221598c482b05ab68b7660b41aa23saberian    private final String SET_FORGET_RATE = "setForgetRate";
43b019e89cbea221598c482b05ab68b7660b41aa23saberian
44b019e89cbea221598c482b05ab68b7660b41aa23saberian    /* When automatic cross validation is active, this parameter will
45b019e89cbea221598c482b05ab68b7660b41aa23saberian    set the minium number of required training pairs before using the user model */
46b019e89cbea221598c482b05ab68b7660b41aa23saberian    private final String SET_MIN_TRAIN_PAIR = "setMinTrainingPair";
47b019e89cbea221598c482b05ab68b7660b41aa23saberian
48b019e89cbea221598c482b05ab68b7660b41aa23saberian    private final String SET_USER_PERF = "setUserPerformance";
49b019e89cbea221598c482b05ab68b7660b41aa23saberian    private final String SET_PRIOR_PERF = "setPriorPerformance";
50b019e89cbea221598c482b05ab68b7660b41aa23saberian    private final String SET_NUM_TRAIN_PAIR = "setNumberTrainingPairs";
51b019e89cbea221598c482b05ab68b7660b41aa23saberian    private final String SET_AUTO_ALPHA = "setAutoAlpha";
52b019e89cbea221598c482b05ab68b7660b41aa23saberian
53b019e89cbea221598c482b05ab68b7660b41aa23saberian
54b019e89cbea221598c482b05ab68b7660b41aa23saberian
55b019e89cbea221598c482b05ab68b7660b41aa23saberian    private HashMap<String, Float> mPriorWeights = new HashMap<String, Float>();
56b019e89cbea221598c482b05ab68b7660b41aa23saberian    private float mAlpha = 0;
57b019e89cbea221598c482b05ab68b7660b41aa23saberian    private float mAutoAlpha = 0;
58b019e89cbea221598c482b05ab68b7660b41aa23saberian    private float mForgetRate = 0;
59b019e89cbea221598c482b05ab68b7660b41aa23saberian    private float mUserRankerPerf = 0;
60b019e89cbea221598c482b05ab68b7660b41aa23saberian    private float mPriorRankerPerf = 0;
61b019e89cbea221598c482b05ab68b7660b41aa23saberian    private int mMinReqTrainingPair = 0;
62b019e89cbea221598c482b05ab68b7660b41aa23saberian    private int mNumTrainPair = 0;
63b019e89cbea221598c482b05ab68b7660b41aa23saberian    private boolean mUsePrior = false;
64b019e89cbea221598c482b05ab68b7660b41aa23saberian    private boolean mUseAutoAlpha = false;
65b019e89cbea221598c482b05ab68b7660b41aa23saberian
66b019e89cbea221598c482b05ab68b7660b41aa23saberian    static public class Model implements Serializable {
67b019e89cbea221598c482b05ab68b7660b41aa23saberian        public StochasticLinearRanker.Model uModel = new StochasticLinearRanker.Model();
68b019e89cbea221598c482b05ab68b7660b41aa23saberian        public HashMap<String, Float> priorWeights = new HashMap<String, Float>();
69b019e89cbea221598c482b05ab68b7660b41aa23saberian        public HashMap<String, String> priorParameters = new HashMap<String, String>();
70b019e89cbea221598c482b05ab68b7660b41aa23saberian    }
71b019e89cbea221598c482b05ab68b7660b41aa23saberian
72b019e89cbea221598c482b05ab68b7660b41aa23saberian    @Override
73b019e89cbea221598c482b05ab68b7660b41aa23saberian    public void resetRanker(){
74b019e89cbea221598c482b05ab68b7660b41aa23saberian        super.resetRanker();
75b019e89cbea221598c482b05ab68b7660b41aa23saberian        mPriorWeights.clear();
76b019e89cbea221598c482b05ab68b7660b41aa23saberian        mAlpha = 0;
77b019e89cbea221598c482b05ab68b7660b41aa23saberian        mAutoAlpha = 0;
78b019e89cbea221598c482b05ab68b7660b41aa23saberian        mForgetRate = 0;
79b019e89cbea221598c482b05ab68b7660b41aa23saberian        mMinReqTrainingPair = 0;
80b019e89cbea221598c482b05ab68b7660b41aa23saberian        mUserRankerPerf = 0;
81b019e89cbea221598c482b05ab68b7660b41aa23saberian        mPriorRankerPerf = 0;
82b019e89cbea221598c482b05ab68b7660b41aa23saberian        mNumTrainPair = 0;
83b019e89cbea221598c482b05ab68b7660b41aa23saberian        mUsePrior = false;
84b019e89cbea221598c482b05ab68b7660b41aa23saberian        mUseAutoAlpha = false;
85b019e89cbea221598c482b05ab68b7660b41aa23saberian    }
86b019e89cbea221598c482b05ab68b7660b41aa23saberian
87b019e89cbea221598c482b05ab68b7660b41aa23saberian    @Override
88b019e89cbea221598c482b05ab68b7660b41aa23saberian    public float scoreSample(String[] keys, float[] values) {
89b019e89cbea221598c482b05ab68b7660b41aa23saberian        if (!mUsePrior){
90b019e89cbea221598c482b05ab68b7660b41aa23saberian            return super.scoreSample(keys, values);
91b019e89cbea221598c482b05ab68b7660b41aa23saberian        } else {
92b019e89cbea221598c482b05ab68b7660b41aa23saberian            if (mUseAutoAlpha) {
93b019e89cbea221598c482b05ab68b7660b41aa23saberian                if (mNumTrainPair > mMinReqTrainingPair)
94b019e89cbea221598c482b05ab68b7660b41aa23saberian                    return (1 - mAutoAlpha) * super.scoreSample(keys,values) +
95b019e89cbea221598c482b05ab68b7660b41aa23saberian                            mAutoAlpha * priorScoreSample(keys,values);
96b019e89cbea221598c482b05ab68b7660b41aa23saberian                else
97b019e89cbea221598c482b05ab68b7660b41aa23saberian                    return priorScoreSample(keys,values);
98b019e89cbea221598c482b05ab68b7660b41aa23saberian            } else
99b019e89cbea221598c482b05ab68b7660b41aa23saberian                return (1 - mAlpha) * super.scoreSample(keys,values) +
100b019e89cbea221598c482b05ab68b7660b41aa23saberian                        mAlpha * priorScoreSample(keys,values);
101b019e89cbea221598c482b05ab68b7660b41aa23saberian        }
102b019e89cbea221598c482b05ab68b7660b41aa23saberian    }
103b019e89cbea221598c482b05ab68b7660b41aa23saberian
104b019e89cbea221598c482b05ab68b7660b41aa23saberian    public float priorScoreSample(String[] keys, float[] values) {
105b019e89cbea221598c482b05ab68b7660b41aa23saberian        float score = 0;
106b019e89cbea221598c482b05ab68b7660b41aa23saberian        for (int i=0; i< keys.length; i++){
107b019e89cbea221598c482b05ab68b7660b41aa23saberian            if (mPriorWeights.get(keys[i]) != null )
108b019e89cbea221598c482b05ab68b7660b41aa23saberian                score = score + mPriorWeights.get(keys[i]) * values[i];
109b019e89cbea221598c482b05ab68b7660b41aa23saberian        }
110b019e89cbea221598c482b05ab68b7660b41aa23saberian        return score;
111b019e89cbea221598c482b05ab68b7660b41aa23saberian    }
112b019e89cbea221598c482b05ab68b7660b41aa23saberian
113b019e89cbea221598c482b05ab68b7660b41aa23saberian    @Override
114b019e89cbea221598c482b05ab68b7660b41aa23saberian    public boolean updateClassifier(String[] keys_positive,
115b019e89cbea221598c482b05ab68b7660b41aa23saberian                                    float[] values_positive,
116b019e89cbea221598c482b05ab68b7660b41aa23saberian                                    String[] keys_negative,
117b019e89cbea221598c482b05ab68b7660b41aa23saberian                                    float[] values_negative){
118b019e89cbea221598c482b05ab68b7660b41aa23saberian        if (mUsePrior && mUseAutoAlpha && (mNumTrainPair > mMinReqTrainingPair))
119b019e89cbea221598c482b05ab68b7660b41aa23saberian            updateAutoAlpha(keys_positive, values_positive, keys_negative, values_negative);
120b019e89cbea221598c482b05ab68b7660b41aa23saberian        mNumTrainPair ++;
121b019e89cbea221598c482b05ab68b7660b41aa23saberian        return super.updateClassifier(keys_positive, values_positive,
122b019e89cbea221598c482b05ab68b7660b41aa23saberian                                      keys_negative, values_negative);
123b019e89cbea221598c482b05ab68b7660b41aa23saberian    }
124b019e89cbea221598c482b05ab68b7660b41aa23saberian
125b019e89cbea221598c482b05ab68b7660b41aa23saberian    void updateAutoAlpha(String[] keys_positive,
126b019e89cbea221598c482b05ab68b7660b41aa23saberian                     float[] values_positive,
127b019e89cbea221598c482b05ab68b7660b41aa23saberian                     String[] keys_negative,
128b019e89cbea221598c482b05ab68b7660b41aa23saberian                     float[] values_negative) {
129b019e89cbea221598c482b05ab68b7660b41aa23saberian        float positiveUserScore = super.scoreSample(keys_positive, values_positive);
130b019e89cbea221598c482b05ab68b7660b41aa23saberian        float negativeUserScore = super.scoreSample(keys_negative, values_negative);
131b019e89cbea221598c482b05ab68b7660b41aa23saberian        float positivePriorScore = priorScoreSample(keys_positive, values_positive);
132b019e89cbea221598c482b05ab68b7660b41aa23saberian        float negativePriorScore = priorScoreSample(keys_negative, values_negative);
133b019e89cbea221598c482b05ab68b7660b41aa23saberian        float userDecision = 0;
134b019e89cbea221598c482b05ab68b7660b41aa23saberian        float priorDecision = 0;
135b019e89cbea221598c482b05ab68b7660b41aa23saberian        if (positiveUserScore > negativeUserScore)
136b019e89cbea221598c482b05ab68b7660b41aa23saberian            userDecision = 1;
137b019e89cbea221598c482b05ab68b7660b41aa23saberian        if (positivePriorScore > negativePriorScore)
138b019e89cbea221598c482b05ab68b7660b41aa23saberian            priorDecision = 1;
139b019e89cbea221598c482b05ab68b7660b41aa23saberian        mUserRankerPerf = (1 - mForgetRate) * mUserRankerPerf + userDecision;
140b019e89cbea221598c482b05ab68b7660b41aa23saberian        mPriorRankerPerf = (1 - mForgetRate) * mPriorRankerPerf + priorDecision;
141b019e89cbea221598c482b05ab68b7660b41aa23saberian        mAutoAlpha = (mPriorRankerPerf + EPSILON) / (mUserRankerPerf + mPriorRankerPerf + EPSILON);
142b019e89cbea221598c482b05ab68b7660b41aa23saberian    }
143b019e89cbea221598c482b05ab68b7660b41aa23saberian
144b019e89cbea221598c482b05ab68b7660b41aa23saberian    public Model getModel(){
145b019e89cbea221598c482b05ab68b7660b41aa23saberian        Model m = new Model();
146b019e89cbea221598c482b05ab68b7660b41aa23saberian        m.uModel = super.getUModel();
147b019e89cbea221598c482b05ab68b7660b41aa23saberian        m.priorWeights.putAll(mPriorWeights);
148b019e89cbea221598c482b05ab68b7660b41aa23saberian        m.priorParameters.put(SET_ALPHA, String.valueOf(mAlpha));
149b019e89cbea221598c482b05ab68b7660b41aa23saberian        m.priorParameters.put(SET_AUTO_ALPHA, String.valueOf(mAutoAlpha));
150b019e89cbea221598c482b05ab68b7660b41aa23saberian        m.priorParameters.put(SET_FORGET_RATE, String.valueOf(mForgetRate));
151b019e89cbea221598c482b05ab68b7660b41aa23saberian        m.priorParameters.put(SET_MIN_TRAIN_PAIR, String.valueOf(mMinReqTrainingPair));
152b019e89cbea221598c482b05ab68b7660b41aa23saberian        m.priorParameters.put(SET_USER_PERF, String.valueOf(mUserRankerPerf));
153b019e89cbea221598c482b05ab68b7660b41aa23saberian        m.priorParameters.put(SET_PRIOR_PERF, String.valueOf(mPriorRankerPerf));
154b019e89cbea221598c482b05ab68b7660b41aa23saberian        m.priorParameters.put(SET_NUM_TRAIN_PAIR, String.valueOf(mNumTrainPair));
155b019e89cbea221598c482b05ab68b7660b41aa23saberian        m.priorParameters.put(USE_AUTO_ALPHA, String.valueOf(mUseAutoAlpha));
156b019e89cbea221598c482b05ab68b7660b41aa23saberian        m.priorParameters.put(USE_PRIOR, String.valueOf(mUsePrior));
157b019e89cbea221598c482b05ab68b7660b41aa23saberian        return m;
158b019e89cbea221598c482b05ab68b7660b41aa23saberian    }
159b019e89cbea221598c482b05ab68b7660b41aa23saberian
160b019e89cbea221598c482b05ab68b7660b41aa23saberian    public boolean loadModel(Model m) {
161b019e89cbea221598c482b05ab68b7660b41aa23saberian        mPriorWeights.clear();
162b019e89cbea221598c482b05ab68b7660b41aa23saberian        mPriorWeights.putAll(m.priorWeights);
163b019e89cbea221598c482b05ab68b7660b41aa23saberian        for (Map.Entry<String, String> e : m.priorParameters.entrySet()) {
164b019e89cbea221598c482b05ab68b7660b41aa23saberian            boolean res = setModelParameter(e.getKey(), e.getValue());
165b019e89cbea221598c482b05ab68b7660b41aa23saberian            if (!res) return false;
166b019e89cbea221598c482b05ab68b7660b41aa23saberian        }
167b019e89cbea221598c482b05ab68b7660b41aa23saberian        return super.loadModel(m.uModel);
168b019e89cbea221598c482b05ab68b7660b41aa23saberian    }
169b019e89cbea221598c482b05ab68b7660b41aa23saberian
170b019e89cbea221598c482b05ab68b7660b41aa23saberian    public boolean setModelPriorWeights(HashMap<String, Float> pw){
171b019e89cbea221598c482b05ab68b7660b41aa23saberian        mPriorWeights.clear();
172b019e89cbea221598c482b05ab68b7660b41aa23saberian        mPriorWeights.putAll(pw);
173b019e89cbea221598c482b05ab68b7660b41aa23saberian        return true;
174b019e89cbea221598c482b05ab68b7660b41aa23saberian    }
175b019e89cbea221598c482b05ab68b7660b41aa23saberian
176b019e89cbea221598c482b05ab68b7660b41aa23saberian    public boolean setModelParameter(String key, String value){
177b019e89cbea221598c482b05ab68b7660b41aa23saberian        if (key.equals(USE_AUTO_ALPHA)){
178b019e89cbea221598c482b05ab68b7660b41aa23saberian            mUseAutoAlpha = Boolean.parseBoolean(value);
179b019e89cbea221598c482b05ab68b7660b41aa23saberian        } else if (key.equals(USE_PRIOR)){
180b019e89cbea221598c482b05ab68b7660b41aa23saberian            mUsePrior = Boolean.parseBoolean(value);
181b019e89cbea221598c482b05ab68b7660b41aa23saberian        } else if (key.equals(SET_ALPHA)){
182b019e89cbea221598c482b05ab68b7660b41aa23saberian            mAlpha = Float.valueOf(value.trim()).floatValue();
183b019e89cbea221598c482b05ab68b7660b41aa23saberian        }else if (key.equals(SET_AUTO_ALPHA)){
184b019e89cbea221598c482b05ab68b7660b41aa23saberian            mAutoAlpha = Float.valueOf(value.trim()).floatValue();
185b019e89cbea221598c482b05ab68b7660b41aa23saberian        }else if (key.equals(SET_FORGET_RATE)){
186b019e89cbea221598c482b05ab68b7660b41aa23saberian            mForgetRate = Float.valueOf(value.trim()).floatValue();
187b019e89cbea221598c482b05ab68b7660b41aa23saberian        }else if (key.equals(SET_MIN_TRAIN_PAIR)){
188b019e89cbea221598c482b05ab68b7660b41aa23saberian            mMinReqTrainingPair = (int) Float.valueOf(value.trim()).floatValue();
189b019e89cbea221598c482b05ab68b7660b41aa23saberian        }else if (key.equals(SET_USER_PERF)){
190b019e89cbea221598c482b05ab68b7660b41aa23saberian            mUserRankerPerf = Float.valueOf(value.trim()).floatValue();
191b019e89cbea221598c482b05ab68b7660b41aa23saberian        }else if (key.equals(SET_PRIOR_PERF)){
192b019e89cbea221598c482b05ab68b7660b41aa23saberian            mPriorRankerPerf = Float.valueOf(value.trim()).floatValue();
193b019e89cbea221598c482b05ab68b7660b41aa23saberian        }else if (key.equals(SET_NUM_TRAIN_PAIR)){
194b019e89cbea221598c482b05ab68b7660b41aa23saberian            mNumTrainPair = (int) Float.valueOf(value.trim()).floatValue();
195b019e89cbea221598c482b05ab68b7660b41aa23saberian        }else
196b019e89cbea221598c482b05ab68b7660b41aa23saberian            return super.setModelParameter(key, value);
197b019e89cbea221598c482b05ab68b7660b41aa23saberian        return true;
198b019e89cbea221598c482b05ab68b7660b41aa23saberian    }
199b019e89cbea221598c482b05ab68b7660b41aa23saberian
200b019e89cbea221598c482b05ab68b7660b41aa23saberian    public void print(Model m){
201b019e89cbea221598c482b05ab68b7660b41aa23saberian        super.print(m.uModel);
202b019e89cbea221598c482b05ab68b7660b41aa23saberian        String Spw = "";
203b019e89cbea221598c482b05ab68b7660b41aa23saberian        for (Map.Entry<String, Float> e : m.priorWeights.entrySet())
204b019e89cbea221598c482b05ab68b7660b41aa23saberian            Spw = Spw + "<" + e.getKey() + "," + e.getValue() + "> ";
205b019e89cbea221598c482b05ab68b7660b41aa23saberian        Log.i(TAG, "Prior model is " + Spw);
206b019e89cbea221598c482b05ab68b7660b41aa23saberian        String Spp = "";
207b019e89cbea221598c482b05ab68b7660b41aa23saberian        for (Map.Entry<String, String> e : m.priorParameters.entrySet())
208b019e89cbea221598c482b05ab68b7660b41aa23saberian            Spp = Spp + "<" + e.getKey() + "," + e.getValue() + "> ";
209b019e89cbea221598c482b05ab68b7660b41aa23saberian        Log.i(TAG, "Prior parameters are " + Spp);
210b019e89cbea221598c482b05ab68b7660b41aa23saberian    }
211b019e89cbea221598c482b05ab68b7660b41aa23saberian}
212