1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package androidx.arch.core.executor.testing;
18
19import android.os.SystemClock;
20
21import androidx.arch.core.executor.ArchTaskExecutor;
22import androidx.arch.core.executor.DefaultTaskExecutor;
23
24import org.junit.rules.TestWatcher;
25import org.junit.runner.Description;
26
27import java.util.concurrent.TimeUnit;
28import java.util.concurrent.TimeoutException;
29
30/**
31 * A JUnit Test Rule that swaps the background executor used by the Architecture Components with a
32 * different one which counts the tasks as they are start and finish.
33 * <p>
34 * You can use this rule for your host side tests that use Architecture Components.
35 */
36public class CountingTaskExecutorRule extends TestWatcher {
37    private final Object mCountLock = new Object();
38    private int mTaskCount = 0;
39
40    @Override
41    protected void starting(Description description) {
42        super.starting(description);
43        ArchTaskExecutor.getInstance().setDelegate(new DefaultTaskExecutor() {
44            @Override
45            public void executeOnDiskIO(Runnable runnable) {
46                super.executeOnDiskIO(new CountingRunnable(runnable));
47            }
48
49            @Override
50            public void postToMainThread(Runnable runnable) {
51                super.postToMainThread(new CountingRunnable(runnable));
52            }
53        });
54    }
55
56    @Override
57    protected void finished(Description description) {
58        super.finished(description);
59        ArchTaskExecutor.getInstance().setDelegate(null);
60    }
61
62    private void increment() {
63        synchronized (mCountLock) {
64            mTaskCount++;
65        }
66    }
67
68    private void decrement() {
69        synchronized (mCountLock) {
70            mTaskCount--;
71            if (mTaskCount == 0) {
72                onIdle();
73                mCountLock.notifyAll();
74            }
75        }
76    }
77
78    /**
79     * Called when the number of awaiting tasks reaches to 0.
80     *
81     * @see #isIdle()
82     */
83    protected void onIdle() {
84
85    }
86
87    /**
88     * Returns false if there are tasks waiting to be executed, true otherwise.
89     *
90     * @return False if there are tasks waiting to be executed, true otherwise.
91     *
92     * @see #onIdle()
93     */
94    public boolean isIdle() {
95        synchronized (mCountLock) {
96            return mTaskCount == 0;
97        }
98    }
99
100    /**
101     * Waits until all active tasks are finished.
102     *
103     * @param time The duration to wait
104     * @param timeUnit The time unit for the {@code time} parameter
105     *
106     * @throws InterruptedException If thread is interrupted while waiting
107     * @throws TimeoutException If tasks cannot be drained at the given time
108     */
109    public void drainTasks(int time, TimeUnit timeUnit)
110            throws InterruptedException, TimeoutException {
111        long end = SystemClock.uptimeMillis() + timeUnit.toMillis(time);
112        synchronized (mCountLock) {
113            while (mTaskCount != 0) {
114                long now = SystemClock.uptimeMillis();
115                long remaining = end - now;
116                if (remaining > 0) {
117                    mCountLock.wait(remaining);
118                } else {
119                    throw new TimeoutException("could not drain tasks");
120                }
121            }
122        }
123    }
124
125    class CountingRunnable implements Runnable {
126        final Runnable mWrapped;
127
128        CountingRunnable(Runnable wrapped) {
129            mWrapped = wrapped;
130            increment();
131        }
132
133        @Override
134        public void run() {
135            try {
136                mWrapped.run();
137            } finally {
138                decrement();
139            }
140        }
141    }
142}
143