1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package androidx.room;
18
19import static org.hamcrest.CoreMatchers.is;
20import static org.hamcrest.MatcherAssert.assertThat;
21import static org.hamcrest.core.IsCollectionContaining.hasItem;
22import static org.hamcrest.core.IsCollectionContaining.hasItems;
23import static org.mockito.Matchers.any;
24import static org.mockito.Matchers.anyInt;
25import static org.mockito.Matchers.anyString;
26import static org.mockito.Matchers.eq;
27import static org.mockito.Mockito.doReturn;
28import static org.mockito.Mockito.doThrow;
29import static org.mockito.Mockito.mock;
30import static org.mockito.Mockito.reset;
31import static org.mockito.Mockito.verify;
32import static org.mockito.Mockito.when;
33
34import android.database.Cursor;
35import android.database.sqlite.SQLiteException;
36
37import androidx.annotation.NonNull;
38import androidx.arch.core.executor.JunitTaskExecutorRule;
39import androidx.sqlite.db.SupportSQLiteDatabase;
40import androidx.sqlite.db.SupportSQLiteOpenHelper;
41import androidx.sqlite.db.SupportSQLiteStatement;
42
43import org.junit.After;
44import org.junit.Before;
45import org.junit.Rule;
46import org.junit.Test;
47import org.junit.runner.RunWith;
48import org.junit.runners.JUnit4;
49import org.mockito.Mockito;
50import org.mockito.invocation.InvocationOnMock;
51import org.mockito.stubbing.Answer;
52
53import java.lang.ref.WeakReference;
54import java.util.ArrayList;
55import java.util.Locale;
56import java.util.Set;
57import java.util.concurrent.CountDownLatch;
58import java.util.concurrent.TimeUnit;
59import java.util.concurrent.atomic.AtomicInteger;
60import java.util.concurrent.locks.ReentrantLock;
61
62@RunWith(JUnit4.class)
63public class InvalidationTrackerTest {
64    private InvalidationTracker mTracker;
65    private RoomDatabase mRoomDatabase;
66    private SupportSQLiteOpenHelper mOpenHelper;
67    @Rule
68    public JunitTaskExecutorRule mTaskExecutorRule = new JunitTaskExecutorRule(1, true);
69
70    @Before
71    public void setup() {
72        mRoomDatabase = mock(RoomDatabase.class);
73        SupportSQLiteDatabase sqliteDb = mock(SupportSQLiteDatabase.class);
74        final SupportSQLiteStatement statement = mock(SupportSQLiteStatement.class);
75        mOpenHelper = mock(SupportSQLiteOpenHelper.class);
76
77        doReturn(statement).when(sqliteDb).compileStatement(eq(InvalidationTracker.CLEANUP_SQL));
78        doReturn(sqliteDb).when(mOpenHelper).getWritableDatabase();
79        doReturn(true).when(mRoomDatabase).isOpen();
80        ReentrantLock closeLock = new ReentrantLock();
81        doReturn(closeLock).when(mRoomDatabase).getCloseLock();
82        //noinspection ResultOfMethodCallIgnored
83        doReturn(mOpenHelper).when(mRoomDatabase).getOpenHelper();
84
85        mTracker = new InvalidationTracker(mRoomDatabase, "a", "B", "i");
86        mTracker.internalInit(sqliteDb);
87    }
88
89    @Before
90    public void setLocale() {
91        Locale.setDefault(Locale.forLanguageTag("tr-TR"));
92    }
93
94    @After
95    public void unsetLocale() {
96        Locale.setDefault(Locale.US);
97    }
98
99    @Test
100    public void tableIds() {
101        assertThat(mTracker.mTableIdLookup.get("a"), is(0));
102        assertThat(mTracker.mTableIdLookup.get("b"), is(1));
103    }
104
105    @Test
106    public void testWeak() throws InterruptedException {
107        final AtomicInteger data = new AtomicInteger(0);
108        InvalidationTracker.Observer observer = new InvalidationTracker.Observer("a") {
109            @Override
110            public void onInvalidated(@NonNull Set<String> tables) {
111                data.incrementAndGet();
112            }
113        };
114        mTracker.addWeakObserver(observer);
115        setVersions(1, 0);
116        refreshSync();
117        assertThat(data.get(), is(1));
118        observer = null;
119        forceGc();
120        setVersions(2, 0);
121        refreshSync();
122        assertThat(data.get(), is(1));
123    }
124
125    @Test
126    public void addRemoveObserver() throws Exception {
127        InvalidationTracker.Observer observer = new LatchObserver(1, "a");
128        mTracker.addObserver(observer);
129        assertThat(mTracker.mObserverMap.size(), is(1));
130        mTracker.removeObserver(new LatchObserver(1, "a"));
131        assertThat(mTracker.mObserverMap.size(), is(1));
132        mTracker.removeObserver(observer);
133        assertThat(mTracker.mObserverMap.size(), is(0));
134    }
135
136    private void drainTasks() throws InterruptedException {
137        mTaskExecutorRule.drainTasks(200);
138    }
139
140    @Test(expected = IllegalArgumentException.class)
141    public void badObserver() {
142        InvalidationTracker.Observer observer = new LatchObserver(1, "x");
143        mTracker.addObserver(observer);
144    }
145
146    @Test
147    public void refreshReadValues() throws Exception {
148        setVersions(1, 0, 2, 1);
149        refreshSync();
150        assertThat(mTracker.mTableVersions, is(new long[]{1, 2, 0}));
151
152        setVersions(3, 1);
153        refreshSync();
154        assertThat(mTracker.mTableVersions, is(new long[]{1, 3, 0}));
155
156        setVersions(7, 0);
157        refreshSync();
158        assertThat(mTracker.mTableVersions, is(new long[]{7, 3, 0}));
159
160        refreshSync();
161        assertThat(mTracker.mTableVersions, is(new long[]{7, 3, 0}));
162    }
163
164    private void refreshSync() throws InterruptedException {
165        mTracker.refreshVersionsAsync();
166        drainTasks();
167    }
168
169    @Test
170    public void refreshCheckTasks() throws Exception {
171        when(mRoomDatabase.query(anyString(), any(Object[].class)))
172                .thenReturn(mock(Cursor.class));
173        mTracker.refreshVersionsAsync();
174        mTracker.refreshVersionsAsync();
175        verify(mTaskExecutorRule.getTaskExecutor()).executeOnDiskIO(mTracker.mRefreshRunnable);
176        drainTasks();
177
178        reset(mTaskExecutorRule.getTaskExecutor());
179        mTracker.refreshVersionsAsync();
180        verify(mTaskExecutorRule.getTaskExecutor()).executeOnDiskIO(mTracker.mRefreshRunnable);
181    }
182
183    @Test
184    public void observe1Table() throws Exception {
185        LatchObserver observer = new LatchObserver(1, "a");
186        mTracker.addObserver(observer);
187        setVersions(1, 0, 2, 1);
188        refreshSync();
189        assertThat(observer.await(), is(true));
190        assertThat(observer.getInvalidatedTables().size(), is(1));
191        assertThat(observer.getInvalidatedTables(), hasItem("a"));
192
193        setVersions(3, 1);
194        observer.reset(1);
195        refreshSync();
196        assertThat(observer.await(), is(false));
197
198        setVersions(4, 0);
199        refreshSync();
200        assertThat(observer.await(), is(true));
201        assertThat(observer.getInvalidatedTables().size(), is(1));
202        assertThat(observer.getInvalidatedTables(), hasItem("a"));
203    }
204
205    @Test
206    public void observe2Tables() throws Exception {
207        LatchObserver observer = new LatchObserver(1, "A", "B");
208        mTracker.addObserver(observer);
209        setVersions(1, 0, 2, 1);
210        refreshSync();
211        assertThat(observer.await(), is(true));
212        assertThat(observer.getInvalidatedTables().size(), is(2));
213        assertThat(observer.getInvalidatedTables(), hasItems("A", "B"));
214
215        setVersions(3, 1);
216        observer.reset(1);
217        refreshSync();
218        assertThat(observer.await(), is(true));
219        assertThat(observer.getInvalidatedTables().size(), is(1));
220        assertThat(observer.getInvalidatedTables(), hasItem("B"));
221
222        setVersions(4, 0);
223        observer.reset(1);
224        refreshSync();
225        assertThat(observer.await(), is(true));
226        assertThat(observer.getInvalidatedTables().size(), is(1));
227        assertThat(observer.getInvalidatedTables(), hasItem("A"));
228
229        observer.reset(1);
230        refreshSync();
231        assertThat(observer.await(), is(false));
232    }
233
234    @Test
235    public void locale() {
236        LatchObserver observer = new LatchObserver(1, "I");
237        mTracker.addObserver(observer);
238    }
239
240    @Test
241    public void closedDb() {
242        doReturn(false).when(mRoomDatabase).isOpen();
243        doThrow(new IllegalStateException("foo")).when(mOpenHelper).getWritableDatabase();
244        mTracker.addObserver(new LatchObserver(1, "a", "b"));
245        mTracker.mRefreshRunnable.run();
246    }
247
248    // @Test - disabled due to flakiness b/65257997
249    public void closedDbAfterOpen() throws InterruptedException {
250        setVersions(3, 1);
251        mTracker.addObserver(new LatchObserver(1, "a", "b"));
252        mTracker.syncTriggers();
253        mTracker.mRefreshRunnable.run();
254        doThrow(new SQLiteException("foo")).when(mRoomDatabase).query(
255                Mockito.eq(InvalidationTracker.SELECT_UPDATED_TABLES_SQL),
256                any(Object[].class));
257        mTracker.mPendingRefresh.set(true);
258        mTracker.mRefreshRunnable.run();
259    }
260
261    /**
262     * Key value pairs of VERSION, TABLE_ID
263     */
264    private void setVersions(int... keyValuePairs) throws InterruptedException {
265        // mockito does not like multi-threaded access so before setting versions, make sure we
266        // sync background tasks.
267        drainTasks();
268        Cursor cursor = createCursorWithValues(keyValuePairs);
269        doReturn(cursor).when(mRoomDatabase).query(
270                Mockito.eq(InvalidationTracker.SELECT_UPDATED_TABLES_SQL),
271                any(Object[].class)
272        );
273    }
274
275    private Cursor createCursorWithValues(final int... keyValuePairs) {
276        Cursor cursor = mock(Cursor.class);
277        final AtomicInteger index = new AtomicInteger(-2);
278        when(cursor.moveToNext()).thenAnswer(new Answer<Boolean>() {
279            @Override
280            public Boolean answer(InvocationOnMock invocation) throws Throwable {
281                return index.addAndGet(2) < keyValuePairs.length;
282            }
283        });
284        Answer<Integer> intAnswer = new Answer<Integer>() {
285            @Override
286            public Integer answer(InvocationOnMock invocation) throws Throwable {
287                return keyValuePairs[index.intValue() + (Integer) invocation.getArguments()[0]];
288            }
289        };
290        Answer<Long> longAnswer = new Answer<Long>() {
291            @Override
292            public Long answer(InvocationOnMock invocation) throws Throwable {
293                return (long) keyValuePairs[index.intValue()
294                        + (Integer) invocation.getArguments()[0]];
295            }
296        };
297        when(cursor.getInt(anyInt())).thenAnswer(intAnswer);
298        when(cursor.getLong(anyInt())).thenAnswer(longAnswer);
299        return cursor;
300    }
301
302    static class LatchObserver extends InvalidationTracker.Observer {
303        private CountDownLatch mLatch;
304        private Set<String> mInvalidatedTables;
305
306        LatchObserver(int count, String... tableNames) {
307            super(tableNames);
308            mLatch = new CountDownLatch(count);
309        }
310
311        boolean await() throws InterruptedException {
312            return mLatch.await(3, TimeUnit.SECONDS);
313        }
314
315        @Override
316        public void onInvalidated(@NonNull Set<String> tables) {
317            mInvalidatedTables = tables;
318            mLatch.countDown();
319        }
320
321        void reset(@SuppressWarnings("SameParameterValue") int count) {
322            mInvalidatedTables = null;
323            mLatch = new CountDownLatch(count);
324        }
325
326        Set<String> getInvalidatedTables() {
327            return mInvalidatedTables;
328        }
329    }
330
331    private static void forceGc() {
332        // Use a random index in the list to detect the garbage collection each time because
333        // .get() may accidentally trigger a strong reference during collection.
334        ArrayList<WeakReference<byte[]>> leak = new ArrayList<>();
335        do {
336            WeakReference<byte[]> arr = new WeakReference<>(new byte[100]);
337            leak.add(arr);
338        } while (leak.get((int) (Math.random() * leak.size())).get() != null);
339    }
340}
341