1/*
2 * Copyright (C) 2016 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package android.arch.persistence.room.writer
18
19import android.arch.persistence.room.ext.L
20import android.arch.persistence.room.ext.N
21import android.arch.persistence.room.ext.RoomTypeNames
22import android.arch.persistence.room.ext.SupportDbTypeNames
23import android.arch.persistence.room.ext.T
24import android.arch.persistence.room.ext.typeName
25import android.arch.persistence.room.parser.QueryType
26import android.arch.persistence.room.processor.OnConflictProcessor
27import android.arch.persistence.room.solver.CodeGenScope
28import android.arch.persistence.room.vo.Dao
29import android.arch.persistence.room.vo.Entity
30import android.arch.persistence.room.vo.InsertionMethod
31import android.arch.persistence.room.vo.QueryMethod
32import android.arch.persistence.room.vo.ShortcutMethod
33import android.arch.persistence.room.vo.TransactionMethod
34import com.google.auto.common.MoreTypes
35import com.squareup.javapoet.ClassName
36import com.squareup.javapoet.CodeBlock
37import com.squareup.javapoet.FieldSpec
38import com.squareup.javapoet.MethodSpec
39import com.squareup.javapoet.ParameterSpec
40import com.squareup.javapoet.TypeName
41import com.squareup.javapoet.TypeSpec
42import stripNonJava
43import javax.annotation.processing.ProcessingEnvironment
44import javax.lang.model.element.ElementKind
45import javax.lang.model.element.ExecutableElement
46import javax.lang.model.element.Modifier.FINAL
47import javax.lang.model.element.Modifier.PRIVATE
48import javax.lang.model.element.Modifier.PUBLIC
49import javax.lang.model.type.DeclaredType
50import javax.lang.model.type.TypeKind
51
/**
 * Creates the implementation for a class annotated with `@Dao`.
 *
 * The generated class implements the DAO when it is an interface and extends it otherwise. It
 * keeps the database in [dbField] and overrides the DAO's query, insertion, deletion, update and
 * transaction methods.
 */
class DaoWriter(val dao: Dao, val processingEnv: ProcessingEnvironment)
    : ClassWriter(dao.typeName) {
    /** The DAO type as a [DeclaredType]; used to resolve the signatures of overridden methods. */
    val declaredDao = MoreTypes.asDeclared(dao.element.asType())

    companion object {
        // TODO nothing prevents this from conflicting, we should fix.
        val dbField: FieldSpec = FieldSpec
                .builder(RoomTypeNames.ROOM_DB, "__db", PRIVATE, FINAL)
                .build()

        /**
         * Builds a field-name fragment for [typeName].
         *
         * A [ClassName] yields its simple name. Any other type yields its string form with '.'
         * replaced by '_', passed through `stripNonJava`.
         */
        private fun typeNameToFieldName(typeName: TypeName?): String = when (typeName) {
            is ClassName -> typeName.simpleName()
            else -> typeName.toString().replace('.', '_').stripNonJava()
        }
    }

    override fun createTypeSpecBuilder(): TypeSpec.Builder {
        val builder = TypeSpec.classBuilder(dao.implTypeName)
        // DELETE / UPDATE @Query methods are executed through a compiled statement.
        // If none of a method's parameters binds a variable number of values, the statement is
        // prepared once, stored in a shared field and re-used, which performs better.
        // Otherwise the SQL must be rebuilt and compiled on every call.
        // A parameter without a query adapter is treated as dynamic.
        val groupedDeleteUpdate = dao.queryMethods
                .filter { it.query.type == QueryType.DELETE || it.query.type == QueryType.UPDATE }
                .groupBy { method ->
                    method.parameters.any { param -> param.queryParamAdapter?.isMultiple ?: true }
                }
        // delete / update queries that can be prepared ahead of time
        val preparedDeleteOrUpdateQueries = groupedDeleteUpdate[false] ?: emptyList()
        // delete / update queries that must be rebuilt every single time
        val oneOffDeleteOrUpdateQueries = groupedDeleteUpdate[true] ?: emptyList()
        // Methods emitted together with the member fields (if any) they rely on.
        val shortcutMethods = createInsertionMethods() +
                createDeletionMethods() + createUpdateMethods() + createTransactionMethods() +
                createPreparedDeleteOrUpdateQueries(preparedDeleteOrUpdateQueries)

        builder.apply {
            addModifiers(PUBLIC)
            if (dao.element.kind == ElementKind.INTERFACE) {
                addSuperinterface(dao.typeName)
            } else {
                superclass(dao.typeName)
            }
            addField(dbField)
            // When the DAO declares a constructor parameter type, the generated constructor takes
            // that type and forwards it to super(...). Otherwise it takes the database type.
            val dbParam = ParameterSpec
                    .builder(dao.constructorParamType ?: dbField.type, dbField.name).build()

            addMethod(createConstructor(dbParam, shortcutMethods, dao.constructorParamType != null))

            shortcutMethods.forEach {
                addMethod(it.methodImpl)
            }

            dao.queryMethods.filter { it.query.type == QueryType.SELECT }.forEach { method ->
                addMethod(createSelectMethod(method))
            }
            oneOffDeleteOrUpdateQueries.forEach {
                addMethod(createDeleteOrUpdateQueryMethod(it))
            }
        }
        return builder
    }

    /**
     * For each DELETE / UPDATE query whose statement can be compiled once, creates:
     * - the shared prepared-statement field and its initializer;
     * - the overriding method that uses that field.
     */
    private fun createPreparedDeleteOrUpdateQueries(preparedDeleteQueries: List<QueryMethod>)
            : List<PreparedStmtQuery> {
        return preparedDeleteQueries.map { method ->
            val fieldSpec = getOrCreateField(PreparedStatementField(method))
            val queryWriter = QueryWriter(method)
            val fieldImpl = PreparedStatementWriter(queryWriter)
                    .createAnonymous(this@DaoWriter, dbField)
            val methodBody = createPreparedDeleteQueryMethodBody(method, fieldSpec, queryWriter)
            // The statement field is not tied to any single parameter of the method.
            PreparedStmtQuery(mapOf(PreparedStmtQuery.NO_PARAM_FIELD
                    to (fieldSpec to fieldImpl)), methodBody)
        }
    }

    /**
     * Generates the override for a prepared DELETE / UPDATE query. The generated method:
     * - acquires the statement from [preparedStmtField];
     * - begins a transaction;
     * - binds the arguments and calls `executeUpdateDelete()`;
     * - in `finally`, ends the transaction and releases the statement.
     *
     * If the method returns a value, the result of `executeUpdateDelete()` is returned.
     */
    private fun createPreparedDeleteQueryMethodBody(method: QueryMethod,
                                                    preparedStmtField: FieldSpec,
                                                    queryWriter: QueryWriter): MethodSpec {
        val scope = CodeGenScope(this)
        val methodBuilder = overrideWithoutAnnotations(method.element, declaredDao).apply {
            val stmtName = scope.getTmpVar("_stmt")
            addStatement("final $T $L = $N.acquire()",
                    SupportDbTypeNames.SQLITE_STMT, stmtName, preparedStmtField)
            addStatement("$N.beginTransaction()", dbField)
            beginControlFlow("try").apply {
                val bindScope = scope.fork()
                queryWriter.bindArgs(stmtName, emptyList(), bindScope)
                addCode(bindScope.builder().build())
                if (method.returnsValue) {
                    val resultVar = scope.getTmpVar("_result")
                    addStatement("final $L $L = $L.executeUpdateDelete()",
                            method.returnType.typeName(), resultVar, stmtName)
                    addStatement("$N.setTransactionSuccessful()", dbField)
                    addStatement("return $L", resultVar)
                } else {
                    addStatement("$L.executeUpdateDelete()", stmtName)
                    addStatement("$N.setTransactionSuccessful()", dbField)
                }
            }
            nextControlFlow("finally").apply {
                addStatement("$N.endTransaction()", dbField)
                addStatement("$N.release($L)", preparedStmtField, stmtName)
            }
            endControlFlow()
        }
        return methodBuilder.build()
    }

    /** Creates one override per transaction method; these need no member fields. */
    private fun createTransactionMethods(): List<PreparedStmtQuery> {
        return dao.transactionMethods.map {
            PreparedStmtQuery(emptyMap(), createTransactionMethodBody(it))
        }
    }

    /**
     * Generates an override that calls the super implementation inside a database transaction.
     * The super call's result, if any, is returned after the transaction is marked successful.
     */
    private fun createTransactionMethodBody(method: TransactionMethod): MethodSpec {
        val scope = CodeGenScope(this)
        val methodBuilder = overrideWithoutAnnotations(method.element, declaredDao).apply {
            addStatement("$N.beginTransaction()", dbField)
            beginControlFlow("try").apply {
                val returnsValue = method.element.returnType.kind != TypeKind.VOID
                val resultVar = if (returnsValue) {
                    scope.getTmpVar("_result")
                } else {
                    null
                }
                addDelegateToSuperStatement(method.element, resultVar)
                addStatement("$N.setTransactionSuccessful()", dbField)
                if (returnsValue) {
                    addStatement("return $L", resultVar)
                }
            }
            nextControlFlow("finally").apply {
                addStatement("$N.endTransaction()", dbField)
            }
            endControlFlow()
        }
        return methodBuilder.build()
    }

    /**
     * Emits `super.<name>(<params>)`, forwarding every parameter of [element] by name.
     * When [result] is non-null, the call's value is assigned to a local of that name, typed with
     * the element's return type.
     */
    private fun MethodSpec.Builder.addDelegateToSuperStatement(element: ExecutableElement,
                                                               result: String?) {
        val params: MutableList<Any> = mutableListOf()
        val format = buildString {
            if (result != null) {
                append("$T $L = ")
                params.add(element.returnType)
                params.add(result)
            }
            append("super.$N(")
            params.add(element.simpleName)
            append(element.parameters.joinToString(", ") { L })
            params.addAll(element.parameters.map { it.simpleName })
            append(")")
        }
        addStatement(format, *params.toTypedArray())
    }

    /**
     * Creates the public constructor. It:
     * - forwards [dbParam] to `super(...)` when [callSuper] is true;
     * - stores [dbParam] in [dbField];
     * - initializes every member field required by [shortcutMethods].
     *
     * A field shared by several methods is initialized only once, using the first initializer
     * encountered.
     */
    private fun createConstructor(dbParam: ParameterSpec,
                                  shortcutMethods: List<PreparedStmtQuery>,
                                  callSuper: Boolean): MethodSpec {
        return MethodSpec.constructorBuilder().apply {
            addParameter(dbParam)
            addModifiers(PUBLIC)
            if (callSuper) {
                addStatement("super($N)", dbParam)
            }
            addStatement("this.$N = $N", dbField, dbParam)
            shortcutMethods
                    .flatMap { it.fields.values }
                    .distinctBy { it.first.name }
                    .forEach { addStatement("this.$N = $L", it.first, it.second) }
        }.build()
    }

    private fun createSelectMethod(method: QueryMethod): MethodSpec {
        return overrideWithoutAnnotations(method.element, declaredDao).apply {
            addCode(createQueryMethodBody(method))
        }.build()
    }

    private fun createDeleteOrUpdateQueryMethod(method: QueryMethod): MethodSpec {
        return overrideWithoutAnnotations(method.element, declaredDao).apply {
            addCode(createDeleteOrUpdateQueryMethodBody(method))
        }.build()
    }

    /**
     * For each insertion method, creates the insertion adapter fields and the method that uses
     * them.
     *
     * Adapter fields are shared between methods through [getOrCreateField], keyed on entity type
     * and conflict strategy.
     */
    private fun createInsertionMethods(): List<PreparedStmtQuery> {
        return dao.insertionMethods.map { insertionMethod ->
            val onConflict = OnConflictProcessor.onConflictText(insertionMethod.onConflict)
            val fields = insertionMethod.entities.mapValues {
                val spec = getOrCreateField(InsertionMethodField(it.value, onConflict))
                val impl = EntityInsertionAdapterWriter(it.value, onConflict)
                        .createAnonymous(this@DaoWriter, dbField.name)
                spec to impl
            }
            val methodImpl = overrideWithoutAnnotations(insertionMethod.element,
                    declaredDao).apply {
                addCode(createInsertionMethodBody(insertionMethod, fields))
            }.build()
            PreparedStmtQuery(fields, methodImpl)
        }
    }

    /**
     * Generates the body of an insertion method.
     *
     * Inside a transaction, it calls the matching adapter for each parameter. If the insertion
     * type is not INSERT_VOID, it returns the adapter's result.
     *
     * The body is empty when there are no adapters or the insertion type is unknown (null).
     */
    private fun createInsertionMethodBody(method: InsertionMethod,
                                          insertionAdapters: Map<String, Pair<FieldSpec, TypeSpec>>)
            : CodeBlock {
        val insertionType = method.insertionType
        if (insertionAdapters.isEmpty() || insertionType == null) {
            return CodeBlock.builder().build()
        }
        val scope = CodeGenScope(this)

        return scope.builder().apply {
            // TODO assert thread
            // TODO collect results
            addStatement("$N.beginTransaction()", dbField)
            val needsReturnType = insertionType != InsertionMethod.Type.INSERT_VOID
            val resultVar = if (needsReturnType) {
                scope.getTmpVar("_result")
            } else {
                null
            }

            beginControlFlow("try").apply {
                method.parameters.forEach { param ->
                    val insertionAdapter = insertionAdapters[param.name]?.first
                    if (needsReturnType) {
                        // if it has more than 1 parameter, we would've already printed the error
                        // so we don't care about re-declaring the variable here
                        addStatement("$T $L = $N.$L($L)",
                                insertionType.returnTypeName, resultVar,
                                insertionAdapter, insertionType.methodName,
                                param.name)
                    } else {
                        addStatement("$N.$L($L)", insertionAdapter, insertionType.methodName,
                                param.name)
                    }
                }
                addStatement("$N.setTransactionSuccessful()", dbField)
                if (needsReturnType) {
                    addStatement("return $L", resultVar)
                }
            }
            nextControlFlow("finally").apply {
                addStatement("$N.endTransaction()", dbField)
            }
            endControlFlow()
        }.build()
    }

    /**
     * Creates an EntityDeletionAdapter for each @Delete method.
     */
    private fun createDeletionMethods(): List<PreparedStmtQuery> {
        return createShortcutMethods(dao.deletionMethods, "deletion", { _, entity ->
            EntityDeletionAdapterWriter(entity)
                    .createAnonymous(this@DaoWriter, dbField.name)
        })
    }

    /**
     * Creates an EntityUpdateAdapter for each @Update method.
     *
     * The conflict strategy is part of the adapter field's key. Methods with different strategies
     * therefore get separate adapters instead of silently sharing the first one.
     */
    private fun createUpdateMethods(): List<PreparedStmtQuery> {
        return createShortcutMethods(dao.updateMethods, "update",
                implCallback = { update, entity ->
                    val onConflict = OnConflictProcessor.onConflictText(update.onConflictStrategy)
                    EntityUpdateAdapterWriter(entity, onConflict)
                            .createAnonymous(this@DaoWriter, dbField.name)
                },
                onConflictTextOf = { update ->
                    OnConflictProcessor.onConflictText(update.onConflictStrategy)
                })
    }

    /**
     * Shared implementation for @Delete / @Update methods.
     *
     * For every method with at least one entity parameter, it creates (or re-uses) an adapter
     * field per entity and the overriding method that calls those adapters.
     *
     * @param methods the shortcut methods to implement; methods without entities are skipped.
     * @param methodPrefix prefix of the adapter field name (e.g. "deletion" or "update").
     * @param implCallback creates the initializer of the adapter field for a method/entity pair.
     * @param onConflictTextOf extra text added to the adapter field's unique key. Methods whose
     * adapters differ must yield different values. Defaults to "" (no distinction).
     */
    private fun <T : ShortcutMethod> createShortcutMethods(
            methods: List<T>,
            methodPrefix: String,
            implCallback: (T, Entity) -> TypeSpec,
            onConflictTextOf: (T) -> String = { "" })
            : List<PreparedStmtQuery> {
        return methods.mapNotNull { method ->
            val entities = method.entities
            if (entities.isEmpty()) {
                null
            } else {
                val onConflictText = onConflictTextOf(method)
                val fields = entities.mapValues {
                    val spec = getOrCreateField(
                            DeleteOrUpdateAdapterField(it.value, methodPrefix, onConflictText))
                    val impl = implCallback(method, it.value)
                    spec to impl
                }
                val methodSpec = overrideWithoutAnnotations(method.element, declaredDao).apply {
                    addCode(createDeleteOrUpdateMethodBody(method, fields))
                }.build()
                PreparedStmtQuery(fields, methodSpec)
            }
        }
    }

    /**
     * Generates the body of a @Delete / @Update method.
     *
     * Inside a transaction, it calls the matching adapter for each parameter. When
     * [ShortcutMethod.returnCount] is set, it sums the adapters' results into a local `int` and
     * returns it.
     *
     * The body is empty when there are no adapters.
     */
    private fun createDeleteOrUpdateMethodBody(method: ShortcutMethod,
                                               adapters: Map<String, Pair<FieldSpec, TypeSpec>>)
            : CodeBlock {
        if (adapters.isEmpty()) {
            return CodeBlock.builder().build()
        }
        val scope = CodeGenScope(this)
        val resultVar = if (method.returnCount) {
            scope.getTmpVar("_total")
        } else {
            null
        }
        return scope.builder().apply {
            if (resultVar != null) {
                addStatement("$T $L = 0", TypeName.INT, resultVar)
            }
            addStatement("$N.beginTransaction()", dbField)
            beginControlFlow("try").apply {
                method.parameters.forEach { param ->
                    val adapter = adapters[param.name]?.first
                    addStatement("$L$N.$L($L)",
                            if (resultVar == null) "" else "$resultVar +=",
                            adapter, param.handleMethodName(), param.name)
                }
                addStatement("$N.setTransactionSuccessful()", dbField)
                if (resultVar != null) {
                    addStatement("return $L", resultVar)
                }
            }
            nextControlFlow("finally").apply {
                addStatement("$N.endTransaction()", dbField)
            }
            endControlFlow()
        }.build()
    }

    /**
     * Generates the body of a DELETE / UPDATE @Query method whose arguments are dynamic.
     *
     * The generated body:
     * - builds the SQL, compiles a fresh statement and binds the arguments;
     * - calls `executeUpdateDelete()` inside a transaction;
     * - returns the result when the method returns a value.
     */
    private fun createDeleteOrUpdateQueryMethodBody(method: QueryMethod): CodeBlock {
        val queryWriter = QueryWriter(method)
        val scope = CodeGenScope(this)
        val sqlVar = scope.getTmpVar("_sql")
        val stmtVar = scope.getTmpVar("_stmt")
        val listSizeArgs = queryWriter.prepareQuery(sqlVar, scope)
        scope.builder().apply {
            addStatement("$T $L = $N.compileStatement($L)",
                    SupportDbTypeNames.SQLITE_STMT, stmtVar, dbField, sqlVar)
            queryWriter.bindArgs(stmtVar, listSizeArgs, scope)
            addStatement("$N.beginTransaction()", dbField)
            beginControlFlow("try").apply {
                if (method.returnsValue) {
                    val resultVar = scope.getTmpVar("_result")
                    addStatement("final $L $L = $L.executeUpdateDelete()",
                            method.returnType.typeName(), resultVar, stmtVar)
                    addStatement("$N.setTransactionSuccessful()", dbField)
                    addStatement("return $L", resultVar)
                } else {
                    addStatement("$L.executeUpdateDelete()", stmtVar)
                    addStatement("$N.setTransactionSuccessful()", dbField)
                }
            }
            nextControlFlow("finally").apply {
                addStatement("$N.endTransaction()", dbField)
            }
            endControlFlow()
        }
        return scope.builder().build()
    }

    /**
     * Generates the body of a SELECT @Query method. It prepares and binds the query, then
     * delegates conversion of the result to the method's `queryResultBinder`.
     */
    private fun createQueryMethodBody(method: QueryMethod): CodeBlock {
        val queryWriter = QueryWriter(method)
        val scope = CodeGenScope(this)
        val sqlVar = scope.getTmpVar("_sql")
        val roomSQLiteQueryVar = scope.getTmpVar("_statement")
        queryWriter.prepareReadAndBind(sqlVar, roomSQLiteQueryVar, scope)
        method.queryResultBinder.convertAndReturn(roomSQLiteQueryVar, dbField, scope)
        return scope.builder().build()
    }

    /**
     * Returns a builder for an override of [elm] as a member of [owner].
     *
     * `@Override` is the only annotation; the original method's annotations are not copied.
     * Modifiers, type variables, parameters, varargs, return type and thrown exceptions are taken
     * from JavaPoet's resolved override.
     */
    private fun overrideWithoutAnnotations(elm: ExecutableElement,
                                           owner: DeclaredType): MethodSpec.Builder {
        val baseSpec = MethodSpec.overriding(elm, owner, processingEnv.typeUtils).build()
        return MethodSpec.methodBuilder(baseSpec.name).apply {
            addAnnotation(Override::class.java)
            addModifiers(baseSpec.modifiers)
            // Without type variables, generic overrides don't compile.
            addTypeVariables(baseSpec.typeVariables)
            addParameters(baseSpec.parameters)
            varargs(baseSpec.varargs)
            returns(baseSpec.returnType)
            // Without the throws clause, a @Transaction override calling a super implementation
            // that declares checked exceptions would not compile.
            addExceptions(baseSpec.exceptions)
        }
    }

    /**
     * Represents a query statement prepared in Dao implementation.
     *
     * @param fields This map holds all the member fields necessary for this query. The key is the
     * corresponding parameter name in the defining query method. The value is a pair from the field
     * declaration to definition.
     * @param methodImpl The body of the query method implementation.
     */
    data class PreparedStmtQuery(val fields: Map<String, Pair<FieldSpec, TypeSpec>>,
                                 val methodImpl: MethodSpec) {
        companion object {
            // The key to be used in `fields` where the method requires a field that is not
            // associated with any of its parameters
            const val NO_PARAM_FIELD = "-"
        }
    }

    /** Private final insertion-adapter field, shared per entity type and conflict strategy. */
    private class InsertionMethodField(val entity: Entity, val onConflictText: String)
        : SharedFieldSpec(
            "insertionAdapterOf${Companion.typeNameToFieldName(entity.typeName)}",
            RoomTypeNames.INSERTION_ADAPTER) {

        override fun getUniqueKey(): String {
            return "${entity.typeName} $onConflictText"
        }

        override fun prepare(writer: ClassWriter, builder: FieldSpec.Builder) {
            builder.addModifiers(FINAL, PRIVATE)
        }
    }

    /**
     * Private final delete/update adapter field. It is shared per entity type, [methodPrefix] and
     * [onConflictText].
     *
     * @param onConflictText the conflict strategy text the adapter was generated with. Adapters
     * built with different strategies must not be shared, so it is part of the unique key. It is
     * empty (the default) for deletions.
     */
    class DeleteOrUpdateAdapterField(val entity: Entity, val methodPrefix: String,
                                     val onConflictText: String = "")
        : SharedFieldSpec(
            "${methodPrefix}AdapterOf${Companion.typeNameToFieldName(entity.typeName)}",
            RoomTypeNames.DELETE_OR_UPDATE_ADAPTER) {
        override fun prepare(writer: ClassWriter, builder: FieldSpec.Builder) {
            builder.addModifiers(PRIVATE, FINAL)
        }

        override fun getUniqueKey(): String {
            return "${entity.typeName}$methodPrefix $onConflictText"
        }
    }

    /**
     * Private final shared-statement field for a prepared DELETE / UPDATE query. Methods with the
     * same original query text share one field.
     */
    class PreparedStatementField(val method: QueryMethod) : SharedFieldSpec(
            "preparedStmtOf${method.name.capitalize()}", RoomTypeNames.SHARED_SQLITE_STMT) {
        override fun prepare(writer: ClassWriter, builder: FieldSpec.Builder) {
            builder.addModifiers(PRIVATE, FINAL)
        }

        override fun getUniqueKey(): String {
            return method.query.original
        }
    }
}
519