1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package androidx.room.writer
18
19import androidx.room.ext.L
20import androidx.room.ext.T
21import androidx.room.ext.defaultValue
22import androidx.room.ext.typeName
23import androidx.room.solver.CodeGenScope
24import androidx.room.vo.CallType
25import androidx.room.vo.Constructor
26import androidx.room.vo.EmbeddedField
27import androidx.room.vo.Field
28import androidx.room.vo.FieldWithIndex
29import androidx.room.vo.Pojo
30import androidx.room.vo.RelationCollector
31import com.squareup.javapoet.TypeName
32
33/**
34 * Handles writing a field into statement or reading it form statement.
35 */
36class FieldReadWriteWriter(fieldWithIndex: FieldWithIndex) {
37    val field = fieldWithIndex.field
38    val indexVar = fieldWithIndex.indexVar
39    val alwaysExists = fieldWithIndex.alwaysExists
40
41    companion object {
42        /*
43         * Get all parents including the ones which have grand children in this list but does not
44         * have any direct children in the list.
45         */
46        fun getAllParents(fields: List<Field>): Set<EmbeddedField> {
47            val allParents = mutableSetOf<EmbeddedField>()
48            fun addAllParents(field: Field) {
49                var parent = field.parent
50                while (parent != null) {
51                    if (allParents.add(parent)) {
52                        parent = parent.parent
53                    } else {
54                        break
55                    }
56                }
57            }
58            fields.forEach(::addAllParents)
59            return allParents
60        }
61
62        /**
63         * Convert the fields with indices into a Node tree so that we can recursively process
64         * them. This work is done here instead of parsing because the result may include arbitrary
65         * fields.
66         */
67        private fun createNodeTree(
68                rootVar: String,
69                fieldsWithIndices: List<FieldWithIndex>,
70                scope: CodeGenScope): Node {
71            val allParents = getAllParents(fieldsWithIndices.map { it.field })
72            val rootNode = Node(rootVar, null)
73            rootNode.directFields = fieldsWithIndices.filter { it.field.parent == null }
74            val parentNodes = allParents.associate {
75                Pair(it, Node(
76                        varName = scope.getTmpVar("_tmp${it.field.name.capitalize()}"),
77                        fieldParent = it))
78            }
79            parentNodes.values.forEach { node ->
80                val fieldParent = node.fieldParent!!
81                val grandParent = fieldParent.parent
82                val grandParentNode = grandParent?.let {
83                    parentNodes[it]
84                } ?: rootNode
85                node.directFields = fieldsWithIndices.filter { it.field.parent == fieldParent }
86                node.parentNode = grandParentNode
87                grandParentNode.subNodes.add(node)
88            }
89            return rootNode
90        }
91
92        fun bindToStatement(
93                ownerVar: String,
94                stmtParamVar: String,
95                fieldsWithIndices: List<FieldWithIndex>,
96                scope: CodeGenScope
97        ) {
98            fun visitNode(node: Node) {
99                fun bindWithDescendants() {
100                    node.directFields.forEach {
101                        FieldReadWriteWriter(it).bindToStatement(
102                                ownerVar = node.varName,
103                                stmtParamVar = stmtParamVar,
104                                scope = scope
105                        )
106                    }
107                    node.subNodes.forEach(::visitNode)
108                }
109
110                val fieldParent = node.fieldParent
111                if (fieldParent != null) {
112                    fieldParent.getter.writeGet(
113                            ownerVar = node.parentNode!!.varName,
114                            outVar = node.varName,
115                            builder = scope.builder()
116                    )
117                    scope.builder().apply {
118                        beginControlFlow("if($L != null)", node.varName).apply {
119                            bindWithDescendants()
120                        }
121                        nextControlFlow("else").apply {
122                            node.allFields().forEach {
123                                addStatement("$L.bindNull($L)", stmtParamVar, it.indexVar)
124                            }
125                        }
126                        endControlFlow()
127                    }
128                } else {
129                    bindWithDescendants()
130                }
131            }
132            visitNode(createNodeTree(ownerVar, fieldsWithIndices, scope))
133        }
134
135        /**
136         * Just constructs the given item, does NOT DECLARE. Declaration happens outside the
137         * reading statement since we may never read if the cursor does not have necessary
138         * columns.
139         */
140        private fun construct(
141                outVar: String,
142                constructor: Constructor?,
143                typeName: TypeName,
144                localVariableNames: Map<String, FieldWithIndex>,
145                localEmbeddeds: List<Node>, scope: CodeGenScope
146        ) {
147            if (constructor == null) {
148                // best hope code generation
149                scope.builder().apply {
150                    addStatement("$L = new $T()", outVar, typeName)
151                }
152                return
153            }
154            val variableNames = constructor.params.map { param ->
155                when (param) {
156                    is Constructor.FieldParam -> localVariableNames.entries.firstOrNull {
157                        it.value.field === param.field
158                    }?.key
159                    is Constructor.EmbeddedParam -> localEmbeddeds.firstOrNull {
160                        it.fieldParent === param.embedded
161                    }?.varName
162                    else -> null
163                }
164            }
165            val args = variableNames.joinToString(",") { it ?: "null" }
166            scope.builder().apply {
167                addStatement("$L = new $T($L)", outVar, typeName, args)
168            }
169        }
170
171        /**
172         * Reads the row into the given variable. It does not declare it but constructs it.
173         */
174        fun readFromCursor(
175                outVar: String,
176                outPojo: Pojo,
177                cursorVar: String,
178                fieldsWithIndices: List<FieldWithIndex>,
179                scope: CodeGenScope,
180                relationCollectors: List<RelationCollector>
181        ) {
182            fun visitNode(node: Node) {
183                val fieldParent = node.fieldParent
184                fun readNode() {
185                    // read constructor parameters into local fields
186                    val constructorFields = node.directFields.filter {
187                        it.field.setter.callType == CallType.CONSTRUCTOR
188                    }.associateBy { fwi ->
189                        FieldReadWriteWriter(fwi).readIntoTmpVar(cursorVar, scope)
190                    }
191                    // read decomposed fields
192                    node.subNodes.forEach(::visitNode)
193                    // construct the object
194                    if (fieldParent != null) {
195                        construct(outVar = node.varName,
196                                constructor = fieldParent.pojo.constructor,
197                                typeName = fieldParent.field.typeName,
198                                localEmbeddeds = node.subNodes,
199                                localVariableNames = constructorFields,
200                                scope = scope)
201                    } else {
202                        construct(outVar = node.varName,
203                                constructor = outPojo.constructor,
204                                typeName = outPojo.typeName,
205                                localEmbeddeds = node.subNodes,
206                                localVariableNames = constructorFields,
207                                scope = scope)
208                    }
209                    // ready any field that was not part of the constructor
210                    node.directFields.filterNot {
211                        it.field.setter.callType == CallType.CONSTRUCTOR
212                    }.forEach { fwi ->
213                        FieldReadWriteWriter(fwi).readFromCursor(
214                                ownerVar = node.varName,
215                                cursorVar = cursorVar,
216                                scope = scope)
217                    }
218                    // assign relationship fields which will be read later
219                    relationCollectors.filter { (relation) ->
220                        relation.field.parent === fieldParent
221                    }.forEach {
222                        it.writeReadParentKeyCode(
223                                cursorVarName = cursorVar,
224                                itemVar = node.varName,
225                                fieldsWithIndices = fieldsWithIndices,
226                                scope = scope)
227                    }
228                    // assign sub modes to fields if they were not part of the constructor.
229                    node.subNodes.mapNotNull {
230                        val setter = it.fieldParent?.setter
231                        if (setter != null && setter.callType != CallType.CONSTRUCTOR) {
232                            Pair(it.varName, setter)
233                        } else {
234                            null
235                        }
236                    }.forEach { (varName, setter) ->
237                        setter.writeSet(
238                                ownerVar = node.varName,
239                                inVar = varName,
240                                builder = scope.builder())
241                    }
242                }
243                if (fieldParent == null) {
244                    // root element
245                    // always declared by the caller so we don't declare this
246                    readNode()
247                } else {
248                    // always declare, we'll set below
249                    scope.builder().addStatement("final $T $L", fieldParent.pojo.typeName,
250                                        node.varName)
251                    if (fieldParent.nonNull) {
252                        readNode()
253                    } else {
254                        val myDescendants = node.allFields()
255                        val allNullCheck = myDescendants.joinToString(" && ") {
256                            if (it.alwaysExists) {
257                                "$cursorVar.isNull(${it.indexVar})"
258                            } else {
259                                "( ${it.indexVar} == -1 || $cursorVar.isNull(${it.indexVar}))"
260                            }
261                        }
262                        scope.builder().apply {
263                            beginControlFlow("if (! ($L))", allNullCheck).apply {
264                                readNode()
265                            }
266                            nextControlFlow(" else ").apply {
267                                addStatement("$L = null", node.varName)
268                            }
269                            endControlFlow()
270                        }
271                    }
272                }
273            }
274            visitNode(createNodeTree(outVar, fieldsWithIndices, scope))
275        }
276    }
277
278    /**
279     * @param ownerVar The entity / pojo that owns this field. It must own this field! (not the
280     * container pojo)
281     * @param stmtParamVar The statement variable
282     * @param scope The code generation scope
283     */
284    private fun bindToStatement(ownerVar: String, stmtParamVar: String, scope: CodeGenScope) {
285        field.statementBinder?.let { binder ->
286            val varName = if (field.getter.callType == CallType.FIELD) {
287                "$ownerVar.${field.name}"
288            } else {
289                "$ownerVar.${field.getter.name}()"
290            }
291            binder.bindToStmt(stmtParamVar, indexVar, varName, scope)
292        }
293    }
294
295    /**
296     * @param ownerVar The entity / pojo that owns this field. It must own this field (not the
297     * container pojo)
298     * @param cursorVar The cursor variable
299     * @param scope The code generation scope
300     */
301    private fun readFromCursor(ownerVar: String, cursorVar: String, scope: CodeGenScope) {
302        fun toRead() {
303            field.cursorValueReader?.let { reader ->
304                scope.builder().apply {
305                    when (field.setter.callType) {
306                        CallType.FIELD -> {
307                            reader.readFromCursor("$ownerVar.${field.getter.name}", cursorVar,
308                                    indexVar, scope)
309                        }
310                        CallType.METHOD -> {
311                            val tmpField = scope.getTmpVar("_tmp${field.name.capitalize()}")
312                            addStatement("final $T $L", field.getter.type.typeName(), tmpField)
313                            reader.readFromCursor(tmpField, cursorVar, indexVar, scope)
314                            addStatement("$L.$L($L)", ownerVar, field.setter.name, tmpField)
315                        }
316                        CallType.CONSTRUCTOR -> {
317                            // no-op
318                        }
319                    }
320                }
321            }
322        }
323        if (alwaysExists) {
324            toRead()
325        } else {
326            scope.builder().apply {
327                beginControlFlow("if ($L != -1)", indexVar).apply {
328                    toRead()
329                }
330                endControlFlow()
331            }
332        }
333    }
334
335    /**
336     * Reads the value into a temporary local variable.
337     */
338    fun readIntoTmpVar(cursorVar: String, scope: CodeGenScope): String {
339        val tmpField = scope.getTmpVar("_tmp${field.name.capitalize()}")
340        val typeName = field.getter.type.typeName()
341        scope.builder().apply {
342            addStatement("final $T $L", typeName, tmpField)
343            if (alwaysExists) {
344                field.cursorValueReader?.readFromCursor(tmpField, cursorVar, indexVar, scope)
345            } else {
346                beginControlFlow("if ($L == -1)", indexVar).apply {
347                    addStatement("$L = $L", tmpField, typeName.defaultValue())
348                }
349                nextControlFlow("else").apply {
350                    field.cursorValueReader?.readFromCursor(tmpField, cursorVar, indexVar, scope)
351                }
352                endControlFlow()
353            }
354        }
355        return tmpField
356    }
357
358    /**
359     * On demand node which is created based on the fields that were passed into this class.
360     */
361    private class Node(
362            // root for me
363            val varName: String,
364            // set if I'm a FieldParent
365            val fieldParent: EmbeddedField?
366    ) {
367        // whom do i belong
368        var parentNode: Node? = null
369        // these fields are my direct fields
370        lateinit var directFields: List<FieldWithIndex>
371        // these nodes are under me
372        val subNodes = mutableListOf<Node>()
373
374        fun allFields(): List<FieldWithIndex> {
375            return directFields + subNodes.flatMap { it.allFields() }
376        }
377    }
378}
379