1/* 2 * Copyright (c) 2016 Mockito contributors 3 * This program is made available under the terms of the MIT License. 4 */ 5package org.mockito.internal.creation.bytebuddy; 6 7import net.bytebuddy.asm.Advice; 8import net.bytebuddy.description.method.MethodDescription; 9import net.bytebuddy.description.type.TypeDescription; 10import net.bytebuddy.dynamic.scaffold.MethodGraph; 11import net.bytebuddy.implementation.bind.annotation.Argument; 12import net.bytebuddy.implementation.bind.annotation.This; 13import net.bytebuddy.implementation.bytecode.assign.Assigner; 14import org.mockito.exceptions.base.MockitoException; 15import org.mockito.internal.debugging.LocationImpl; 16import org.mockito.internal.exceptions.stacktrace.ConditionalStackTraceFilter; 17import org.mockito.internal.invocation.RealMethod; 18import org.mockito.internal.invocation.SerializableMethod; 19import org.mockito.internal.invocation.mockref.MockReference; 20import org.mockito.internal.invocation.mockref.MockWeakReference; 21import org.mockito.internal.util.concurrent.WeakConcurrentMap; 22 23import java.io.IOException; 24import java.io.ObjectInputStream; 25import java.io.Serializable; 26import java.lang.annotation.Retention; 27import java.lang.annotation.RetentionPolicy; 28import java.lang.ref.SoftReference; 29import java.lang.reflect.InvocationTargetException; 30import java.lang.reflect.Method; 31import java.lang.reflect.Modifier; 32import java.util.ArrayList; 33import java.util.List; 34import java.util.concurrent.Callable; 35 36public class MockMethodAdvice extends MockMethodDispatcher { 37 38 final WeakConcurrentMap<Object, MockMethodInterceptor> interceptors; 39 40 private final String identifier; 41 42 private final SelfCallInfo selfCallInfo = new SelfCallInfo(); 43 private final MethodGraph.Compiler compiler = MethodGraph.Compiler.Default.forJavaHierarchy(); 44 private final WeakConcurrentMap<Class<?>, SoftReference<MethodGraph>> graphs 45 = new WeakConcurrentMap.WithInlinedExpunction<Class<?>, SoftReference<MethodGraph>>(); 46 47 public MockMethodAdvice(WeakConcurrentMap<Object, MockMethodInterceptor> interceptors, String identifier) { 48 this.interceptors = interceptors; 49 this.identifier = identifier; 50 } 51 52 @SuppressWarnings("unused") 53 @Advice.OnMethodEnter(skipOn = Advice.OnNonDefaultValue.class) 54 private static Callable<?> enter(@Identifier String identifier, 55 @Advice.This Object mock, 56 @Advice.Origin Method origin, 57 @Advice.AllArguments Object[] arguments) throws Throwable { 58 MockMethodDispatcher dispatcher = MockMethodDispatcher.get(identifier, mock); 59 if (dispatcher == null || !dispatcher.isMocked(mock) || dispatcher.isOverridden(mock, origin)) { 60 return null; 61 } else { 62 return dispatcher.handle(mock, origin, arguments); 63 } 64 } 65 66 @SuppressWarnings({"unused", "UnusedAssignment"}) 67 @Advice.OnMethodExit 68 private static void exit(@Advice.Return(readOnly = false, typing = Assigner.Typing.DYNAMIC) Object returned, 69 @Advice.Enter Callable<?> mocked) throws Throwable { 70 if (mocked != null) { 71 returned = mocked.call(); 72 } 73 } 74 75 static Throwable hideRecursiveCall(Throwable throwable, int current, Class<?> targetType) { 76 try { 77 StackTraceElement[] stack = throwable.getStackTrace(); 78 int skip = 0; 79 StackTraceElement next; 80 do { 81 next = stack[stack.length - current - ++skip]; 82 } while (!next.getClassName().equals(targetType.getName())); 83 int top = stack.length - current - skip; 84 StackTraceElement[] cleared = new StackTraceElement[stack.length - skip]; 85 System.arraycopy(stack, 0, cleared, 0, top); 86 System.arraycopy(stack, top + skip, cleared, top, current); 87 throwable.setStackTrace(cleared); 88 return throwable; 89 } catch (RuntimeException ignored) { 90 // This should not happen unless someone instrumented or manipulated exception stack traces. 91 return throwable; 92 } 93 } 94 95 @Override 96 public Callable<?> handle(Object instance, Method origin, Object[] arguments) throws Throwable { 97 MockMethodInterceptor interceptor = interceptors.get(instance); 98 if (interceptor == null) { 99 return null; 100 } 101 RealMethod realMethod; 102 if (instance instanceof Serializable) { 103 realMethod = new SerializableRealMethodCall(identifier, origin, instance, arguments); 104 } else { 105 realMethod = new RealMethodCall(selfCallInfo, origin, instance, arguments); 106 } 107 Throwable t = new Throwable(); 108 t.setStackTrace(skipInlineMethodElement(t.getStackTrace())); 109 return new ReturnValueWrapper(interceptor.doIntercept(instance, 110 origin, 111 arguments, 112 realMethod, 113 new LocationImpl(t))); 114 } 115 116 @Override 117 public boolean isMock(Object instance) { 118 // We need to exclude 'interceptors.target' explicitly to avoid a recursive check on whether 119 // the map is a mock object what requires reading from the map. 120 return instance != interceptors.target && interceptors.containsKey(instance); 121 } 122 123 @Override 124 public boolean isMocked(Object instance) { 125 return selfCallInfo.checkSuperCall(instance) && isMock(instance); 126 } 127 128 @Override 129 public boolean isOverridden(Object instance, Method origin) { 130 SoftReference<MethodGraph> reference = graphs.get(instance.getClass()); 131 MethodGraph methodGraph = reference == null ? null : reference.get(); 132 if (methodGraph == null) { 133 methodGraph = compiler.compile(new TypeDescription.ForLoadedType(instance.getClass())); 134 graphs.put(instance.getClass(), new SoftReference<MethodGraph>(methodGraph)); 135 } 136 MethodGraph.Node node = methodGraph.locate(new MethodDescription.ForLoadedMethod(origin).asSignatureToken()); 137 return !node.getSort().isResolved() || !node.getRepresentative().asDefined().getDeclaringType().represents(origin.getDeclaringClass()); 138 } 139 140 private static class RealMethodCall implements RealMethod { 141 142 private final SelfCallInfo selfCallInfo; 143 144 private final Method origin; 145 146 private final MockWeakReference<Object> instanceRef; 147 148 private final Object[] arguments; 149 150 private RealMethodCall(SelfCallInfo selfCallInfo, Method origin, Object instance, Object[] arguments) { 151 this.selfCallInfo = selfCallInfo; 152 this.origin = origin; 153 this.instanceRef = new MockWeakReference<Object>(instance); 154 this.arguments = arguments; 155 } 156 157 @Override 158 public boolean isInvokable() { 159 return true; 160 } 161 162 @Override 163 public Object invoke() throws Throwable { 164 if (!Modifier.isPublic(origin.getDeclaringClass().getModifiers() & origin.getModifiers())) { 165 origin.setAccessible(true); 166 } 167 selfCallInfo.set(instanceRef.get()); 168 return tryInvoke(origin, instanceRef.get(), arguments); 169 } 170 171 } 172 173 private static class SerializableRealMethodCall implements RealMethod { 174 175 private final String identifier; 176 177 private final SerializableMethod origin; 178 179 private final MockReference<Object> instanceRef; 180 181 private final Object[] arguments; 182 183 private SerializableRealMethodCall(String identifier, Method origin, Object instance, Object[] arguments) { 184 this.origin = new SerializableMethod(origin); 185 this.identifier = identifier; 186 this.instanceRef = new MockWeakReference<Object>(instance); 187 this.arguments = arguments; 188 } 189 190 @Override 191 public boolean isInvokable() { 192 return true; 193 } 194 195 @Override 196 public Object invoke() throws Throwable { 197 Method method = origin.getJavaMethod(); 198 if (!Modifier.isPublic(method.getDeclaringClass().getModifiers() & method.getModifiers())) { 199 method.setAccessible(true); 200 } 201 MockMethodDispatcher mockMethodDispatcher = MockMethodDispatcher.get(identifier, instanceRef.get()); 202 if (!(mockMethodDispatcher instanceof MockMethodAdvice)) { 203 throw new MockitoException("Unexpected dispatcher for advice-based super call"); 204 } 205 Object previous = ((MockMethodAdvice) mockMethodDispatcher).selfCallInfo.replace(instanceRef.get()); 206 try { 207 return tryInvoke(method, instanceRef.get(), arguments); 208 } finally { 209 ((MockMethodAdvice) mockMethodDispatcher).selfCallInfo.set(previous); 210 } 211 } 212 } 213 214 private static Object tryInvoke(Method origin, Object instance, Object[] arguments) throws Throwable { 215 try { 216 return origin.invoke(instance, arguments); 217 } catch (InvocationTargetException exception) { 218 Throwable cause = exception.getCause(); 219 new ConditionalStackTraceFilter().filter(hideRecursiveCall(cause, new Throwable().getStackTrace().length, origin.getDeclaringClass())); 220 throw cause; 221 } 222 } 223 224 // With inline mocking, mocks for concrete classes are not subclassed, so elements of the stubbing methods are not filtered out. 225 // Therefore, if the method is inlined, skip the element. 226 private static StackTraceElement[] skipInlineMethodElement(StackTraceElement[] elements) { 227 List<StackTraceElement> list = new ArrayList<StackTraceElement>(elements.length); 228 for (int i = 0; i < elements.length; i++) { 229 StackTraceElement element = elements[i]; 230 list.add(element); 231 if (element.getClassName().equals(MockMethodAdvice.class.getName()) && element.getMethodName().equals("handle")) { 232 // If the current element is MockMethodAdvice#handle(), the next is assumed to be an inlined method. 233 i++; 234 } 235 } 236 return list.toArray(new StackTraceElement[list.size()]); 237 } 238 239 private static class ReturnValueWrapper implements Callable<Object> { 240 241 private final Object returned; 242 243 private ReturnValueWrapper(Object returned) { 244 this.returned = returned; 245 } 246 247 @Override 248 public Object call() { 249 return returned; 250 } 251 } 252 253 private static class SelfCallInfo extends ThreadLocal<Object> { 254 255 Object replace(Object value) { 256 Object current = get(); 257 set(value); 258 return current; 259 } 260 261 boolean checkSuperCall(Object value) { 262 if (value == get()) { 263 set(null); 264 return false; 265 } else { 266 return true; 267 } 268 } 269 } 270 271 @Retention(RetentionPolicy.RUNTIME) 272 @interface Identifier { 273 274 } 275 276 static class ForHashCode { 277 278 @SuppressWarnings("unused") 279 @Advice.OnMethodEnter(skipOn = Advice.OnNonDefaultValue.class) 280 private static boolean enter(@Identifier String id, 281 @Advice.This Object self) { 282 MockMethodDispatcher dispatcher = MockMethodDispatcher.get(id, self); 283 return dispatcher != null && dispatcher.isMock(self); 284 } 285 286 @SuppressWarnings({"unused", "UnusedAssignment"}) 287 @Advice.OnMethodExit 288 private static void enter(@Advice.This Object self, 289 @Advice.Return(readOnly = false) int hashCode, 290 @Advice.Enter boolean skipped) { 291 if (skipped) { 292 hashCode = System.identityHashCode(self); 293 } 294 } 295 } 296 297 static class ForEquals { 298 299 @SuppressWarnings("unused") 300 @Advice.OnMethodEnter(skipOn = Advice.OnNonDefaultValue.class) 301 private static boolean enter(@Identifier String identifier, 302 @Advice.This Object self) { 303 MockMethodDispatcher dispatcher = MockMethodDispatcher.get(identifier, self); 304 return dispatcher != null && dispatcher.isMock(self); 305 } 306 307 @SuppressWarnings({"unused", "UnusedAssignment"}) 308 @Advice.OnMethodExit 309 private static void enter(@Advice.This Object self, 310 @Advice.Argument(0) Object other, 311 @Advice.Return(readOnly = false) boolean equals, 312 @Advice.Enter boolean skipped) { 313 if (skipped) { 314 equals = self == other; 315 } 316 } 317 } 318 319 public static class ForReadObject { 320 321 @SuppressWarnings("unused") 322 public static void doReadObject(@Identifier String identifier, 323 @This MockAccess thiz, 324 @Argument(0) ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException { 325 objectInputStream.defaultReadObject(); 326 MockMethodAdvice mockMethodAdvice = (MockMethodAdvice) MockMethodDispatcher.get(identifier, thiz); 327 if (mockMethodAdvice != null) { 328 mockMethodAdvice.interceptors.put(thiz, thiz.getMockitoInterceptor()); 329 } 330 } 331 } 332} 333