1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "jni_binder.h"
18
19#include <dlfcn.h>
20#include <inttypes.h>
21#include <stdio.h>
22
23#include "android-base/logging.h"
24#include "android-base/stringprintf.h"
25
26#include "jvmti_helper.h"
27#include "scoped_local_ref.h"
28#include "scoped_utf_chars.h"
29#include "ti_utf.h"
30
31namespace art {
32
33static std::string MangleForJni(const std::string& s) {
34  std::string result;
35  size_t char_count = ti::CountModifiedUtf8Chars(s.c_str(), s.length());
36  const char* cp = &s[0];
37  for (size_t i = 0; i < char_count; ++i) {
38    uint32_t ch = ti::GetUtf16FromUtf8(&cp);
39    if ((ch >= 'A' && ch <= 'Z') || (ch >= 'a' && ch <= 'z') || (ch >= '0' && ch <= '9')) {
40      result.push_back(ch);
41    } else if (ch == '.' || ch == '/') {
42      result += "_";
43    } else if (ch == '_') {
44      result += "_1";
45    } else if (ch == ';') {
46      result += "_2";
47    } else if (ch == '[') {
48      result += "_3";
49    } else {
50      const uint16_t leading = ti::GetLeadingUtf16Char(ch);
51      const uint32_t trailing = ti::GetTrailingUtf16Char(ch);
52
53      android::base::StringAppendF(&result, "_0%04x", leading);
54      if (trailing != 0) {
55        android::base::StringAppendF(&result, "_0%04x", trailing);
56      }
57    }
58  }
59  return result;
60}
61
62static std::string GetJniShortName(const std::string& class_descriptor, const std::string& method) {
63  // Remove the leading 'L' and trailing ';'...
64  std::string class_name(class_descriptor);
65  CHECK_EQ(class_name[0], 'L') << class_name;
66  CHECK_EQ(class_name[class_name.size() - 1], ';') << class_name;
67  class_name.erase(0, 1);
68  class_name.erase(class_name.size() - 1, 1);
69
70  std::string short_name;
71  short_name += "Java_";
72  short_name += MangleForJni(class_name);
73  short_name += "_";
74  short_name += MangleForJni(method);
75  return short_name;
76}
77
78static void BindMethod(jvmtiEnv* jvmti_env, JNIEnv* env, jclass klass, jmethodID method) {
79  std::string name;
80  std::string signature;
81  std::string mangled_names[2];
82  {
83    char* name_cstr;
84    char* sig_cstr;
85    jvmtiError name_result = jvmti_env->GetMethodName(method, &name_cstr, &sig_cstr, nullptr);
86    CheckJvmtiError(jvmti_env, name_result);
87    CHECK(name_cstr != nullptr);
88    CHECK(sig_cstr != nullptr);
89    name = name_cstr;
90    signature = sig_cstr;
91
92    char* klass_name;
93    jvmtiError klass_result = jvmti_env->GetClassSignature(klass, &klass_name, nullptr);
94    CheckJvmtiError(jvmti_env, klass_result);
95
96    mangled_names[0] = GetJniShortName(klass_name, name);
97    // TODO: Long JNI name.
98
99    CheckJvmtiError(jvmti_env, Deallocate(jvmti_env, name_cstr));
100    CheckJvmtiError(jvmti_env, Deallocate(jvmti_env, sig_cstr));
101    CheckJvmtiError(jvmti_env, Deallocate(jvmti_env, klass_name));
102  }
103
104  for (const std::string& mangled_name : mangled_names) {
105    if (mangled_name.empty()) {
106      continue;
107    }
108    void* sym = dlsym(RTLD_DEFAULT, mangled_name.c_str());
109    if (sym == nullptr) {
110      continue;
111    }
112
113    JNINativeMethod native_method;
114    native_method.fnPtr = sym;
115    native_method.name = name.c_str();
116    native_method.signature = signature.c_str();
117
118    env->RegisterNatives(klass, &native_method, 1);
119
120    return;
121  }
122
123  LOG(FATAL) << "Could not find " << mangled_names[0];
124}
125
// Converts a type descriptor into the dotted form accepted by Class.forName:
//   "Ljava/lang/String;"  -> "java.lang.String"     (class: 'L' and ';' removed)
//   "[Ljava/lang/String;" -> "[Ljava.lang.String;"  (array: wrapper kept)
// Inputs of length 0 or 1 (e.g. primitive descriptors) are returned unchanged.
static std::string DescriptorToDot(const char* descriptor) {
  const size_t len = strlen(descriptor);
  if (len <= 1) {
    return descriptor;
  }

  const bool is_class_descriptor = descriptor[0] == 'L' && descriptor[len - 1] == ';';
  std::string dotted = is_class_descriptor ? std::string(descriptor + 1, len - 2)
                                           : std::string(descriptor, len);
  std::replace(dotted.begin(), dotted.end(), '/', '.');
  return dotted;
}
144
145static jobject GetSystemClassLoader(JNIEnv* env) {
146  ScopedLocalRef<jclass> cl_klass(env, env->FindClass("java/lang/ClassLoader"));
147  CHECK(cl_klass.get() != nullptr);
148  jmethodID getsystemclassloader_method = env->GetStaticMethodID(cl_klass.get(),
149                                                                 "getSystemClassLoader",
150                                                                 "()Ljava/lang/ClassLoader;");
151  CHECK(getsystemclassloader_method != nullptr);
152  return env->CallStaticObjectMethod(cl_klass.get(), getsystemclassloader_method);
153}
154
155static jclass FindClassWithClassLoader(JNIEnv* env, const char* class_name, jobject class_loader) {
156  // Create a String of the name.
157  std::string descriptor = android::base::StringPrintf("L%s;", class_name);
158  std::string dot_name = DescriptorToDot(descriptor.c_str());
159  ScopedLocalRef<jstring> name_str(env, env->NewStringUTF(dot_name.c_str()));
160
161  // Call Class.forName with it.
162  ScopedLocalRef<jclass> c_klass(env, env->FindClass("java/lang/Class"));
163  CHECK(c_klass.get() != nullptr);
164  jmethodID forname_method = env->GetStaticMethodID(
165      c_klass.get(),
166      "forName",
167      "(Ljava/lang/String;ZLjava/lang/ClassLoader;)Ljava/lang/Class;");
168  CHECK(forname_method != nullptr);
169
170  return static_cast<jclass>(env->CallStaticObjectMethod(c_klass.get(),
171                                                         forname_method,
172                                                         name_str.get(),
173                                                         JNI_FALSE,
174                                                         class_loader));
175}
176
177jclass FindClass(jvmtiEnv* jvmti_env, JNIEnv* env, const char* class_name, jobject class_loader) {
178  if (class_loader != nullptr) {
179    return FindClassWithClassLoader(env, class_name, class_loader);
180  }
181
182  jclass from_implied = env->FindClass(class_name);
183  if (from_implied != nullptr) {
184    return from_implied;
185  }
186  env->ExceptionClear();
187
188  ScopedLocalRef<jobject> system_class_loader(env, GetSystemClassLoader(env));
189  CHECK(system_class_loader.get() != nullptr);
190  jclass from_system = FindClassWithClassLoader(env, class_name, system_class_loader.get());
191  if (from_system != nullptr) {
192    return from_system;
193  }
194  env->ExceptionClear();
195
196  // Look at the context classloaders of all threads.
197  jint thread_count;
198  jthread* threads;
199  CheckJvmtiError(jvmti_env, jvmti_env->GetAllThreads(&thread_count, &threads));
200  JvmtiUniquePtr threads_uptr = MakeJvmtiUniquePtr(jvmti_env, threads);
201
202  jclass result = nullptr;
203  for (jint t = 0; t != thread_count; ++t) {
204    // Always loop over all elements, as we need to free the local references.
205    if (result == nullptr) {
206      jvmtiThreadInfo info;
207      CheckJvmtiError(jvmti_env, jvmti_env->GetThreadInfo(threads[t], &info));
208      CheckJvmtiError(jvmti_env, Deallocate(jvmti_env, info.name));
209      if (info.thread_group != nullptr) {
210        env->DeleteLocalRef(info.thread_group);
211      }
212      if (info.context_class_loader != nullptr) {
213        result = FindClassWithClassLoader(env, class_name, info.context_class_loader);
214        env->ExceptionClear();
215        env->DeleteLocalRef(info.context_class_loader);
216      }
217    }
218    env->DeleteLocalRef(threads[t]);
219  }
220
221  if (result != nullptr) {
222    return result;
223  }
224
225  // TODO: Implement scanning *all* classloaders.
226  LOG(FATAL) << "Unimplemented";
227
228  return nullptr;
229}
230
231void BindFunctionsOnClass(jvmtiEnv* jvmti_env, JNIEnv* env, jclass klass) {
232  // Use JVMTI to get the methods.
233  jint method_count;
234  jmethodID* methods;
235  jvmtiError methods_result = jvmti_env->GetClassMethods(klass, &method_count, &methods);
236  CheckJvmtiError(jvmti_env, methods_result);
237
238  // Check each method.
239  for (jint i = 0; i < method_count; ++i) {
240    jint modifiers;
241    jvmtiError mod_result = jvmti_env->GetMethodModifiers(methods[i], &modifiers);
242    CheckJvmtiError(jvmti_env, mod_result);
243    constexpr jint kNative = static_cast<jint>(0x0100);
244    if ((modifiers & kNative) != 0) {
245      BindMethod(jvmti_env, env, klass, methods[i]);
246    }
247  }
248
249  CheckJvmtiError(jvmti_env, Deallocate(jvmti_env, methods));
250}
251
252void BindFunctions(jvmtiEnv* jvmti_env, JNIEnv* env, const char* class_name, jobject class_loader) {
253  // Use JNI to load the class.
254  ScopedLocalRef<jclass> klass(env, FindClass(jvmti_env, env, class_name, class_loader));
255  CHECK(klass.get() != nullptr) << class_name;
256  BindFunctionsOnClass(jvmti_env, env, klass.get());
257}
258
259}  // namespace art
260