1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package androidx.room.integration.testapp.test;
18
19import static org.hamcrest.CoreMatchers.is;
20import static org.hamcrest.CoreMatchers.notNullValue;
21import static org.hamcrest.MatcherAssert.assertThat;
22
23import android.support.test.InstrumentationRegistry;
24import android.support.test.filters.SmallTest;
25
26import androidx.annotation.NonNull;
27import androidx.arch.core.executor.ArchTaskExecutor;
28import androidx.arch.core.executor.testing.CountingTaskExecutorRule;
29import androidx.lifecycle.Lifecycle;
30import androidx.lifecycle.LiveData;
31import androidx.paging.DataSource;
32import androidx.paging.LivePagedListBuilder;
33import androidx.paging.PagedList;
34import androidx.paging.PositionalDataSource;
35import androidx.room.Dao;
36import androidx.room.Database;
37import androidx.room.Entity;
38import androidx.room.Ignore;
39import androidx.room.Insert;
40import androidx.room.PrimaryKey;
41import androidx.room.Query;
42import androidx.room.Relation;
43import androidx.room.Room;
44import androidx.room.RoomDatabase;
45import androidx.room.RoomWarnings;
46import androidx.room.Transaction;
47import androidx.room.paging.LimitOffsetDataSource;
48
49import org.junit.After;
50import org.junit.Before;
51import org.junit.Rule;
52import org.junit.Test;
53import org.junit.runner.RunWith;
54import org.junit.runners.Parameterized;
55
56import java.util.Collections;
57import java.util.List;
58import java.util.concurrent.ExecutionException;
59import java.util.concurrent.FutureTask;
60import java.util.concurrent.TimeUnit;
61import java.util.concurrent.TimeoutException;
62import java.util.concurrent.atomic.AtomicInteger;
63
64import io.reactivex.Flowable;
65import io.reactivex.Maybe;
66import io.reactivex.Single;
67import io.reactivex.observers.TestObserver;
68import io.reactivex.schedulers.Schedulers;
69import io.reactivex.subscribers.TestSubscriber;
70
71@SmallTest
72@RunWith(Parameterized.class)
73@SuppressWarnings("CheckReturnValue")
74public class QueryTransactionTest {
75    @Rule
76    public CountingTaskExecutorRule countingTaskExecutorRule = new CountingTaskExecutorRule();
77    private static final AtomicInteger sStartedTransactionCount = new AtomicInteger(0);
78    private TransactionDb mDb;
79    private final boolean mUseTransactionDao;
80    private Entity1Dao mDao;
81    private final LiveDataQueryTest.TestLifecycleOwner mLifecycleOwner = new LiveDataQueryTest
82            .TestLifecycleOwner();
83
84    @NonNull
85    @Parameterized.Parameters(name = "useTransaction_{0}")
86    public static Boolean[] getParams() {
87        return new Boolean[]{false, true};
88    }
89
90    public QueryTransactionTest(boolean useTransactionDao) {
91        mUseTransactionDao = useTransactionDao;
92    }
93
94    @Before
95    public void initDb() {
96        InstrumentationRegistry.getInstrumentation().runOnMainSync(
97                () -> mLifecycleOwner.handleEvent(Lifecycle.Event.ON_START));
98
99        resetTransactionCount();
100        mDb = Room.inMemoryDatabaseBuilder(InstrumentationRegistry.getTargetContext(),
101                TransactionDb.class).build();
102        mDao = mUseTransactionDao ? mDb.transactionDao() : mDb.dao();
103        drain();
104    }
105
106    @After
107    public void closeDb() {
108        InstrumentationRegistry.getInstrumentation().runOnMainSync(
109                () -> mLifecycleOwner.handleEvent(Lifecycle.Event.ON_DESTROY));
110        drain();
111        mDb.close();
112    }
113
114    @Test
115    public void readList() {
116        mDao.insert(new Entity1(1, "foo"));
117        resetTransactionCount();
118
119        int expectedTransactionCount = mUseTransactionDao ? 1 : 0;
120        List<Entity1> allEntities = mDao.allEntities();
121        assertTransactionCount(allEntities, expectedTransactionCount);
122    }
123
124    @Test
125    public void liveData() {
126        LiveData<List<Entity1>> listLiveData = mDao.liveData();
127        observeForever(listLiveData);
128        drain();
129        assertThat(listLiveData.getValue(), is(Collections.<Entity1>emptyList()));
130
131        resetTransactionCount();
132        mDao.insert(new Entity1(1, "foo"));
133        drain();
134
135        //noinspection ConstantConditions
136        assertThat(listLiveData.getValue().size(), is(1));
137        int expectedTransactionCount = mUseTransactionDao ? 2 : 1;
138        assertTransactionCount(listLiveData.getValue(), expectedTransactionCount);
139    }
140
141    @Test
142    public void flowable() {
143        Flowable<List<Entity1>> flowable = mDao.flowable();
144        TestSubscriber<List<Entity1>> subscriber = observe(flowable);
145        drain();
146        assertThat(subscriber.values().size(), is(1));
147
148        resetTransactionCount();
149        mDao.insert(new Entity1(1, "foo"));
150        drain();
151
152        List<Entity1> allEntities = subscriber.values().get(1);
153        assertThat(allEntities.size(), is(1));
154        int expectedTransactionCount = mUseTransactionDao ? 2 : 1;
155        assertTransactionCount(allEntities, expectedTransactionCount);
156    }
157
158    @Test
159    public void maybe() {
160        mDao.insert(new Entity1(1, "foo"));
161        resetTransactionCount();
162
163        int expectedTransactionCount = mUseTransactionDao ? 1 : 0;
164        Maybe<List<Entity1>> listMaybe = mDao.maybe();
165        TestObserver<List<Entity1>> observer = observe(listMaybe);
166        drain();
167        List<Entity1> allEntities = observer.values().get(0);
168        assertTransactionCount(allEntities, expectedTransactionCount);
169    }
170
171    @Test
172    public void single() {
173        mDao.insert(new Entity1(1, "foo"));
174        resetTransactionCount();
175
176        int expectedTransactionCount = mUseTransactionDao ? 1 : 0;
177        Single<List<Entity1>> listMaybe = mDao.single();
178        TestObserver<List<Entity1>> observer = observe(listMaybe);
179        drain();
180        List<Entity1> allEntities = observer.values().get(0);
181        assertTransactionCount(allEntities, expectedTransactionCount);
182    }
183
184    @Test
185    public void relation() {
186        mDao.insert(new Entity1(1, "foo"));
187        mDao.insert(new Child(1, 1));
188        mDao.insert(new Child(2, 1));
189        resetTransactionCount();
190
191        List<Entity1WithChildren> result = mDao.withRelation();
192        int expectedTransactionCount = mUseTransactionDao ? 1 : 0;
193        assertTransactionCountWithChildren(result, expectedTransactionCount);
194    }
195
196    @Test
197    public void pagedList() {
198        LiveData<PagedList<Entity1>> pagedList =
199                new LivePagedListBuilder<>(mDao.pagedList(), 10).build();
200        observeForever(pagedList);
201        drain();
202        assertThat(sStartedTransactionCount.get(), is(mUseTransactionDao ? 0 : 0));
203
204        mDao.insert(new Entity1(1, "foo"));
205        drain();
206        //noinspection ConstantConditions
207        assertThat(pagedList.getValue().size(), is(1));
208        assertTransactionCount(pagedList.getValue(), mUseTransactionDao ? 2 : 1);
209
210        mDao.insert(new Entity1(2, "bar"));
211        drain();
212        assertThat(pagedList.getValue().size(), is(2));
213        assertTransactionCount(pagedList.getValue(), mUseTransactionDao ? 4 : 2);
214    }
215
216    @Test
217    public void dataSource() {
218        mDao.insert(new Entity1(2, "bar"));
219        drain();
220        resetTransactionCount();
221        @SuppressWarnings("deprecation")
222        LimitOffsetDataSource<Entity1> dataSource =
223                (LimitOffsetDataSource<Entity1>) mDao.dataSource();
224        dataSource.loadRange(0, 10);
225        assertThat(sStartedTransactionCount.get(), is(mUseTransactionDao ? 1 : 0));
226    }
227
228    private void assertTransactionCount(List<Entity1> allEntities, int expectedTransactionCount) {
229        assertThat(sStartedTransactionCount.get(), is(expectedTransactionCount));
230        assertThat(allEntities.isEmpty(), is(false));
231        for (Entity1 entity1 : allEntities) {
232            assertThat(entity1.transactionId, is(expectedTransactionCount));
233        }
234    }
235
236    private void assertTransactionCountWithChildren(List<Entity1WithChildren> allEntities,
237            int expectedTransactionCount) {
238        assertThat(sStartedTransactionCount.get(), is(expectedTransactionCount));
239        assertThat(allEntities.isEmpty(), is(false));
240        for (Entity1WithChildren entity1 : allEntities) {
241            assertThat(entity1.transactionId, is(expectedTransactionCount));
242            assertThat(entity1.children, notNullValue());
243            assertThat(entity1.children.isEmpty(), is(false));
244            for (Child child : entity1.children) {
245                assertThat(child.transactionId, is(expectedTransactionCount));
246            }
247        }
248    }
249
250    private void resetTransactionCount() {
251        sStartedTransactionCount.set(0);
252    }
253
254    private void drain() {
255        try {
256            countingTaskExecutorRule.drainTasks(30, TimeUnit.SECONDS);
257        } catch (InterruptedException e) {
258            throw new AssertionError("interrupted", e);
259        } catch (TimeoutException e) {
260            throw new AssertionError("drain timed out", e);
261        }
262    }
263
264    private <T> TestSubscriber<T> observe(final Flowable<T> flowable) {
265        TestSubscriber<T> subscriber = new TestSubscriber<>();
266        flowable.observeOn(Schedulers.from(ArchTaskExecutor.getMainThreadExecutor()))
267                .subscribeWith(subscriber);
268        return subscriber;
269    }
270
271    private <T> TestObserver<T> observe(final Maybe<T> maybe) {
272        TestObserver<T> observer = new TestObserver<>();
273        maybe.observeOn(Schedulers.from(ArchTaskExecutor.getMainThreadExecutor()))
274                .subscribeWith(observer);
275        return observer;
276    }
277
278    private <T> TestObserver<T> observe(final Single<T> single) {
279        TestObserver<T> observer = new TestObserver<>();
280        single.observeOn(Schedulers.from(ArchTaskExecutor.getMainThreadExecutor()))
281                .subscribeWith(observer);
282        return observer;
283    }
284
285    private <T> void observeForever(final LiveData<T> liveData) {
286        FutureTask<Void> futureTask = new FutureTask<>(() -> {
287            liveData.observe(mLifecycleOwner, t -> {
288            });
289            return null;
290        });
291        ArchTaskExecutor.getMainThreadExecutor().execute(futureTask);
292        try {
293            futureTask.get();
294        } catch (InterruptedException e) {
295            throw new AssertionError("interrupted", e);
296        } catch (ExecutionException e) {
297            throw new AssertionError("execution error", e);
298        }
299    }
300
301    @SuppressWarnings("WeakerAccess")
302    static class Entity1WithChildren extends Entity1 {
303        @Relation(entity = Child.class, parentColumn = "id",
304                entityColumn = "entity1Id")
305        public List<Child> children;
306
307        Entity1WithChildren(int id, String value) {
308            super(id, value);
309        }
310    }
311
312    @SuppressWarnings("WeakerAccess")
313    @Entity
314    static class Child {
315        @PrimaryKey(autoGenerate = true)
316        public int id;
317        public int entity1Id;
318        @Ignore
319        public final int transactionId = sStartedTransactionCount.get();
320
321        Child(int id, int entity1Id) {
322            this.id = id;
323            this.entity1Id = entity1Id;
324        }
325    }
326
327    @SuppressWarnings("WeakerAccess")
328    @Entity
329    static class Entity1 {
330        @PrimaryKey(autoGenerate = true)
331        public int id;
332        public String value;
333        @Ignore
334        public final int transactionId = sStartedTransactionCount.get();
335
336        Entity1(int id, String value) {
337            this.id = id;
338            this.value = value;
339        }
340    }
341
342    // we don't support dao inheritance for queries so for now, go with this
343    interface Entity1Dao {
344        String SELECT_ALL = "select * from Entity1";
345
346        List<Entity1> allEntities();
347
348        Flowable<List<Entity1>> flowable();
349
350        Maybe<List<Entity1>> maybe();
351
352        Single<List<Entity1>> single();
353
354        LiveData<List<Entity1>> liveData();
355
356        List<Entity1WithChildren> withRelation();
357
358        DataSource.Factory<Integer, Entity1> pagedList();
359
360        PositionalDataSource<Entity1> dataSource();
361
362        @Insert
363        void insert(Entity1 entity1);
364
365        @Insert
366        void insert(Child entity1);
367    }
368
369    @Dao
370    interface EntityDao extends Entity1Dao {
371        @Override
372        @Query(SELECT_ALL)
373        List<Entity1> allEntities();
374
375        @Override
376        @Query(SELECT_ALL)
377        Flowable<List<Entity1>> flowable();
378
379        @Override
380        @Query(SELECT_ALL)
381        LiveData<List<Entity1>> liveData();
382
383        @Override
384        @Query(SELECT_ALL)
385        Maybe<List<Entity1>> maybe();
386
387        @Override
388        @Query(SELECT_ALL)
389        Single<List<Entity1>> single();
390
391        @Override
392        @Query(SELECT_ALL)
393        @SuppressWarnings(RoomWarnings.RELATION_QUERY_WITHOUT_TRANSACTION)
394        List<Entity1WithChildren> withRelation();
395
396        @Override
397        @Query(SELECT_ALL)
398        DataSource.Factory<Integer, Entity1> pagedList();
399
400        @Override
401        @Query(SELECT_ALL)
402        PositionalDataSource<Entity1> dataSource();
403    }
404
405    @Dao
406    interface TransactionDao extends Entity1Dao {
407        @Override
408        @Transaction
409        @Query(SELECT_ALL)
410        List<Entity1> allEntities();
411
412        @Override
413        @Transaction
414        @Query(SELECT_ALL)
415        Flowable<List<Entity1>> flowable();
416
417        @Override
418        @Transaction
419        @Query(SELECT_ALL)
420        LiveData<List<Entity1>> liveData();
421
422        @Override
423        @Transaction
424        @Query(SELECT_ALL)
425        Maybe<List<Entity1>> maybe();
426
427        @Override
428        @Transaction
429        @Query(SELECT_ALL)
430        Single<List<Entity1>> single();
431
432        @Override
433        @Transaction
434        @Query(SELECT_ALL)
435        List<Entity1WithChildren> withRelation();
436
437        @Override
438        @Transaction
439        @Query(SELECT_ALL)
440        DataSource.Factory<Integer, Entity1> pagedList();
441
442        @Override
443        @Transaction
444        @Query(SELECT_ALL)
445        PositionalDataSource<Entity1> dataSource();
446    }
447
448    @Database(version = 1, entities = {Entity1.class, Child.class}, exportSchema = false)
449    abstract static class TransactionDb extends RoomDatabase {
450        abstract EntityDao dao();
451
452        abstract TransactionDao transactionDao();
453
454        @Override
455        public void beginTransaction() {
456            super.beginTransaction();
457            sStartedTransactionCount.incrementAndGet();
458        }
459    }
460}
461