1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file
5 * except in compliance with the License. You may obtain a copy of the License at
6 *
7 *      http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software distributed under the
10 * License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
11 * KIND, either express or implied. See the License for the specific language governing
12 * permissions and limitations under the License.
13 */
14
15package android.testing;
16
17import android.os.Handler;
18import android.os.HandlerThread;
19import android.os.Looper;
20import android.os.Message;
21import android.os.MessageQueue;
22import android.os.TestLooperManager;
23import android.support.test.InstrumentationRegistry;
24import android.util.ArrayMap;
25
26import org.junit.runners.model.FrameworkMethod;
27
28import java.lang.annotation.ElementType;
29import java.lang.annotation.Retention;
30import java.lang.annotation.RetentionPolicy;
31import java.lang.annotation.Target;
32import java.lang.reflect.Field;
33import java.util.Map;
34
35/**
36 * Creates a looper on the current thread with control over if/when messages are
37 * executed. Warning: This class works through some reflection and may break/need
38 * to be updated from time to time.
39 */
40public class TestableLooper {
41
42    private Looper mLooper;
43    private MessageQueue mQueue;
44    private boolean mMain;
45    private Object mOriginalMain;
46    private MessageHandler mMessageHandler;
47
48    private Handler mHandler;
49    private Runnable mEmptyMessage;
50    private TestLooperManager mQueueWrapper;
51
52    public TestableLooper(Looper l) throws Exception {
53        this(InstrumentationRegistry.getInstrumentation().acquireLooperManager(l), l);
54    }
55
56    private TestableLooper(TestLooperManager wrapper, Looper l) throws Exception {
57        mQueueWrapper = wrapper;
58        setupQueue(l);
59    }
60
61    private TestableLooper(Looper looper, boolean b) throws Exception {
62        setupQueue(looper);
63    }
64
65    public Looper getLooper() {
66        return mLooper;
67    }
68
69    private void setupQueue(Looper l) throws Exception {
70        mLooper = l;
71        mQueue = mLooper.getQueue();
72        mHandler = new Handler(mLooper);
73    }
74
75    public void setAsMainLooper() throws NoSuchFieldException, IllegalAccessException {
76        mMain = true;
77        setAsMainInt();
78    }
79
80    private void setAsMainInt() throws NoSuchFieldException, IllegalAccessException {
81        Field field = mLooper.getClass().getDeclaredField("sMainLooper");
82        field.setAccessible(true);
83        if (mOriginalMain == null) {
84            mOriginalMain = field.get(null);
85        }
86        field.set(null, mLooper);
87    }
88
89    /**
90     * Must be called if setAsMainLooper is called to restore the main looper when the
91     * test is complete, otherwise the main looper will not be available for any subsequent
92     * tests.
93     */
94    public void destroy() throws NoSuchFieldException, IllegalAccessException {
95        mQueueWrapper.release();
96        if (mMain && mOriginalMain != null) {
97            Field field = mLooper.getClass().getDeclaredField("sMainLooper");
98            field.setAccessible(true);
99            field.set(null, mOriginalMain);
100            mOriginalMain = null;
101        }
102    }
103
104    public void setMessageHandler(MessageHandler handler) {
105        mMessageHandler = handler;
106    }
107
108    /**
109     * Parse num messages from the message queue.
110     *
111     * @param num Number of messages to parse
112     */
113    public int processMessages(int num) {
114        for (int i = 0; i < num; i++) {
115            if (!parseMessageInt()) {
116                return i + 1;
117            }
118        }
119        return num;
120    }
121
122    public void processAllMessages() {
123        while (processQueuedMessages() != 0) ;
124    }
125
126    private int processQueuedMessages() {
127        int count = 0;
128        mEmptyMessage = () -> { };
129        mHandler.post(mEmptyMessage);
130        waitForMessage(mQueueWrapper, mHandler, mEmptyMessage);
131        while (parseMessageInt()) count++;
132        return count;
133    }
134
135    private boolean parseMessageInt() {
136        try {
137            Message result = mQueueWrapper.next();
138            if (result != null) {
139                // This is a break message.
140                if (result.getCallback() == mEmptyMessage) {
141                    mQueueWrapper.recycle(result);
142                    return false;
143                }
144
145                if (mMessageHandler != null) {
146                    if (mMessageHandler.onMessageHandled(result)) {
147                        result.getTarget().dispatchMessage(result);
148                        mQueueWrapper.recycle(result);
149                    } else {
150                        mQueueWrapper.recycle(result);
151                        // Message handler indicated it doesn't want us to continue.
152                        return false;
153                    }
154                } else {
155                    result.getTarget().dispatchMessage(result);
156                    mQueueWrapper.recycle(result);
157                }
158            } else {
159                // No messages, don't continue parsing
160                return false;
161            }
162        } catch (Exception e) {
163            throw new RuntimeException(e);
164        }
165        return true;
166    }
167
168    /**
169     * Runs an executable with myLooper set and processes all messages added.
170     */
171    public void runWithLooper(RunnableWithException runnable) throws Exception {
172        new Handler(getLooper()).post(() -> {
173            try {
174                runnable.run();
175            } catch (Exception e) {
176                throw new RuntimeException(e);
177            }
178        });
179        processAllMessages();
180    }
181
182    public interface RunnableWithException {
183        void run() throws Exception;
184    }
185
186    @Retention(RetentionPolicy.RUNTIME)
187    @Target({ElementType.METHOD, ElementType.TYPE})
188    public @interface RunWithLooper {
189        boolean setAsMainLooper() default false;
190    }
191
192    private static void waitForMessage(TestLooperManager queueWrapper, Handler handler,
193            Runnable execute) {
194        for (int i = 0; i < 10; i++) {
195            if (!queueWrapper.hasMessages(handler, null, execute)) {
196                try {
197                    Thread.sleep(1);
198                } catch (InterruptedException e) {
199                }
200            }
201        }
202        if (!queueWrapper.hasMessages(handler, null, execute)) {
203            throw new RuntimeException("Message didn't queue...");
204        }
205    }
206
207    private static final Map<Object, TestableLooper> sLoopers = new ArrayMap<>();
208
209    public static TestableLooper get(Object test) {
210        return sLoopers.get(test);
211    }
212
213    public static class LooperFrameworkMethod extends FrameworkMethod {
214        private HandlerThread mHandlerThread;
215
216        private final TestableLooper mTestableLooper;
217        private final Looper mLooper;
218        private final Handler mHandler;
219
220        public LooperFrameworkMethod(FrameworkMethod base, boolean setAsMain, Object test) {
221            super(base.getMethod());
222            try {
223                mLooper = setAsMain ? Looper.getMainLooper() : createLooper();
224                mTestableLooper = new TestableLooper(mLooper, false);
225            } catch (Exception e) {
226                throw new RuntimeException(e);
227            }
228            sLoopers.put(test, mTestableLooper);
229            mHandler = new Handler(mLooper);
230        }
231
232        public LooperFrameworkMethod(TestableLooper other, FrameworkMethod base) {
233            super(base.getMethod());
234            mLooper = other.mLooper;
235            mTestableLooper = other;
236            mHandler = new Handler(mLooper);
237        }
238
239        public static FrameworkMethod get(FrameworkMethod base, boolean setAsMain, Object test) {
240            if (sLoopers.containsKey(test)) {
241                return new LooperFrameworkMethod(sLoopers.get(test), base);
242            }
243            return new LooperFrameworkMethod(base, setAsMain, test);
244        }
245
246        @Override
247        public Object invokeExplosively(Object target, Object... params) throws Throwable {
248            if (Looper.myLooper() == mLooper) {
249                // Already on the right thread from another statement, just execute then.
250                return super.invokeExplosively(target, params);
251            }
252            boolean set = mTestableLooper.mQueueWrapper == null;
253            if (set) {
254                mTestableLooper.mQueueWrapper = InstrumentationRegistry.getInstrumentation()
255                        .acquireLooperManager(mLooper);
256            }
257            try {
258                Object[] ret = new Object[1];
259                // Run the execution on the looper thread.
260                Runnable execute = () -> {
261                    try {
262                        ret[0] = super.invokeExplosively(target, params);
263                    } catch (Throwable throwable) {
264                        throw new LooperException(throwable);
265                    }
266                };
267                Message m = Message.obtain(mHandler, execute);
268
269                // Dispatch our message.
270                try {
271                    mTestableLooper.mQueueWrapper.execute(m);
272                } catch (LooperException e) {
273                    throw e.getSource();
274                } catch (RuntimeException re) {
275                    // If the TestLooperManager has to post, it will wrap what it throws in a
276                    // RuntimeException, make sure we grab the actual source.
277                    if (re.getCause() instanceof LooperException) {
278                        throw ((LooperException) re.getCause()).getSource();
279                    } else {
280                        throw re.getCause();
281                    }
282                } finally {
283                    m.recycle();
284                }
285                return ret[0];
286            } finally {
287                if (set) {
288                    mTestableLooper.mQueueWrapper.release();
289                    mTestableLooper.mQueueWrapper = null;
290                }
291            }
292        }
293
294        private Looper createLooper() {
295            // TODO: Find way to share these.
296            mHandlerThread = new HandlerThread(TestableLooper.class.getSimpleName());
297            mHandlerThread.start();
298            return mHandlerThread.getLooper();
299        }
300
301        @Override
302        protected void finalize() throws Throwable {
303            super.finalize();
304            if (mHandlerThread != null) {
305                mHandlerThread.quit();
306            }
307        }
308
309        private static class LooperException extends RuntimeException {
310            private final Throwable mSource;
311
312            public LooperException(Throwable t) {
313                mSource = t;
314            }
315
316            public Throwable getSource() {
317                return mSource;
318            }
319        }
320    }
321
322    public interface MessageHandler {
323        /**
324         * Return true to have the message executed and delivered to target.
325         * Return false to not execute the message and stop executing messages.
326         */
327        boolean onMessageHandled(Message m);
328    }
329}
330