1/*
2 * Copyright (C) 2011 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "reflection.h"
18
19#include "class_linker.h"
20#include "common_throws.h"
21#include "dex_file-inl.h"
22#include "invoke_arg_array_builder.h"
23#include "jni_internal.h"
24#include "mirror/art_field-inl.h"
25#include "mirror/art_method-inl.h"
26#include "mirror/class.h"
27#include "mirror/class-inl.h"
28#include "mirror/object_array.h"
29#include "mirror/object_array-inl.h"
30#include "object_utils.h"
31#include "scoped_thread_state_change.h"
32#include "well_known_classes.h"
33
34namespace art {
35
36jobject InvokeMethod(const ScopedObjectAccess& soa, jobject javaMethod, jobject javaReceiver,
37                     jobject javaArgs) {
38  jmethodID mid = soa.Env()->FromReflectedMethod(javaMethod);
39  mirror::ArtMethod* m = soa.DecodeMethod(mid);
40
41  mirror::Class* declaring_class = m->GetDeclaringClass();
42  if (!Runtime::Current()->GetClassLinker()->EnsureInitialized(declaring_class, true, true)) {
43    return NULL;
44  }
45
46  mirror::Object* receiver = NULL;
47  if (!m->IsStatic()) {
48    // Check that the receiver is non-null and an instance of the field's declaring class.
49    receiver = soa.Decode<mirror::Object*>(javaReceiver);
50    if (!VerifyObjectInClass(receiver, declaring_class)) {
51      return NULL;
52    }
53
54    // Find the actual implementation of the virtual method.
55    m = receiver->GetClass()->FindVirtualMethodForVirtualOrInterface(m);
56    mid = soa.EncodeMethod(m);
57  }
58
59  // Get our arrays of arguments and their types, and check they're the same size.
60  mirror::ObjectArray<mirror::Object>* objects =
61      soa.Decode<mirror::ObjectArray<mirror::Object>*>(javaArgs);
62  MethodHelper mh(m);
63  const DexFile::TypeList* classes = mh.GetParameterTypeList();
64  uint32_t classes_size = classes == NULL ? 0 : classes->Size();
65  uint32_t arg_count = (objects != NULL) ? objects->GetLength() : 0;
66  if (arg_count != classes_size) {
67    ThrowIllegalArgumentException(NULL,
68                                  StringPrintf("Wrong number of arguments; expected %d, got %d",
69                                               classes_size, arg_count).c_str());
70    return NULL;
71  }
72
73  // Translate javaArgs to a jvalue[].
74  UniquePtr<jvalue[]> args(new jvalue[arg_count]);
75  JValue* decoded_args = reinterpret_cast<JValue*>(args.get());
76  for (uint32_t i = 0; i < arg_count; ++i) {
77    mirror::Object* arg = objects->Get(i);
78    mirror::Class* dst_class = mh.GetClassFromTypeIdx(classes->GetTypeItem(i).type_idx_);
79    if (!UnboxPrimitiveForArgument(arg, dst_class, decoded_args[i], m, i)) {
80      return NULL;
81    }
82    if (!dst_class->IsPrimitive()) {
83      args[i].l = soa.AddLocalReference<jobject>(arg);
84    }
85  }
86
87  // Invoke the method.
88  JValue value(InvokeWithJValues(soa, javaReceiver, mid, args.get()));
89
90  // Wrap any exception with "Ljava/lang/reflect/InvocationTargetException;" and return early.
91  if (soa.Self()->IsExceptionPending()) {
92    jthrowable th = soa.Env()->ExceptionOccurred();
93    soa.Env()->ExceptionClear();
94    jclass exception_class = soa.Env()->FindClass("java/lang/reflect/InvocationTargetException");
95    jmethodID mid = soa.Env()->GetMethodID(exception_class, "<init>", "(Ljava/lang/Throwable;)V");
96    jobject exception_instance = soa.Env()->NewObject(exception_class, mid, th);
97    soa.Env()->Throw(reinterpret_cast<jthrowable>(exception_instance));
98    return NULL;
99  }
100
101  // Box if necessary and return.
102  return soa.AddLocalReference<jobject>(BoxPrimitive(mh.GetReturnType()->GetPrimitiveType(), value));
103}
104
105bool VerifyObjectInClass(mirror::Object* o, mirror::Class* c) {
106  if (o == NULL) {
107    ThrowNullPointerException(NULL, "null receiver");
108    return false;
109  } else if (!o->InstanceOf(c)) {
110    std::string expected_class_name(PrettyDescriptor(c));
111    std::string actual_class_name(PrettyTypeOf(o));
112    ThrowIllegalArgumentException(NULL,
113                                  StringPrintf("Expected receiver of type %s, but got %s",
114                                               expected_class_name.c_str(),
115                                               actual_class_name.c_str()).c_str());
116    return false;
117  }
118  return true;
119}
120
121bool ConvertPrimitiveValue(const ThrowLocation* throw_location, bool unbox_for_result,
122                           Primitive::Type srcType, Primitive::Type dstType,
123                           const JValue& src, JValue& dst) {
124  CHECK(srcType != Primitive::kPrimNot && dstType != Primitive::kPrimNot);
125  switch (dstType) {
126  case Primitive::kPrimBoolean:
127    if (srcType == Primitive::kPrimBoolean) {
128      dst.SetZ(src.GetZ());
129      return true;
130    }
131    break;
132  case Primitive::kPrimChar:
133    if (srcType == Primitive::kPrimChar) {
134      dst.SetC(src.GetC());
135      return true;
136    }
137    break;
138  case Primitive::kPrimByte:
139    if (srcType == Primitive::kPrimByte) {
140      dst.SetB(src.GetB());
141      return true;
142    }
143    break;
144  case Primitive::kPrimShort:
145    if (srcType == Primitive::kPrimByte || srcType == Primitive::kPrimShort) {
146      dst.SetS(src.GetI());
147      return true;
148    }
149    break;
150  case Primitive::kPrimInt:
151    if (srcType == Primitive::kPrimByte || srcType == Primitive::kPrimChar ||
152        srcType == Primitive::kPrimShort || srcType == Primitive::kPrimInt) {
153      dst.SetI(src.GetI());
154      return true;
155    }
156    break;
157  case Primitive::kPrimLong:
158    if (srcType == Primitive::kPrimByte || srcType == Primitive::kPrimChar ||
159        srcType == Primitive::kPrimShort || srcType == Primitive::kPrimInt) {
160      dst.SetJ(src.GetI());
161      return true;
162    } else if (srcType == Primitive::kPrimLong) {
163      dst.SetJ(src.GetJ());
164      return true;
165    }
166    break;
167  case Primitive::kPrimFloat:
168    if (srcType == Primitive::kPrimByte || srcType == Primitive::kPrimChar ||
169        srcType == Primitive::kPrimShort || srcType == Primitive::kPrimInt) {
170      dst.SetF(src.GetI());
171      return true;
172    } else if (srcType == Primitive::kPrimLong) {
173      dst.SetF(src.GetJ());
174      return true;
175    } else if (srcType == Primitive::kPrimFloat) {
176      dst.SetF(src.GetF());
177      return true;
178    }
179    break;
180  case Primitive::kPrimDouble:
181    if (srcType == Primitive::kPrimByte || srcType == Primitive::kPrimChar ||
182        srcType == Primitive::kPrimShort || srcType == Primitive::kPrimInt) {
183      dst.SetD(src.GetI());
184      return true;
185    } else if (srcType == Primitive::kPrimLong) {
186      dst.SetD(src.GetJ());
187      return true;
188    } else if (srcType == Primitive::kPrimFloat) {
189      dst.SetD(src.GetF());
190      return true;
191    } else if (srcType == Primitive::kPrimDouble) {
192      dst.SetJ(src.GetJ());
193      return true;
194    }
195    break;
196  default:
197    break;
198  }
199  if (!unbox_for_result) {
200    ThrowIllegalArgumentException(throw_location,
201                                  StringPrintf("Invalid primitive conversion from %s to %s",
202                                               PrettyDescriptor(srcType).c_str(),
203                                               PrettyDescriptor(dstType).c_str()).c_str());
204  } else {
205    ThrowClassCastException(throw_location,
206                            StringPrintf("Couldn't convert result of type %s to %s",
207                                         PrettyDescriptor(srcType).c_str(),
208                                         PrettyDescriptor(dstType).c_str()).c_str());
209  }
210  return false;
211}
212
213mirror::Object* BoxPrimitive(Primitive::Type src_class, const JValue& value) {
214  if (src_class == Primitive::kPrimNot) {
215    return value.GetL();
216  }
217
218  jmethodID m = NULL;
219  switch (src_class) {
220  case Primitive::kPrimBoolean:
221    m = WellKnownClasses::java_lang_Boolean_valueOf;
222    break;
223  case Primitive::kPrimByte:
224    m = WellKnownClasses::java_lang_Byte_valueOf;
225    break;
226  case Primitive::kPrimChar:
227    m = WellKnownClasses::java_lang_Character_valueOf;
228    break;
229  case Primitive::kPrimDouble:
230    m = WellKnownClasses::java_lang_Double_valueOf;
231    break;
232  case Primitive::kPrimFloat:
233    m = WellKnownClasses::java_lang_Float_valueOf;
234    break;
235  case Primitive::kPrimInt:
236    m = WellKnownClasses::java_lang_Integer_valueOf;
237    break;
238  case Primitive::kPrimLong:
239    m = WellKnownClasses::java_lang_Long_valueOf;
240    break;
241  case Primitive::kPrimShort:
242    m = WellKnownClasses::java_lang_Short_valueOf;
243    break;
244  case Primitive::kPrimVoid:
245    // There's no such thing as a void field, and void methods invoked via reflection return null.
246    return NULL;
247  default:
248    LOG(FATAL) << static_cast<int>(src_class);
249  }
250
251  ScopedObjectAccessUnchecked soa(Thread::Current());
252  if (kIsDebugBuild) {
253    CHECK_EQ(soa.Self()->GetState(), kRunnable);
254  }
255
256  ArgArray arg_array(NULL, 0);
257  JValue result;
258  if (src_class == Primitive::kPrimDouble || src_class == Primitive::kPrimLong) {
259    arg_array.AppendWide(value.GetJ());
260  } else {
261    arg_array.Append(value.GetI());
262  }
263
264  soa.DecodeMethod(m)->Invoke(soa.Self(), arg_array.GetArray(), arg_array.GetNumBytes(),
265                              &result, 'L');
266  return result.GetL();
267}
268
269static std::string UnboxingFailureKind(mirror::ArtMethod* m, int index, mirror::ArtField* f)
270    SHARED_LOCKS_REQUIRED(Locks::mutator_lock_) {
271  if (m != NULL && index != -1) {
272    ++index;  // Humans count from 1.
273    return StringPrintf("method %s argument %d", PrettyMethod(m, false).c_str(), index);
274  }
275  if (f != NULL) {
276    return "field " + PrettyField(f, false);
277  }
278  return "result";
279}
280
281static bool UnboxPrimitive(const ThrowLocation* throw_location, mirror::Object* o,
282                           mirror::Class* dst_class, JValue& unboxed_value,
283                           mirror::ArtMethod* m, int index, mirror::ArtField* f)
284    SHARED_LOCKS_REQUIRED(Locks::mutator_lock_) {
285  bool unbox_for_result = (f == NULL) && (index == -1);
286  if (!dst_class->IsPrimitive()) {
287    if (UNLIKELY(o != NULL && !o->InstanceOf(dst_class))) {
288      if (!unbox_for_result) {
289        ThrowIllegalArgumentException(throw_location,
290                                      StringPrintf("%s has type %s, got %s",
291                                                   UnboxingFailureKind(m, index, f).c_str(),
292                                                   PrettyDescriptor(dst_class).c_str(),
293                                                   PrettyTypeOf(o).c_str()).c_str());
294      } else {
295        ThrowClassCastException(throw_location,
296                                StringPrintf("Couldn't convert result of type %s to %s",
297                                             PrettyTypeOf(o).c_str(),
298                                             PrettyDescriptor(dst_class).c_str()).c_str());
299      }
300      return false;
301    }
302    unboxed_value.SetL(o);
303    return true;
304  }
305  if (UNLIKELY(dst_class->GetPrimitiveType() == Primitive::kPrimVoid)) {
306    ThrowIllegalArgumentException(throw_location,
307                                  StringPrintf("Can't unbox %s to void",
308                                               UnboxingFailureKind(m, index, f).c_str()).c_str());
309    return false;
310  }
311  if (UNLIKELY(o == NULL)) {
312    if (!unbox_for_result) {
313      ThrowIllegalArgumentException(throw_location,
314                                    StringPrintf("%s has type %s, got null",
315                                                 UnboxingFailureKind(m, index, f).c_str(),
316                                                 PrettyDescriptor(dst_class).c_str()).c_str());
317    } else {
318      ThrowNullPointerException(throw_location,
319                                StringPrintf("Expected to unbox a '%s' primitive type but was returned null",
320                                             PrettyDescriptor(dst_class).c_str()).c_str());
321    }
322    return false;
323  }
324
325  JValue boxed_value;
326  std::string src_descriptor(ClassHelper(o->GetClass()).GetDescriptor());
327  mirror::Class* src_class = NULL;
328  ClassLinker* class_linker = Runtime::Current()->GetClassLinker();
329  mirror::ArtField* primitive_field = o->GetClass()->GetIFields()->Get(0);
330  if (src_descriptor == "Ljava/lang/Boolean;") {
331    src_class = class_linker->FindPrimitiveClass('Z');
332    boxed_value.SetZ(primitive_field->GetBoolean(o));
333  } else if (src_descriptor == "Ljava/lang/Byte;") {
334    src_class = class_linker->FindPrimitiveClass('B');
335    boxed_value.SetB(primitive_field->GetByte(o));
336  } else if (src_descriptor == "Ljava/lang/Character;") {
337    src_class = class_linker->FindPrimitiveClass('C');
338    boxed_value.SetC(primitive_field->GetChar(o));
339  } else if (src_descriptor == "Ljava/lang/Float;") {
340    src_class = class_linker->FindPrimitiveClass('F');
341    boxed_value.SetF(primitive_field->GetFloat(o));
342  } else if (src_descriptor == "Ljava/lang/Double;") {
343    src_class = class_linker->FindPrimitiveClass('D');
344    boxed_value.SetD(primitive_field->GetDouble(o));
345  } else if (src_descriptor == "Ljava/lang/Integer;") {
346    src_class = class_linker->FindPrimitiveClass('I');
347    boxed_value.SetI(primitive_field->GetInt(o));
348  } else if (src_descriptor == "Ljava/lang/Long;") {
349    src_class = class_linker->FindPrimitiveClass('J');
350    boxed_value.SetJ(primitive_field->GetLong(o));
351  } else if (src_descriptor == "Ljava/lang/Short;") {
352    src_class = class_linker->FindPrimitiveClass('S');
353    boxed_value.SetS(primitive_field->GetShort(o));
354  } else {
355    ThrowIllegalArgumentException(throw_location,
356                                  StringPrintf("%s has type %s, got %s",
357                                               UnboxingFailureKind(m, index, f).c_str(),
358                                               PrettyDescriptor(dst_class).c_str(),
359                                               PrettyDescriptor(src_descriptor.c_str()).c_str()).c_str());
360    return false;
361  }
362
363  return ConvertPrimitiveValue(throw_location, unbox_for_result,
364                               src_class->GetPrimitiveType(), dst_class->GetPrimitiveType(),
365                               boxed_value, unboxed_value);
366}
367
368bool UnboxPrimitiveForArgument(mirror::Object* o, mirror::Class* dst_class, JValue& unboxed_value,
369                               mirror::ArtMethod* m, size_t index) {
370  CHECK(m != NULL);
371  return UnboxPrimitive(NULL, o, dst_class, unboxed_value, m, index, NULL);
372}
373
374bool UnboxPrimitiveForField(mirror::Object* o, mirror::Class* dst_class, JValue& unboxed_value,
375                            mirror::ArtField* f) {
376  CHECK(f != NULL);
377  return UnboxPrimitive(NULL, o, dst_class, unboxed_value, NULL, -1, f);
378}
379
380bool UnboxPrimitiveForResult(const ThrowLocation& throw_location, mirror::Object* o,
381                             mirror::Class* dst_class, JValue& unboxed_value) {
382  return UnboxPrimitive(&throw_location, o, dst_class, unboxed_value, NULL, -1, NULL);
383}
384
385}  // namespace art
386