1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package android.arch.persistence.room;
18
19import static org.hamcrest.CoreMatchers.is;
20import static org.hamcrest.MatcherAssert.assertThat;
21import static org.hamcrest.core.IsCollectionContaining.hasItem;
22import static org.hamcrest.core.IsCollectionContaining.hasItems;
23import static org.mockito.Matchers.any;
24import static org.mockito.Matchers.anyInt;
25import static org.mockito.Matchers.anyString;
26import static org.mockito.Matchers.eq;
27import static org.mockito.Mockito.doReturn;
28import static org.mockito.Mockito.doThrow;
29import static org.mockito.Mockito.mock;
30import static org.mockito.Mockito.reset;
31import static org.mockito.Mockito.verify;
32import static org.mockito.Mockito.when;
33
34import android.arch.core.executor.JunitTaskExecutorRule;
35import android.arch.persistence.db.SupportSQLiteDatabase;
36import android.arch.persistence.db.SupportSQLiteOpenHelper;
37import android.arch.persistence.db.SupportSQLiteStatement;
38import android.database.Cursor;
39import android.database.sqlite.SQLiteException;
40import android.support.annotation.NonNull;
41
42import org.junit.After;
43import org.junit.Before;
44import org.junit.Rule;
45import org.junit.Test;
46import org.junit.runner.RunWith;
47import org.junit.runners.JUnit4;
48import org.mockito.Mockito;
49import org.mockito.invocation.InvocationOnMock;
50import org.mockito.stubbing.Answer;
51
52import java.lang.ref.WeakReference;
53import java.util.ArrayList;
54import java.util.Locale;
55import java.util.Set;
56import java.util.concurrent.CountDownLatch;
57import java.util.concurrent.TimeUnit;
58import java.util.concurrent.atomic.AtomicInteger;
59import java.util.concurrent.locks.ReentrantLock;
60
61@RunWith(JUnit4.class)
62public class InvalidationTrackerTest {
63    private InvalidationTracker mTracker;
64    private RoomDatabase mRoomDatabase;
65    private SupportSQLiteOpenHelper mOpenHelper;
66    @Rule
67    public JunitTaskExecutorRule mTaskExecutorRule = new JunitTaskExecutorRule(1, true);
68
69    @Before
70    public void setup() {
71        mRoomDatabase = mock(RoomDatabase.class);
72        SupportSQLiteDatabase sqliteDb = mock(SupportSQLiteDatabase.class);
73        final SupportSQLiteStatement statement = mock(SupportSQLiteStatement.class);
74        mOpenHelper = mock(SupportSQLiteOpenHelper.class);
75
76        doReturn(statement).when(sqliteDb).compileStatement(eq(InvalidationTracker.CLEANUP_SQL));
77        doReturn(sqliteDb).when(mOpenHelper).getWritableDatabase();
78        doReturn(true).when(mRoomDatabase).isOpen();
79        ReentrantLock closeLock = new ReentrantLock();
80        doReturn(closeLock).when(mRoomDatabase).getCloseLock();
81        //noinspection ResultOfMethodCallIgnored
82        doReturn(mOpenHelper).when(mRoomDatabase).getOpenHelper();
83
84        mTracker = new InvalidationTracker(mRoomDatabase, "a", "B", "i");
85        mTracker.internalInit(sqliteDb);
86    }
87
88    @Before
89    public void setLocale() {
90        Locale.setDefault(Locale.forLanguageTag("tr-TR"));
91    }
92
93    @After
94    public void unsetLocale() {
95        Locale.setDefault(Locale.US);
96    }
97
98    @Test
99    public void tableIds() {
100        assertThat(mTracker.mTableIdLookup.get("a"), is(0));
101        assertThat(mTracker.mTableIdLookup.get("b"), is(1));
102    }
103
104    @Test
105    public void testWeak() throws InterruptedException {
106        final AtomicInteger data = new AtomicInteger(0);
107        InvalidationTracker.Observer observer = new InvalidationTracker.Observer("a") {
108            @Override
109            public void onInvalidated(@NonNull Set<String> tables) {
110                data.incrementAndGet();
111            }
112        };
113        mTracker.addWeakObserver(observer);
114        setVersions(1, 0);
115        refreshSync();
116        assertThat(data.get(), is(1));
117        observer = null;
118        forceGc();
119        setVersions(2, 0);
120        refreshSync();
121        assertThat(data.get(), is(1));
122    }
123
124    @Test
125    public void addRemoveObserver() throws Exception {
126        InvalidationTracker.Observer observer = new LatchObserver(1, "a");
127        mTracker.addObserver(observer);
128        drainTasks();
129        assertThat(mTracker.mObserverMap.size(), is(1));
130        mTracker.removeObserver(new LatchObserver(1, "a"));
131        drainTasks();
132        assertThat(mTracker.mObserverMap.size(), is(1));
133        mTracker.removeObserver(observer);
134        drainTasks();
135        assertThat(mTracker.mObserverMap.size(), is(0));
136    }
137
138    private void drainTasks() throws InterruptedException {
139        mTaskExecutorRule.drainTasks(200);
140    }
141
142    @Test(expected = IllegalArgumentException.class)
143    public void badObserver() {
144        InvalidationTracker.Observer observer = new LatchObserver(1, "x");
145        mTracker.addObserver(observer);
146    }
147
148    @Test
149    public void refreshReadValues() throws Exception {
150        setVersions(1, 0, 2, 1);
151        refreshSync();
152        assertThat(mTracker.mTableVersions, is(new long[]{1, 2, 0}));
153
154        setVersions(3, 1);
155        refreshSync();
156        assertThat(mTracker.mTableVersions, is(new long[]{1, 3, 0}));
157
158        setVersions(7, 0);
159        refreshSync();
160        assertThat(mTracker.mTableVersions, is(new long[]{7, 3, 0}));
161
162        refreshSync();
163        assertThat(mTracker.mTableVersions, is(new long[]{7, 3, 0}));
164    }
165
166    private void refreshSync() throws InterruptedException {
167        mTracker.refreshVersionsAsync();
168        drainTasks();
169    }
170
171    @Test
172    public void refreshCheckTasks() throws Exception {
173        when(mRoomDatabase.query(anyString(), any(Object[].class)))
174                .thenReturn(mock(Cursor.class));
175        mTracker.refreshVersionsAsync();
176        mTracker.refreshVersionsAsync();
177        verify(mTaskExecutorRule.getTaskExecutor()).executeOnDiskIO(mTracker.mRefreshRunnable);
178        drainTasks();
179
180        reset(mTaskExecutorRule.getTaskExecutor());
181        mTracker.refreshVersionsAsync();
182        verify(mTaskExecutorRule.getTaskExecutor()).executeOnDiskIO(mTracker.mRefreshRunnable);
183    }
184
185    @Test
186    public void observe1Table() throws Exception {
187        LatchObserver observer = new LatchObserver(1, "a");
188        mTracker.addObserver(observer);
189        setVersions(1, 0, 2, 1);
190        refreshSync();
191        assertThat(observer.await(), is(true));
192        assertThat(observer.getInvalidatedTables().size(), is(1));
193        assertThat(observer.getInvalidatedTables(), hasItem("a"));
194
195        setVersions(3, 1);
196        observer.reset(1);
197        refreshSync();
198        assertThat(observer.await(), is(false));
199
200        setVersions(4, 0);
201        refreshSync();
202        assertThat(observer.await(), is(true));
203        assertThat(observer.getInvalidatedTables().size(), is(1));
204        assertThat(observer.getInvalidatedTables(), hasItem("a"));
205    }
206
207    @Test
208    public void observe2Tables() throws Exception {
209        LatchObserver observer = new LatchObserver(1, "A", "B");
210        mTracker.addObserver(observer);
211        setVersions(1, 0, 2, 1);
212        refreshSync();
213        assertThat(observer.await(), is(true));
214        assertThat(observer.getInvalidatedTables().size(), is(2));
215        assertThat(observer.getInvalidatedTables(), hasItems("A", "B"));
216
217        setVersions(3, 1);
218        observer.reset(1);
219        refreshSync();
220        assertThat(observer.await(), is(true));
221        assertThat(observer.getInvalidatedTables().size(), is(1));
222        assertThat(observer.getInvalidatedTables(), hasItem("B"));
223
224        setVersions(4, 0);
225        observer.reset(1);
226        refreshSync();
227        assertThat(observer.await(), is(true));
228        assertThat(observer.getInvalidatedTables().size(), is(1));
229        assertThat(observer.getInvalidatedTables(), hasItem("A"));
230
231        observer.reset(1);
232        refreshSync();
233        assertThat(observer.await(), is(false));
234    }
235
236    @Test
237    public void locale() {
238        LatchObserver observer = new LatchObserver(1, "I");
239        mTracker.addObserver(observer);
240    }
241
242    @Test
243    public void closedDb() {
244        doThrow(new IllegalStateException("foo")).when(mOpenHelper).getWritableDatabase();
245        mTracker.addObserver(new LatchObserver(1, "a", "b"));
246        mTracker.syncTriggers();
247        mTracker.mRefreshRunnable.run();
248    }
249
250    @Test
251    public void closedDbAfterOpen() throws InterruptedException {
252        setVersions(3, 1);
253        mTracker.addObserver(new LatchObserver(1, "a", "b"));
254        mTracker.syncTriggers();
255        mTracker.mRefreshRunnable.run();
256        doThrow(new SQLiteException("foo")).when(mRoomDatabase).query(
257                Mockito.eq(InvalidationTracker.SELECT_UPDATED_TABLES_SQL),
258                any(Object[].class));
259        mTracker.mPendingRefresh.set(true);
260        mTracker.mRefreshRunnable.run();
261    }
262
263    /**
264     * Key value pairs of VERSION, TABLE_ID
265     */
266    private void setVersions(int... keyValuePairs) throws InterruptedException {
267        // mockito does not like multi-threaded access so before setting versions, make sure we
268        // sync background tasks.
269        drainTasks();
270        Cursor cursor = createCursorWithValues(keyValuePairs);
271        doReturn(cursor).when(mRoomDatabase).query(
272                Mockito.eq(InvalidationTracker.SELECT_UPDATED_TABLES_SQL),
273                any(Object[].class)
274        );
275    }
276
277    private Cursor createCursorWithValues(final int... keyValuePairs) {
278        Cursor cursor = mock(Cursor.class);
279        final AtomicInteger index = new AtomicInteger(-2);
280        when(cursor.moveToNext()).thenAnswer(new Answer<Boolean>() {
281            @Override
282            public Boolean answer(InvocationOnMock invocation) throws Throwable {
283                return index.addAndGet(2) < keyValuePairs.length;
284            }
285        });
286        Answer<Integer> intAnswer = new Answer<Integer>() {
287            @Override
288            public Integer answer(InvocationOnMock invocation) throws Throwable {
289                return keyValuePairs[index.intValue() + (Integer) invocation.getArguments()[0]];
290            }
291        };
292        Answer<Long> longAnswer = new Answer<Long>() {
293            @Override
294            public Long answer(InvocationOnMock invocation) throws Throwable {
295                return (long) keyValuePairs[index.intValue()
296                        + (Integer) invocation.getArguments()[0]];
297            }
298        };
299        when(cursor.getInt(anyInt())).thenAnswer(intAnswer);
300        when(cursor.getLong(anyInt())).thenAnswer(longAnswer);
301        return cursor;
302    }
303
304    static class LatchObserver extends InvalidationTracker.Observer {
305        private CountDownLatch mLatch;
306        private Set<String> mInvalidatedTables;
307
308        LatchObserver(int count, String... tableNames) {
309            super(tableNames);
310            mLatch = new CountDownLatch(count);
311        }
312
313        boolean await() throws InterruptedException {
314            return mLatch.await(3, TimeUnit.SECONDS);
315        }
316
317        @Override
318        public void onInvalidated(@NonNull Set<String> tables) {
319            mInvalidatedTables = tables;
320            mLatch.countDown();
321        }
322
323        void reset(@SuppressWarnings("SameParameterValue") int count) {
324            mInvalidatedTables = null;
325            mLatch = new CountDownLatch(count);
326        }
327
328        Set<String> getInvalidatedTables() {
329            return mInvalidatedTables;
330        }
331    }
332
333    private static void forceGc() {
334        // Use a random index in the list to detect the garbage collection each time because
335        // .get() may accidentally trigger a strong reference during collection.
336        ArrayList<WeakReference<byte[]>> leak = new ArrayList<>();
337        do {
338            WeakReference<byte[]> arr = new WeakReference<>(new byte[100]);
339            leak.add(arr);
340        } while (leak.get((int) (Math.random() * leak.size())).get() != null);
341    }
342}
343