1/* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17// JNI wrapper for the TextClassifier. 18 19#include "textclassifier_jni.h" 20 21#include <jni.h> 22#include <type_traits> 23#include <vector> 24 25#include "text-classifier.h" 26#include "util/base/integral_types.h" 27#include "util/java/scoped_local_ref.h" 28#include "util/java/string_utils.h" 29#include "util/memory/mmap.h" 30#include "util/utf8/unilib.h" 31 32using libtextclassifier2::AnnotatedSpan; 33using libtextclassifier2::AnnotationOptions; 34using libtextclassifier2::ClassificationOptions; 35using libtextclassifier2::ClassificationResult; 36using libtextclassifier2::CodepointSpan; 37using libtextclassifier2::JStringToUtf8String; 38using libtextclassifier2::Model; 39using libtextclassifier2::ScopedLocalRef; 40using libtextclassifier2::SelectionOptions; 41using libtextclassifier2::TextClassifier; 42#ifdef LIBTEXTCLASSIFIER_UNILIB_JAVAICU 43using libtextclassifier2::UniLib; 44#endif 45 46namespace libtextclassifier2 { 47 48using libtextclassifier2::CodepointSpan; 49 50namespace { 51 52std::string ToStlString(JNIEnv* env, const jstring& str) { 53 std::string result; 54 JStringToUtf8String(env, str, &result); 55 return result; 56} 57 58jobjectArray ClassificationResultsToJObjectArray( 59 JNIEnv* env, 60 const std::vector<ClassificationResult>& classification_result) { 61 const ScopedLocalRef<jclass> result_class( 62 env->FindClass(TC_PACKAGE_PATH TC_CLASS_NAME_STR "$ClassificationResult"), 63 env); 64 if (!result_class) { 65 TC_LOG(ERROR) << "Couldn't find ClassificationResult class."; 66 return nullptr; 67 } 68 const ScopedLocalRef<jclass> datetime_parse_class( 69 env->FindClass(TC_PACKAGE_PATH TC_CLASS_NAME_STR "$DatetimeResult"), env); 70 if (!datetime_parse_class) { 71 TC_LOG(ERROR) << "Couldn't find DatetimeResult class."; 72 return nullptr; 73 } 74 75 const jmethodID result_class_constructor = 76 env->GetMethodID(result_class.get(), "<init>", 77 "(Ljava/lang/String;FL" TC_PACKAGE_PATH TC_CLASS_NAME_STR 78 "$DatetimeResult;)V"); 79 const jmethodID datetime_parse_class_constructor = 80 env->GetMethodID(datetime_parse_class.get(), "<init>", "(JI)V"); 81 82 const jobjectArray results = env->NewObjectArray(classification_result.size(), 83 result_class.get(), nullptr); 84 for (int i = 0; i < classification_result.size(); i++) { 85 jstring row_string = 86 env->NewStringUTF(classification_result[i].collection.c_str()); 87 jobject row_datetime_parse = nullptr; 88 if (classification_result[i].datetime_parse_result.IsSet()) { 89 row_datetime_parse = env->NewObject( 90 datetime_parse_class.get(), datetime_parse_class_constructor, 91 classification_result[i].datetime_parse_result.time_ms_utc, 92 classification_result[i].datetime_parse_result.granularity); 93 } 94 jobject result = 95 env->NewObject(result_class.get(), result_class_constructor, row_string, 96 static_cast<jfloat>(classification_result[i].score), 97 row_datetime_parse); 98 env->SetObjectArrayElement(results, i, result); 99 env->DeleteLocalRef(result); 100 } 101 return results; 102} 103 104template <typename T, typename F> 105std::pair<bool, T> CallJniMethod0(JNIEnv* env, jobject object, 106 jclass class_object, F function, 107 const std::string& method_name, 108 const std::string& return_java_type) { 109 const jmethodID method = env->GetMethodID(class_object, method_name.c_str(), 110 ("()" + return_java_type).c_str()); 111 if (!method) { 112 return std::make_pair(false, T()); 113 } 114 return std::make_pair(true, (env->*function)(object, method)); 115} 116 117SelectionOptions FromJavaSelectionOptions(JNIEnv* env, jobject joptions) { 118 if (!joptions) { 119 return {}; 120 } 121 122 const ScopedLocalRef<jclass> options_class( 123 env->FindClass(TC_PACKAGE_PATH TC_CLASS_NAME_STR "$SelectionOptions"), 124 env); 125 const std::pair<bool, jobject> status_or_locales = CallJniMethod0<jobject>( 126 env, joptions, options_class.get(), &JNIEnv::CallObjectMethod, 127 "getLocales", "Ljava/lang/String;"); 128 if (!status_or_locales.first) { 129 return {}; 130 } 131 132 SelectionOptions options; 133 options.locales = 134 ToStlString(env, reinterpret_cast<jstring>(status_or_locales.second)); 135 136 return options; 137} 138 139template <typename T> 140T FromJavaOptionsInternal(JNIEnv* env, jobject joptions, 141 const std::string& class_name) { 142 if (!joptions) { 143 return {}; 144 } 145 146 const ScopedLocalRef<jclass> options_class(env->FindClass(class_name.c_str()), 147 env); 148 if (!options_class) { 149 return {}; 150 } 151 152 const std::pair<bool, jobject> status_or_locales = CallJniMethod0<jobject>( 153 env, joptions, options_class.get(), &JNIEnv::CallObjectMethod, 154 "getLocale", "Ljava/lang/String;"); 155 const std::pair<bool, jobject> status_or_reference_timezone = 156 CallJniMethod0<jobject>(env, joptions, options_class.get(), 157 &JNIEnv::CallObjectMethod, "getReferenceTimezone", 158 "Ljava/lang/String;"); 159 const std::pair<bool, int64> status_or_reference_time_ms_utc = 160 CallJniMethod0<int64>(env, joptions, options_class.get(), 161 &JNIEnv::CallLongMethod, "getReferenceTimeMsUtc", 162 "J"); 163 164 if (!status_or_locales.first || !status_or_reference_timezone.first || 165 !status_or_reference_time_ms_utc.first) { 166 return {}; 167 } 168 169 T options; 170 options.locales = 171 ToStlString(env, reinterpret_cast<jstring>(status_or_locales.second)); 172 options.reference_timezone = ToStlString( 173 env, reinterpret_cast<jstring>(status_or_reference_timezone.second)); 174 options.reference_time_ms_utc = status_or_reference_time_ms_utc.second; 175 return options; 176} 177 178ClassificationOptions FromJavaClassificationOptions(JNIEnv* env, 179 jobject joptions) { 180 return FromJavaOptionsInternal<ClassificationOptions>( 181 env, joptions, 182 TC_PACKAGE_PATH TC_CLASS_NAME_STR "$ClassificationOptions"); 183} 184 185AnnotationOptions FromJavaAnnotationOptions(JNIEnv* env, jobject joptions) { 186 return FromJavaOptionsInternal<AnnotationOptions>( 187 env, joptions, TC_PACKAGE_PATH TC_CLASS_NAME_STR "$AnnotationOptions"); 188} 189 190CodepointSpan ConvertIndicesBMPUTF8(const std::string& utf8_str, 191 CodepointSpan orig_indices, 192 bool from_utf8) { 193 const libtextclassifier2::UnicodeText unicode_str = 194 libtextclassifier2::UTF8ToUnicodeText(utf8_str, /*do_copy=*/false); 195 196 int unicode_index = 0; 197 int bmp_index = 0; 198 199 const int* source_index; 200 const int* target_index; 201 if (from_utf8) { 202 source_index = &unicode_index; 203 target_index = &bmp_index; 204 } else { 205 source_index = &bmp_index; 206 target_index = &unicode_index; 207 } 208 209 CodepointSpan result{-1, -1}; 210 std::function<void()> assign_indices_fn = [&result, &orig_indices, 211 &source_index, &target_index]() { 212 if (orig_indices.first == *source_index) { 213 result.first = *target_index; 214 } 215 216 if (orig_indices.second == *source_index) { 217 result.second = *target_index; 218 } 219 }; 220 221 for (auto it = unicode_str.begin(); it != unicode_str.end(); 222 ++it, ++unicode_index, ++bmp_index) { 223 assign_indices_fn(); 224 225 // There is 1 extra character in the input for each UTF8 character > 0xFFFF. 226 if (*it > 0xFFFF) { 227 ++bmp_index; 228 } 229 } 230 assign_indices_fn(); 231 232 return result; 233} 234 235} // namespace 236 237CodepointSpan ConvertIndicesBMPToUTF8(const std::string& utf8_str, 238 CodepointSpan bmp_indices) { 239 return ConvertIndicesBMPUTF8(utf8_str, bmp_indices, /*from_utf8=*/false); 240} 241 242CodepointSpan ConvertIndicesUTF8ToBMP(const std::string& utf8_str, 243 CodepointSpan utf8_indices) { 244 return ConvertIndicesBMPUTF8(utf8_str, utf8_indices, /*from_utf8=*/true); 245} 246 247jint GetFdFromAssetFileDescriptor(JNIEnv* env, jobject afd) { 248 // Get system-level file descriptor from AssetFileDescriptor. 249 ScopedLocalRef<jclass> afd_class( 250 env->FindClass("android/content/res/AssetFileDescriptor"), env); 251 if (afd_class == nullptr) { 252 TC_LOG(ERROR) << "Couldn't find AssetFileDescriptor."; 253 return reinterpret_cast<jlong>(nullptr); 254 } 255 jmethodID afd_class_getFileDescriptor = env->GetMethodID( 256 afd_class.get(), "getFileDescriptor", "()Ljava/io/FileDescriptor;"); 257 if (afd_class_getFileDescriptor == nullptr) { 258 TC_LOG(ERROR) << "Couldn't find getFileDescriptor."; 259 return reinterpret_cast<jlong>(nullptr); 260 } 261 262 ScopedLocalRef<jclass> fd_class(env->FindClass("java/io/FileDescriptor"), 263 env); 264 if (fd_class == nullptr) { 265 TC_LOG(ERROR) << "Couldn't find FileDescriptor."; 266 return reinterpret_cast<jlong>(nullptr); 267 } 268 jfieldID fd_class_descriptor = 269 env->GetFieldID(fd_class.get(), "descriptor", "I"); 270 if (fd_class_descriptor == nullptr) { 271 TC_LOG(ERROR) << "Couldn't find descriptor."; 272 return reinterpret_cast<jlong>(nullptr); 273 } 274 275 jobject bundle_jfd = env->CallObjectMethod(afd, afd_class_getFileDescriptor); 276 return env->GetIntField(bundle_jfd, fd_class_descriptor); 277} 278 279jstring GetLocalesFromMmap(JNIEnv* env, libtextclassifier2::ScopedMmap* mmap) { 280 if (!mmap->handle().ok()) { 281 return env->NewStringUTF(""); 282 } 283 const Model* model = libtextclassifier2::ViewModel( 284 mmap->handle().start(), mmap->handle().num_bytes()); 285 if (!model || !model->locales()) { 286 return env->NewStringUTF(""); 287 } 288 return env->NewStringUTF(model->locales()->c_str()); 289} 290 291jint GetVersionFromMmap(JNIEnv* env, libtextclassifier2::ScopedMmap* mmap) { 292 if (!mmap->handle().ok()) { 293 return 0; 294 } 295 const Model* model = libtextclassifier2::ViewModel( 296 mmap->handle().start(), mmap->handle().num_bytes()); 297 if (!model) { 298 return 0; 299 } 300 return model->version(); 301} 302 303jstring GetNameFromMmap(JNIEnv* env, libtextclassifier2::ScopedMmap* mmap) { 304 if (!mmap->handle().ok()) { 305 return env->NewStringUTF(""); 306 } 307 const Model* model = libtextclassifier2::ViewModel( 308 mmap->handle().start(), mmap->handle().num_bytes()); 309 if (!model || !model->name()) { 310 return env->NewStringUTF(""); 311 } 312 return env->NewStringUTF(model->name()->c_str()); 313} 314 315} // namespace libtextclassifier2 316 317using libtextclassifier2::ClassificationResultsToJObjectArray; 318using libtextclassifier2::ConvertIndicesBMPToUTF8; 319using libtextclassifier2::ConvertIndicesUTF8ToBMP; 320using libtextclassifier2::FromJavaAnnotationOptions; 321using libtextclassifier2::FromJavaClassificationOptions; 322using libtextclassifier2::FromJavaSelectionOptions; 323using libtextclassifier2::ToStlString; 324 325JNI_METHOD(jlong, TC_CLASS_NAME, nativeNew) 326(JNIEnv* env, jobject thiz, jint fd) { 327#ifdef LIBTEXTCLASSIFIER_UNILIB_JAVAICU 328 return reinterpret_cast<jlong>( 329 TextClassifier::FromFileDescriptor(fd).release(), new UniLib(env)); 330#else 331 return reinterpret_cast<jlong>( 332 TextClassifier::FromFileDescriptor(fd).release()); 333#endif 334} 335 336JNI_METHOD(jlong, TC_CLASS_NAME, nativeNewFromPath) 337(JNIEnv* env, jobject thiz, jstring path) { 338 const std::string path_str = ToStlString(env, path); 339#ifdef LIBTEXTCLASSIFIER_UNILIB_JAVAICU 340 return reinterpret_cast<jlong>( 341 TextClassifier::FromPath(path_str, new UniLib(env)).release()); 342#else 343 return reinterpret_cast<jlong>(TextClassifier::FromPath(path_str).release()); 344#endif 345} 346 347JNI_METHOD(jlong, TC_CLASS_NAME, nativeNewFromAssetFileDescriptor) 348(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) { 349 const jint fd = libtextclassifier2::GetFdFromAssetFileDescriptor(env, afd); 350#ifdef LIBTEXTCLASSIFIER_UNILIB_JAVAICU 351 return reinterpret_cast<jlong>( 352 TextClassifier::FromFileDescriptor(fd, offset, size, new UniLib(env)) 353 .release()); 354#else 355 return reinterpret_cast<jlong>( 356 TextClassifier::FromFileDescriptor(fd, offset, size).release()); 357#endif 358} 359 360JNI_METHOD(jintArray, TC_CLASS_NAME, nativeSuggestSelection) 361(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin, 362 jint selection_end, jobject options) { 363 if (!ptr) { 364 return nullptr; 365 } 366 367 TextClassifier* model = reinterpret_cast<TextClassifier*>(ptr); 368 369 const std::string context_utf8 = ToStlString(env, context); 370 CodepointSpan input_indices = 371 ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end}); 372 CodepointSpan selection = model->SuggestSelection( 373 context_utf8, input_indices, FromJavaSelectionOptions(env, options)); 374 selection = ConvertIndicesUTF8ToBMP(context_utf8, selection); 375 376 jintArray result = env->NewIntArray(2); 377 env->SetIntArrayRegion(result, 0, 1, &(std::get<0>(selection))); 378 env->SetIntArrayRegion(result, 1, 1, &(std::get<1>(selection))); 379 return result; 380} 381 382JNI_METHOD(jobjectArray, TC_CLASS_NAME, nativeClassifyText) 383(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin, 384 jint selection_end, jobject options) { 385 if (!ptr) { 386 return nullptr; 387 } 388 TextClassifier* ff_model = reinterpret_cast<TextClassifier*>(ptr); 389 390 const std::string context_utf8 = ToStlString(env, context); 391 const CodepointSpan input_indices = 392 ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end}); 393 const std::vector<ClassificationResult> classification_result = 394 ff_model->ClassifyText(context_utf8, input_indices, 395 FromJavaClassificationOptions(env, options)); 396 397 return ClassificationResultsToJObjectArray(env, classification_result); 398} 399 400JNI_METHOD(jobjectArray, TC_CLASS_NAME, nativeAnnotate) 401(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jobject options) { 402 if (!ptr) { 403 return nullptr; 404 } 405 TextClassifier* model = reinterpret_cast<TextClassifier*>(ptr); 406 std::string context_utf8 = ToStlString(env, context); 407 std::vector<AnnotatedSpan> annotations = 408 model->Annotate(context_utf8, FromJavaAnnotationOptions(env, options)); 409 410 jclass result_class = 411 env->FindClass(TC_PACKAGE_PATH TC_CLASS_NAME_STR "$AnnotatedSpan"); 412 if (!result_class) { 413 TC_LOG(ERROR) << "Couldn't find result class: " 414 << TC_PACKAGE_PATH TC_CLASS_NAME_STR "$AnnotatedSpan"; 415 return nullptr; 416 } 417 418 jmethodID result_class_constructor = env->GetMethodID( 419 result_class, "<init>", 420 "(II[L" TC_PACKAGE_PATH TC_CLASS_NAME_STR "$ClassificationResult;)V"); 421 422 jobjectArray results = 423 env->NewObjectArray(annotations.size(), result_class, nullptr); 424 425 for (int i = 0; i < annotations.size(); ++i) { 426 CodepointSpan span_bmp = 427 ConvertIndicesUTF8ToBMP(context_utf8, annotations[i].span); 428 jobject result = env->NewObject( 429 result_class, result_class_constructor, 430 static_cast<jint>(span_bmp.first), static_cast<jint>(span_bmp.second), 431 ClassificationResultsToJObjectArray(env, 432 433 annotations[i].classification)); 434 env->SetObjectArrayElement(results, i, result); 435 env->DeleteLocalRef(result); 436 } 437 env->DeleteLocalRef(result_class); 438 return results; 439} 440 441JNI_METHOD(void, TC_CLASS_NAME, nativeClose) 442(JNIEnv* env, jobject thiz, jlong ptr) { 443 TextClassifier* model = reinterpret_cast<TextClassifier*>(ptr); 444 delete model; 445} 446 447JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetLanguage) 448(JNIEnv* env, jobject clazz, jint fd) { 449 TC_LOG(WARNING) << "Using deprecated getLanguage()."; 450 return JNI_METHOD_NAME(TC_CLASS_NAME, nativeGetLocales)(env, clazz, fd); 451} 452 453JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetLocales) 454(JNIEnv* env, jobject clazz, jint fd) { 455 const std::unique_ptr<libtextclassifier2::ScopedMmap> mmap( 456 new libtextclassifier2::ScopedMmap(fd)); 457 return GetLocalesFromMmap(env, mmap.get()); 458} 459 460JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetLocalesFromAssetFileDescriptor) 461(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) { 462 const jint fd = libtextclassifier2::GetFdFromAssetFileDescriptor(env, afd); 463 const std::unique_ptr<libtextclassifier2::ScopedMmap> mmap( 464 new libtextclassifier2::ScopedMmap(fd, offset, size)); 465 return GetLocalesFromMmap(env, mmap.get()); 466} 467 468JNI_METHOD(jint, TC_CLASS_NAME, nativeGetVersion) 469(JNIEnv* env, jobject clazz, jint fd) { 470 const std::unique_ptr<libtextclassifier2::ScopedMmap> mmap( 471 new libtextclassifier2::ScopedMmap(fd)); 472 return GetVersionFromMmap(env, mmap.get()); 473} 474 475JNI_METHOD(jint, TC_CLASS_NAME, nativeGetVersionFromAssetFileDescriptor) 476(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) { 477 const jint fd = libtextclassifier2::GetFdFromAssetFileDescriptor(env, afd); 478 const std::unique_ptr<libtextclassifier2::ScopedMmap> mmap( 479 new libtextclassifier2::ScopedMmap(fd, offset, size)); 480 return GetVersionFromMmap(env, mmap.get()); 481} 482 483JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetName) 484(JNIEnv* env, jobject clazz, jint fd) { 485 const std::unique_ptr<libtextclassifier2::ScopedMmap> mmap( 486 new libtextclassifier2::ScopedMmap(fd)); 487 return GetNameFromMmap(env, mmap.get()); 488} 489 490JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetNameFromAssetFileDescriptor) 491(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) { 492 const jint fd = libtextclassifier2::GetFdFromAssetFileDescriptor(env, afd); 493 const std::unique_ptr<libtextclassifier2::ScopedMmap> mmap( 494 new libtextclassifier2::ScopedMmap(fd, offset, size)); 495 return GetNameFromMmap(env, mmap.get()); 496} 497