1/*
2 * Copyright (C) 2011 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "jni/jni_multiclass_pa.h"
18#include "native/multiclass_pa.h"
19
20#include <vector>
21
22using learningfw::MulticlassPA;
23using std::vector;
24using std::pair;
25
// Fills *pairs with (indices[i], values[i]) for i in [0, length).
//
// Any previous contents of *pairs are discarded. If pairs is NULL the call is
// a no-op; if either input array is NULL or length is not positive, *pairs is
// left empty.
//
// indices: array of at least `length` feature indices.
// values:  array of at least `length` feature values, parallel to indices.
// length:  number of entries to read from both arrays.
// pairs:   output vector receiving the combined (index, value) entries.
void CreateIndexValuePairs(const int* indices, const float* values,
                           const int length,
                           std::vector<std::pair<int, float> >* pairs) {
  if (pairs == NULL) {
    return;
  }
  pairs->clear();

  if (indices == NULL || values == NULL || length <= 0) {
    return;
  }

  // The final size is known, so allocate once instead of growing repeatedly.
  pairs->reserve(
      static_cast<std::vector<std::pair<int, float> >::size_type>(length));

  for (int i = 0; i < length; ++i) {
    pairs->push_back(std::make_pair(indices[i], values[i]));
  }
}
35
36jint Java_android_bordeaux_learning_MulticlassPA_initNativeClassifier(JNIEnv* env,
37                                                       jobject thiz,
38                                                       jint num_classes,
39                                                       jint num_dims,
40                                                       jfloat aggressiveness) {
41  MulticlassPA* classifier = new MulticlassPA(num_classes,
42                                              num_dims,
43                                              aggressiveness);
44  return ((jint) classifier);
45}
46
47
48jboolean Java_android_bordeaux_learning_MulticlassPA_deleteNativeClassifier(JNIEnv* env,
49                                                             jobject thiz,
50                                                             jint paPtr) {
51  MulticlassPA* classifier = (MulticlassPA*) paPtr;
52  delete classifier;
53  return JNI_TRUE;
54}
55
56jboolean Java_android_bordeaux_learning_MulticlassPA_nativeSparseTrainOneExample(JNIEnv* env,
57                                                                  jobject thiz,
58                                                                  jintArray index_array,
59                                                                  jfloatArray value_array,
60                                                                  jint target,
61                                                                  jint paPtr) {
62  MulticlassPA* classifier = (MulticlassPA*) paPtr;
63
64  if (classifier && index_array && value_array) {
65
66    jfloat* values = env->GetFloatArrayElements(value_array, NULL);
67    jint* indices = env->GetIntArrayElements(index_array, NULL);
68    const int value_len = env->GetArrayLength(value_array);
69    const int index_len = env->GetArrayLength(index_array);
70
71    if (values && indices && value_len == index_len) {
72      vector<pair<int, float> > inputs;
73
74      CreateIndexValuePairs(indices, values, value_len, &inputs);
75      classifier->SparseTrainOneExample(inputs, target);
76      env->ReleaseIntArrayElements(index_array, indices, JNI_ABORT);
77      env->ReleaseFloatArrayElements(value_array, values, JNI_ABORT);
78
79      return JNI_TRUE;
80    }
81    env->ReleaseIntArrayElements(index_array, indices, JNI_ABORT);
82    env->ReleaseFloatArrayElements(value_array, values, JNI_ABORT);
83  }
84
85  return JNI_FALSE;
86}
87
88
89jint Java_android_bordeaux_learning_MulticlassPA_nativeSparseGetClass(JNIEnv* env,
90                                                       jobject thiz,
91                                                       jintArray index_array,
92                                                       jfloatArray value_array,
93                                                       jint paPtr) {
94
95  MulticlassPA* classifier = (MulticlassPA*) paPtr;
96
97  if (classifier && index_array && value_array) {
98
99    jfloat* values = env->GetFloatArrayElements(value_array, NULL);
100    jint* indices = env->GetIntArrayElements(index_array, NULL);
101    const int value_len = env->GetArrayLength(value_array);
102    const int index_len = env->GetArrayLength(index_array);
103
104    if (values && indices && value_len == index_len) {
105      vector<pair<int, float> > inputs;
106      CreateIndexValuePairs(indices, values, value_len, &inputs);
107      env->ReleaseIntArrayElements(index_array, indices, JNI_ABORT);
108      env->ReleaseFloatArrayElements(value_array, values, JNI_ABORT);
109      return classifier->SparseGetClass(inputs);
110    }
111    env->ReleaseIntArrayElements(index_array, indices, JNI_ABORT);
112    env->ReleaseFloatArrayElements(value_array, values, JNI_ABORT);
113  }
114
115  return -1;
116}
117