/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package androidx.room.writer

import androidx.room.ext.L
import androidx.room.ext.T
import androidx.room.ext.defaultValue
import androidx.room.ext.typeName
import androidx.room.solver.CodeGenScope
import androidx.room.vo.CallType
import androidx.room.vo.Constructor
import androidx.room.vo.EmbeddedField
import androidx.room.vo.Field
import androidx.room.vo.FieldWithIndex
import androidx.room.vo.Pojo
import androidx.room.vo.RelationCollector
import com.squareup.javapoet.TypeName

/**
 * Handles writing a field into a statement or reading it from a cursor.
 *
 * An instance wraps a single [FieldWithIndex]; the companion object offers the entry points
 * that process a whole list of fields (including embedded ones) at once.
 */
class FieldReadWriteWriter(fieldWithIndex: FieldWithIndex) {
    val field = fieldWithIndex.field
    val indexVar = fieldWithIndex.indexVar
    val alwaysExists = fieldWithIndex.alwaysExists

    companion object {
        /**
         * Get all parents including the ones which have grand children in this list but does not
         * have any direct children in the list.
         *
         * @param fields The fields whose parent chains are collected
         * @return every [EmbeddedField] reachable by following `parent` links from [fields]
         */
        fun getAllParents(fields: List<Field>): Set<EmbeddedField> {
            val allParents = mutableSetOf<EmbeddedField>()
            fun addAllParents(field: Field) {
                var parent = field.parent
                // Stop climbing as soon as we hit a parent that was already recorded: it was
                // added by an earlier walk that kept climbing from there.
                while (parent != null && allParents.add(parent)) {
                    parent = parent.parent
                }
            }
            fields.forEach(::addAllParents)
            return allParents
        }

        /**
         * Convert the fields with indices into a Node tree so that we can recursively process
         * them. This work is done here instead of parsing because the result may include arbitrary
         * fields.
         *
         * @param rootVar The variable name used for the root node
         * @param fieldsWithIndices The fields to distribute over the tree
         * @param scope The code generation scope, used to allocate a tmp variable per parent
         * @return the root node; fields without a parent are its direct fields
         */
        private fun createNodeTree(
            rootVar: String,
            fieldsWithIndices: List<FieldWithIndex>,
            scope: CodeGenScope
        ): Node {
            val allParents = getAllParents(fieldsWithIndices.map { it.field })
            // Bucket the fields by their direct parent once instead of re-filtering the whole
            // list for every node. Order inside each bucket matches the input order.
            val fieldsByParent = fieldsWithIndices.groupBy { it.field.parent }
            val rootNode = Node(rootVar, null)
            rootNode.directFields = fieldsByParent[null].orEmpty()
            val parentNodes = allParents.associate {
                Pair(it, Node(
                        varName = scope.getTmpVar("_tmp${it.field.name.capitalize()}"),
                        fieldParent = it))
            }
            parentNodes.forEach { (fieldParent, node) ->
                // A parent without its own parent hangs directly under the root.
                val grandParentNode = fieldParent.parent?.let { parentNodes[it] } ?: rootNode
                node.directFields = fieldsByParent[fieldParent].orEmpty()
                node.parentNode = grandParentNode
                grandParentNode.subNodes.add(node)
            }
            return rootNode
        }

        /**
         * Writes the code that binds all of the given fields into the statement.
         *
         * For every embedded parent, the parent is first read through its getter into a local
         * variable; if that variable is null in the generated code, `bindNull` is emitted for
         * every field below it.
         *
         * @param ownerVar The variable that holds the root object
         * @param stmtParamVar The statement variable
         * @param fieldsWithIndices The fields to bind
         * @param scope The code generation scope
         */
        fun bindToStatement(
            ownerVar: String,
            stmtParamVar: String,
            fieldsWithIndices: List<FieldWithIndex>,
            scope: CodeGenScope
        ) {
            fun visitNode(node: Node) {
                fun bindWithDescendants() {
                    node.directFields.forEach {
                        FieldReadWriteWriter(it).bindToStatement(
                                ownerVar = node.varName,
                                stmtParamVar = stmtParamVar,
                                scope = scope
                        )
                    }
                    node.subNodes.forEach(::visitNode)
                }

                val fieldParent = node.fieldParent
                if (fieldParent != null) {
                    // Only the root lacks a fieldParent, and createNodeTree assigns a parentNode
                    // to every non-root node.
                    val parentNode = checkNotNull(node.parentNode) {
                        "embedded node ${node.varName} must have a parent node"
                    }
                    fieldParent.getter.writeGet(
                            ownerVar = parentNode.varName,
                            outVar = node.varName,
                            builder = scope.builder()
                    )
                    scope.builder().apply {
                        beginControlFlow("if($L != null)", node.varName).apply {
                            bindWithDescendants()
                        }
                        nextControlFlow("else").apply {
                            node.allFields().forEach {
                                addStatement("$L.bindNull($L)", stmtParamVar, it.indexVar)
                            }
                        }
                        endControlFlow()
                    }
                } else {
                    bindWithDescendants()
                }
            }
            visitNode(createNodeTree(ownerVar, fieldsWithIndices, scope))
        }

        /**
         * Just constructs the given item, does NOT DECLARE. Declaration happens outside the
         * reading statement since we may never read if the cursor does not have necessary
         * columns.
         *
         * @param outVar The variable that receives the new instance
         * @param constructor The constructor to call; when null a no-arg constructor is used
         * @param typeName The type to instantiate
         * @param localVariableNames Local variable name -> field that was read into it
         * @param localEmbeddeds Nodes of the embedded objects available as local variables
         * @param scope The code generation scope
         */
        private fun construct(
            outVar: String,
            constructor: Constructor?,
            typeName: TypeName,
            localVariableNames: Map<String, FieldWithIndex>,
            localEmbeddeds: List<Node>,
            scope: CodeGenScope
        ) {
            if (constructor == null) {
                // best hope code generation
                scope.builder().apply {
                    addStatement("$L = new $T()", outVar, typeName)
                }
                return
            }
            // Resolve each constructor parameter to the local variable holding its value.
            // Identity comparison is used to match the parameter to the field / embedded.
            val variableNames = constructor.params.map { param ->
                when (param) {
                    is Constructor.FieldParam -> localVariableNames.entries.firstOrNull {
                        it.value.field === param.field
                    }?.key
                    is Constructor.EmbeddedParam -> localEmbeddeds.firstOrNull {
                        it.fieldParent === param.embedded
                    }?.varName
                    else -> null
                }
            }
            // Parameters that could not be resolved are passed as null.
            val args = variableNames.joinToString(",") { it ?: "null" }
            scope.builder().apply {
                addStatement("$L = new $T($L)", outVar, typeName, args)
            }
        }

        /**
         * Reads the row into the given variable. It does not declare it but constructs it.
         *
         * @param outVar The variable that receives the constructed object
         * @param outPojo The pojo that is being read
         * @param cursorVar The cursor variable
         * @param fieldsWithIndices The fields to read
         * @param scope The code generation scope
         * @param relationCollectors Collectors whose parent key code is written once the
         * owning object has been constructed
         */
        fun readFromCursor(
            outVar: String,
            outPojo: Pojo,
            cursorVar: String,
            fieldsWithIndices: List<FieldWithIndex>,
            scope: CodeGenScope,
            relationCollectors: List<RelationCollector>
        ) {
            fun visitNode(node: Node) {
                val fieldParent = node.fieldParent
                fun readNode() {
                    // read constructor parameters into local fields
                    val constructorFields = node.directFields.filter {
                        it.field.setter.callType == CallType.CONSTRUCTOR
                    }.associateBy { fwi ->
                        FieldReadWriteWriter(fwi).readIntoTmpVar(cursorVar, scope)
                    }
                    // read decomposed fields
                    node.subNodes.forEach(::visitNode)
                    // construct the object
                    if (fieldParent != null) {
                        construct(outVar = node.varName,
                                constructor = fieldParent.pojo.constructor,
                                typeName = fieldParent.field.typeName,
                                localEmbeddeds = node.subNodes,
                                localVariableNames = constructorFields,
                                scope = scope)
                    } else {
                        construct(outVar = node.varName,
                                constructor = outPojo.constructor,
                                typeName = outPojo.typeName,
                                localEmbeddeds = node.subNodes,
                                localVariableNames = constructorFields,
                                scope = scope)
                    }
                    // read any field that was not part of the constructor
                    node.directFields.filterNot {
                        it.field.setter.callType == CallType.CONSTRUCTOR
                    }.forEach { fwi ->
                        FieldReadWriteWriter(fwi).readFromCursor(
                                ownerVar = node.varName,
                                cursorVar = cursorVar,
                                scope = scope)
                    }
                    // assign relationship fields which will be read later
                    relationCollectors.filter { (relation) ->
                        relation.field.parent === fieldParent
                    }.forEach {
                        it.writeReadParentKeyCode(
                                cursorVarName = cursorVar,
                                itemVar = node.varName,
                                fieldsWithIndices = fieldsWithIndices,
                                scope = scope)
                    }
                    // assign sub nodes to fields if they were not part of the constructor.
                    node.subNodes.mapNotNull {
                        val setter = it.fieldParent?.setter
                        if (setter != null && setter.callType != CallType.CONSTRUCTOR) {
                            Pair(it.varName, setter)
                        } else {
                            null
                        }
                    }.forEach { (varName, setter) ->
                        setter.writeSet(
                                ownerVar = node.varName,
                                inVar = varName,
                                builder = scope.builder())
                    }
                }
                if (fieldParent == null) {
                    // root element
                    // always declared by the caller so we don't declare this
                    readNode()
                } else {
                    // always declare, we'll set below
                    scope.builder().addStatement("final $T $L", fieldParent.pojo.typeName,
                            node.varName)
                    if (fieldParent.nonNull) {
                        readNode()
                    } else {
                        // The embedded object is left null when every column below it is
                        // either missing from the cursor or null.
                        val myDescendants = node.allFields()
                        val allNullCheck = myDescendants.joinToString(" && ") {
                            if (it.alwaysExists) {
                                "$cursorVar.isNull(${it.indexVar})"
                            } else {
                                "( ${it.indexVar} == -1 || $cursorVar.isNull(${it.indexVar}))"
                            }
                        }
                        scope.builder().apply {
                            beginControlFlow("if (! ($L))", allNullCheck).apply {
                                readNode()
                            }
                            nextControlFlow(" else ").apply {
                                addStatement("$L = null", node.varName)
                            }
                            endControlFlow()
                        }
                    }
                }
            }
            visitNode(createNodeTree(outVar, fieldsWithIndices, scope))
        }
    }

    /**
     * Binds this field into the statement, if the field has a statement binder.
     *
     * @param ownerVar The entity / pojo that owns this field. It must own this field! (not the
     * container pojo)
     * @param stmtParamVar The statement variable
     * @param scope The code generation scope
     */
    private fun bindToStatement(ownerVar: String, stmtParamVar: String, scope: CodeGenScope) {
        field.statementBinder?.let { binder ->
            val varName = if (field.getter.callType == CallType.FIELD) {
                "$ownerVar.${field.name}"
            } else {
                "$ownerVar.${field.getter.name}()"
            }
            binder.bindToStmt(stmtParamVar, indexVar, varName, scope)
        }
    }

    /**
     * Reads this field from the cursor into its owner, if the field has a cursor value reader.
     * When the column is not guaranteed to exist, the read is guarded by an index check.
     *
     * @param ownerVar The entity / pojo that owns this field. It must own this field (not the
     * container pojo)
     * @param cursorVar The cursor variable
     * @param scope The code generation scope
     */
    private fun readFromCursor(ownerVar: String, cursorVar: String, scope: CodeGenScope) {
        fun toRead() {
            field.cursorValueReader?.let { reader ->
                scope.builder().apply {
                    when (field.setter.callType) {
                        CallType.FIELD -> {
                            // This branch is selected by the setter's call type, so the
                            // assignment target must be the setter's name, not the getter's
                            // (which may be a method name).
                            reader.readFromCursor("$ownerVar.${field.setter.name}", cursorVar,
                                    indexVar, scope)
                        }
                        CallType.METHOD -> {
                            // Read into a tmp variable first, then hand it to the setter method.
                            val tmpField = scope.getTmpVar("_tmp${field.name.capitalize()}")
                            addStatement("final $T $L", field.getter.type.typeName(), tmpField)
                            reader.readFromCursor(tmpField, cursorVar, indexVar, scope)
                            addStatement("$L.$L($L)", ownerVar, field.setter.name, tmpField)
                        }
                        CallType.CONSTRUCTOR -> {
                            // no-op
                        }
                    }
                }
            }
        }
        if (alwaysExists) {
            toRead()
        } else {
            scope.builder().apply {
                beginControlFlow("if ($L != -1)", indexVar).apply {
                    toRead()
                }
                endControlFlow()
            }
        }
    }

    /**
     * Reads the value into a temporary local variable.
     *
     * When the column is not guaranteed to exist, the variable is assigned the type's default
     * value if the index is -1.
     *
     * @param cursorVar The cursor variable
     * @param scope The code generation scope
     * @return the name of the temporary variable that was declared
     */
    fun readIntoTmpVar(cursorVar: String, scope: CodeGenScope): String {
        val tmpField = scope.getTmpVar("_tmp${field.name.capitalize()}")
        val typeName = field.getter.type.typeName()
        scope.builder().apply {
            addStatement("final $T $L", typeName, tmpField)
            if (alwaysExists) {
                field.cursorValueReader?.readFromCursor(tmpField, cursorVar, indexVar, scope)
            } else {
                beginControlFlow("if ($L == -1)", indexVar).apply {
                    addStatement("$L = $L", tmpField, typeName.defaultValue())
                }
                nextControlFlow("else").apply {
                    field.cursorValueReader?.readFromCursor(tmpField, cursorVar, indexVar, scope)
                }
                endControlFlow()
            }
        }
        return tmpField
    }

    /**
     * On demand node which is created based on the fields that were passed into this class.
     */
    private class Node(
        // root for me
        val varName: String,
        // set if I'm a FieldParent
        val fieldParent: EmbeddedField?
    ) {
        // whom do i belong
        var parentNode: Node? = null
        // these fields are my direct fields
        lateinit var directFields: List<FieldWithIndex>
        // these nodes are under me
        val subNodes = mutableListOf<Node>()

        /** Returns my direct fields followed by the fields of all nodes under me. */
        fun allFields(): List<FieldWithIndex> {
            return directFields + subNodes.flatMap { it.allFields() }
        }
    }
}