1/*
2 * Copyright (C) 2016 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package androidx.room.solver.query
18
19import androidx.room.Dao
20import androidx.room.Query
21import androidx.room.ext.RoomTypeNames.ROOM_SQL_QUERY
22import androidx.room.ext.RoomTypeNames.STRING_UTIL
23import androidx.room.processor.QueryMethodProcessor
24import androidx.room.testing.TestProcessor
25import androidx.room.writer.QueryWriter
26import com.google.auto.common.MoreElements
27import com.google.auto.common.MoreTypes
28import com.google.common.truth.Truth
29import com.google.testing.compile.CompileTester
30import com.google.testing.compile.JavaFileObjects
31import com.google.testing.compile.JavaSourceSubjectFactory
32import org.hamcrest.CoreMatchers.`is`
33import org.hamcrest.MatcherAssert.assertThat
34import org.junit.Test
35import org.junit.runner.RunWith
36import org.junit.runners.JUnit4
37import testCodeGenScope
38
/**
 * Tests for [QueryWriter.prepareReadAndBind].
 *
 * Each test compiles a single `@Query` method inside a synthetic `@Dao` class and runs Room's
 * query method processor on it. It then compares the Java code emitted for building the SQL
 * string and binding the arguments against an expected snippet.
 */
@RunWith(JUnit4::class)
class QueryWriterTest {

    @Test
    fun simpleNoArgQuery() {
        singleQueryMethod("""
                @Query("SELECT id FROM users")
                abstract java.util.List<Integer> selectAllIds();
                """) { writer ->
            assertThat(generateReadAndBind(writer), `is`("""
                    final java.lang.String _sql = "SELECT id FROM users";
                    final $QUERY _stmt = $QUERY.acquire(_sql, 0);
                    """.trimIndent()))
        }.compilesWithoutError()
    }

    @Test
    fun simpleStringArgs() {
        singleQueryMethod("""
                @Query("SELECT id FROM users WHERE name LIKE :name")
                abstract java.util.List<Integer> selectAllIds(String name);
                """) { writer ->
            assertThat(generateReadAndBind(writer), `is`(
                    """
                    final java.lang.String _sql = "SELECT id FROM users WHERE name LIKE ?";
                    final $QUERY _stmt = $QUERY.acquire(_sql, 1);
                    int _argIndex = 1;
                    if (name == null) {
                      _stmt.bindNull(_argIndex);
                    } else {
                      _stmt.bindString(_argIndex, name);
                    }
                    """.trimIndent()))
        }.compilesWithoutError()
    }

    @Test
    fun twoIntArgs() {
        singleQueryMethod("""
                @Query("SELECT id FROM users WHERE id IN(:id1,:id2)")
                abstract java.util.List<Integer> selectAllIds(int id1, int id2);
                """) { writer ->
            assertThat(generateReadAndBind(writer), `is`(
                    """
                    final java.lang.String _sql = "SELECT id FROM users WHERE id IN(?,?)";
                    final $QUERY _stmt = $QUERY.acquire(_sql, 2);
                    int _argIndex = 1;
                    _stmt.bindLong(_argIndex, id1);
                    _argIndex = 2;
                    _stmt.bindLong(_argIndex, id2);
                    """.trimIndent()))
        }.compilesWithoutError()
    }

    @Test
    fun aLongAndIntVarArg() {
        singleQueryMethod("""
                @Query("SELECT id FROM users WHERE id IN(:ids) AND age > :time")
                abstract java.util.List<Integer> selectAllIds(long time, int... ids);
                """) { writer ->
            // A varargs parameter makes the SQL dynamic: it is assembled with a StringBuilder
            // and the argument count depends on ids.length.
            assertThat(generateReadAndBind(writer), `is`(
                    """
                    java.lang.StringBuilder _stringBuilder = $STRING_UTIL.newStringBuilder();
                    _stringBuilder.append("SELECT id FROM users WHERE id IN(");
                    final int _inputSize = ids.length;
                    $STRING_UTIL.appendPlaceholders(_stringBuilder, _inputSize);
                    _stringBuilder.append(") AND age > ");
                    _stringBuilder.append("?");
                    final java.lang.String _sql = _stringBuilder.toString();
                    final int _argCount = 1 + _inputSize;
                    final $QUERY _stmt = $QUERY.acquire(_sql, _argCount);
                    int _argIndex = 1;
                    for (int _item : ids) {
                      _stmt.bindLong(_argIndex, _item);
                      _argIndex ++;
                    }
                    _argIndex = 1 + _inputSize;
                    _stmt.bindLong(_argIndex, time);
                    """.trimIndent()))
        }.compilesWithoutError()
    }

    /**
     * Expected output shared by [aLongAndIntegerList] and [aLongAndIntegerSet].
     *
     * Both collection types produce the same code: the size comes from `ids.size()`, and the
     * boxed elements are iterated with a null check before binding.
     */
    val collectionOut = """
                    java.lang.StringBuilder _stringBuilder = $STRING_UTIL.newStringBuilder();
                    _stringBuilder.append("SELECT id FROM users WHERE id IN(");
                    final int _inputSize = ids.size();
                    $STRING_UTIL.appendPlaceholders(_stringBuilder, _inputSize);
                    _stringBuilder.append(") AND age > ");
                    _stringBuilder.append("?");
                    final java.lang.String _sql = _stringBuilder.toString();
                    final int _argCount = 1 + _inputSize;
                    final $QUERY _stmt = $QUERY.acquire(_sql, _argCount);
                    int _argIndex = 1;
                    for (java.lang.Integer _item : ids) {
                      if (_item == null) {
                        _stmt.bindNull(_argIndex);
                      } else {
                        _stmt.bindLong(_argIndex, _item);
                      }
                      _argIndex ++;
                    }
                    _argIndex = 1 + _inputSize;
                    _stmt.bindLong(_argIndex, time);
                    """.trimIndent()

    @Test
    fun aLongAndIntegerList() {
        singleQueryMethod("""
                @Query("SELECT id FROM users WHERE id IN(:ids) AND age > :time")
                abstract List<Integer> selectAllIds(long time, List<Integer> ids);
                """) { writer ->
            assertThat(generateReadAndBind(writer), `is`(collectionOut))
        }.compilesWithoutError()
    }

    @Test
    fun aLongAndIntegerSet() {
        singleQueryMethod("""
                @Query("SELECT id FROM users WHERE id IN(:ids) AND age > :time")
                abstract List<Integer> selectAllIds(long time, Set<Integer> ids);
                """) { writer ->
            assertThat(generateReadAndBind(writer), `is`(collectionOut))
        }.compilesWithoutError()
    }

    @Test
    fun testMultipleBindParamsWithSameName() {
        singleQueryMethod("""
                @Query("SELECT id FROM users WHERE age > :age OR bage > :age")
                abstract List<Integer> selectAllIds(int age);
                """) { writer ->
            // Each occurrence of :age gets its own placeholder and its own bind call.
            assertThat(generateReadAndBind(writer), `is`("""
                    final java.lang.String _sql = "SELECT id FROM users WHERE age > ? OR bage > ?";
                    final $QUERY _stmt = $QUERY.acquire(_sql, 2);
                    int _argIndex = 1;
                    _stmt.bindLong(_argIndex, age);
                    _argIndex = 2;
                    _stmt.bindLong(_argIndex, age);
                    """.trimIndent()))
        }.compilesWithoutError()
    }

    @Test
    fun testMultipleBindParamsWithSameNameWithVarArg() {
        singleQueryMethod("""
                @Query("SELECT id FROM users WHERE age > :age OR bage > :age OR fage IN(:ages)")
                abstract List<Integer> selectAllIds(int age, int... ages);
                """) { writer ->
            assertThat(generateReadAndBind(writer), `is`("""
                    java.lang.StringBuilder _stringBuilder = $STRING_UTIL.newStringBuilder();
                    _stringBuilder.append("SELECT id FROM users WHERE age > ");
                    _stringBuilder.append("?");
                    _stringBuilder.append(" OR bage > ");
                    _stringBuilder.append("?");
                    _stringBuilder.append(" OR fage IN(");
                    final int _inputSize = ages.length;
                    $STRING_UTIL.appendPlaceholders(_stringBuilder, _inputSize);
                    _stringBuilder.append(")");
                    final java.lang.String _sql = _stringBuilder.toString();
                    final int _argCount = 2 + _inputSize;
                    final $QUERY _stmt = $QUERY.acquire(_sql, _argCount);
                    int _argIndex = 1;
                    _stmt.bindLong(_argIndex, age);
                    _argIndex = 2;
                    _stmt.bindLong(_argIndex, age);
                    _argIndex = 3;
                    for (int _item : ages) {
                      _stmt.bindLong(_argIndex, _item);
                      _argIndex ++;
                    }
                    """.trimIndent()))
        }.compilesWithoutError()
    }

    @Test
    fun testMultipleBindParamsWithSameNameWithVarArgInTwoBindings() {
        singleQueryMethod("""
                @Query("SELECT id FROM users WHERE age IN (:ages) OR bage > :age OR fage IN(:ages)")
                abstract List<Integer> selectAllIds(int age, int... ages);
                """) { writer ->
            // The second use of :ages gets its own uniquely named locals (_inputSize_1,
            // _item_1), and it contributes separately to the argument count.
            assertThat(generateReadAndBind(writer), `is`("""
                    java.lang.StringBuilder _stringBuilder = $STRING_UTIL.newStringBuilder();
                    _stringBuilder.append("SELECT id FROM users WHERE age IN (");
                    final int _inputSize = ages.length;
                    $STRING_UTIL.appendPlaceholders(_stringBuilder, _inputSize);
                    _stringBuilder.append(") OR bage > ");
                    _stringBuilder.append("?");
                    _stringBuilder.append(" OR fage IN(");
                    final int _inputSize_1 = ages.length;
                    $STRING_UTIL.appendPlaceholders(_stringBuilder, _inputSize_1);
                    _stringBuilder.append(")");
                    final java.lang.String _sql = _stringBuilder.toString();
                    final int _argCount = 1 + _inputSize + _inputSize_1;
                    final $QUERY _stmt = $QUERY.acquire(_sql, _argCount);
                    int _argIndex = 1;
                    for (int _item : ages) {
                      _stmt.bindLong(_argIndex, _item);
                      _argIndex ++;
                    }
                    _argIndex = 1 + _inputSize;
                    _stmt.bindLong(_argIndex, age);
                    _argIndex = 2 + _inputSize;
                    for (int _item_1 : ages) {
                      _stmt.bindLong(_argIndex, _item_1);
                      _argIndex ++;
                    }
                    """.trimIndent()))
        }.compilesWithoutError()
    }

    /**
     * Runs [QueryWriter.prepareReadAndBind] in a fresh test code-gen scope, using `_sql` and
     * `_stmt` as the generated variable names.
     *
     * Returns the generated code with surrounding whitespace trimmed.
     */
    private fun generateReadAndBind(writer: QueryWriter) =
            testCodeGenScope()
                    .also { scope -> writer.prepareReadAndBind("_sql", "_stmt", scope) }
                    .generate()
                    .trim()

    /**
     * Builds the Java source `foo.bar.MyClass` from [input] and compiles it with a test processor.
     *
     * The source is an abstract `@Dao` class whose body is the given member declarations, joined
     * by newlines. The processor:
     * 1. picks the first `@Dao` element that has at least one `@Query`-annotated member;
     * 2. runs [QueryMethodProcessor] on the first such member;
     * 3. passes a [QueryWriter] for the parsed query to [handler].
     *
     * @param input member declarations placed inside the generated DAO class.
     * @param handler receives the [QueryWriter] built for the parsed query method.
     * @return the [CompileTester], so callers can assert on the compilation result.
     */
    fun singleQueryMethod(vararg input: String,
                          handler: (QueryWriter) -> Unit):
            CompileTester {
        val daoSource = DAO_PREFIX + input.joinToString("\n") + DAO_SUFFIX
        return Truth.assertAbout(JavaSourceSubjectFactory.javaSource())
                .that(JavaFileObjects.forSourceString("foo.bar.MyClass", daoSource))
                .processedWith(TestProcessor.builder()
                        .forAnnotations(Query::class, Dao::class)
                        .nextRunHandler { invocation ->
                            // Pair each @Dao element with its @Query-annotated members. The
                            // lambda parameters are named explicitly so the nested lambdas do
                            // not shadow each other's implicit `it`.
                            val (owner, methods) = invocation.roundEnv
                                    .getElementsAnnotatedWith(Dao::class.java)
                                    .map { daoElement ->
                                        val queryMembers = invocation.processingEnv.elementUtils
                                                .getAllMembers(MoreElements.asType(daoElement))
                                                .filter { member ->
                                                    MoreElements.isAnnotationPresent(member,
                                                            Query::class.java)
                                                }
                                        daoElement to queryMembers
                                    }.first { (_, queryMembers) -> queryMembers.isNotEmpty() }
                            val parser = QueryMethodProcessor(
                                    baseContext = invocation.context,
                                    containing = MoreTypes.asDeclared(owner.asType()),
                                    executableElement = MoreElements.asExecutable(methods.first()))
                            val parsedQuery = parser.process()
                            handler(QueryWriter(parsedQuery))
                            true
                        }
                        .build())
    }

    companion object {
        /** Opening of the synthetic DAO source; the method declarations under test follow it. */
        const val DAO_PREFIX = """
                package foo.bar;
                import androidx.room.*;
                import java.util.*;
                @Dao
                abstract class MyClass {
                """
        /** Closes the class opened by [DAO_PREFIX]. */
        const val DAO_SUFFIX = "}"
        /** Fully qualified name of the Room query type, as it appears in the generated code. */
        val QUERY = ROOM_SQL_QUERY.toString()
    }
}
311