1/*
2 * Copyright (C) 2016 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package androidx.room.processor
18
19import COMMON
20import androidx.room.Dao
21import androidx.room.Insert
22import androidx.room.OnConflictStrategy
23import androidx.room.ext.CommonTypeNames
24import androidx.room.ext.typeName
25import androidx.room.testing.TestInvocation
26import androidx.room.testing.TestProcessor
27import androidx.room.vo.InsertionMethod
28import androidx.room.vo.InsertionMethod.Type
29import com.google.auto.common.MoreElements
30import com.google.auto.common.MoreTypes
31import com.google.common.truth.Truth.assertAbout
32import com.google.testing.compile.CompileTester
33import com.google.testing.compile.JavaFileObjects
34import com.google.testing.compile.JavaSourcesSubjectFactory
35import com.squareup.javapoet.ArrayTypeName
36import com.squareup.javapoet.ClassName
37import com.squareup.javapoet.ParameterizedTypeName
38import com.squareup.javapoet.TypeName
39import org.hamcrest.CoreMatchers.`is`
40import org.hamcrest.CoreMatchers.nullValue
41import org.hamcrest.MatcherAssert.assertThat
42import org.junit.Test
43import org.junit.runner.RunWith
44import org.junit.runners.JUnit4
45
/**
 * Tests for [InsertionMethodProcessor].
 *
 * Each test wraps a Java snippet in an abstract `@Dao` class named `foo.bar.MyClass`
 * (see [DAO_PREFIX] / [DAO_SUFFIX]), compiles it through [singleInsertMethod], and asserts on
 * the resulting [InsertionMethod] and/or on the compilation outcome.
 */
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN")
@RunWith(JUnit4::class)
class InsertionMethodProcessorTest {
    companion object {
        /** Java source placed before the snippet under test; opens the DAO class body. */
        const val DAO_PREFIX = """
                package foo.bar;
                import androidx.room.*;
                import java.util.*;
                @Dao
                abstract class MyClass {
                """

        /** Closes the DAO class body opened by [DAO_PREFIX]. */
        const val DAO_SUFFIX = "}"

        /** Expected type name for `User` parameters and entities. */
        val USER_TYPE_NAME: TypeName = COMMON.USER_TYPE_NAME

        /** Expected type name for `Book` parameters and entities. */
        val BOOK_TYPE_NAME: TypeName = ClassName.get("foo.bar", "Book")
    }

    /** An `@Insert` method without parameters is still processed but fails compilation. */
    @Test
    fun readNoParams() {
        singleInsertMethod(
                """
                @Insert
                abstract public void foo();
                """) { insertion, _ ->
            assertThat(insertion.name, `is`("foo"))
            assertThat(insertion.parameters.size, `is`(0))
            assertThat(insertion.returnType.typeName(), `is`(TypeName.VOID))
            assertThat(insertion.entities.size, `is`(0))
        }.failsToCompile().withErrorContaining(
                ProcessorErrors.INSERTION_DOES_NOT_HAVE_ANY_PARAMETERS_TO_INSERT)
    }

    /** A single `User` parameter with a `long` return type. */
    @Test
    fun insertSingle() {
        singleInsertMethod(
                """
                @Insert
                abstract public long foo(User user);
                """) { insertion, _ ->
            assertThat(insertion.name, `is`("foo"))
            assertThat(insertion.parameters.size, `is`(1))
            val param = insertion.parameters.first()
            assertThat(param.type.typeName(), `is`(USER_TYPE_NAME))
            assertThat(param.entityType?.typeName(), `is`(USER_TYPE_NAME))
            // Use the shared constant, as the other tests do, instead of rebuilding the name.
            assertThat(insertion.entities["user"]?.typeName, `is`(USER_TYPE_NAME))
            assertThat(insertion.returnType.typeName(), `is`(TypeName.LONG))
        }.compilesWithoutError()
    }

    /** A parameter whose type is not an entity yields no entity type and a compile error. */
    @Test
    fun insertNotAnEntity() {
        singleInsertMethod(
                """
                @Insert
                abstract public void foo(NotAnEntity notValid);
                """) { insertion, _ ->
            assertThat(insertion.name, `is`("foo"))
            assertThat(insertion.parameters.size, `is`(1))
            val param = insertion.parameters.first()
            assertThat(param.entityType, `is`(nullValue()))
            assertThat(insertion.entities.size, `is`(0))
        }.failsToCompile().withErrorContaining(
                ProcessorErrors.CANNOT_FIND_ENTITY_FOR_SHORTCUT_QUERY_PARAMETER
        )
    }

    /** Two `User` parameters are both reported, keyed by parameter name, in declaration order. */
    @Test
    fun insertTwo() {
        singleInsertMethod(
                """
                @Insert
                abstract public void foo(User u1, User u2);
                """) { insertion, _ ->
            assertThat(insertion.name, `is`("foo"))

            assertThat(insertion.parameters.size, `is`(2))
            insertion.parameters.forEach { param ->
                assertThat(param.type.typeName(), `is`(USER_TYPE_NAME))
                assertThat(param.entityType?.typeName(), `is`(USER_TYPE_NAME))
            }
            assertThat(insertion.entities.size, `is`(2))
            assertThat(insertion.entities["u1"]?.typeName, `is`(USER_TYPE_NAME))
            assertThat(insertion.entities["u2"]?.typeName, `is`(USER_TYPE_NAME))
            assertThat(insertion.parameters.map { it.name }, `is`(listOf("u1", "u2")))
            assertThat(insertion.returnType.typeName(), `is`(TypeName.VOID))
        }.compilesWithoutError()
    }

    /** A `List<User>` parameter with a `List<Long>` return type. */
    @Test
    fun insertList() {
        singleInsertMethod(
                """
                @Insert
                abstract public List<Long> insertUsers(List<User> users);
                """) { insertion, _ ->
            assertThat(insertion.name, `is`("insertUsers"))
            assertThat(insertion.parameters.size, `is`(1))
            val param = insertion.parameters.first()
            assertThat(param.type.typeName(), `is`(
                    ParameterizedTypeName.get(
                            ClassName.get("java.util", "List"), USER_TYPE_NAME) as TypeName))
            assertThat(param.entityType?.typeName(), `is`(USER_TYPE_NAME))
            assertThat(insertion.entities.size, `is`(1))
            assertThat(insertion.entities["users"]?.typeName, `is`(USER_TYPE_NAME))
            assertThat(insertion.returnType.typeName(), `is`(
                    ParameterizedTypeName.get(ClassName.get("java.util", "List"),
                            ClassName.get("java.lang", "Long")) as TypeName
            ))
        }.compilesWithoutError()
    }

    /** A `User[]` parameter with a `void` return type. */
    @Test
    fun insertArray() {
        singleInsertMethod(
                """
                @Insert
                abstract public void insertUsers(User[] users);
                """) { insertion, _ ->
            assertThat(insertion.name, `is`("insertUsers"))
            assertThat(insertion.parameters.size, `is`(1))
            val param = insertion.parameters.first()
            assertThat(param.type.typeName(), `is`(
                    ArrayTypeName.of(USER_TYPE_NAME) as TypeName))
            assertThat(insertion.entities.size, `is`(1))
            assertThat(insertion.entities["users"]?.typeName, `is`(USER_TYPE_NAME))
            assertThat(insertion.returnType.typeName(), `is`(TypeName.VOID))
        }.compilesWithoutError()
    }

    /** A `Set<User>` parameter with a `void` return type. */
    @Test
    fun insertSet() {
        singleInsertMethod(
                """
                @Insert
                abstract public void insertUsers(Set<User> users);
                """) { insertion, _ ->
            assertThat(insertion.name, `is`("insertUsers"))
            assertThat(insertion.parameters.size, `is`(1))
            val param = insertion.parameters.first()
            assertThat(param.type.typeName(), `is`(
                    ParameterizedTypeName.get(
                            ClassName.get("java.util", "Set"), USER_TYPE_NAME) as TypeName))
            assertThat(insertion.entities.size, `is`(1))
            assertThat(insertion.entities["users"]?.typeName, `is`(USER_TYPE_NAME))
            assertThat(insertion.returnType.typeName(), `is`(TypeName.VOID))
        }.compilesWithoutError()
    }

    /** A `Queue<User>` parameter with a `void` return type. */
    @Test
    fun insertQueue() {
        singleInsertMethod(
                """
                @Insert
                abstract public void insertUsers(Queue<User> users);
                """) { insertion, _ ->
            assertThat(insertion.name, `is`("insertUsers"))
            assertThat(insertion.parameters.size, `is`(1))
            val param = insertion.parameters.first()
            assertThat(param.type.typeName(), `is`(
                    ParameterizedTypeName.get(
                            ClassName.get("java.util", "Queue"), USER_TYPE_NAME) as TypeName))
            assertThat(insertion.entities.size, `is`(1))
            assertThat(insertion.entities["users"]?.typeName, `is`(USER_TYPE_NAME))
            assertThat(insertion.returnType.typeName(), `is`(TypeName.VOID))
        }.compilesWithoutError()
    }

    /** An `Iterable<User>` parameter with a `void` return type. */
    @Test
    fun insertIterable() {
        singleInsertMethod("""
                @Insert
                abstract public void insert(Iterable<User> users);
                """) { insertion, _ ->
            assertThat(insertion.name, `is`("insert"))
            assertThat(insertion.parameters.size, `is`(1))
            val param = insertion.parameters.first()
            assertThat(param.type.typeName(), `is`(ParameterizedTypeName.get(
                    ClassName.get("java.lang", "Iterable"), USER_TYPE_NAME) as TypeName))
            assertThat(insertion.entities.size, `is`(1))
            assertThat(insertion.entities["users"]?.typeName, `is`(USER_TYPE_NAME))
            assertThat(insertion.returnType.typeName(), `is`(TypeName.VOID))
        }.compilesWithoutError()
    }

    /**
     * A user-defined collection (`MyList<Irrelevant, Item> extends ArrayList<Item>`) whose
     * element type is its second type argument still resolves `User` as the entity.
     */
    @Test
    fun insertCustomCollection() {
        singleInsertMethod("""
                static class MyList<Irrelevant, Item> extends ArrayList<Item> {}
                @Insert
                abstract public void insert(MyList<String, User> users);
                """) { insertion, _ ->
            assertThat(insertion.name, `is`("insert"))
            assertThat(insertion.parameters.size, `is`(1))
            val param = insertion.parameters.first()
            // MyList is nested in MyClass, so it is spelled as enclosing + nested simple names
            // rather than as one dotted "simple" name.
            assertThat(param.type.typeName(), `is`(ParameterizedTypeName.get(
                    ClassName.get("foo.bar", "MyClass", "MyList"),
                    CommonTypeNames.STRING, USER_TYPE_NAME) as TypeName))
            assertThat(insertion.entities.size, `is`(1))
            assertThat(insertion.entities["users"]?.typeName, `is`(USER_TYPE_NAME))
            assertThat(insertion.returnType.typeName(), `is`(TypeName.VOID))
        }.compilesWithoutError()
    }

    /** Parameters of two different entity types (`User`, `Book`) in one method. */
    @Test
    fun insertDifferentTypes() {
        singleInsertMethod(
                """
                @Insert
                abstract public void foo(User u1, Book b1);
                """) { insertion, _ ->
            assertThat(insertion.parameters.size, `is`(2))
            assertThat(insertion.parameters[0].type.typeName().toString(),
                    `is`("foo.bar.User"))
            assertThat(insertion.parameters[1].type.typeName().toString(),
                    `is`("foo.bar.Book"))
            assertThat(insertion.parameters.map { it.name }, `is`(listOf("u1", "b1")))
            assertThat(insertion.returnType.typeName(), `is`(TypeName.VOID))
            assertThat(insertion.entities.size, `is`(2))
            assertThat(insertion.entities["u1"]?.typeName, `is`(USER_TYPE_NAME))
            assertThat(insertion.entities["b1"]?.typeName, `is`(BOOK_TYPE_NAME))
        }.compilesWithoutError()
    }

    /** Without an explicit `onConflict`, the processed value is [OnConflictStrategy.ABORT]. */
    @Test
    fun onConflict_Default() {
        singleInsertMethod(
                """
                @Insert
                abstract public void foo(User user);
                """) { insertion, _ ->
            assertThat(insertion.onConflict, `is`(OnConflictStrategy.ABORT))
        }.compilesWithoutError()
    }

    /** An `onConflict` value of `-1` is rejected with `INVALID_ON_CONFLICT_VALUE`. */
    @Test
    fun onConflict_Invalid() {
        singleInsertMethod(
                """
                @Insert(onConflict = -1)
                abstract public void foo(User user);
                """) { _, _ ->
            // Only the compilation error is of interest here.
        }.failsToCompile().withErrorContaining(ProcessorErrors.INVALID_ON_CONFLICT_VALUE)
    }

    /** Every named `OnConflictStrategy` constant is read back as its expected integer value. */
    @Test
    fun onConflict_EachValue() {
        listOf(
                Pair("REPLACE", 1),
                Pair("ROLLBACK", 2),
                Pair("ABORT", 3),
                Pair("FAIL", 4),
                Pair("IGNORE", 5)
        ).forEach { pair ->
            singleInsertMethod(
                    """
                @Insert(onConflict=OnConflictStrategy.${pair.first})
                abstract public void foo(User user);
                """) { insertion, _ ->
                assertThat(insertion.onConflict, `is`(pair.second))
            }.compilesWithoutError()
        }
    }

    /** An `int` return type is rejected and leaves `insertionType` null. */
    @Test
    fun invalidReturnType() {
        singleInsertMethod(
                """
                @Insert
                abstract public int foo(User user);
                """) { insertion, _ ->
            assertThat(insertion.insertionType, `is`(nullValue()))
        }.failsToCompile().withErrorContaining(
                ProcessorErrors.INVALID_INSERTION_METHOD_RETURN_TYPE)
    }

    /** `long[]` returned for a single-item parameter is reported against `SINGLE_ITEM_SET`. */
    @Test
    fun mismatchedReturnType() {
        singleInsertMethod(
                """
                @Insert
                abstract public long[] foo(User user);
                """) { insertion, _ ->
            assertThat(insertion.insertionType, `is`(nullValue()))
        }.failsToCompile().withErrorContaining(
                ProcessorErrors.insertionMethodReturnTypeMismatch(
                        ArrayTypeName.of(TypeName.LONG),
                        InsertionMethodProcessor.SINGLE_ITEM_SET.map { it.returnTypeName }))
    }

    /** `long` returned for a varargs parameter is reported against `MULTIPLE_ITEM_SET`. */
    @Test
    fun mismatchedReturnType2() {
        singleInsertMethod(
                """
                @Insert
                abstract public long foo(User... user);
                """) { insertion, _ ->
            assertThat(insertion.insertionType, `is`(nullValue()))
        }.failsToCompile().withErrorContaining(
                ProcessorErrors.insertionMethodReturnTypeMismatch(
                        TypeName.LONG,
                        InsertionMethodProcessor.MULTIPLE_ITEM_SET.map { it.returnTypeName }))
    }

    /** `long` returned for a two-parameter method is reported against `VOID_SET`. */
    @Test
    fun mismatchedReturnType3() {
        singleInsertMethod(
                """
                @Insert
                abstract public long foo(User user1, User user2);
                """) { insertion, _ ->
            assertThat(insertion.insertionType, `is`(nullValue()))
        }.failsToCompile().withErrorContaining(
                ProcessorErrors.insertionMethodReturnTypeMismatch(
                        TypeName.LONG,
                        InsertionMethodProcessor.VOID_SET.map { it.returnTypeName }))
    }

    /**
     * Each supported return type maps to its [Type]. The array and list return types are
     * paired with a varargs (`User...`) parameter; the others take a single `User`.
     */
    @Test
    fun validReturnTypes() {
        listOf(
                Pair("void", Type.INSERT_VOID),
                Pair("long", Type.INSERT_SINGLE_ID),
                Pair("long[]", Type.INSERT_ID_ARRAY),
                Pair("Long[]", Type.INSERT_ID_ARRAY_BOX),
                Pair("List<Long>", Type.INSERT_ID_LIST)
        ).forEach { pair ->
            val dots = if (pair.second in setOf(Type.INSERT_ID_LIST, Type.INSERT_ID_ARRAY,
                    Type.INSERT_ID_ARRAY_BOX)) {
                "..."
            } else {
                ""
            }
            singleInsertMethod(
                    """
                @Insert
                abstract public ${pair.first} foo(User$dots user);
                """) { insertion, _ ->
                assertThat(insertion.insertMethodTypeFor(insertion.parameters.first()),
                        `is`(pair.second))
                assertThat(pair.toString(), insertion.insertionType, `is`(pair.second))
            }.compilesWithoutError()
        }
    }

    /**
     * Compiles a `foo.bar.MyClass` DAO whose body is [input] and hands the processed
     * `@Insert` method to [handler].
     *
     * The snippets in [input] are joined with newlines and wrapped in [DAO_PREFIX] and
     * [DAO_SUFFIX]. The resulting source is compiled together with `COMMON.USER`,
     * `COMMON.BOOK` and `COMMON.NOT_AN_ENTITY`. During the processor run, the first
     * `@Dao`-annotated element that has at least one `@Insert`-annotated member is selected,
     * and only its first such member is run through [InsertionMethodProcessor].
     *
     * @param input Java member declarations to place inside the DAO class body.
     * @param handler receives the processed [InsertionMethod] and the current
     * [TestInvocation]; assertions made here run inside the annotation processing round.
     * @return the [CompileTester] on which the caller asserts the compilation outcome.
     */
    fun singleInsertMethod(
            vararg input: String,
            handler: (InsertionMethod, TestInvocation) -> Unit
    ): CompileTester {
        return assertAbout(JavaSourcesSubjectFactory.javaSources())
                .that(listOf(JavaFileObjects.forSourceString("foo.bar.MyClass",
                        DAO_PREFIX + input.joinToString("\n") + DAO_SUFFIX
                ), COMMON.USER, COMMON.BOOK, COMMON.NOT_AN_ENTITY))
                .processedWith(TestProcessor.builder()
                        .forAnnotations(Insert::class, Dao::class)
                        .nextRunHandler { invocation ->
                            // Pair every @Dao element with its @Insert members, then keep the
                            // first pair that actually has one.
                            val (owner, methods) = invocation.roundEnv
                                    .getElementsAnnotatedWith(Dao::class.java)
                                    .map { dao ->
                                        Pair(dao,
                                                invocation.processingEnv.elementUtils
                                                        .getAllMembers(MoreElements.asType(dao))
                                                        .filter { member ->
                                                            MoreElements.isAnnotationPresent(
                                                                    member, Insert::class.java)
                                                        }
                                        )
                                    }.first { (_, insertMembers) -> insertMembers.isNotEmpty() }
                            val processor = InsertionMethodProcessor(
                                    baseContext = invocation.context,
                                    containing = MoreTypes.asDeclared(owner.asType()),
                                    executableElement = MoreElements.asExecutable(methods.first()))
                            val processed = processor.process()
                            handler(processed, invocation)
                            true
                        }
                        .build())
    }
}
424