1/*
2 * Copyright (C) 2011 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17
18package android.bordeaux.learning;
19
20import android.util.Log;
21
22import java.io.Serializable;
23import java.util.List;
24import java.util.Arrays;
25import java.util.ArrayList;
26import java.util.HashMap;
27import java.util.Map;
28
29/**
30 * Stochastic Linear Ranker, learns how to rank a sample. The learned rank score
31 * can be used to compare samples.
32 * This java class wraps the native StochasticLinearRanker class.
33 * To update the ranker, call updateClassifier with two samples, with the first
34 * one having higher rank than the second one.
35 * To get the rank score of the sample call scoreSample.
36 *  TODO: adding more interfaces for changing the learning parameters
37 */
38public class StochasticLinearRanker {
39    String TAG = "StochasticLinearRanker";
40    public static int VAR_NUM = 14;
41    static public class Model implements Serializable {
42        public HashMap<String, Float> weights = new HashMap<String, Float>();
43        public float weightNormalizer = 1;
44        public HashMap<String, String> parameters = new HashMap<String, String>();
45    }
46
47    /**
48     * Initializing a ranker
49     */
50    public StochasticLinearRanker() {
51        mNativeClassifier = initNativeClassifier();
52    }
53
54    /**
55     * Reset the ranker
56     */
57    public void resetRanker(){
58        deleteNativeClassifier(mNativeClassifier);
59        mNativeClassifier = initNativeClassifier();
60    }
61
62    /**
63     * Train the ranker with a pair of samples. A sample,  a pair of arrays of
64     * keys and values. The first sample should have higher rank than the second
65     * one.
66     */
67    public boolean updateClassifier(String[] keys_positive,
68                                    float[] values_positive,
69                                    String[] keys_negative,
70                                    float[] values_negative) {
71        return nativeUpdateClassifier(keys_positive, values_positive,
72                                      keys_negative, values_negative,
73                                      mNativeClassifier);
74    }
75
76    /**
77     * Get the rank score of the sample, a sample is a list of key, value pairs.
78     */
79    public float scoreSample(String[] keys, float[] values) {
80        return nativeScoreSample(keys, values, mNativeClassifier);
81    }
82
83    /**
84     * Get the current model and parameters of ranker
85     */
86    public Model getUModel(){
87        Model slrModel = new Model();
88        int len = nativeGetLengthClassifier(mNativeClassifier);
89        String[] wKeys = new String[len];
90        float[] wValues = new float[len];
91        float wNormalizer = 1;
92        nativeGetWeightClassifier(wKeys, wValues, wNormalizer, mNativeClassifier);
93        slrModel.weightNormalizer = wNormalizer;
94        for (int  i=0; i< wKeys.length ; i++)
95            slrModel.weights.put(wKeys[i], wValues[i]);
96
97        String[] paramKeys = new String[VAR_NUM];
98        String[] paramValues = new String[VAR_NUM];
99        nativeGetParameterClassifier(paramKeys, paramValues, mNativeClassifier);
100        for (int  i=0; i< paramKeys.length ; i++)
101            slrModel.parameters.put(paramKeys[i], paramValues[i]);
102        return slrModel;
103    }
104
105    /**
106     * load the given model and parameters to the ranker
107     */
108    public boolean loadModel(Model model) {
109        String[] wKeys = new String[model.weights.size()];
110        float[] wValues = new float[model.weights.size()];
111        int i = 0 ;
112        for (Map.Entry<String, Float> e : model.weights.entrySet()){
113            wKeys[i] = e.getKey();
114            wValues[i] = e.getValue();
115            i++;
116        }
117        boolean res = setModelWeights(wKeys, wValues, model.weightNormalizer);
118        if (!res)
119            return false;
120
121        for (Map.Entry<String, String> e : model.parameters.entrySet()){
122            res = setModelParameter(e.getKey(), e.getValue());
123            if (!res)
124                return false;
125        }
126        return res;
127    }
128
129    public boolean setModelWeights(String[] keys, float [] values, float normalizer){
130        return nativeSetWeightClassifier(keys, values, normalizer, mNativeClassifier);
131    }
132
133    public boolean setModelParameter(String key, String value){
134        boolean res = nativeSetParameterClassifier(key, value, mNativeClassifier);
135        return res;
136    }
137
138    /**
139     * Print a model for debugging
140     */
141    public void print(Model model){
142        String Sw = "";
143        String Sp = "";
144        for (Map.Entry<String, Float> e : model.weights.entrySet())
145            Sw = Sw + "<" + e.getKey() + "," + e.getValue() + "> ";
146        for (Map.Entry<String, String> e : model.parameters.entrySet())
147            Sp = Sp + "<" + e.getKey() + "," + e.getValue() + "> ";
148        Log.i(TAG, "Weights are " + Sw);
149        Log.i(TAG, "Normalizer is " + model.weightNormalizer);
150        Log.i(TAG, "Parameters are " + Sp);
151    }
152
153    @Override
154    protected void finalize() throws Throwable {
155        deleteNativeClassifier(mNativeClassifier);
156    }
157
158    static {
159        System.loadLibrary("bordeaux");
160    }
161
162    private int mNativeClassifier;
163
164    /*
165     * The following methods are the java stubs for the jni implementations.
166     */
167    private native int initNativeClassifier();
168
169    private native void deleteNativeClassifier(int classifierPtr);
170
171    private native boolean nativeUpdateClassifier(
172            String[] keys_positive,
173            float[] values_positive,
174            String[] keys_negative,
175            float[] values_negative,
176            int classifierPtr);
177
178    private native float nativeScoreSample(String[] keys, float[] values, int classifierPtr);
179
180    private native void nativeGetWeightClassifier(String [] keys, float[] values, float normalizer,
181                                                  int classifierPtr);
182
183    private native void nativeGetParameterClassifier(String [] keys, String[] values,
184                                                  int classifierPtr);
185
186    private native int nativeGetLengthClassifier(int classifierPtr);
187
188    private native boolean nativeSetWeightClassifier(String [] keys, float[] values,
189                                                     float normalizer, int classifierPtr);
190
191    private native boolean nativeSetParameterClassifier(String key, String value,
192                                                        int classifierPtr);
193}
194