jni_stochastic_linear_ranker.cpp revision 6b4eebc73439cbc3ddfb547444a341d1f9be7996
1/*
2 * Copyright (C) 2011 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *    http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "jni/jni_stochastic_linear_ranker.h"
18#include "native/common_defs.h"
19#include "native/sparse_weight_vector.h"
20#include "native/stochastic_linear_ranker.h"
21
22#include <vector>
23#include <string>
24using std::string;
25using std::vector;
26using std::hash_map;
27using learning_stochastic_linear::StochasticLinearRanker;
28using learning_stochastic_linear::SparseWeightVector;
29
30void CreateSparseWeightVector(JNIEnv* env, const jobjectArray keys, const float* values,
31    const int length, SparseWeightVector<string> * sample) {
32  for (int i = 0; i < length; ++i) {
33    jboolean iscopy;
34    jstring s = (jstring) env->GetObjectArrayElement(keys, i);
35    const char *key = env->GetStringUTFChars(s, &iscopy);
36    sample->SetElement(key, static_cast<double>(values[i]));
37    env->ReleaseStringUTFChars(s,key);
38  }
39}
40
41void DecomposeSparseWeightVector(JNIEnv* env, jobjectArray *keys, jfloatArray *values,
42    const int length, SparseWeightVector<string> *sample) {
43
44  SparseWeightVector<string>::Wmap w_ = sample->GetMap();
45  int i=0;
46  for ( SparseWeightVector<string>::Witer_const iter = w_.begin();
47    iter != w_.end(); ++iter) {
48    std::string key = iter->first;
49    jstring jstr = env->NewStringUTF(key.c_str());
50    env->SetObjectArrayElement(*keys, i, jstr);
51    double value = iter->second;
52    jfloat s[1];
53    s[0] = value;
54    env->SetFloatArrayRegion(*values, i, 1, s);
55    i++;
56  }
57}
58
59jboolean Java_android_bordeaux_learning_StochasticLinearRanker_nativeLoadClassifier(
60    JNIEnv* env,
61    jobject thiz,
62    jobjectArray key_array_model,
63    jfloatArray value_array_model,
64    jfloatArray value_array_param,
65    jint paPtr) {
66
67  StochasticLinearRanker<string>* classifier = (StochasticLinearRanker<string>*) paPtr;
68  if (classifier && key_array_model && value_array_model && value_array_param) {
69    const int keys_m_len = env->GetArrayLength(key_array_model);
70    jfloat* values_m = env->GetFloatArrayElements(value_array_model, NULL);
71    const int values_m_len = env->GetArrayLength(value_array_model);
72    jfloat* param_m = env->GetFloatArrayElements(value_array_param, NULL);
73
74    if (values_m && key_array_model && values_m_len == keys_m_len) {
75      SparseWeightVector<string> model;
76      CreateSparseWeightVector(env, key_array_model, values_m, values_m_len, &model);
77      model.SetNormalizer((double) param_m[0]);
78      classifier->LoadWeights(model);
79      classifier->SetIterationNumber((uint64) param_m[1]);
80      classifier->SetNormConstraint((double) param_m[2]);
81
82      switch ((int) param_m[3]){
83      case 0 :
84        classifier->SetRegularizationType(learning_stochastic_linear::L0);
85        break;
86      case 1 :
87        classifier->SetRegularizationType(learning_stochastic_linear::L1);
88        break;
89      case 2 :
90        classifier->SetRegularizationType(learning_stochastic_linear::L2);
91        break;
92      case 3 :
93        classifier->SetRegularizationType(learning_stochastic_linear::L1L2);
94        break;
95      case 4 :
96        classifier->SetRegularizationType(learning_stochastic_linear::L1LInf);
97        break;
98      }
99
100      classifier->SetLambda((double) param_m[4]);
101
102      switch ((int) param_m[5]){
103      case 0 :
104        classifier->SetUpdateType(learning_stochastic_linear::FULL_CS);
105        break;
106      case 1 :
107        classifier->SetUpdateType(learning_stochastic_linear::CLIP_CS);
108        break;
109      case 2 :
110        classifier->SetUpdateType(learning_stochastic_linear::REG_CS);
111        break;
112      case 3 :
113        classifier->SetUpdateType(learning_stochastic_linear::SL);
114        break;
115      case 4 :
116        classifier->SetUpdateType(learning_stochastic_linear::ADAPTIVE_REG);
117        break;
118      }
119
120      switch ((int) param_m[6]){
121      case 0 :
122        classifier->SetAdaptationMode(learning_stochastic_linear::CONST);
123        break;
124      case 1 :
125        classifier->SetAdaptationMode(learning_stochastic_linear::INV_LINEAR);
126        break;
127      case 2 :
128        classifier->SetAdaptationMode(learning_stochastic_linear::INV_QUADRATIC);
129        break;
130      case 3 :
131        classifier->SetAdaptationMode(learning_stochastic_linear::INV_SQRT);
132        break;
133      }
134
135      switch ((int) param_m[7]){
136      case 0 :
137        classifier->SetKernelType(learning_stochastic_linear::LINEAR, (double) param_m[8],
138                                  (double) param_m[9],(double) param_m[10]);
139        break;
140      case 1 : classifier->SetKernelType(learning_stochastic_linear::POLY, (double) param_m[8],
141                                         (double) param_m[9],(double) param_m[10]);
142        break;
143      case 2 : classifier->SetKernelType(learning_stochastic_linear::RBF, (double) param_m[8],
144                                          (double) param_m[9],(double) param_m[10]);
145        break;
146      }
147
148      switch ((int) param_m[11]){
149      case 0 :
150        classifier->SetRankLossType(learning_stochastic_linear::PAIRWISE);
151        break;
152      case 1 :
153        classifier->SetRankLossType(learning_stochastic_linear::RECIPROCAL_RANK);
154        break;
155      }
156
157      classifier->SetAcceptanceProbability((double) param_m[12]);
158      classifier->SetMiniBatchSize((uint64)param_m[13]);
159      classifier->SetGradientL0Norm((int32)param_m[14]);
160      env->ReleaseFloatArrayElements(value_array_model, values_m, JNI_ABORT);
161      env->ReleaseFloatArrayElements(value_array_param, param_m, JNI_ABORT);
162      return JNI_TRUE;
163    }
164  }
165  return JNI_FALSE;
166}
167
168jint Java_android_bordeaux_learning_StochasticLinearRanker_nativeGetLengthClassifier(
169  JNIEnv* env,
170  jobject thiz,
171  jint paPtr) {
172
173  StochasticLinearRanker<string>* classifier = (StochasticLinearRanker<string>*) paPtr;
174  SparseWeightVector<string> M_weights;
175  classifier->SaveWeights(&M_weights);
176
177  SparseWeightVector<string>::Wmap w_map = M_weights.GetMap();
178  int len = w_map.size();
179  return len;
180}
181
182void Java_android_bordeaux_learning_StochasticLinearRanker_nativeGetClassifier(
183  JNIEnv* env,
184  jobject thiz,
185  jobjectArray key_array_model,
186  jfloatArray value_array_model,
187  jfloatArray value_array_param,
188  jint paPtr) {
189
190  StochasticLinearRanker<string>* classifier = (StochasticLinearRanker<string>*) paPtr;
191
192  SparseWeightVector<string> M_weights;
193  classifier->SaveWeights(&M_weights);
194  double Jni_weight_normalizer = M_weights.GetNormalizer();
195  int Jni_itr_num = classifier->GetIterationNumber();
196  double Jni_norm_cont = classifier->GetNormContraint();
197  int Jni_reg_type = classifier->GetRegularizationType();
198  double Jni_lambda = classifier->GetLambda();
199  int Jni_update_type = classifier->GetUpdateType();
200  int Jni_AdaptationMode = classifier->GetAdaptationMode();
201  double Jni_kernel_param, Jni_kernel_gain, Jni_kernel_bias;
202  int Jni_kernel_type = classifier->GetKernelType(&Jni_kernel_param, &Jni_kernel_gain, &Jni_kernel_bias);
203  int Jni_rank_loss_type = classifier->GetRankLossType();
204  double Jni_accp_prob = classifier->GetAcceptanceProbability();
205  uint64 Jni_min_batch_size = classifier->GetMiniBatchSize();
206  int32 Jni_GradL0Norm = classifier->GetGradientL0Norm();
207  const int Var_num = 15;
208  jfloat s[Var_num]= {  (float) Jni_weight_normalizer,
209                        (float) Jni_itr_num,
210                        (float) Jni_norm_cont,
211                        (float) Jni_reg_type,
212                        (float) Jni_lambda,
213                        (float) Jni_update_type,
214                        (float) Jni_AdaptationMode,
215                        (float) Jni_kernel_type,
216                        (float) Jni_kernel_param,
217                        (float) Jni_kernel_gain,
218                        (float) Jni_kernel_bias,
219                        (float) Jni_rank_loss_type,
220                        (float) Jni_accp_prob,
221                        (float) Jni_min_batch_size,
222                        (float) Jni_GradL0Norm};
223
224  env->SetFloatArrayRegion(value_array_param, 0, Var_num, s);
225
226  SparseWeightVector<string>::Wmap w_map = M_weights.GetMap();
227  int array_len = w_map.size();
228
229  DecomposeSparseWeightVector(env, &key_array_model, &value_array_model, array_len, &M_weights);
230}
231
232jint Java_android_bordeaux_learning_StochasticLinearRanker_initNativeClassifier(JNIEnv* env,
233                             jobject thiz) {
234  StochasticLinearRanker<string>* classifier = new StochasticLinearRanker<string>();
235  classifier->SetUpdateType(learning_stochastic_linear::REG_CS);
236  classifier->SetRegularizationType(learning_stochastic_linear::L2);
237  return ((jint) classifier);
238}
239
240
241jboolean Java_android_bordeaux_learning_StochasticLinearRanker_deleteNativeClassifier(JNIEnv* env,
242                               jobject thiz,
243                               jint paPtr) {
244  StochasticLinearRanker<string>* classifier = (StochasticLinearRanker<string>*) paPtr;
245  delete classifier;
246  return JNI_TRUE;
247}
248
249jboolean Java_android_bordeaux_learning_StochasticLinearRanker_nativeUpdateClassifier(
250  JNIEnv* env,
251  jobject thiz,
252  jobjectArray key_array_positive,
253  jfloatArray value_array_positive,
254  jobjectArray key_array_negative,
255  jfloatArray value_array_negative,
256  jint paPtr) {
257  StochasticLinearRanker<string>* classifier = (StochasticLinearRanker<string>*) paPtr;
258
259  if (classifier && key_array_positive && value_array_positive &&
260      key_array_negative && value_array_negative) {
261
262    const int keys_p_len = env->GetArrayLength(key_array_positive);
263    jfloat* values_p = env->GetFloatArrayElements(value_array_positive, NULL);
264    const int values_p_len = env->GetArrayLength(value_array_positive);
265    jfloat* values_n = env->GetFloatArrayElements(value_array_negative, NULL);
266    const int values_n_len = env->GetArrayLength(value_array_negative);
267    const int keys_n_len = env->GetArrayLength(key_array_negative);
268
269    if (values_p && key_array_positive && values_p_len == keys_p_len &&
270      values_n && key_array_negative && values_n_len == keys_n_len) {
271
272      SparseWeightVector<string> sample_pos;
273      SparseWeightVector<string> sample_neg;
274      CreateSparseWeightVector(env, key_array_positive, values_p, values_p_len, &sample_pos);
275      CreateSparseWeightVector(env, key_array_negative, values_n, values_n_len, &sample_neg);
276      classifier->UpdateClassifier(sample_pos, sample_neg);
277      env->ReleaseFloatArrayElements(value_array_negative, values_n, JNI_ABORT);
278      env->ReleaseFloatArrayElements(value_array_positive, values_p, JNI_ABORT);
279
280      return JNI_TRUE;
281    }
282    env->ReleaseFloatArrayElements(value_array_negative, values_n, JNI_ABORT);
283    env->ReleaseFloatArrayElements(value_array_positive, values_p, JNI_ABORT);
284  }
285  return JNI_FALSE;
286}
287
288
289jfloat Java_android_bordeaux_learning_StochasticLinearRanker_nativeScoreSample(
290  JNIEnv* env,
291  jobject thiz,
292  jobjectArray key_array,
293  jfloatArray value_array,
294  jint paPtr) {
295
296  StochasticLinearRanker<string>* classifier = (StochasticLinearRanker<string>*) paPtr;
297
298  if (classifier && key_array && value_array) {
299
300    jfloat* values = env->GetFloatArrayElements(value_array, NULL);
301    const int values_len = env->GetArrayLength(value_array);
302    const int keys_len = env->GetArrayLength(key_array);
303
304    if (values && key_array && values_len == keys_len) {
305      SparseWeightVector<string> sample;
306      CreateSparseWeightVector(env, key_array, values, values_len, &sample);
307      env->ReleaseFloatArrayElements(value_array, values, JNI_ABORT);
308      return classifier->ScoreSample(sample);
309    }
310  }
311  return -1;
312}
313