/*
 * Copyright (C) 2016 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package android.arch.persistence.room.writer

import android.arch.persistence.room.ext.L
import android.arch.persistence.room.ext.N
import android.arch.persistence.room.ext.RoomTypeNames
import android.arch.persistence.room.ext.SupportDbTypeNames
import android.arch.persistence.room.ext.T
import android.arch.persistence.room.ext.typeName
import android.arch.persistence.room.parser.QueryType
import android.arch.persistence.room.processor.OnConflictProcessor
import android.arch.persistence.room.solver.CodeGenScope
import android.arch.persistence.room.vo.Dao
import android.arch.persistence.room.vo.Entity
import android.arch.persistence.room.vo.InsertionMethod
import android.arch.persistence.room.vo.QueryMethod
import android.arch.persistence.room.vo.ShortcutMethod
import android.arch.persistence.room.vo.TransactionMethod
import com.google.auto.common.MoreTypes
import com.squareup.javapoet.ClassName
import com.squareup.javapoet.CodeBlock
import com.squareup.javapoet.FieldSpec
import com.squareup.javapoet.MethodSpec
import com.squareup.javapoet.ParameterSpec
import com.squareup.javapoet.TypeName
import com.squareup.javapoet.TypeSpec
import stripNonJava
import javax.annotation.processing.ProcessingEnvironment
import javax.lang.model.element.ElementKind
import javax.lang.model.element.ExecutableElement
import javax.lang.model.element.Modifier.FINAL
import javax.lang.model.element.Modifier.PRIVATE
import javax.lang.model.element.Modifier.PUBLIC
import javax.lang.model.type.DeclaredType
import javax.lang.model.type.TypeKind

/**
 * Creates the implementation for a class annotated with Dao.
 *
 * The generated class is public. It implements the Dao type when the Dao is an interface and
 * extends it otherwise. It keeps the database in the private final `__db` field and overrides
 * the Dao's query, insertion, deletion, update and transaction methods.
 *
 * @param dao The processed Dao to generate an implementation for.
 * @param processingEnv Used to resolve overridden method signatures against the Dao type.
 */
class DaoWriter(val dao: Dao, val processingEnv: ProcessingEnvironment)
    : ClassWriter(dao.typeName) {
    /** The Dao element's type as a [DeclaredType], used as the owner when overriding methods. */
    val declaredDao = MoreTypes.asDeclared(dao.element.asType())

    companion object {
        // TODO nothing prevents this from conflicting, we should fix.
        val dbField: FieldSpec = FieldSpec
                .builder(RoomTypeNames.ROOM_DB, "__db", PRIVATE, FINAL)
                .build()

        /**
         * Derives a field-name fragment from [typeName]. A [ClassName] yields its simple name.
         * Any other type (including `null`) yields its string form with '.' replaced by '_'
         * and non-Java characters stripped.
         */
        private fun typeNameToFieldName(typeName: TypeName?): String = when (typeName) {
            is ClassName -> typeName.simpleName()
            else -> typeName.toString().replace('.', '_').stripNonJava()
        }
    }

    override fun createTypeSpecBuilder(): TypeSpec.Builder {
        val builder = TypeSpec.classBuilder(dao.implTypeName)
        /*
         * If a delete / update query method wants to return modified rows, we need a prepared
         * query. In that case, if the args are dynamic we cannot re-use the query; if they are
         * not, we should re-use it. This requires more work but creates good performance.
         */
        val groupedDeleteUpdate = dao.queryMethods
                .filter { it.query.type == QueryType.DELETE || it.query.type == QueryType.UPDATE }
                .groupBy { method ->
                    // A parameter without a query param adapter counts as "multiple", which puts
                    // the method into the rebuild-every-time group.
                    method.parameters.any { param -> param.queryParamAdapter?.isMultiple ?: true }
                }
        // delete / update queries that can be prepared ahead of time
        val preparedDeleteOrUpdateQueries = groupedDeleteUpdate[false] ?: emptyList()
        // delete / update queries that must be rebuilt every single time
        val oneOffDeleteOrUpdateQueries = groupedDeleteUpdate[true] ?: emptyList()
        val shortcutMethods = createInsertionMethods() +
                createDeletionMethods() + createUpdateMethods() + createTransactionMethods() +
                createPreparedDeleteOrUpdateQueries(preparedDeleteOrUpdateQueries)

        builder.apply {
            addModifiers(PUBLIC)
            if (dao.element.kind == ElementKind.INTERFACE) {
                addSuperinterface(dao.typeName)
            } else {
                superclass(dao.typeName)
            }
            addField(dbField)
            // Use the Dao's declared constructor parameter type when present, otherwise the
            // type of the database field.
            val dbParam = ParameterSpec
                    .builder(dao.constructorParamType ?: dbField.type, dbField.name).build()

            addMethod(createConstructor(dbParam, shortcutMethods, dao.constructorParamType != null))

            shortcutMethods.forEach { shortcut -> addMethod(shortcut.methodImpl) }

            dao.queryMethods
                    .filter { it.query.type == QueryType.SELECT }
                    .forEach { method -> addMethod(createSelectMethod(method)) }
            oneOffDeleteOrUpdateQueries.forEach { method ->
                addMethod(createDeleteOrUpdateQueryMethod(method))
            }
        }
        return builder
    }

    /**
     * For each delete / update query whose statement can be reused, registers a shared
     * prepared-statement field and builds the overriding method that uses it.
     *
     * The single field of each result is stored under [PreparedStmtQuery.NO_PARAM_FIELD]
     * because it is not tied to any method parameter.
     */
    private fun createPreparedDeleteOrUpdateQueries(preparedDeleteQueries: List<QueryMethod>)
            : List<PreparedStmtQuery> {
        return preparedDeleteQueries.map { method ->
            val fieldSpec = getOrCreateField(PreparedStatementField(method))
            val queryWriter = QueryWriter(method)
            val fieldImpl = PreparedStatementWriter(queryWriter)
                    .createAnonymous(this@DaoWriter, dbField)
            val methodBody = createPreparedDeleteQueryMethodBody(method, fieldSpec, queryWriter)
            PreparedStmtQuery(
                    mapOf(PreparedStmtQuery.NO_PARAM_FIELD to (fieldSpec to fieldImpl)),
                    methodBody)
        }
    }

    /**
     * Builds an override that acquires the statement from [preparedStmtField], binds the
     * arguments, and runs `executeUpdateDelete()` inside a database transaction. It returns
     * the result when the method returns a value.
     *
     * The transaction is ended and the statement released in a `finally` block.
     */
    private fun createPreparedDeleteQueryMethodBody(method: QueryMethod,
                                                    preparedStmtField: FieldSpec,
                                                    queryWriter: QueryWriter): MethodSpec {
        val scope = CodeGenScope(this)
        val methodBuilder = overrideWithoutAnnotations(method.element, declaredDao).apply {
            val stmtName = scope.getTmpVar("_stmt")
            addStatement("final $T $L = $N.acquire()",
                    SupportDbTypeNames.SQLITE_STMT, stmtName, preparedStmtField)
            addStatement("$N.beginTransaction()", dbField)
            beginControlFlow("try").apply {
                val bindScope = scope.fork()
                queryWriter.bindArgs(stmtName, emptyList(), bindScope)
                addCode(bindScope.builder().build())
                if (method.returnsValue) {
                    val resultVar = scope.getTmpVar("_result")
                    addStatement("final $L $L = $L.executeUpdateDelete()",
                            method.returnType.typeName(), resultVar, stmtName)
                    addStatement("$N.setTransactionSuccessful()", dbField)
                    addStatement("return $L", resultVar)
                } else {
                    addStatement("$L.executeUpdateDelete()", stmtName)
                    addStatement("$N.setTransactionSuccessful()", dbField)
                }
            }
            nextControlFlow("finally").apply {
                addStatement("$N.endTransaction()", dbField)
                addStatement("$N.release($L)", preparedStmtField, stmtName)
            }
            endControlFlow()
        }
        return methodBuilder.build()
    }

    /** Wraps each transaction method in an override. These need no extra fields. */
    private fun createTransactionMethods(): List<PreparedStmtQuery> {
        return dao.transactionMethods.map { method ->
            PreparedStmtQuery(emptyMap(), createTransactionMethodBody(method))
        }
    }

    /**
     * Builds an override that calls the super implementation between `beginTransaction()`
     * and `setTransactionSuccessful()`, and always calls `endTransaction()` in `finally`.
     * It returns the super result for non-void methods.
     */
    private fun createTransactionMethodBody(method: TransactionMethod): MethodSpec {
        val scope = CodeGenScope(this)
        val methodBuilder = overrideWithoutAnnotations(method.element, declaredDao).apply {
            addStatement("$N.beginTransaction()", dbField)
            beginControlFlow("try").apply {
                val returnsValue = method.element.returnType.kind != TypeKind.VOID
                val resultVar = if (returnsValue) {
                    scope.getTmpVar("_result")
                } else {
                    null
                }
                // NOTE(review): the result local is typed with element.returnType, which is not
                // resolved against declaredDao. A return type that is a type variable of a
                // generic Dao supertype may therefore not resolve in the generated class.
                // Confirm whether that case is reachable.
                addDelegateToSuperStatement(method.element, resultVar)
                addStatement("$N.setTransactionSuccessful()", dbField)
                if (returnsValue) {
                    addStatement("return $N", resultVar)
                }
            }
            nextControlFlow("finally").apply {
                addStatement("$N.endTransaction()", dbField)
            }
            endControlFlow()
        }
        return methodBuilder.build()
    }

    /**
     * Emits `super.name(arg0, arg1, ...)`, passing the method's own parameters through.
     * When [result] is non-null, the call's value is assigned to a new local named [result].
     */
    private fun MethodSpec.Builder.addDelegateToSuperStatement(element: ExecutableElement,
                                                               result: String?) {
        val params: MutableList<Any> = mutableListOf()
        val format = buildString {
            if (result != null) {
                append("$T $L = ")
                params.add(element.returnType)
                params.add(result)
            }
            append("super.$N(")
            params.add(element.simpleName)
            // One literal placeholder per parameter, comma separated.
            append(element.parameters.joinToString(", ") { L })
            element.parameters.forEach { param -> params.add(param.simpleName) }
            append(")")
        }
        addStatement(format, *params.toTypedArray())
    }

    /**
     * Builds the public constructor. It stores [dbParam] in the database field and
     * initializes every field required by [shortcutMethods].
     *
     * Fields shared by several methods are initialized once, using the first definition
     * encountered for each field name.
     *
     * @param callSuper When true, the constructor first calls `super(dbParam)`.
     */
    private fun createConstructor(dbParam: ParameterSpec,
                                  shortcutMethods: List<PreparedStmtQuery>,
                                  callSuper: Boolean): MethodSpec {
        return MethodSpec.constructorBuilder().apply {
            addParameter(dbParam)
            addModifiers(PUBLIC)
            if (callSuper) {
                addStatement("super($N)", dbParam)
            }
            addStatement("this.$N = $N", dbField, dbParam)
            shortcutMethods
                    .flatMap { it.fields.values }
                    .distinctBy { (fieldSpec, _) -> fieldSpec.name }
                    .forEach { (fieldSpec, fieldImpl) ->
                        addStatement("this.$N = $L", fieldSpec, fieldImpl)
                    }
        }.build()
    }

    /** Builds the override for a SELECT query method. */
    private fun createSelectMethod(method: QueryMethod): MethodSpec {
        return overrideWithoutAnnotations(method.element, declaredDao).apply {
            addCode(createQueryMethodBody(method))
        }.build()
    }

    /** Builds the override for a delete / update query whose statement is compiled per call. */
    private fun createDeleteOrUpdateQueryMethod(method: QueryMethod): MethodSpec {
        return overrideWithoutAnnotations(method.element, declaredDao).apply {
            addCode(createDeleteOrUpdateQueryMethodBody(method))
        }.build()
    }

    /**
     * Groups all insertion methods based on the insert statement they will use then creates all
     * field specs, EntityInsertionAdapterWriter and actual insert methods.
     */
    private fun createInsertionMethods(): List<PreparedStmtQuery> {
        return dao.insertionMethods.map { insertionMethod ->
            val onConflict = OnConflictProcessor.onConflictText(insertionMethod.onConflict)
            val fields = insertionMethod.entities.mapValues { (_, entity) ->
                val spec = getOrCreateField(InsertionMethodField(entity, onConflict))
                val impl = EntityInsertionAdapterWriter(entity, onConflict)
                        .createAnonymous(this@DaoWriter, dbField.name)
                spec to impl
            }
            val methodImpl = overrideWithoutAnnotations(insertionMethod.element,
                    declaredDao).apply {
                addCode(createInsertionMethodBody(insertionMethod, fields))
            }.build()
            PreparedStmtQuery(fields, methodImpl)
        }
    }

    /**
     * Emits the insertion body. Each parameter is passed to its insertion adapter inside a
     * transaction.
     *
     * Returns an empty code block when there are no adapters or the insertion type is
     * unknown.
     */
    private fun createInsertionMethodBody(method: InsertionMethod,
                                          insertionAdapters: Map<String, Pair<FieldSpec, TypeSpec>>)
            : CodeBlock {
        val insertionType = method.insertionType
        if (insertionAdapters.isEmpty() || insertionType == null) {
            return CodeBlock.builder().build()
        }
        val scope = CodeGenScope(this)

        return scope.builder().apply {
            // TODO assert thread
            // TODO collect results
            addStatement("$N.beginTransaction()", dbField)
            val needsReturnType = insertionType != InsertionMethod.Type.INSERT_VOID
            val resultVar = if (needsReturnType) {
                scope.getTmpVar("_result")
            } else {
                null
            }

            beginControlFlow("try").apply {
                method.parameters.forEach { param ->
                    val insertionAdapter = insertionAdapters[param.name]?.first
                    if (needsReturnType) {
                        // if it has more than 1 parameter, we would've already printed the error
                        // so we don't care about re-declaring the variable here
                        addStatement("$T $L = $N.$L($L)",
                                insertionType.returnTypeName, resultVar,
                                insertionAdapter, insertionType.methodName,
                                param.name)
                    } else {
                        addStatement("$N.$L($L)", insertionAdapter, insertionType.methodName,
                                param.name)
                    }
                }
                addStatement("$N.setTransactionSuccessful()", dbField)
                if (needsReturnType) {
                    addStatement("return $L", resultVar)
                }
            }
            nextControlFlow("finally").apply {
                addStatement("$N.endTransaction()", dbField)
            }
            endControlFlow()
        }.build()
    }

    /**
     * Creates an entity deletion adapter (via EntityDeletionAdapterWriter) for each deletion
     * method.
     */
    private fun createDeletionMethods(): List<PreparedStmtQuery> {
        return createShortcutMethods(dao.deletionMethods, "deletion") { _, entity ->
            EntityDeletionAdapterWriter(entity)
                    .createAnonymous(this@DaoWriter, dbField.name)
        }
    }

    /**
     * Creates an entity update adapter (via EntityUpdateAdapterWriter) for each @Update
     * method. The adapter uses that method's on-conflict strategy.
     */
    private fun createUpdateMethods(): List<PreparedStmtQuery> {
        return createShortcutMethods(dao.updateMethods, "update") { update, entity ->
            val onConflict = OnConflictProcessor.onConflictText(update.onConflictStrategy)
            EntityUpdateAdapterWriter(entity, onConflict)
                    .createAnonymous(this@DaoWriter, dbField.name)
        }
    }

    /**
     * Shared implementation for deletion and update methods. Methods without entities are
     * skipped. The others get one adapter field per entity parameter, created through
     * [implCallback].
     *
     * @param methodPrefix Prefix for the adapter field name (e.g. "deletion", "update").
     */
    private fun <T : ShortcutMethod> createShortcutMethods(methods: List<T>, methodPrefix: String,
                                                           implCallback: (T, Entity) -> TypeSpec)
            : List<PreparedStmtQuery> {
        return methods.mapNotNull { method ->
            val entities = method.entities
            if (entities.isEmpty()) {
                null
            } else {
                val fields = entities.mapValues { (_, entity) ->
                    val spec = getOrCreateField(DeleteOrUpdateAdapterField(entity, methodPrefix))
                    val impl = implCallback(method, entity)
                    spec to impl
                }
                val methodSpec = overrideWithoutAnnotations(method.element, declaredDao).apply {
                    addCode(createDeleteOrUpdateMethodBody(method, fields))
                }.build()
                PreparedStmtQuery(fields, methodSpec)
            }
        }
    }

    /**
     * Emits the body of a deletion / update shortcut method. Each parameter is passed to its
     * adapter inside a transaction.
     *
     * When the method returns a count, the adapter results are summed into an `int` local
     * and returned. Returns an empty code block when there are no adapters.
     */
    private fun createDeleteOrUpdateMethodBody(method: ShortcutMethod,
                                               adapters: Map<String, Pair<FieldSpec, TypeSpec>>)
            : CodeBlock {
        if (adapters.isEmpty()) {
            return CodeBlock.builder().build()
        }
        val scope = CodeGenScope(this)
        val resultVar = if (method.returnCount) {
            scope.getTmpVar("_total")
        } else {
            null
        }
        return scope.builder().apply {
            if (resultVar != null) {
                addStatement("$T $L = 0", TypeName.INT, resultVar)
            }
            addStatement("$N.beginTransaction()", dbField)
            beginControlFlow("try").apply {
                method.parameters.forEach { param ->
                    val adapter = adapters[param.name]?.first
                    addStatement("$L$N.$L($L)",
                            if (resultVar == null) "" else "$resultVar +=",
                            adapter, param.handleMethodName(), param.name)
                }
                addStatement("$N.setTransactionSuccessful()", dbField)
                if (resultVar != null) {
                    addStatement("return $L", resultVar)
                }
            }
            nextControlFlow("finally").apply {
                addStatement("$N.endTransaction()", dbField)
            }
            endControlFlow()
        }.build()
    }

    /**
     * Body for a @Query delete or update method whose statement must be compiled on every call.
     *
     * The SQL is built and compiled, the arguments are bound, and `executeUpdateDelete()`
     * runs inside a transaction. The result is returned when the method returns a value.
     */
    private fun createDeleteOrUpdateQueryMethodBody(method: QueryMethod): CodeBlock {
        val queryWriter = QueryWriter(method)
        val scope = CodeGenScope(this)
        val sqlVar = scope.getTmpVar("_sql")
        val stmtVar = scope.getTmpVar("_stmt")
        val listSizeArgs = queryWriter.prepareQuery(sqlVar, scope)
        scope.builder().apply {
            addStatement("$T $L = $N.compileStatement($L)",
                    SupportDbTypeNames.SQLITE_STMT, stmtVar, dbField, sqlVar)
            queryWriter.bindArgs(stmtVar, listSizeArgs, scope)
            addStatement("$N.beginTransaction()", dbField)
            beginControlFlow("try").apply {
                if (method.returnsValue) {
                    val resultVar = scope.getTmpVar("_result")
                    addStatement("final $L $L = $L.executeUpdateDelete()",
                            method.returnType.typeName(), resultVar, stmtVar)
                    addStatement("$N.setTransactionSuccessful()", dbField)
                    addStatement("return $L", resultVar)
                } else {
                    addStatement("$L.executeUpdateDelete()", stmtVar)
                    addStatement("$N.setTransactionSuccessful()", dbField)
                }
            }
            nextControlFlow("finally").apply {
                addStatement("$N.endTransaction()", dbField)
            }
            endControlFlow()
        }
        return scope.builder().build()
    }

    /**
     * Body for a SELECT query method. It prepares and binds the query, then delegates result
     * conversion to the method's query result binder.
     */
    private fun createQueryMethodBody(method: QueryMethod): CodeBlock {
        val queryWriter = QueryWriter(method)
        val scope = CodeGenScope(this)
        val sqlVar = scope.getTmpVar("_sql")
        val roomSQLiteQueryVar = scope.getTmpVar("_statement")
        queryWriter.prepareReadAndBind(sqlVar, roomSQLiteQueryVar, scope)
        method.queryResultBinder.convertAndReturn(roomSQLiteQueryVar, dbField, scope)
        return scope.builder().build()
    }

    /**
     * Starts an override of [elm] as seen from [owner].
     *
     * The builder gets the same name, modifiers, parameters, varargs flag and return type as
     * `MethodSpec.overriding(...)`, plus an `@Override` annotation. No other annotations and
     * no thrown exception types are copied.
     */
    private fun overrideWithoutAnnotations(elm: ExecutableElement,
                                           owner: DeclaredType): MethodSpec.Builder {
        val baseSpec = MethodSpec.overriding(elm, owner, processingEnv.typeUtils).build()
        return MethodSpec.methodBuilder(baseSpec.name).apply {
            addAnnotation(Override::class.java)
            addModifiers(baseSpec.modifiers)
            addParameters(baseSpec.parameters)
            varargs(baseSpec.varargs)
            returns(baseSpec.returnType)
        }
    }

    /**
     * Represents a query statement prepared in Dao implementation.
     *
     * @param fields This map holds all the member fields necessary for this query. The key is the
     * corresponding parameter name in the defining query method. The value is a pair from the field
     * declaration to definition.
     * @param methodImpl The body of the query method implementation.
     */
    data class PreparedStmtQuery(val fields: Map<String, Pair<FieldSpec, TypeSpec>>,
                                 val methodImpl: MethodSpec) {
        companion object {
            // The key to be used in `fields` where the method requires a field that is not
            // associated with any of its parameters
            const val NO_PARAM_FIELD = "-"
        }
    }

    /**
     * Shared insertion-adapter field. Its unique key combines the entity type and the
     * on-conflict text.
     */
    private class InsertionMethodField(val entity: Entity, val onConflictText: String)
        : SharedFieldSpec(
            "insertionAdapterOf${Companion.typeNameToFieldName(entity.typeName)}",
            RoomTypeNames.INSERTION_ADAPTER) {

        override fun getUniqueKey(): String {
            return "${entity.typeName} $onConflictText"
        }

        override fun prepare(writer: ClassWriter, builder: FieldSpec.Builder) {
            builder.addModifiers(FINAL, PRIVATE)
        }
    }

    /**
     * Shared deletion / update adapter field. Its unique key combines the entity type and
     * the method prefix.
     */
    class DeleteOrUpdateAdapterField(val entity: Entity, val methodPrefix: String)
        : SharedFieldSpec(
            "${methodPrefix}AdapterOf${Companion.typeNameToFieldName(entity.typeName)}",
            RoomTypeNames.DELETE_OR_UPDATE_ADAPTER) {
        override fun prepare(writer: ClassWriter, builder: FieldSpec.Builder) {
            builder.addModifiers(PRIVATE, FINAL)
        }

        override fun getUniqueKey(): String {
            return entity.typeName.toString() + methodPrefix
        }
    }

    /**
     * Shared prepared-statement field for a reusable delete / update query. Its unique key
     * is the original query text.
     */
    class PreparedStatementField(val method: QueryMethod) : SharedFieldSpec(
            "preparedStmtOf${method.name.capitalize()}", RoomTypeNames.SHARED_SQLITE_STMT) {
        override fun prepare(writer: ClassWriter, builder: FieldSpec.Builder) {
            builder.addModifiers(PRIVATE, FINAL)
        }

        override fun getUniqueKey(): String {
            return method.query.original
        }
    }
}