2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
17package androidx.room.writer
19import androidx.room.ext.L
20import androidx.room.ext.T
21import androidx.room.ext.defaultValue
22import androidx.room.ext.typeName
23import androidx.room.solver.CodeGenScope
24import androidx.room.vo.CallType
25import androidx.room.vo.Constructor
26import androidx.room.vo.EmbeddedField
27import androidx.room.vo.Field
28import androidx.room.vo.FieldWithIndex
29import androidx.room.vo.Pojo
30import androidx.room.vo.RelationCollector
31import com.squareup.javapoet.TypeName
 * Handles writing a field into a statement or reading it from a cursor.
35 */
class FieldReadWriteWriter(fieldWithIndex: FieldWithIndex) {
    val field = fieldWithIndex.field
    val indexVar = fieldWithIndex.indexVar
    val alwaysExists = fieldWithIndex.alwaysExists

    companion object {
        /**
         * Returns every [EmbeddedField] that is an ancestor of at least one of [fields]. This
         * includes intermediate parents that have grandchildren but no direct children in the
         * list. Iteration order follows first discovery (fields in order, each climbing upward).
         */
        fun getAllParents(fields: List<Field>): Set<EmbeddedField> {
            val allParents = mutableSetOf<EmbeddedField>()
            fields.forEach { field ->
                var parent = field.parent
                // Stop climbing once a parent is already recorded: its ancestors were added
                // when it was first seen.
                while (parent != null && allParents.add(parent)) {
                    parent = parent.parent
                }
            }
            return allParents
        }

        /**
         * Converts the fields with indices into a [Node] tree so they can be processed
         * recursively. This is done here instead of during parsing because the result may
         * include arbitrary fields.
         *
         * @param rootVar variable name used for the root node
         * @param fieldsWithIndices the fields (with their column index variables) to arrange
         * @param scope used to allocate a temporary variable name for every embedded parent
         * @return the root node; its [Node.fieldParent] is `null`
         */
        private fun createNodeTree(
                rootVar: String,
                fieldsWithIndices: List<FieldWithIndex>,
                scope: CodeGenScope): Node {
            val allParents = getAllParents(fieldsWithIndices.map { it.field })
            // Bucket fields by their immediate parent once; the `null` key holds root fields.
            val fieldsByParent = fieldsWithIndices.groupBy { it.field.parent }
            val rootNode = Node(rootVar, null)
            rootNode.directFields = fieldsByParent[null].orEmpty()
            val parentNodes = allParents.associate {
                Pair(it, Node(
                        varName = scope.getTmpVar("_tmp${it.field.name.capitalize()}"),
                        fieldParent = it))
            }
            parentNodes.forEach { (fieldParent, node) ->
                // A parent whose own parent is not embedded hangs directly off the root.
                val parentNode = fieldParent.parent?.let { parentNodes[it] } ?: rootNode
                node.directFields = fieldsByParent[fieldParent].orEmpty()
                node.parentNode = parentNode
                parentNode.subNodes.add(node)
            }
            return rootNode
        }

        /**
         * Generates code that binds every field in [fieldsWithIndices] to the statement
         * [stmtParamVar], reading values from [ownerVar].
         *
         * Each embedded object is first read into its own local via its getter. If that local is
         * null at runtime, every column beneath it is bound with `bindNull`.
         */
        fun bindToStatement(
                ownerVar: String,
                stmtParamVar: String,
                fieldsWithIndices: List<FieldWithIndex>,
                scope: CodeGenScope
        ) {
            fun visitNode(node: Node) {
                fun bindWithDescendants() {
                    node.directFields.forEach {
                        FieldReadWriteWriter(it).bindToStatement(
                                ownerVar = node.varName,
                                stmtParamVar = stmtParamVar,
                                scope = scope
                        )
                    }
                    node.subNodes.forEach(::visitNode)
                }

                val fieldParent = node.fieldParent
                if (fieldParent == null) {
                    // root node: the owner variable already exists
                    bindWithDescendants()
                    return
                }
                val parentNode = checkNotNull(node.parentNode) {
                    "embedded node ${node.varName} is not attached to a parent node"
                }
                fieldParent.getter.writeGet(
                        ownerVar = parentNode.varName,
                        outVar = node.varName,
                        builder = scope.builder()
                )
                scope.builder().apply {
                    beginControlFlow("if ($L != null)", node.varName)
                    bindWithDescendants()
                    nextControlFlow("else")
                    node.allFields().forEach {
                        addStatement("$L.bindNull($L)", stmtParamVar, it.indexVar)
                    }
                    endControlFlow()
                }
            }
            visitNode(createNodeTree(ownerVar, fieldsWithIndices, scope))
        }

        /**
         * Just constructs the given item; it does NOT DECLARE it. Declaration happens outside
         * the reading statement because we may never read if the cursor does not have the
         * necessary columns.
         *
         * @param outVar variable that receives the new instance
         * @param constructor the constructor to call; when `null`, a no-arg `new T()` is emitted
         * @param typeName the type to instantiate
         * @param localVariableNames temp variable name -> field it holds, for field parameters
         * @param localEmbeddeds already-read embedded nodes, for embedded parameters
         * @param scope the code generation scope
         */
        private fun construct(
                outVar: String,
                constructor: Constructor?,
                typeName: TypeName,
                localVariableNames: Map<String, FieldWithIndex>,
                localEmbeddeds: List<Node>, scope: CodeGenScope
        ) {
            if (constructor == null) {
                // best hope code generation
                scope.builder().addStatement("$L = new $T()", outVar, typeName)
                return
            }
            val args = constructor.params.joinToString(",") { param ->
                val argName = when (param) {
                    is Constructor.FieldParam -> localVariableNames.entries.firstOrNull {
                        it.value.field === param.field
                    }?.key
                    is Constructor.EmbeddedParam -> localEmbeddeds.firstOrNull {
                        it.fieldParent === param.embedded
                    }?.varName
                    else -> null
                }
                // parameters that cannot be matched to a local are passed as null
                argName ?: "null"
            }
            scope.builder().addStatement("$L = new $T($L)", outVar, typeName, args)
        }

        /**
         * Reads the row into the given variable. It does not declare [outVar] but constructs it.
         *
         * Embedded objects are declared as locals. Unless an embedded field is `nonNull`, it is
         * only constructed when at least one of its columns is present and non-null; otherwise
         * it is set to `null`.
         *
         * @param outVar variable (declared by the caller) that receives the root [outPojo]
         * @param outPojo the pojo being read
         * @param cursorVar the cursor variable
         * @param fieldsWithIndices the fields to read, with their column index variables
         * @param scope the code generation scope
         * @param relationCollectors collectors whose parent keys are read alongside their owner
         */
        fun readFromCursor(
                outVar: String,
                outPojo: Pojo,
                cursorVar: String,
                fieldsWithIndices: List<FieldWithIndex>,
                scope: CodeGenScope,
                relationCollectors: List<RelationCollector>
        ) {
            fun visitNode(node: Node) {
                val fieldParent = node.fieldParent
                fun readNode() {
                    val (constructorFields, otherFields) = node.directFields.partition {
                        it.field.setter.callType == CallType.CONSTRUCTOR
                    }
                    // read constructor parameters into local variables (tmp var name -> field)
                    val constructorVars = constructorFields.associateBy { fwi ->
                        FieldReadWriteWriter(fwi).readIntoTmpVar(cursorVar, scope)
                    }
                    // read decomposed (embedded) fields before constructing their owner
                    node.subNodes.forEach(::visitNode)
                    // construct the object
                    val (constructor, typeName) = if (fieldParent != null) {
                        Pair(fieldParent.pojo.constructor, fieldParent.field.typeName)
                    } else {
                        Pair(outPojo.constructor, outPojo.typeName)
                    }
                    construct(outVar = node.varName,
                            constructor = constructor,
                            typeName = typeName,
                            localEmbeddeds = node.subNodes,
                            localVariableNames = constructorVars,
                            scope = scope)
                    // read any field that was not part of the constructor
                    otherFields.forEach { fwi ->
                        FieldReadWriteWriter(fwi).readFromCursor(
                                ownerVar = node.varName,
                                cursorVar = cursorVar,
                                scope = scope)
                    }
                    // assign relationship fields which will be read later
                    relationCollectors.filter { (relation) ->
                        relation.field.parent === fieldParent
                    }.forEach {
                        it.writeReadParentKeyCode(
                                cursorVarName = cursorVar,
                                itemVar = node.varName,
                                fieldsWithIndices = fieldsWithIndices,
                                scope = scope)
                    }
                    // assign sub nodes to fields if they were not part of the constructor
                    node.subNodes.forEach { subNode ->
                        val setter = subNode.fieldParent?.setter
                        if (setter != null && setter.callType != CallType.CONSTRUCTOR) {
                            setter.writeSet(
                                    ownerVar = node.varName,
                                    inVar = subNode.varName,
                                    builder = scope.builder())
                        }
                    }
                }

                if (fieldParent == null) {
                    // root element: always declared by the caller so we don't declare it here
                    readNode()
                    return
                }
                // always declare; it is assigned below
                scope.builder().addStatement("final $T $L", fieldParent.pojo.typeName,
                        node.varName)
                if (fieldParent.nonNull) {
                    readNode()
                } else {
                    // true when every column of this embedded object is missing (-1) or null
                    val allNullCheck = node.allFields().joinToString(" && ") {
                        if (it.alwaysExists) {
                            "$cursorVar.isNull(${it.indexVar})"
                        } else {
                            "( ${it.indexVar} == -1 || $cursorVar.isNull(${it.indexVar}))"
                        }
                    }
                    scope.builder().apply {
                        beginControlFlow("if (! ($L))", allNullCheck)
                        readNode()
                        nextControlFlow("else")
                        addStatement("$L = null", node.varName)
                        endControlFlow()
                    }
                }
            }
            visitNode(createNodeTree(outVar, fieldsWithIndices, scope))
        }
    }

    /**
     * Binds this field's value into the statement. Does nothing if the field has no
     * statement binder.
     *
     * @param ownerVar The entity / pojo that owns this field. It must own this field! (not the
     * container pojo)
     * @param stmtParamVar The statement variable
     * @param scope The code generation scope
     */
    private fun bindToStatement(ownerVar: String, stmtParamVar: String, scope: CodeGenScope) {
        field.statementBinder?.let { binder ->
            val varName = if (field.getter.callType == CallType.FIELD) {
                "$ownerVar.${field.name}"
            } else {
                "$ownerVar.${field.getter.name}()"
            }
            binder.bindToStmt(stmtParamVar, indexVar, varName, scope)
        }
    }

    /**
     * Reads this field from the cursor and stores it into [ownerVar] via its setter. When the
     * column is not guaranteed to exist, the read is guarded by `indexVar != -1`. Does nothing
     * if the field has no cursor value reader.
     *
     * @param ownerVar The entity / pojo that owns this field. It must own this field (not the
     * container pojo)
     * @param cursorVar The cursor variable
     * @param scope The code generation scope
     */
    private fun readFromCursor(ownerVar: String, cursorVar: String, scope: CodeGenScope) {
        fun toRead() {
            field.cursorValueReader?.let { reader ->
                scope.builder().apply {
                    when (field.setter.callType) {
                        CallType.FIELD -> {
                            // This is a write, so address the field through the setter's name.
                            reader.readFromCursor("$ownerVar.${field.setter.name}", cursorVar,
                                    indexVar, scope)
                        }
                        CallType.METHOD -> {
                            val tmpField = scope.getTmpVar("_tmp${field.name.capitalize()}")
                            addStatement("final $T $L", field.getter.type.typeName(), tmpField)
                            reader.readFromCursor(tmpField, cursorVar, indexVar, scope)
                            addStatement("$L.$L($L)", ownerVar, field.setter.name, tmpField)
                        }
                        CallType.CONSTRUCTOR -> {
                            // no-op: constructor parameters are read via readIntoTmpVar
                        }
                    }
                }
            }
        }
        if (alwaysExists) {
            toRead()
        } else {
            scope.builder().apply {
                beginControlFlow("if ($L != -1)", indexVar)
                toRead()
                endControlFlow()
            }
        }
    }

    /**
     * Declares a `final` temporary local and reads the value into it.
     *
     * When the column may be missing, `indexVar == -1` assigns the type's default value
     * instead. NOTE(review): if the field has no cursor value reader, the local is declared
     * but never assigned.
     *
     * @return the name of the temporary variable
     */
    fun readIntoTmpVar(cursorVar: String, scope: CodeGenScope): String {
        val tmpField = scope.getTmpVar("_tmp${field.name.capitalize()}")
        val typeName = field.getter.type.typeName()
        scope.builder().apply {
            addStatement("final $T $L", typeName, tmpField)
            if (alwaysExists) {
                field.cursorValueReader?.readFromCursor(tmpField, cursorVar, indexVar, scope)
            } else {
                beginControlFlow("if ($L == -1)", indexVar)
                addStatement("$L = $L", tmpField, typeName.defaultValue())
                nextControlFlow("else")
                field.cursorValueReader?.readFromCursor(tmpField, cursorVar, indexVar, scope)
                endControlFlow()
            }
        }
        return tmpField
    }

    /**
     * On-demand node created from the fields passed into this class. The root node has no
     * [fieldParent]; every other node stands for one [EmbeddedField].
     */
    private class Node(
            // variable name that holds this node's object in generated code
            val varName: String,
            // set if this node represents an embedded field; null for the root
            val fieldParent: EmbeddedField?
    ) {
        // the node this one belongs to (null for the root)
        var parentNode: Node? = null
        // fields whose immediate parent is this node
        lateinit var directFields: List<FieldWithIndex>
        // embedded nodes nested directly under this node
        val subNodes = mutableListOf<Node>()

        /** Returns this node's direct fields followed by all descendants' fields, depth-first. */
        fun allFields(): List<FieldWithIndex> {
            return directFields + subNodes.flatMap { it.allFields() }
        }
    }
}