1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17// JNI wrapper for the TextClassifier.
18
19#include "textclassifier_jni.h"
20
21#include <jni.h>
22#include <type_traits>
23#include <vector>
24
25#include "text-classifier.h"
26#include "util/base/integral_types.h"
27#include "util/java/scoped_local_ref.h"
28#include "util/java/string_utils.h"
29#include "util/memory/mmap.h"
30#include "util/utf8/unilib.h"
31
32using libtextclassifier2::AnnotatedSpan;
33using libtextclassifier2::AnnotationOptions;
34using libtextclassifier2::ClassificationOptions;
35using libtextclassifier2::ClassificationResult;
36using libtextclassifier2::CodepointSpan;
37using libtextclassifier2::JStringToUtf8String;
38using libtextclassifier2::Model;
39using libtextclassifier2::ScopedLocalRef;
40using libtextclassifier2::SelectionOptions;
41using libtextclassifier2::TextClassifier;
42#ifdef LIBTEXTCLASSIFIER_UNILIB_JAVAICU
43using libtextclassifier2::UniLib;
44#endif
45
46namespace libtextclassifier2 {
47
48using libtextclassifier2::CodepointSpan;
49
50namespace {
51
52std::string ToStlString(JNIEnv* env, const jstring& str) {
53  std::string result;
54  JStringToUtf8String(env, str, &result);
55  return result;
56}
57
58jobjectArray ClassificationResultsToJObjectArray(
59    JNIEnv* env,
60    const std::vector<ClassificationResult>& classification_result) {
61  const ScopedLocalRef<jclass> result_class(
62      env->FindClass(TC_PACKAGE_PATH TC_CLASS_NAME_STR "$ClassificationResult"),
63      env);
64  if (!result_class) {
65    TC_LOG(ERROR) << "Couldn't find ClassificationResult class.";
66    return nullptr;
67  }
68  const ScopedLocalRef<jclass> datetime_parse_class(
69      env->FindClass(TC_PACKAGE_PATH TC_CLASS_NAME_STR "$DatetimeResult"), env);
70  if (!datetime_parse_class) {
71    TC_LOG(ERROR) << "Couldn't find DatetimeResult class.";
72    return nullptr;
73  }
74
75  const jmethodID result_class_constructor =
76      env->GetMethodID(result_class.get(), "<init>",
77                       "(Ljava/lang/String;FL" TC_PACKAGE_PATH TC_CLASS_NAME_STR
78                       "$DatetimeResult;)V");
79  const jmethodID datetime_parse_class_constructor =
80      env->GetMethodID(datetime_parse_class.get(), "<init>", "(JI)V");
81
82  const jobjectArray results = env->NewObjectArray(classification_result.size(),
83                                                   result_class.get(), nullptr);
84  for (int i = 0; i < classification_result.size(); i++) {
85    jstring row_string =
86        env->NewStringUTF(classification_result[i].collection.c_str());
87    jobject row_datetime_parse = nullptr;
88    if (classification_result[i].datetime_parse_result.IsSet()) {
89      row_datetime_parse = env->NewObject(
90          datetime_parse_class.get(), datetime_parse_class_constructor,
91          classification_result[i].datetime_parse_result.time_ms_utc,
92          classification_result[i].datetime_parse_result.granularity);
93    }
94    jobject result =
95        env->NewObject(result_class.get(), result_class_constructor, row_string,
96                       static_cast<jfloat>(classification_result[i].score),
97                       row_datetime_parse);
98    env->SetObjectArrayElement(results, i, result);
99    env->DeleteLocalRef(result);
100  }
101  return results;
102}
103
104template <typename T, typename F>
105std::pair<bool, T> CallJniMethod0(JNIEnv* env, jobject object,
106                                  jclass class_object, F function,
107                                  const std::string& method_name,
108                                  const std::string& return_java_type) {
109  const jmethodID method = env->GetMethodID(class_object, method_name.c_str(),
110                                            ("()" + return_java_type).c_str());
111  if (!method) {
112    return std::make_pair(false, T());
113  }
114  return std::make_pair(true, (env->*function)(object, method));
115}
116
117SelectionOptions FromJavaSelectionOptions(JNIEnv* env, jobject joptions) {
118  if (!joptions) {
119    return {};
120  }
121
122  const ScopedLocalRef<jclass> options_class(
123      env->FindClass(TC_PACKAGE_PATH TC_CLASS_NAME_STR "$SelectionOptions"),
124      env);
125  const std::pair<bool, jobject> status_or_locales = CallJniMethod0<jobject>(
126      env, joptions, options_class.get(), &JNIEnv::CallObjectMethod,
127      "getLocales", "Ljava/lang/String;");
128  if (!status_or_locales.first) {
129    return {};
130  }
131
132  SelectionOptions options;
133  options.locales =
134      ToStlString(env, reinterpret_cast<jstring>(status_or_locales.second));
135
136  return options;
137}
138
139template <typename T>
140T FromJavaOptionsInternal(JNIEnv* env, jobject joptions,
141                          const std::string& class_name) {
142  if (!joptions) {
143    return {};
144  }
145
146  const ScopedLocalRef<jclass> options_class(env->FindClass(class_name.c_str()),
147                                             env);
148  if (!options_class) {
149    return {};
150  }
151
152  const std::pair<bool, jobject> status_or_locales = CallJniMethod0<jobject>(
153      env, joptions, options_class.get(), &JNIEnv::CallObjectMethod,
154      "getLocale", "Ljava/lang/String;");
155  const std::pair<bool, jobject> status_or_reference_timezone =
156      CallJniMethod0<jobject>(env, joptions, options_class.get(),
157                              &JNIEnv::CallObjectMethod, "getReferenceTimezone",
158                              "Ljava/lang/String;");
159  const std::pair<bool, int64> status_or_reference_time_ms_utc =
160      CallJniMethod0<int64>(env, joptions, options_class.get(),
161                            &JNIEnv::CallLongMethod, "getReferenceTimeMsUtc",
162                            "J");
163
164  if (!status_or_locales.first || !status_or_reference_timezone.first ||
165      !status_or_reference_time_ms_utc.first) {
166    return {};
167  }
168
169  T options;
170  options.locales =
171      ToStlString(env, reinterpret_cast<jstring>(status_or_locales.second));
172  options.reference_timezone = ToStlString(
173      env, reinterpret_cast<jstring>(status_or_reference_timezone.second));
174  options.reference_time_ms_utc = status_or_reference_time_ms_utc.second;
175  return options;
176}
177
178ClassificationOptions FromJavaClassificationOptions(JNIEnv* env,
179                                                    jobject joptions) {
180  return FromJavaOptionsInternal<ClassificationOptions>(
181      env, joptions,
182      TC_PACKAGE_PATH TC_CLASS_NAME_STR "$ClassificationOptions");
183}
184
185AnnotationOptions FromJavaAnnotationOptions(JNIEnv* env, jobject joptions) {
186  return FromJavaOptionsInternal<AnnotationOptions>(
187      env, joptions, TC_PACKAGE_PATH TC_CLASS_NAME_STR "$AnnotationOptions");
188}
189
190CodepointSpan ConvertIndicesBMPUTF8(const std::string& utf8_str,
191                                    CodepointSpan orig_indices,
192                                    bool from_utf8) {
193  const libtextclassifier2::UnicodeText unicode_str =
194      libtextclassifier2::UTF8ToUnicodeText(utf8_str, /*do_copy=*/false);
195
196  int unicode_index = 0;
197  int bmp_index = 0;
198
199  const int* source_index;
200  const int* target_index;
201  if (from_utf8) {
202    source_index = &unicode_index;
203    target_index = &bmp_index;
204  } else {
205    source_index = &bmp_index;
206    target_index = &unicode_index;
207  }
208
209  CodepointSpan result{-1, -1};
210  std::function<void()> assign_indices_fn = [&result, &orig_indices,
211                                             &source_index, &target_index]() {
212    if (orig_indices.first == *source_index) {
213      result.first = *target_index;
214    }
215
216    if (orig_indices.second == *source_index) {
217      result.second = *target_index;
218    }
219  };
220
221  for (auto it = unicode_str.begin(); it != unicode_str.end();
222       ++it, ++unicode_index, ++bmp_index) {
223    assign_indices_fn();
224
225    // There is 1 extra character in the input for each UTF8 character > 0xFFFF.
226    if (*it > 0xFFFF) {
227      ++bmp_index;
228    }
229  }
230  assign_indices_fn();
231
232  return result;
233}
234
235}  // namespace
236
237CodepointSpan ConvertIndicesBMPToUTF8(const std::string& utf8_str,
238                                      CodepointSpan bmp_indices) {
239  return ConvertIndicesBMPUTF8(utf8_str, bmp_indices, /*from_utf8=*/false);
240}
241
242CodepointSpan ConvertIndicesUTF8ToBMP(const std::string& utf8_str,
243                                      CodepointSpan utf8_indices) {
244  return ConvertIndicesBMPUTF8(utf8_str, utf8_indices, /*from_utf8=*/true);
245}
246
247jint GetFdFromAssetFileDescriptor(JNIEnv* env, jobject afd) {
248  // Get system-level file descriptor from AssetFileDescriptor.
249  ScopedLocalRef<jclass> afd_class(
250      env->FindClass("android/content/res/AssetFileDescriptor"), env);
251  if (afd_class == nullptr) {
252    TC_LOG(ERROR) << "Couldn't find AssetFileDescriptor.";
253    return reinterpret_cast<jlong>(nullptr);
254  }
255  jmethodID afd_class_getFileDescriptor = env->GetMethodID(
256      afd_class.get(), "getFileDescriptor", "()Ljava/io/FileDescriptor;");
257  if (afd_class_getFileDescriptor == nullptr) {
258    TC_LOG(ERROR) << "Couldn't find getFileDescriptor.";
259    return reinterpret_cast<jlong>(nullptr);
260  }
261
262  ScopedLocalRef<jclass> fd_class(env->FindClass("java/io/FileDescriptor"),
263                                  env);
264  if (fd_class == nullptr) {
265    TC_LOG(ERROR) << "Couldn't find FileDescriptor.";
266    return reinterpret_cast<jlong>(nullptr);
267  }
268  jfieldID fd_class_descriptor =
269      env->GetFieldID(fd_class.get(), "descriptor", "I");
270  if (fd_class_descriptor == nullptr) {
271    TC_LOG(ERROR) << "Couldn't find descriptor.";
272    return reinterpret_cast<jlong>(nullptr);
273  }
274
275  jobject bundle_jfd = env->CallObjectMethod(afd, afd_class_getFileDescriptor);
276  return env->GetIntField(bundle_jfd, fd_class_descriptor);
277}
278
279jstring GetLocalesFromMmap(JNIEnv* env, libtextclassifier2::ScopedMmap* mmap) {
280  if (!mmap->handle().ok()) {
281    return env->NewStringUTF("");
282  }
283  const Model* model = libtextclassifier2::ViewModel(
284      mmap->handle().start(), mmap->handle().num_bytes());
285  if (!model || !model->locales()) {
286    return env->NewStringUTF("");
287  }
288  return env->NewStringUTF(model->locales()->c_str());
289}
290
291jint GetVersionFromMmap(JNIEnv* env, libtextclassifier2::ScopedMmap* mmap) {
292  if (!mmap->handle().ok()) {
293    return 0;
294  }
295  const Model* model = libtextclassifier2::ViewModel(
296      mmap->handle().start(), mmap->handle().num_bytes());
297  if (!model) {
298    return 0;
299  }
300  return model->version();
301}
302
303jstring GetNameFromMmap(JNIEnv* env, libtextclassifier2::ScopedMmap* mmap) {
304  if (!mmap->handle().ok()) {
305    return env->NewStringUTF("");
306  }
307  const Model* model = libtextclassifier2::ViewModel(
308      mmap->handle().start(), mmap->handle().num_bytes());
309  if (!model || !model->name()) {
310    return env->NewStringUTF("");
311  }
312  return env->NewStringUTF(model->name()->c_str());
313}
314
315}  // namespace libtextclassifier2
316
317using libtextclassifier2::ClassificationResultsToJObjectArray;
318using libtextclassifier2::ConvertIndicesBMPToUTF8;
319using libtextclassifier2::ConvertIndicesUTF8ToBMP;
320using libtextclassifier2::FromJavaAnnotationOptions;
321using libtextclassifier2::FromJavaClassificationOptions;
322using libtextclassifier2::FromJavaSelectionOptions;
323using libtextclassifier2::ToStlString;
324
325JNI_METHOD(jlong, TC_CLASS_NAME, nativeNew)
326(JNIEnv* env, jobject thiz, jint fd) {
327#ifdef LIBTEXTCLASSIFIER_UNILIB_JAVAICU
328  return reinterpret_cast<jlong>(
329      TextClassifier::FromFileDescriptor(fd).release(), new UniLib(env));
330#else
331  return reinterpret_cast<jlong>(
332      TextClassifier::FromFileDescriptor(fd).release());
333#endif
334}
335
336JNI_METHOD(jlong, TC_CLASS_NAME, nativeNewFromPath)
337(JNIEnv* env, jobject thiz, jstring path) {
338  const std::string path_str = ToStlString(env, path);
339#ifdef LIBTEXTCLASSIFIER_UNILIB_JAVAICU
340  return reinterpret_cast<jlong>(
341      TextClassifier::FromPath(path_str, new UniLib(env)).release());
342#else
343  return reinterpret_cast<jlong>(TextClassifier::FromPath(path_str).release());
344#endif
345}
346
347JNI_METHOD(jlong, TC_CLASS_NAME, nativeNewFromAssetFileDescriptor)
348(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) {
349  const jint fd = libtextclassifier2::GetFdFromAssetFileDescriptor(env, afd);
350#ifdef LIBTEXTCLASSIFIER_UNILIB_JAVAICU
351  return reinterpret_cast<jlong>(
352      TextClassifier::FromFileDescriptor(fd, offset, size, new UniLib(env))
353          .release());
354#else
355  return reinterpret_cast<jlong>(
356      TextClassifier::FromFileDescriptor(fd, offset, size).release());
357#endif
358}
359
360JNI_METHOD(jintArray, TC_CLASS_NAME, nativeSuggestSelection)
361(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
362 jint selection_end, jobject options) {
363  if (!ptr) {
364    return nullptr;
365  }
366
367  TextClassifier* model = reinterpret_cast<TextClassifier*>(ptr);
368
369  const std::string context_utf8 = ToStlString(env, context);
370  CodepointSpan input_indices =
371      ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end});
372  CodepointSpan selection = model->SuggestSelection(
373      context_utf8, input_indices, FromJavaSelectionOptions(env, options));
374  selection = ConvertIndicesUTF8ToBMP(context_utf8, selection);
375
376  jintArray result = env->NewIntArray(2);
377  env->SetIntArrayRegion(result, 0, 1, &(std::get<0>(selection)));
378  env->SetIntArrayRegion(result, 1, 1, &(std::get<1>(selection)));
379  return result;
380}
381
382JNI_METHOD(jobjectArray, TC_CLASS_NAME, nativeClassifyText)
383(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
384 jint selection_end, jobject options) {
385  if (!ptr) {
386    return nullptr;
387  }
388  TextClassifier* ff_model = reinterpret_cast<TextClassifier*>(ptr);
389
390  const std::string context_utf8 = ToStlString(env, context);
391  const CodepointSpan input_indices =
392      ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end});
393  const std::vector<ClassificationResult> classification_result =
394      ff_model->ClassifyText(context_utf8, input_indices,
395                             FromJavaClassificationOptions(env, options));
396
397  return ClassificationResultsToJObjectArray(env, classification_result);
398}
399
400JNI_METHOD(jobjectArray, TC_CLASS_NAME, nativeAnnotate)
401(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jobject options) {
402  if (!ptr) {
403    return nullptr;
404  }
405  TextClassifier* model = reinterpret_cast<TextClassifier*>(ptr);
406  std::string context_utf8 = ToStlString(env, context);
407  std::vector<AnnotatedSpan> annotations =
408      model->Annotate(context_utf8, FromJavaAnnotationOptions(env, options));
409
410  jclass result_class =
411      env->FindClass(TC_PACKAGE_PATH TC_CLASS_NAME_STR "$AnnotatedSpan");
412  if (!result_class) {
413    TC_LOG(ERROR) << "Couldn't find result class: "
414                  << TC_PACKAGE_PATH TC_CLASS_NAME_STR "$AnnotatedSpan";
415    return nullptr;
416  }
417
418  jmethodID result_class_constructor = env->GetMethodID(
419      result_class, "<init>",
420      "(II[L" TC_PACKAGE_PATH TC_CLASS_NAME_STR "$ClassificationResult;)V");
421
422  jobjectArray results =
423      env->NewObjectArray(annotations.size(), result_class, nullptr);
424
425  for (int i = 0; i < annotations.size(); ++i) {
426    CodepointSpan span_bmp =
427        ConvertIndicesUTF8ToBMP(context_utf8, annotations[i].span);
428    jobject result = env->NewObject(
429        result_class, result_class_constructor,
430        static_cast<jint>(span_bmp.first), static_cast<jint>(span_bmp.second),
431        ClassificationResultsToJObjectArray(env,
432
433                                            annotations[i].classification));
434    env->SetObjectArrayElement(results, i, result);
435    env->DeleteLocalRef(result);
436  }
437  env->DeleteLocalRef(result_class);
438  return results;
439}
440
441JNI_METHOD(void, TC_CLASS_NAME, nativeClose)
442(JNIEnv* env, jobject thiz, jlong ptr) {
443  TextClassifier* model = reinterpret_cast<TextClassifier*>(ptr);
444  delete model;
445}
446
447JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetLanguage)
448(JNIEnv* env, jobject clazz, jint fd) {
449  TC_LOG(WARNING) << "Using deprecated getLanguage().";
450  return JNI_METHOD_NAME(TC_CLASS_NAME, nativeGetLocales)(env, clazz, fd);
451}
452
453JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetLocales)
454(JNIEnv* env, jobject clazz, jint fd) {
455  const std::unique_ptr<libtextclassifier2::ScopedMmap> mmap(
456      new libtextclassifier2::ScopedMmap(fd));
457  return GetLocalesFromMmap(env, mmap.get());
458}
459
460JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetLocalesFromAssetFileDescriptor)
461(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) {
462  const jint fd = libtextclassifier2::GetFdFromAssetFileDescriptor(env, afd);
463  const std::unique_ptr<libtextclassifier2::ScopedMmap> mmap(
464      new libtextclassifier2::ScopedMmap(fd, offset, size));
465  return GetLocalesFromMmap(env, mmap.get());
466}
467
468JNI_METHOD(jint, TC_CLASS_NAME, nativeGetVersion)
469(JNIEnv* env, jobject clazz, jint fd) {
470  const std::unique_ptr<libtextclassifier2::ScopedMmap> mmap(
471      new libtextclassifier2::ScopedMmap(fd));
472  return GetVersionFromMmap(env, mmap.get());
473}
474
475JNI_METHOD(jint, TC_CLASS_NAME, nativeGetVersionFromAssetFileDescriptor)
476(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) {
477  const jint fd = libtextclassifier2::GetFdFromAssetFileDescriptor(env, afd);
478  const std::unique_ptr<libtextclassifier2::ScopedMmap> mmap(
479      new libtextclassifier2::ScopedMmap(fd, offset, size));
480  return GetVersionFromMmap(env, mmap.get());
481}
482
483JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetName)
484(JNIEnv* env, jobject clazz, jint fd) {
485  const std::unique_ptr<libtextclassifier2::ScopedMmap> mmap(
486      new libtextclassifier2::ScopedMmap(fd));
487  return GetNameFromMmap(env, mmap.get());
488}
489
490JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetNameFromAssetFileDescriptor)
491(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) {
492  const jint fd = libtextclassifier2::GetFdFromAssetFileDescriptor(env, afd);
493  const std::unique_ptr<libtextclassifier2::ScopedMmap> mmap(
494      new libtextclassifier2::ScopedMmap(fd, offset, size));
495  return GetNameFromMmap(env, mmap.get());
496}
497