StochasticLinearRanker.java revision 1dd8ef56681617db46caec7776c9bf416f01d8dd
1/*
2 * Copyright (C) 2011 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17
18package android.bordeaux.learning;
19
20import android.util.Log;
21
22import java.io.Serializable;
23import java.util.List;
24import java.util.Arrays;
25import java.util.ArrayList;
26
/**
 * Stochastic Linear Ranker: learns how to rank samples. The learned rank score
 * can be used to compare samples.
 * <p>
 * This Java class wraps the native StochasticLinearRanker class, which is loaded
 * from the "bordeaux" native library. To update the ranker, call
 * {@link #updateClassifier} with two samples, the first one having a higher rank
 * than the second one. To get the rank score of a sample, call
 * {@link #scoreSample}.
 * <p>
 * TODO: add more interfaces for changing the learning parameters.
 */
public class StochasticLinearRanker {
    static final String TAG = "StochasticLinearRanker";

    /**
     * Snapshot of a ranker's state, as produced by {@link #getModel()} and
     * consumed by {@link #loadModel(Model)}.
     * <p>
     * {@code keys} and {@code values} are parallel lists: entry i of one belongs
     * to entry i of the other. {@code parameters} holds the parameter values
     * exchanged with the native ranker.
     * <p>
     * NOTE(review): no explicit serialVersionUID is declared, so the serialized
     * form relies on the JVM-computed default. Adding one now would break
     * previously serialized models unless it equals that computed value. For the
     * same reason, the fields below are intentionally left unchanged.
     */
    public static class Model implements Serializable {
        public ArrayList<String> keys = new ArrayList<String>();
        public ArrayList<Float> values = new ArrayList<Float>();
        public ArrayList<Float> parameters = new ArrayList<Float>();
    }

    /** Number of parameter slots requested from the native ranker by {@link #getModel()}. */
    static final int VAR_NUM = 15;

    /** Creates a ranker backed by a newly initialized native classifier. */
    public StochasticLinearRanker() {
        mNativeClassifier = initNativeClassifier();
    }

    /**
     * Trains the ranker with a pair of samples. A sample is a pair of arrays of
     * keys and values. The first (positive) sample should rank higher than the
     * second (negative) one.
     *
     * @param keys_positive   keys of the higher-ranked sample
     * @param values_positive values of the higher-ranked sample
     * @param keys_negative   keys of the lower-ranked sample
     * @param values_negative values of the lower-ranked sample
     * @return the result reported by the native update (presumably whether the
     *         update succeeded; confirm against the JNI implementation)
     */
    public boolean updateClassifier(String[] keys_positive,
                                    float[] values_positive,
                                    String[] keys_negative,
                                    float[] values_negative) {
        return nativeUpdateClassifier(keys_positive, values_positive,
                                      keys_negative, values_negative,
                                      mNativeClassifier);
    }

    /**
     * Returns the rank score of a sample. A sample is a list of key/value pairs,
     * given here as two parallel arrays.
     *
     * @param keys   sample keys
     * @param values sample values, one per key
     * @return the native ranker's score for the sample
     */
    public float scoreSample(String[] keys, float[] values) {
        return nativeScoreSample(keys, values, mNativeClassifier);
    }

    /**
     * Returns a snapshot of the ranker's current model and parameters.
     *
     * @return a new {@link Model}. Its {@code keys}/{@code values} have the length
     *         reported by the native side, and its {@code parameters} list has
     *         {@link #VAR_NUM} entries.
     */
    public Model getModel() {
        Model model = new Model();
        int len = nativeGetLengthClassifier(mNativeClassifier);
        String[] keys = new String[len];
        float[] values = new float[len];
        float[] param = new float[VAR_NUM];
        // The native side fills the three pre-sized output arrays.
        nativeGetClassifier(keys, values, param, mNativeClassifier);

        model.keys.ensureCapacity(len);
        model.values.ensureCapacity(len);
        model.parameters.ensureCapacity(param.length);
        model.keys.addAll(Arrays.asList(keys));
        for (float value : values) {
            model.values.add(value);
        }
        for (float p : param) {
            model.parameters.add(p);
        }
        return model;
    }

    /**
     * Replaces the ranker's model and parameters with those in {@code model}.
     *
     * @param model the model to load, typically one obtained from {@link #getModel()}
     * @return {@code false} if {@code model.keys} and {@code model.values} have
     *         different sizes; otherwise the result reported by the native loader
     * @throws NullPointerException if {@code model}, one of its lists, or a
     *         value/parameter element is null
     */
    public boolean loadModel(Model model) {
        // keys and values are handed to native code as parallel arrays, so a
        // length mismatch would give it inconsistent input. Reject it up front.
        if (model.keys.size() != model.values.size()) {
            return false;
        }
        // NOTE(review): getModel() always produces VAR_NUM parameters. It is not
        // visible here whether the native loader tolerates a different count.
        String[] keys = model.keys.toArray(new String[model.keys.size()]);
        float[] values = toFloatArray(model.values);
        float[] param = toFloatArray(model.parameters);
        return nativeLoadClassifier(keys, values, param, mNativeClassifier);
    }

    /** Unboxes a list of Floats into a new primitive array of the same length. */
    private static float[] toFloatArray(List<Float> boxed) {
        float[] out = new float[boxed.size()];
        for (int i = 0; i < out.length; i++) {
            out[i] = boxed.get(i);
        }
        return out;
    }

    /**
     * Releases the native classifier. A zero handle is skipped, which is the case
     * if the constructor failed before the handle was assigned. The handle is
     * cleared afterwards so it is never freed twice.
     */
    @Override
    protected void finalize() throws Throwable {
        try {
            if (mNativeClassifier != 0) {
                deleteNativeClassifier(mNativeClassifier);
                mNativeClassifier = 0;
            }
        } finally {
            super.finalize();
        }
    }

    static {
        System.loadLibrary("bordeaux");
    }

    /**
     * Opaque handle to the native classifier, as returned by initNativeClassifier().
     * NOTE(review): the JNI parameters call this a pointer ("classifierPtr"). A
     * Java int is 32 bits and cannot hold a 64-bit native pointer. Widening it
     * requires changing the JNI signatures as well.
     */
    private int mNativeClassifier;

    /*
     * The following methods are the Java stubs for the JNI implementations.
     */
    private native int initNativeClassifier();

    private native void deleteNativeClassifier(int classifierPtr);

    private native boolean nativeUpdateClassifier(
            String[] keys_positive,
            float[] values_positive,
            String[] keys_negative,
            float[] values_negative,
            int classifierPtr);

    private native float nativeScoreSample(String[] keys,
                                           float[] values,
                                           int classifierPtr);

    private native void nativeGetClassifier(String[] keys, float[] values, float[] param,
                                            int classifierPtr);

    private native boolean nativeLoadClassifier(String[] keys, float[] values,
                                                float[] param, int classifierPtr);

    private native int nativeGetLengthClassifier(int classifierPtr);
}
142