/*
 * Copyright (C) 2016 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package androidx.room.solver.query

import androidx.room.Dao
import androidx.room.Query
import androidx.room.ext.RoomTypeNames.ROOM_SQL_QUERY
import androidx.room.ext.RoomTypeNames.STRING_UTIL
import androidx.room.processor.QueryMethodProcessor
import androidx.room.testing.TestProcessor
import androidx.room.writer.QueryWriter
import com.google.auto.common.MoreElements
import com.google.auto.common.MoreTypes
import com.google.common.truth.Truth
import com.google.testing.compile.CompileTester
import com.google.testing.compile.JavaFileObjects
import com.google.testing.compile.JavaSourceSubjectFactory
import org.hamcrest.CoreMatchers.`is`
import org.hamcrest.MatcherAssert.assertThat
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
import testCodeGenScope

/**
 * Tests the Java code that [QueryWriter.prepareReadAndBind] emits for a single `@Query` DAO
 * method.
 *
 * Each test compiles a small DAO, parses its query method with [QueryMethodProcessor], and
 * compares the generated code with an expected snippet. The snippets cover:
 * - how the SQL string is produced: a constant, or a `StringBuilder` when a vararg or
 *   collection parameter expands into a variable number of placeholders;
 * - how each argument is bound to the acquired statement.
 */
@RunWith(JUnit4::class)
class QueryWriterTest {
    companion object {
        /** Opening of the DAO source that wraps each test's method declarations. */
        const val DAO_PREFIX = """
                package foo.bar;
                import androidx.room.*;
                import java.util.*;
                @Dao
                abstract class MyClass {
                """

        /** Closes the class opened by [DAO_PREFIX]. */
        const val DAO_SUFFIX = "}"

        /** Fully qualified name of the Room SQL query type, interpolated into expected code. */
        val QUERY = ROOM_SQL_QUERY.toString()
    }

    @Test
    fun simpleNoArgQuery() {
        singleQueryMethod("""
                @Query("SELECT id FROM users")
                abstract java.util.List<Integer> selectAllIds();
                """) { writer ->
            assertReadAndBindCode(writer, """
                final java.lang.String _sql = "SELECT id FROM users";
                final $QUERY _stmt = $QUERY.acquire(_sql, 0);
                """.trimIndent())
        }.compilesWithoutError()
    }

    @Test
    fun simpleStringArgs() {
        singleQueryMethod("""
                @Query("SELECT id FROM users WHERE name LIKE :name")
                abstract java.util.List<Integer> selectAllIds(String name);
                """) { writer ->
            assertReadAndBindCode(writer, """
                final java.lang.String _sql = "SELECT id FROM users WHERE name LIKE ?";
                final $QUERY _stmt = $QUERY.acquire(_sql, 1);
                int _argIndex = 1;
                if (name == null) {
                  _stmt.bindNull(_argIndex);
                } else {
                  _stmt.bindString(_argIndex, name);
                }
                """.trimIndent())
        }.compilesWithoutError()
    }

    @Test
    fun twoIntArgs() {
        singleQueryMethod("""
                @Query("SELECT id FROM users WHERE id IN(:id1,:id2)")
                abstract java.util.List<Integer> selectAllIds(int id1, int id2);
                """) { writer ->
            assertReadAndBindCode(writer, """
                final java.lang.String _sql = "SELECT id FROM users WHERE id IN(?,?)";
                final $QUERY _stmt = $QUERY.acquire(_sql, 2);
                int _argIndex = 1;
                _stmt.bindLong(_argIndex, id1);
                _argIndex = 2;
                _stmt.bindLong(_argIndex, id2);
                """.trimIndent())
        }.compilesWithoutError()
    }

    @Test
    fun aLongAndIntVarArg() {
        singleQueryMethod("""
                @Query("SELECT id FROM users WHERE id IN(:ids) AND age > :time")
                abstract java.util.List<Integer> selectAllIds(long time, int... ids);
                """) { writer ->
            assertReadAndBindCode(writer, """
                java.lang.StringBuilder _stringBuilder = $STRING_UTIL.newStringBuilder();
                _stringBuilder.append("SELECT id FROM users WHERE id IN(");
                final int _inputSize = ids.length;
                $STRING_UTIL.appendPlaceholders(_stringBuilder, _inputSize);
                _stringBuilder.append(") AND age > ");
                _stringBuilder.append("?");
                final java.lang.String _sql = _stringBuilder.toString();
                final int _argCount = 1 + _inputSize;
                final $QUERY _stmt = $QUERY.acquire(_sql, _argCount);
                int _argIndex = 1;
                for (int _item : ids) {
                  _stmt.bindLong(_argIndex, _item);
                  _argIndex ++;
                }
                _argIndex = 1 + _inputSize;
                _stmt.bindLong(_argIndex, time);
                """.trimIndent())
        }.compilesWithoutError()
    }

    /**
     * Expected output shared by the `List<Integer>` and `Set<Integer>` variants.
     *
     * Both collection types yield the same code: the size comes from `size()`, and each boxed
     * element is null-checked before binding.
     */
    val collectionOut = """
                java.lang.StringBuilder _stringBuilder = $STRING_UTIL.newStringBuilder();
                _stringBuilder.append("SELECT id FROM users WHERE id IN(");
                final int _inputSize = ids.size();
                $STRING_UTIL.appendPlaceholders(_stringBuilder, _inputSize);
                _stringBuilder.append(") AND age > ");
                _stringBuilder.append("?");
                final java.lang.String _sql = _stringBuilder.toString();
                final int _argCount = 1 + _inputSize;
                final $QUERY _stmt = $QUERY.acquire(_sql, _argCount);
                int _argIndex = 1;
                for (java.lang.Integer _item : ids) {
                  if (_item == null) {
                    _stmt.bindNull(_argIndex);
                  } else {
                    _stmt.bindLong(_argIndex, _item);
                  }
                  _argIndex ++;
                }
                _argIndex = 1 + _inputSize;
                _stmt.bindLong(_argIndex, time);
                """.trimIndent()

    @Test
    fun aLongAndIntegerList() {
        singleQueryMethod("""
                @Query("SELECT id FROM users WHERE id IN(:ids) AND age > :time")
                abstract List<Integer> selectAllIds(long time, List<Integer> ids);
                """) { writer ->
            assertReadAndBindCode(writer, collectionOut)
        }.compilesWithoutError()
    }

    @Test
    fun aLongAndIntegerSet() {
        singleQueryMethod("""
                @Query("SELECT id FROM users WHERE id IN(:ids) AND age > :time")
                abstract List<Integer> selectAllIds(long time, Set<Integer> ids);
                """) { writer ->
            assertReadAndBindCode(writer, collectionOut)
        }.compilesWithoutError()
    }

    @Test
    fun testMultipleBindParamsWithSameName() {
        singleQueryMethod("""
                @Query("SELECT id FROM users WHERE age > :age OR bage > :age")
                abstract List<Integer> selectAllIds(int age);
                """) { writer ->
            assertReadAndBindCode(writer, """
                final java.lang.String _sql = "SELECT id FROM users WHERE age > ? OR bage > ?";
                final $QUERY _stmt = $QUERY.acquire(_sql, 2);
                int _argIndex = 1;
                _stmt.bindLong(_argIndex, age);
                _argIndex = 2;
                _stmt.bindLong(_argIndex, age);
                """.trimIndent())
        }.compilesWithoutError()
    }

    @Test
    fun testMultipleBindParamsWithSameNameWithVarArg() {
        singleQueryMethod("""
                @Query("SELECT id FROM users WHERE age > :age OR bage > :age OR fage IN(:ages)")
                abstract List<Integer> selectAllIds(int age, int... ages);
                """) { writer ->
            assertReadAndBindCode(writer, """
                java.lang.StringBuilder _stringBuilder = $STRING_UTIL.newStringBuilder();
                _stringBuilder.append("SELECT id FROM users WHERE age > ");
                _stringBuilder.append("?");
                _stringBuilder.append(" OR bage > ");
                _stringBuilder.append("?");
                _stringBuilder.append(" OR fage IN(");
                final int _inputSize = ages.length;
                $STRING_UTIL.appendPlaceholders(_stringBuilder, _inputSize);
                _stringBuilder.append(")");
                final java.lang.String _sql = _stringBuilder.toString();
                final int _argCount = 2 + _inputSize;
                final $QUERY _stmt = $QUERY.acquire(_sql, _argCount);
                int _argIndex = 1;
                _stmt.bindLong(_argIndex, age);
                _argIndex = 2;
                _stmt.bindLong(_argIndex, age);
                _argIndex = 3;
                for (int _item : ages) {
                  _stmt.bindLong(_argIndex, _item);
                  _argIndex ++;
                }
                """.trimIndent())
        }.compilesWithoutError()
    }

    @Test
    fun testMultipleBindParamsWithSameNameWithVarArgInTwoBindings() {
        singleQueryMethod("""
                @Query("SELECT id FROM users WHERE age IN (:ages) OR bage > :age OR fage IN(:ages)")
                abstract List<Integer> selectAllIds(int age, int... ages);
                """) { writer ->
            assertReadAndBindCode(writer, """
                java.lang.StringBuilder _stringBuilder = $STRING_UTIL.newStringBuilder();
                _stringBuilder.append("SELECT id FROM users WHERE age IN (");
                final int _inputSize = ages.length;
                $STRING_UTIL.appendPlaceholders(_stringBuilder, _inputSize);
                _stringBuilder.append(") OR bage > ");
                _stringBuilder.append("?");
                _stringBuilder.append(" OR fage IN(");
                final int _inputSize_1 = ages.length;
                $STRING_UTIL.appendPlaceholders(_stringBuilder, _inputSize_1);
                _stringBuilder.append(")");
                final java.lang.String _sql = _stringBuilder.toString();
                final int _argCount = 1 + _inputSize + _inputSize_1;
                final $QUERY _stmt = $QUERY.acquire(_sql, _argCount);
                int _argIndex = 1;
                for (int _item : ages) {
                  _stmt.bindLong(_argIndex, _item);
                  _argIndex ++;
                }
                _argIndex = 1 + _inputSize;
                _stmt.bindLong(_argIndex, age);
                _argIndex = 2 + _inputSize;
                for (int _item_1 : ages) {
                  _stmt.bindLong(_argIndex, _item_1);
                  _argIndex ++;
                }
                """.trimIndent())
        }.compilesWithoutError()
    }

    /**
     * Calls [QueryWriter.prepareReadAndBind] on [writer] with the `_sql` / `_stmt` variable
     * names that every expected snippet in this class uses. It then asserts that the trimmed
     * generated code equals [expected].
     */
    private fun assertReadAndBindCode(writer: QueryWriter, expected: String) {
        val scope = testCodeGenScope()
        writer.prepareReadAndBind("_sql", "_stmt", scope)
        assertThat(scope.generate().trim(), `is`(expected))
    }

    /**
     * Compiles a DAO named `foo.bar.MyClass` whose body is the [input] snippets joined by
     * newlines and wrapped in [DAO_PREFIX] / [DAO_SUFFIX].
     *
     * During processing, it finds the first `@Dao` element that declares at least one `@Query`
     * member. It runs [QueryMethodProcessor] on the first such member and passes a
     * [QueryWriter] for the parsed query to [handler].
     *
     * @return the [CompileTester], so callers can chain checks such as `compilesWithoutError()`.
     */
    fun singleQueryMethod(
        vararg input: String,
        handler: (QueryWriter) -> Unit
    ): CompileTester {
        val daoSource = DAO_PREFIX + input.joinToString("\n") + DAO_SUFFIX
        return Truth.assertAbout(JavaSourceSubjectFactory.javaSource())
            .that(JavaFileObjects.forSourceString("foo.bar.MyClass", daoSource))
            .processedWith(
                TestProcessor.builder()
                    .forAnnotations(Query::class, Dao::class)
                    .nextRunHandler { invocation ->
                        val (owner, methods) = invocation.roundEnv
                            .getElementsAnnotatedWith(Dao::class.java)
                            .map { daoElement ->
                                val queryMembers = invocation.processingEnv.elementUtils
                                    .getAllMembers(MoreElements.asType(daoElement))
                                    .filter { member ->
                                        MoreElements.isAnnotationPresent(
                                            member, Query::class.java
                                        )
                                    }
                                Pair(daoElement, queryMembers)
                            }
                            .firstOrNull { (_, queryMembers) -> queryMembers.isNotEmpty() }
                            ?: error("No @Dao element with a @Query method was found")
                        val parser = QueryMethodProcessor(
                            baseContext = invocation.context,
                            containing = MoreTypes.asDeclared(owner.asType()),
                            executableElement = MoreElements.asExecutable(methods.first())
                        )
                        val parsedQuery = parser.process()
                        handler(QueryWriter(parsedQuery))
                        // Report the round as handled to the test processor.
                        true
                    }
                    .build()
            )
    }
}