1/*
2 * Copyright (C) 2016 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package androidx.room.writer
18
19import androidx.room.ext.L
20import androidx.room.ext.N
21import androidx.room.ext.RoomTypeNames
22import androidx.room.ext.SupportDbTypeNames
23import androidx.room.ext.T
24import androidx.room.ext.typeName
25import androidx.room.parser.QueryType
26import androidx.room.processor.OnConflictProcessor
27import androidx.room.solver.CodeGenScope
28import androidx.room.vo.Dao
29import androidx.room.vo.Entity
30import androidx.room.vo.InsertionMethod
31import androidx.room.vo.QueryMethod
32import androidx.room.vo.RawQueryMethod
33import androidx.room.vo.ShortcutMethod
34import androidx.room.vo.TransactionMethod
35import com.google.auto.common.MoreTypes
36import com.squareup.javapoet.ClassName
37import com.squareup.javapoet.CodeBlock
38import com.squareup.javapoet.FieldSpec
39import com.squareup.javapoet.MethodSpec
40import com.squareup.javapoet.ParameterSpec
41import com.squareup.javapoet.TypeName
42import com.squareup.javapoet.TypeSpec
43import me.eugeniomarletti.kotlin.metadata.shadow.load.java.JvmAbi
44import stripNonJava
45import javax.annotation.processing.ProcessingEnvironment
46import javax.lang.model.element.ElementKind
47import javax.lang.model.element.ExecutableElement
48import javax.lang.model.element.Modifier.FINAL
49import javax.lang.model.element.Modifier.PRIVATE
50import javax.lang.model.element.Modifier.PUBLIC
51import javax.lang.model.type.DeclaredType
52import javax.lang.model.type.TypeKind
53
/**
 * Creates the implementation for a class annotated with Dao.
 *
 * The generated class extends the DAO type (or implements it, when the DAO is an interface),
 * keeps the database in [dbField] and overrides every query, raw query, insert, delete, update
 * and transaction method declared on the DAO.
 */
class DaoWriter(val dao: Dao, val processingEnv: ProcessingEnvironment)
    : ClassWriter(dao.typeName) {
    private val declaredDao = MoreTypes.asDeclared(dao.element.asType())

    companion object {
        // TODO nothing prevents this from conflicting, we should fix.
        /** The private final `__db` field that every generated DAO method talks to. */
        val dbField: FieldSpec = FieldSpec
                .builder(RoomTypeNames.ROOM_DB, "__db", PRIVATE, FINAL)
                .build()

        /**
         * Derives the suffix used in generated field names: the simple name for a [ClassName],
         * otherwise the type's string form with '.' replaced by '_' and passed through
         * `stripNonJava`.
         */
        private fun typeNameToFieldName(typeName: TypeName?): String = when (typeName) {
            is ClassName -> typeName.simpleName()
            else -> typeName.toString().replace('.', '_').stripNonJava()
        }
    }

    override fun createTypeSpecBuilder(): TypeSpec.Builder {
        val builder = TypeSpec.classBuilder(dao.implTypeName)
        // If a delete / update query method wants to return the modified row count, we need a
        // prepared statement. When its args are dynamic we cannot re-use the statement; when
        // they are not, we should re-use it. This takes more work but performs better.
        //
        // A query must be rebuilt on every call when any of its parameters may bind multiple
        // args (or has no adapter telling us otherwise).
        val (oneOffDeleteOrUpdateQueries, preparedDeleteOrUpdateQueries) = dao.queryMethods
                .filter { it.query.type == QueryType.DELETE || it.query.type == QueryType.UPDATE }
                .partition { method ->
                    method.parameters.any { param -> param.queryParamAdapter?.isMultiple ?: true }
                }
        // Evaluated before the builder is populated so that every shared field is registered
        // (via getOrCreateField) in a stable order.
        val shortcutMethods = createInsertionMethods() +
                createDeletionMethods() + createUpdateMethods() + createTransactionMethods() +
                createPreparedDeleteOrUpdateQueries(preparedDeleteOrUpdateQueries)

        builder.apply {
            addModifiers(PUBLIC)
            if (dao.element.kind == ElementKind.INTERFACE) {
                addSuperinterface(dao.typeName)
            } else {
                superclass(dao.typeName)
            }
            addField(dbField)
            val dbParam = ParameterSpec
                    .builder(dao.constructorParamType ?: dbField.type, dbField.name).build()

            addMethod(createConstructor(dbParam, shortcutMethods, dao.constructorParamType != null))

            shortcutMethods.forEach { addMethod(it.methodImpl) }

            dao.queryMethods
                    .filter { it.query.type == QueryType.SELECT }
                    .forEach { method -> addMethod(createSelectMethod(method)) }
            oneOffDeleteOrUpdateQueries.forEach { addMethod(createDeleteOrUpdateQueryMethod(it)) }
            dao.rawQueryMethods.forEach { addMethod(createRawQueryMethod(it)) }
        }
        return builder
    }

    /**
     * Creates a shared prepared-statement field and the method implementation for each
     * delete/update @Query whose statement can be compiled once and re-used.
     */
    private fun createPreparedDeleteOrUpdateQueries(
            preparedQueries: List<QueryMethod>): List<PreparedStmtQuery> {
        return preparedQueries.map { method ->
            val fieldSpec = getOrCreateField(PreparedStatementField(method))
            val queryWriter = QueryWriter(method)
            val fieldImpl = PreparedStatementWriter(queryWriter)
                    .createAnonymous(this@DaoWriter, dbField)
            val methodBody = createPreparedDeleteQueryMethodBody(method, fieldSpec, queryWriter)
            PreparedStmtQuery(mapOf(PreparedStmtQuery.NO_PARAM_FIELD
                    to (fieldSpec to fieldImpl)), methodBody)
        }
    }

    /**
     * Builds a method that acquires [preparedStmtField], binds the args and runs
     * `executeUpdateDelete()` inside a transaction. The statement is released in `finally`.
     */
    private fun createPreparedDeleteQueryMethodBody(
            method: QueryMethod,
            preparedStmtField: FieldSpec,
            queryWriter: QueryWriter
    ): MethodSpec {
        val scope = CodeGenScope(this)
        val methodBuilder = overrideWithoutAnnotations(method.element, declaredDao).apply {
            val stmtName = scope.getTmpVar("_stmt")
            addStatement("final $T $L = $N.acquire()",
                    SupportDbTypeNames.SQLITE_STMT, stmtName, preparedStmtField)
            addStatement("$N.beginTransaction()", dbField)
            beginControlFlow("try").apply {
                val bindScope = scope.fork()
                queryWriter.bindArgs(stmtName, emptyList(), bindScope)
                addCode(bindScope.builder().build())
                if (method.returnsValue) {
                    val resultVar = scope.getTmpVar("_result")
                    addStatement("final $L $L = $L.executeUpdateDelete()",
                            method.returnType.typeName(), resultVar, stmtName)
                    addStatement("$N.setTransactionSuccessful()", dbField)
                    addStatement("return $L", resultVar)
                } else {
                    addStatement("$L.executeUpdateDelete()", stmtName)
                    addStatement("$N.setTransactionSuccessful()", dbField)
                }
            }
            nextControlFlow("finally").apply {
                addStatement("$N.endTransaction()", dbField)
                addStatement("$N.release($L)", preparedStmtField, stmtName)
            }
            endControlFlow()
        }
        return methodBuilder.build()
    }

    /** Creates the overrides for @Transaction methods; they need no member fields. */
    private fun createTransactionMethods(): List<PreparedStmtQuery> {
        return dao.transactionMethods.map {
            PreparedStmtQuery(emptyMap(), createTransactionMethodBody(it))
        }
    }

    /**
     * Wraps a call to the DAO's own implementation of [method] in
     * begin/setTransactionSuccessful/endTransaction, returning its result if it has one.
     */
    private fun createTransactionMethodBody(method: TransactionMethod): MethodSpec {
        val scope = CodeGenScope(this)
        val methodBuilder = overrideWithoutAnnotations(method.element, declaredDao).apply {
            addStatement("$N.beginTransaction()", dbField)
            beginControlFlow("try").apply {
                val returnsValue = method.element.returnType.kind != TypeKind.VOID
                val resultVar = if (returnsValue) scope.getTmpVar("_result") else null
                addDelegateToSuperStatement(method.element, method.callType, resultVar)
                addStatement("$N.setTransactionSuccessful()", dbField)
                if (returnsValue) {
                    addStatement("return $N", resultVar)
                }
            }
            nextControlFlow("finally").apply {
                addStatement("$N.endTransaction()", dbField)
            }
            endControlFlow()
        }
        return methodBuilder.build()
    }

    /**
     * Emits the statement that forwards a @Transaction method to the DAO's own implementation.
     * When [result] is non-null, the call's result is assigned to a variable with that name.
     *
     * Depending on [callType], the call is `super.m(args)`, `Iface.super.m(args)` (Java 8 default
     * method) or `Iface.DefaultImpls.m(this, args)` (Kotlin interface implementation).
     */
    private fun MethodSpec.Builder.addDelegateToSuperStatement(
            element: ExecutableElement,
            callType: TransactionMethod.CallType,
            result: String?) {
        val params: MutableList<Any> = mutableListOf()
        val format = buildString {
            if (result != null) {
                append("$T $L = ")
                params.add(element.returnType)
                params.add(result)
            }
            // Format pieces of the call's argument list; joined with ", " below.
            val args = mutableListOf<String>()
            when (callType) {
                TransactionMethod.CallType.CONCRETE -> {
                    append("super.$N(")
                    params.add(element.simpleName)
                }
                TransactionMethod.CallType.DEFAULT_JAVA8 -> {
                    append("$N.super.$N(")
                    params.add(element.enclosingElement.simpleName)
                    params.add(element.simpleName)
                }
                TransactionMethod.CallType.DEFAULT_KOTLIN -> {
                    append("$N.$N.$N(")
                    params.add(element.enclosingElement.simpleName)
                    params.add(JvmAbi.DEFAULT_IMPLS_CLASS_NAME)
                    params.add(element.simpleName)
                    // The DefaultImpls call receives the instance as its first argument.
                    args.add("this")
                }
            }
            element.parameters.forEach { param ->
                args.add(L)
                params.add(param.simpleName)
            }
            // Joining (rather than always emitting "this, ") avoids generating the invalid
            // `DefaultImpls.m(this, )` for parameterless Kotlin interface methods.
            append(args.joinToString(", "))
            append(")")
        }
        addStatement(format, *params.toTypedArray())
    }

    /**
     * Creates the public constructor. It stores the database in [dbField], optionally passes it
     * to `super(...)`, and initializes every shared adapter/statement field exactly once.
     */
    private fun createConstructor(
            dbParam: ParameterSpec,
            shortcutMethods: List<PreparedStmtQuery>,
            callSuper: Boolean): MethodSpec {
        return MethodSpec.constructorBuilder().apply {
            addParameter(dbParam)
            addModifiers(PUBLIC)
            if (callSuper) {
                addStatement("super($N)", dbParam)
            }
            addStatement("this.$N = $N", dbField, dbParam)
            // Several methods may share one field (same name); initialize each field once,
            // using the first definition encountered.
            shortcutMethods
                    .flatMap { it.fields.values }
                    .distinctBy { (fieldSpec, _) -> fieldSpec.name }
                    .forEach { (fieldSpec, fieldImpl) ->
                        addStatement("this.$N = $L", fieldSpec, fieldImpl)
                    }
        }.build()
    }

    /** Creates the override for a SELECT @Query method. */
    private fun createSelectMethod(method: QueryMethod): MethodSpec {
        return overrideWithoutAnnotations(method.element, declaredDao).apply {
            addCode(createQueryMethodBody(method))
        }.build()
    }

    /**
     * Creates the override for a @RawQuery method.
     *
     * A String argument is wrapped in a pooled RoomSQLiteQuery that may be released after use.
     * A support query argument is copied to a final local and is never released by us.
     */
    private fun createRawQueryMethod(method: RawQueryMethod): MethodSpec {
        return overrideWithoutAnnotations(method.element, declaredDao).apply {
            val scope = CodeGenScope(this@DaoWriter)
            val roomSQLiteQueryVar: String
            val queryParam = method.runtimeQueryParam
            val shouldReleaseQuery: Boolean

            when {
                queryParam?.isString() == true -> {
                    roomSQLiteQueryVar = scope.getTmpVar("_statement")
                    shouldReleaseQuery = true
                    addStatement("$T $L = $T.acquire($L, 0)",
                            RoomTypeNames.ROOM_SQL_QUERY,
                            roomSQLiteQueryVar,
                            RoomTypeNames.ROOM_SQL_QUERY,
                            queryParam.paramName)
                }
                queryParam?.isSupportQuery() == true -> {
                    shouldReleaseQuery = false
                    roomSQLiteQueryVar = scope.getTmpVar("_internalQuery")
                    // move it to a final variable so that the generated code can use it inside
                    // callback blocks in java 7
                    addStatement("final $T $L = $N",
                            queryParam.type,
                            roomSQLiteQueryVar,
                            queryParam.paramName)
                }
                else -> {
                    // try to generate compiling code. we would've already reported this error
                    roomSQLiteQueryVar = scope.getTmpVar("_statement")
                    shouldReleaseQuery = false
                    addStatement("$T $L = $T.acquire($L, 0)",
                            RoomTypeNames.ROOM_SQL_QUERY,
                            roomSQLiteQueryVar,
                            RoomTypeNames.ROOM_SQL_QUERY,
                            "missing query parameter")
                }
            }
            // The processor has already reported an error for a raw query method that does not
            // return a value. Skip the result conversion there so we don't add another error.
            if (method.returnsValue) {
                method.queryResultBinder.convertAndReturn(
                        roomSQLiteQueryVar = roomSQLiteQueryVar,
                        canReleaseQuery = shouldReleaseQuery,
                        dbField = dbField,
                        inTransaction = method.inTransaction,
                        scope = scope)
            }
            addCode(scope.builder().build())
        }.build()
    }

    /** Creates the override for a DELETE/UPDATE @Query whose statement is rebuilt per call. */
    private fun createDeleteOrUpdateQueryMethod(method: QueryMethod): MethodSpec {
        return overrideWithoutAnnotations(method.element, declaredDao).apply {
            addCode(createDeleteOrUpdateQueryMethodBody(method))
        }.build()
    }

    /**
     * Groups all insertion methods based on the insert statement they will use then creates all
     * field specs, EntityInsertionAdapterWriter and actual insert methods.
     */
    private fun createInsertionMethods(): List<PreparedStmtQuery> {
        return dao.insertionMethods.map { insertionMethod ->
            val onConflict = OnConflictProcessor.onConflictText(insertionMethod.onConflict)
            val fields = insertionMethod.entities.mapValues { (_, entity) ->
                val spec = getOrCreateField(InsertionMethodField(entity, onConflict))
                val impl = EntityInsertionAdapterWriter(entity, onConflict)
                        .createAnonymous(this@DaoWriter, dbField.name)
                spec to impl
            }
            val methodImpl = overrideWithoutAnnotations(insertionMethod.element, declaredDao)
                    .apply { addCode(createInsertionMethodBody(insertionMethod, fields)) }
                    .build()
            PreparedStmtQuery(fields, methodImpl)
        }
    }

    /**
     * Builds the body of an @Insert method. Each parameter is passed to its insertion adapter
     * inside a transaction. The body is empty when there is no adapter or no known insertion
     * type.
     */
    private fun createInsertionMethodBody(
            method: InsertionMethod,
            insertionAdapters: Map<String, Pair<FieldSpec, TypeSpec>>
    ): CodeBlock {
        val insertionType = method.insertionType
        if (insertionAdapters.isEmpty() || insertionType == null) {
            return CodeBlock.builder().build()
        }
        val scope = CodeGenScope(this)

        return scope.builder().apply {
            // TODO assert thread
            // TODO collect results
            addStatement("$N.beginTransaction()", dbField)
            val needsReturnType = insertionType != InsertionMethod.Type.INSERT_VOID
            val resultVar = if (needsReturnType) scope.getTmpVar("_result") else null

            beginControlFlow("try").apply {
                method.parameters.forEach { param ->
                    val insertionAdapter = insertionAdapters[param.name]?.first
                    if (needsReturnType) {
                        // if it has more than 1 parameter, we would've already printed the error
                        // so we don't care about re-declaring the variable here
                        addStatement("$T $L = $N.$L($L)",
                                insertionType.returnTypeName, resultVar,
                                insertionAdapter, insertionType.methodName,
                                param.name)
                    } else {
                        addStatement("$N.$L($L)", insertionAdapter, insertionType.methodName,
                                param.name)
                    }
                }
                addStatement("$N.setTransactionSuccessful()", dbField)
                if (needsReturnType) {
                    addStatement("return $L", resultVar)
                }
            }
            nextControlFlow("finally").apply {
                addStatement("$N.endTransaction()", dbField)
            }
            endControlFlow()
        }.build()
    }

    /**
     * Creates a deletion adapter (via EntityDeletionAdapterWriter) for each deletion method.
     */
    private fun createDeletionMethods(): List<PreparedStmtQuery> {
        return createShortcutMethods(dao.deletionMethods, "deletion") { _, entity ->
            EntityDeletionAdapterWriter(entity)
                    .createAnonymous(this@DaoWriter, dbField.name)
        }
    }

    /**
     * Creates an update adapter (via EntityUpdateAdapterWriter, honoring the method's conflict
     * strategy) for each @Update method.
     */
    private fun createUpdateMethods(): List<PreparedStmtQuery> {
        return createShortcutMethods(dao.updateMethods, "update") { update, entity ->
            val onConflict = OnConflictProcessor.onConflictText(update.onConflictStrategy)
            EntityUpdateAdapterWriter(entity, onConflict)
                    .createAnonymous(this@DaoWriter, dbField.name)
        }
    }

    /**
     * Shared implementation for @Delete / @Update methods. Methods without entities are skipped.
     * Every other method gets one adapter field per entity parameter (named with
     * [methodPrefix]), and its override is built from those adapters.
     */
    private fun <T : ShortcutMethod> createShortcutMethods(
            methods: List<T>, methodPrefix: String,
            implCallback: (T, Entity) -> TypeSpec
    ): List<PreparedStmtQuery> {
        return methods
                .filter { it.entities.isNotEmpty() }
                .map { method ->
                    val fields = method.entities.mapValues { (_, entity) ->
                        val spec = getOrCreateField(DeleteOrUpdateAdapterField(entity, methodPrefix))
                        spec to implCallback(method, entity)
                    }
                    val methodSpec = overrideWithoutAnnotations(method.element, declaredDao)
                            .apply { addCode(createDeleteOrUpdateMethodBody(method, fields)) }
                            .build()
                    PreparedStmtQuery(fields, methodSpec)
                }
    }

    /**
     * Builds the body of a @Delete / @Update method. Each parameter goes through its adapter
     * inside a transaction. When the method returns a count, the adapters' results are summed
     * into `_total`.
     */
    private fun createDeleteOrUpdateMethodBody(
            method: ShortcutMethod,
            adapters: Map<String, Pair<FieldSpec, TypeSpec>>
    ): CodeBlock {
        if (adapters.isEmpty()) {
            return CodeBlock.builder().build()
        }
        val scope = CodeGenScope(this)
        val resultVar = if (method.returnCount) scope.getTmpVar("_total") else null
        return scope.builder().apply {
            if (resultVar != null) {
                addStatement("$T $L = 0", TypeName.INT, resultVar)
            }
            addStatement("$N.beginTransaction()", dbField)
            beginControlFlow("try").apply {
                method.parameters.forEach { param ->
                    val adapter = adapters[param.name]?.first
                    addStatement("$L$N.$L($L)",
                            if (resultVar == null) "" else "$resultVar +=",
                            adapter, param.handleMethodName(), param.name)
                }
                addStatement("$N.setTransactionSuccessful()", dbField)
                if (resultVar != null) {
                    addStatement("return $L", resultVar)
                }
            }
            nextControlFlow("finally").apply {
                addStatement("$N.endTransaction()", dbField)
            }
            endControlFlow()
        }.build()
    }

    /**
     * Builds the body of a DELETE or UPDATE @Query whose SQL must be rebuilt per call (e.g.
     * multi-valued args). The SQL is compiled into a fresh statement, the args are bound, and
     * `executeUpdateDelete()` runs inside a transaction.
     */
    private fun createDeleteOrUpdateQueryMethodBody(method: QueryMethod): CodeBlock {
        val queryWriter = QueryWriter(method)
        val scope = CodeGenScope(this)
        val sqlVar = scope.getTmpVar("_sql")
        val stmtVar = scope.getTmpVar("_stmt")
        val listSizeArgs = queryWriter.prepareQuery(sqlVar, scope)
        scope.builder().apply {
            addStatement("$T $L = $N.compileStatement($L)",
                    SupportDbTypeNames.SQLITE_STMT, stmtVar, dbField, sqlVar)
            queryWriter.bindArgs(stmtVar, listSizeArgs, scope)
            addStatement("$N.beginTransaction()", dbField)
            beginControlFlow("try").apply {
                if (method.returnsValue) {
                    val resultVar = scope.getTmpVar("_result")
                    addStatement("final $L $L = $L.executeUpdateDelete()",
                            method.returnType.typeName(), resultVar, stmtVar)
                    addStatement("$N.setTransactionSuccessful()", dbField)
                    addStatement("return $L", resultVar)
                } else {
                    addStatement("$L.executeUpdateDelete()", stmtVar)
                    addStatement("$N.setTransactionSuccessful()", dbField)
                }
            }
            nextControlFlow("finally").apply {
                addStatement("$N.endTransaction()", dbField)
            }
            endControlFlow()
        }
        return scope.builder().build()
    }

    /**
     * Builds the body of a SELECT @Query. The SQL is prepared and bound into `_statement`, which
     * the result binder converts and returns (and may release).
     */
    private fun createQueryMethodBody(method: QueryMethod): CodeBlock {
        val queryWriter = QueryWriter(method)
        val scope = CodeGenScope(this)
        val sqlVar = scope.getTmpVar("_sql")
        val roomSQLiteQueryVar = scope.getTmpVar("_statement")
        queryWriter.prepareReadAndBind(sqlVar, roomSQLiteQueryVar, scope)
        method.queryResultBinder.convertAndReturn(
                roomSQLiteQueryVar = roomSQLiteQueryVar,
                canReleaseQuery = true,
                dbField = dbField,
                inTransaction = method.inTransaction,
                scope = scope)
        return scope.builder().build()
    }

    /**
     * Starts an `@Override` method matching [elm] as seen from [owner]. It copies the name,
     * modifiers, parameters, varargs flag and return type, but drops the original method's other
     * annotations.
     */
    private fun overrideWithoutAnnotations(
            elm: ExecutableElement,
            owner: DeclaredType): MethodSpec.Builder {
        val baseSpec = MethodSpec.overriding(elm, owner, processingEnv.typeUtils).build()
        return MethodSpec.methodBuilder(baseSpec.name).apply {
            addAnnotation(Override::class.java)
            addModifiers(baseSpec.modifiers)
            addParameters(baseSpec.parameters)
            varargs(baseSpec.varargs)
            returns(baseSpec.returnType)
        }
    }

    /**
     * Represents a query statement prepared in Dao implementation.
     *
     * @param fields This map holds all the member fields necessary for this query. The key is the
     * corresponding parameter name in the defining query method. The value is a pair from the field
     * declaration to definition.
     * @param methodImpl The body of the query method implementation.
     */
    data class PreparedStmtQuery(
            val fields: Map<String, Pair<FieldSpec, TypeSpec>>,
            val methodImpl: MethodSpec) {
        companion object {
            // The key to be used in `fields` where the method requires a field that is not
            // associated with any of its parameters
            const val NO_PARAM_FIELD = "-"
        }
    }

    /**
     * Shared insertion-adapter field. It is de-duplicated by entity type plus on-conflict
     * strategy.
     */
    private class InsertionMethodField(val entity: Entity, val onConflictText: String)
        : SharedFieldSpec(
            "insertionAdapterOf${Companion.typeNameToFieldName(entity.typeName)}",
            RoomTypeNames.INSERTION_ADAPTER) {

        override fun getUniqueKey(): String {
            return "${entity.typeName} $onConflictText"
        }

        override fun prepare(writer: ClassWriter, builder: FieldSpec.Builder) {
            builder.addModifiers(FINAL, PRIVATE)
        }
    }

    /**
     * Shared delete/update adapter field. It is de-duplicated by entity type plus
     * [methodPrefix].
     */
    class DeleteOrUpdateAdapterField(val entity: Entity, val methodPrefix: String)
        : SharedFieldSpec(
            "${methodPrefix}AdapterOf${Companion.typeNameToFieldName(entity.typeName)}",
            RoomTypeNames.DELETE_OR_UPDATE_ADAPTER) {
        override fun prepare(writer: ClassWriter, builder: FieldSpec.Builder) {
            builder.addModifiers(PRIVATE, FINAL)
        }

        override fun getUniqueKey(): String {
            return entity.typeName.toString() + methodPrefix
        }
    }

    /**
     * Shared prepared-statement field. Methods with identical original SQL share one field.
     */
    class PreparedStatementField(val method: QueryMethod) : SharedFieldSpec(
            "preparedStmtOf${method.name.capitalize()}", RoomTypeNames.SHARED_SQLITE_STMT) {
        override fun prepare(writer: ClassWriter, builder: FieldSpec.Builder) {
            builder.addModifiers(PRIVATE, FINAL)
        }

        override fun getUniqueKey(): String {
            return method.query.original
        }
    }
}
606