1/*
2 * Copyright (c) 2016 Mockito contributors
3 * This program is made available under the terms of the MIT License.
4 */
5package org.mockito.internal.creation.bytebuddy;
6
7import net.bytebuddy.asm.Advice;
8import net.bytebuddy.description.method.MethodDescription;
9import net.bytebuddy.description.type.TypeDescription;
10import net.bytebuddy.dynamic.scaffold.MethodGraph;
11import net.bytebuddy.implementation.bind.annotation.Argument;
12import net.bytebuddy.implementation.bind.annotation.This;
13import net.bytebuddy.implementation.bytecode.assign.Assigner;
14import org.mockito.exceptions.base.MockitoException;
15import org.mockito.internal.debugging.LocationImpl;
16import org.mockito.internal.exceptions.stacktrace.ConditionalStackTraceFilter;
17import org.mockito.internal.invocation.RealMethod;
18import org.mockito.internal.invocation.SerializableMethod;
19import org.mockito.internal.invocation.mockref.MockReference;
20import org.mockito.internal.invocation.mockref.MockWeakReference;
21import org.mockito.internal.util.concurrent.WeakConcurrentMap;
22
23import java.io.IOException;
24import java.io.ObjectInputStream;
25import java.io.Serializable;
26import java.lang.annotation.Retention;
27import java.lang.annotation.RetentionPolicy;
28import java.lang.ref.SoftReference;
29import java.lang.reflect.InvocationTargetException;
30import java.lang.reflect.Method;
31import java.lang.reflect.Modifier;
32import java.util.ArrayList;
33import java.util.List;
34import java.util.concurrent.Callable;
35
36public class MockMethodAdvice extends MockMethodDispatcher {
37
38    final WeakConcurrentMap<Object, MockMethodInterceptor> interceptors;
39
40    private final String identifier;
41
42    private final SelfCallInfo selfCallInfo = new SelfCallInfo();
43    private final MethodGraph.Compiler compiler = MethodGraph.Compiler.Default.forJavaHierarchy();
44    private final WeakConcurrentMap<Class<?>, SoftReference<MethodGraph>> graphs
45        = new WeakConcurrentMap.WithInlinedExpunction<Class<?>, SoftReference<MethodGraph>>();
46
47    public MockMethodAdvice(WeakConcurrentMap<Object, MockMethodInterceptor> interceptors, String identifier) {
48        this.interceptors = interceptors;
49        this.identifier = identifier;
50    }
51
52    @SuppressWarnings("unused")
53    @Advice.OnMethodEnter(skipOn = Advice.OnNonDefaultValue.class)
54    private static Callable<?> enter(@Identifier String identifier,
55                                     @Advice.This Object mock,
56                                     @Advice.Origin Method origin,
57                                     @Advice.AllArguments Object[] arguments) throws Throwable {
58        MockMethodDispatcher dispatcher = MockMethodDispatcher.get(identifier, mock);
59        if (dispatcher == null || !dispatcher.isMocked(mock) || dispatcher.isOverridden(mock, origin)) {
60            return null;
61        } else {
62            return dispatcher.handle(mock, origin, arguments);
63        }
64    }
65
66    @SuppressWarnings({"unused", "UnusedAssignment"})
67    @Advice.OnMethodExit
68    private static void exit(@Advice.Return(readOnly = false, typing = Assigner.Typing.DYNAMIC) Object returned,
69                             @Advice.Enter Callable<?> mocked) throws Throwable {
70        if (mocked != null) {
71            returned = mocked.call();
72        }
73    }
74
75    static Throwable hideRecursiveCall(Throwable throwable, int current, Class<?> targetType) {
76        try {
77            StackTraceElement[] stack = throwable.getStackTrace();
78            int skip = 0;
79            StackTraceElement next;
80            do {
81                next = stack[stack.length - current - ++skip];
82            } while (!next.getClassName().equals(targetType.getName()));
83            int top = stack.length - current - skip;
84            StackTraceElement[] cleared = new StackTraceElement[stack.length - skip];
85            System.arraycopy(stack, 0, cleared, 0, top);
86            System.arraycopy(stack, top + skip, cleared, top, current);
87            throwable.setStackTrace(cleared);
88            return throwable;
89        } catch (RuntimeException ignored) {
90            // This should not happen unless someone instrumented or manipulated exception stack traces.
91            return throwable;
92        }
93    }
94
95    @Override
96    public Callable<?> handle(Object instance, Method origin, Object[] arguments) throws Throwable {
97        MockMethodInterceptor interceptor = interceptors.get(instance);
98        if (interceptor == null) {
99            return null;
100        }
101        RealMethod realMethod;
102        if (instance instanceof Serializable) {
103            realMethod = new SerializableRealMethodCall(identifier, origin, instance, arguments);
104        } else {
105            realMethod = new RealMethodCall(selfCallInfo, origin, instance, arguments);
106        }
107        Throwable t = new Throwable();
108        t.setStackTrace(skipInlineMethodElement(t.getStackTrace()));
109        return new ReturnValueWrapper(interceptor.doIntercept(instance,
110                origin,
111                arguments,
112                realMethod,
113                new LocationImpl(t)));
114    }
115
116    @Override
117    public boolean isMock(Object instance) {
118        // We need to exclude 'interceptors.target' explicitly to avoid a recursive check on whether
119        // the map is a mock object what requires reading from the map.
120        return instance != interceptors.target && interceptors.containsKey(instance);
121    }
122
123    @Override
124    public boolean isMocked(Object instance) {
125        return selfCallInfo.checkSuperCall(instance) && isMock(instance);
126    }
127
128    @Override
129    public boolean isOverridden(Object instance, Method origin) {
130        SoftReference<MethodGraph> reference = graphs.get(instance.getClass());
131        MethodGraph methodGraph = reference == null ? null : reference.get();
132        if (methodGraph == null) {
133            methodGraph = compiler.compile(new TypeDescription.ForLoadedType(instance.getClass()));
134            graphs.put(instance.getClass(), new SoftReference<MethodGraph>(methodGraph));
135        }
136        MethodGraph.Node node = methodGraph.locate(new MethodDescription.ForLoadedMethod(origin).asSignatureToken());
137        return !node.getSort().isResolved() || !node.getRepresentative().asDefined().getDeclaringType().represents(origin.getDeclaringClass());
138    }
139
140    private static class RealMethodCall implements RealMethod {
141
142        private final SelfCallInfo selfCallInfo;
143
144        private final Method origin;
145
146        private final MockWeakReference<Object> instanceRef;
147
148        private final Object[] arguments;
149
150        private RealMethodCall(SelfCallInfo selfCallInfo, Method origin, Object instance, Object[] arguments) {
151            this.selfCallInfo = selfCallInfo;
152            this.origin = origin;
153            this.instanceRef = new MockWeakReference<Object>(instance);
154            this.arguments = arguments;
155        }
156
157        @Override
158        public boolean isInvokable() {
159            return true;
160        }
161
162        @Override
163        public Object invoke() throws Throwable {
164            if (!Modifier.isPublic(origin.getDeclaringClass().getModifiers() & origin.getModifiers())) {
165                origin.setAccessible(true);
166            }
167            selfCallInfo.set(instanceRef.get());
168            return tryInvoke(origin, instanceRef.get(), arguments);
169        }
170
171    }
172
173    private static class SerializableRealMethodCall implements RealMethod {
174
175        private final String identifier;
176
177        private final SerializableMethod origin;
178
179        private final MockReference<Object> instanceRef;
180
181        private final Object[] arguments;
182
183        private SerializableRealMethodCall(String identifier, Method origin, Object instance, Object[] arguments) {
184            this.origin = new SerializableMethod(origin);
185            this.identifier = identifier;
186            this.instanceRef = new MockWeakReference<Object>(instance);
187            this.arguments = arguments;
188        }
189
190        @Override
191        public boolean isInvokable() {
192            return true;
193        }
194
195        @Override
196        public Object invoke() throws Throwable {
197            Method method = origin.getJavaMethod();
198            if (!Modifier.isPublic(method.getDeclaringClass().getModifiers() & method.getModifiers())) {
199                method.setAccessible(true);
200            }
201            MockMethodDispatcher mockMethodDispatcher = MockMethodDispatcher.get(identifier, instanceRef.get());
202            if (!(mockMethodDispatcher instanceof MockMethodAdvice)) {
203                throw new MockitoException("Unexpected dispatcher for advice-based super call");
204            }
205            Object previous = ((MockMethodAdvice) mockMethodDispatcher).selfCallInfo.replace(instanceRef.get());
206            try {
207                return tryInvoke(method, instanceRef.get(), arguments);
208            } finally {
209                ((MockMethodAdvice) mockMethodDispatcher).selfCallInfo.set(previous);
210            }
211        }
212    }
213
214    private static Object tryInvoke(Method origin, Object instance, Object[] arguments) throws Throwable {
215        try {
216            return origin.invoke(instance, arguments);
217        } catch (InvocationTargetException exception) {
218            Throwable cause = exception.getCause();
219            new ConditionalStackTraceFilter().filter(hideRecursiveCall(cause, new Throwable().getStackTrace().length, origin.getDeclaringClass()));
220            throw cause;
221        }
222    }
223
224    // With inline mocking, mocks for concrete classes are not subclassed, so elements of the stubbing methods are not filtered out.
225    // Therefore, if the method is inlined, skip the element.
226    private static StackTraceElement[] skipInlineMethodElement(StackTraceElement[] elements) {
227        List<StackTraceElement> list = new ArrayList<StackTraceElement>(elements.length);
228        for (int i = 0; i < elements.length; i++) {
229            StackTraceElement element = elements[i];
230            list.add(element);
231            if (element.getClassName().equals(MockMethodAdvice.class.getName()) && element.getMethodName().equals("handle")) {
232                // If the current element is MockMethodAdvice#handle(), the next is assumed to be an inlined method.
233                i++;
234            }
235        }
236        return list.toArray(new StackTraceElement[list.size()]);
237    }
238
239    private static class ReturnValueWrapper implements Callable<Object> {
240
241        private final Object returned;
242
243        private ReturnValueWrapper(Object returned) {
244            this.returned = returned;
245        }
246
247        @Override
248        public Object call() {
249            return returned;
250        }
251    }
252
253    private static class SelfCallInfo extends ThreadLocal<Object> {
254
255        Object replace(Object value) {
256            Object current = get();
257            set(value);
258            return current;
259        }
260
261        boolean checkSuperCall(Object value) {
262            if (value == get()) {
263                set(null);
264                return false;
265            } else {
266                return true;
267            }
268        }
269    }
270
271    @Retention(RetentionPolicy.RUNTIME)
272    @interface Identifier {
273
274    }
275
276    static class ForHashCode {
277
278        @SuppressWarnings("unused")
279        @Advice.OnMethodEnter(skipOn = Advice.OnNonDefaultValue.class)
280        private static boolean enter(@Identifier String id,
281                                     @Advice.This Object self) {
282            MockMethodDispatcher dispatcher = MockMethodDispatcher.get(id, self);
283            return dispatcher != null && dispatcher.isMock(self);
284        }
285
286        @SuppressWarnings({"unused", "UnusedAssignment"})
287        @Advice.OnMethodExit
288        private static void enter(@Advice.This Object self,
289                                  @Advice.Return(readOnly = false) int hashCode,
290                                  @Advice.Enter boolean skipped) {
291            if (skipped) {
292                hashCode = System.identityHashCode(self);
293            }
294        }
295    }
296
297    static class ForEquals {
298
299        @SuppressWarnings("unused")
300        @Advice.OnMethodEnter(skipOn = Advice.OnNonDefaultValue.class)
301        private static boolean enter(@Identifier String identifier,
302                                     @Advice.This Object self) {
303            MockMethodDispatcher dispatcher = MockMethodDispatcher.get(identifier, self);
304            return dispatcher != null && dispatcher.isMock(self);
305        }
306
307        @SuppressWarnings({"unused", "UnusedAssignment"})
308        @Advice.OnMethodExit
309        private static void enter(@Advice.This Object self,
310                                  @Advice.Argument(0) Object other,
311                                  @Advice.Return(readOnly = false) boolean equals,
312                                  @Advice.Enter boolean skipped) {
313            if (skipped) {
314                equals = self == other;
315            }
316        }
317    }
318
319    public static class ForReadObject {
320
321        @SuppressWarnings("unused")
322        public static void doReadObject(@Identifier String identifier,
323                                        @This MockAccess thiz,
324                                        @Argument(0) ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
325            objectInputStream.defaultReadObject();
326            MockMethodAdvice mockMethodAdvice = (MockMethodAdvice) MockMethodDispatcher.get(identifier, thiz);
327            if (mockMethodAdvice != null) {
328                mockMethodAdvice.interceptors.put(thiz, thiz.getMockitoInterceptor());
329            }
330        }
331    }
332}
333