16b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua/*
26b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * Copyright (C) 2011 The Android Open Source Project
36b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua *
46b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * Licensed under the Apache License, Version 2.0 (the "License");
56b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * you may not use this file except in compliance with the License.
66b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * You may obtain a copy of the License at
76b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua *
86b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua *      http://www.apache.org/licenses/LICENSE-2.0
96b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua *
106b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * Unless required by applicable law or agreed to in writing, software
116b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * distributed under the License is distributed on an "AS IS" BASIS,
126b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
136b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * See the License for the specific language governing permissions and
146b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * limitations under the License.
156b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua */
166b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
176b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
186b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huapackage android.bordeaux.learning;
191dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua
206b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huaimport android.util.Log;
211dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua
221dd8ef56681617db46caec7776c9bf416f01d8ddWei Huaimport java.io.Serializable;
236b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huaimport java.util.List;
246b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huaimport java.util.Arrays;
256b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huaimport java.util.ArrayList;
26b019e89cbea221598c482b05ab68b7660b41aa23saberianimport java.util.HashMap;
27b019e89cbea221598c482b05ab68b7660b41aa23saberianimport java.util.Map;
286b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
296b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua/**
306b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * Stochastic Linear Ranker, learns how to rank a sample. The learned rank score
316b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * can be used to compare samples.
326b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * This java class wraps the native StochasticLinearRanker class.
336b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * To update the ranker, call updateClassifier with two samples, with the first
346b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * one having higher rank than the second one.
356b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * To get the rank score of the sample call scoreSample.
366b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua *  TODO: adding more interfaces for changing the learning parameters
376b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua */
386b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huapublic class StochasticLinearRanker {
396b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    String TAG = "StochasticLinearRanker";
40b019e89cbea221598c482b05ab68b7660b41aa23saberian    public static int VAR_NUM = 14;
411dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua    static public class Model implements Serializable {
42b019e89cbea221598c482b05ab68b7660b41aa23saberian        public HashMap<String, Float> weights = new HashMap<String, Float>();
43b019e89cbea221598c482b05ab68b7660b41aa23saberian        public float weightNormalizer = 1;
44b019e89cbea221598c482b05ab68b7660b41aa23saberian        public HashMap<String, String> parameters = new HashMap<String, String>();
451dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua    }
461dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua
47b019e89cbea221598c482b05ab68b7660b41aa23saberian    /**
48b019e89cbea221598c482b05ab68b7660b41aa23saberian     * Initializing a ranker
49b019e89cbea221598c482b05ab68b7660b41aa23saberian     */
506b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    public StochasticLinearRanker() {
516b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        mNativeClassifier = initNativeClassifier();
526b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
536b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
546b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    /**
55b019e89cbea221598c482b05ab68b7660b41aa23saberian     * Reset the ranker
56b019e89cbea221598c482b05ab68b7660b41aa23saberian     */
57b019e89cbea221598c482b05ab68b7660b41aa23saberian    public void resetRanker(){
58b019e89cbea221598c482b05ab68b7660b41aa23saberian        deleteNativeClassifier(mNativeClassifier);
59b019e89cbea221598c482b05ab68b7660b41aa23saberian        mNativeClassifier = initNativeClassifier();
60b019e89cbea221598c482b05ab68b7660b41aa23saberian    }
61b019e89cbea221598c482b05ab68b7660b41aa23saberian
62b019e89cbea221598c482b05ab68b7660b41aa23saberian    /**
636b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua     * Train the ranker with a pair of samples. A sample,  a pair of arrays of
646b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua     * keys and values. The first sample should have higher rank than the second
656b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua     * one.
666b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua     */
676b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    public boolean updateClassifier(String[] keys_positive,
686b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua                                    float[] values_positive,
696b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua                                    String[] keys_negative,
706b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua                                    float[] values_negative) {
716b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        return nativeUpdateClassifier(keys_positive, values_positive,
726b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua                                      keys_negative, values_negative,
736b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua                                      mNativeClassifier);
746b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
756b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
766b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    /**
771dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua     * Get the rank score of the sample, a sample is a list of key, value pairs.
786b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua     */
796b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    public float scoreSample(String[] keys, float[] values) {
806b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        return nativeScoreSample(keys, values, mNativeClassifier);
816b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
826b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
836b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    /**
846b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua     * Get the current model and parameters of ranker
856b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua     */
86b019e89cbea221598c482b05ab68b7660b41aa23saberian    public Model getUModel(){
87b019e89cbea221598c482b05ab68b7660b41aa23saberian        Model slrModel = new Model();
886b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        int len = nativeGetLengthClassifier(mNativeClassifier);
89b019e89cbea221598c482b05ab68b7660b41aa23saberian        String[] wKeys = new String[len];
90b019e89cbea221598c482b05ab68b7660b41aa23saberian        float[] wValues = new float[len];
91b019e89cbea221598c482b05ab68b7660b41aa23saberian        float wNormalizer = 1;
92b019e89cbea221598c482b05ab68b7660b41aa23saberian        nativeGetWeightClassifier(wKeys, wValues, wNormalizer, mNativeClassifier);
93b019e89cbea221598c482b05ab68b7660b41aa23saberian        slrModel.weightNormalizer = wNormalizer;
94b019e89cbea221598c482b05ab68b7660b41aa23saberian        for (int  i=0; i< wKeys.length ; i++)
95b019e89cbea221598c482b05ab68b7660b41aa23saberian            slrModel.weights.put(wKeys[i], wValues[i]);
96b019e89cbea221598c482b05ab68b7660b41aa23saberian
97b019e89cbea221598c482b05ab68b7660b41aa23saberian        String[] paramKeys = new String[VAR_NUM];
98b019e89cbea221598c482b05ab68b7660b41aa23saberian        String[] paramValues = new String[VAR_NUM];
99b019e89cbea221598c482b05ab68b7660b41aa23saberian        nativeGetParameterClassifier(paramKeys, paramValues, mNativeClassifier);
100b019e89cbea221598c482b05ab68b7660b41aa23saberian        for (int  i=0; i< paramKeys.length ; i++)
101b019e89cbea221598c482b05ab68b7660b41aa23saberian            slrModel.parameters.put(paramKeys[i], paramValues[i]);
102b019e89cbea221598c482b05ab68b7660b41aa23saberian        return slrModel;
1036b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
1046b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
1056b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    /**
106b019e89cbea221598c482b05ab68b7660b41aa23saberian     * load the given model and parameters to the ranker
1076b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua     */
1081dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua    public boolean loadModel(Model model) {
109b019e89cbea221598c482b05ab68b7660b41aa23saberian        String[] wKeys = new String[model.weights.size()];
110b019e89cbea221598c482b05ab68b7660b41aa23saberian        float[] wValues = new float[model.weights.size()];
111b019e89cbea221598c482b05ab68b7660b41aa23saberian        int i = 0 ;
112b019e89cbea221598c482b05ab68b7660b41aa23saberian        for (Map.Entry<String, Float> e : model.weights.entrySet()){
113b019e89cbea221598c482b05ab68b7660b41aa23saberian            wKeys[i] = e.getKey();
114b019e89cbea221598c482b05ab68b7660b41aa23saberian            wValues[i] = e.getValue();
115b019e89cbea221598c482b05ab68b7660b41aa23saberian            i++;
1161dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua        }
117b019e89cbea221598c482b05ab68b7660b41aa23saberian        boolean res = setModelWeights(wKeys, wValues, model.weightNormalizer);
118b019e89cbea221598c482b05ab68b7660b41aa23saberian        if (!res)
119b019e89cbea221598c482b05ab68b7660b41aa23saberian            return false;
120b019e89cbea221598c482b05ab68b7660b41aa23saberian
121b019e89cbea221598c482b05ab68b7660b41aa23saberian        for (Map.Entry<String, String> e : model.parameters.entrySet()){
122b019e89cbea221598c482b05ab68b7660b41aa23saberian            res = setModelParameter(e.getKey(), e.getValue());
123b019e89cbea221598c482b05ab68b7660b41aa23saberian            if (!res)
124b019e89cbea221598c482b05ab68b7660b41aa23saberian                return false;
1251dd8ef56681617db46caec7776c9bf416f01d8ddWei Hua        }
126b019e89cbea221598c482b05ab68b7660b41aa23saberian        return res;
127b019e89cbea221598c482b05ab68b7660b41aa23saberian    }
128b019e89cbea221598c482b05ab68b7660b41aa23saberian
129b019e89cbea221598c482b05ab68b7660b41aa23saberian    public boolean setModelWeights(String[] keys, float [] values, float normalizer){
130b019e89cbea221598c482b05ab68b7660b41aa23saberian        return nativeSetWeightClassifier(keys, values, normalizer, mNativeClassifier);
131b019e89cbea221598c482b05ab68b7660b41aa23saberian    }
132b019e89cbea221598c482b05ab68b7660b41aa23saberian
133b019e89cbea221598c482b05ab68b7660b41aa23saberian    public boolean setModelParameter(String key, String value){
134b019e89cbea221598c482b05ab68b7660b41aa23saberian        boolean res = nativeSetParameterClassifier(key, value, mNativeClassifier);
135b019e89cbea221598c482b05ab68b7660b41aa23saberian        return res;
136b019e89cbea221598c482b05ab68b7660b41aa23saberian    }
137b019e89cbea221598c482b05ab68b7660b41aa23saberian
138b019e89cbea221598c482b05ab68b7660b41aa23saberian    /**
139b019e89cbea221598c482b05ab68b7660b41aa23saberian     * Print a model for debugging
140b019e89cbea221598c482b05ab68b7660b41aa23saberian     */
141b019e89cbea221598c482b05ab68b7660b41aa23saberian    public void print(Model model){
142b019e89cbea221598c482b05ab68b7660b41aa23saberian        String Sw = "";
143b019e89cbea221598c482b05ab68b7660b41aa23saberian        String Sp = "";
144b019e89cbea221598c482b05ab68b7660b41aa23saberian        for (Map.Entry<String, Float> e : model.weights.entrySet())
145b019e89cbea221598c482b05ab68b7660b41aa23saberian            Sw = Sw + "<" + e.getKey() + "," + e.getValue() + "> ";
146b019e89cbea221598c482b05ab68b7660b41aa23saberian        for (Map.Entry<String, String> e : model.parameters.entrySet())
147b019e89cbea221598c482b05ab68b7660b41aa23saberian            Sp = Sp + "<" + e.getKey() + "," + e.getValue() + "> ";
148b019e89cbea221598c482b05ab68b7660b41aa23saberian        Log.i(TAG, "Weights are " + Sw);
149b019e89cbea221598c482b05ab68b7660b41aa23saberian        Log.i(TAG, "Normalizer is " + model.weightNormalizer);
150b019e89cbea221598c482b05ab68b7660b41aa23saberian        Log.i(TAG, "Parameters are " + Sp);
1516b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
1526b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
1536b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    @Override
1546b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    protected void finalize() throws Throwable {
1556b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        deleteNativeClassifier(mNativeClassifier);
1566b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
1576b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
1586b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    static {
1596b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        System.loadLibrary("bordeaux");
1606b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
1616b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
1626b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    private int mNativeClassifier;
1636b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
1646b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    /*
1656b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua     * The following methods are the java stubs for the jni implementations.
1666b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua     */
1676b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    private native int initNativeClassifier();
1686b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
1696b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    private native void deleteNativeClassifier(int classifierPtr);
1706b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
1716b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    private native boolean nativeUpdateClassifier(
1726b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua            String[] keys_positive,
1736b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua            float[] values_positive,
1746b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua            String[] keys_negative,
1756b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua            float[] values_negative,
1766b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua            int classifierPtr);
1776b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
178b019e89cbea221598c482b05ab68b7660b41aa23saberian    private native float nativeScoreSample(String[] keys, float[] values, int classifierPtr);
179b019e89cbea221598c482b05ab68b7660b41aa23saberian
180b019e89cbea221598c482b05ab68b7660b41aa23saberian    private native void nativeGetWeightClassifier(String [] keys, float[] values, float normalizer,
181b019e89cbea221598c482b05ab68b7660b41aa23saberian                                                  int classifierPtr);
182b019e89cbea221598c482b05ab68b7660b41aa23saberian
183b019e89cbea221598c482b05ab68b7660b41aa23saberian    private native void nativeGetParameterClassifier(String [] keys, String[] values,
184b019e89cbea221598c482b05ab68b7660b41aa23saberian                                                  int classifierPtr);
185b019e89cbea221598c482b05ab68b7660b41aa23saberian
1866b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    private native int nativeGetLengthClassifier(int classifierPtr);
187b019e89cbea221598c482b05ab68b7660b41aa23saberian
188b019e89cbea221598c482b05ab68b7660b41aa23saberian    private native boolean nativeSetWeightClassifier(String [] keys, float[] values,
189b019e89cbea221598c482b05ab68b7660b41aa23saberian                                                     float normalizer, int classifierPtr);
190b019e89cbea221598c482b05ab68b7660b41aa23saberian
191b019e89cbea221598c482b05ab68b7660b41aa23saberian    private native boolean nativeSetParameterClassifier(String key, String value,
192b019e89cbea221598c482b05ab68b7660b41aa23saberian                                                        int classifierPtr);
1936b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
194