1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package androidx.room.vo
18
19import androidx.room.ext.AndroidTypeNames
20import androidx.room.ext.CommonTypeNames
21import androidx.room.ext.L
22import androidx.room.ext.N
23import androidx.room.ext.T
24import androidx.room.ext.typeName
25import androidx.room.parser.ParsedQuery
26import androidx.room.parser.SQLTypeAffinity
27import androidx.room.parser.SqlParser
28import androidx.room.processor.Context
29import androidx.room.processor.ProcessorErrors.CANNOT_FIND_QUERY_RESULT_ADAPTER
30import androidx.room.processor.ProcessorErrors.relationAffinityMismatch
31import androidx.room.solver.CodeGenScope
32import androidx.room.solver.query.result.RowAdapter
33import androidx.room.solver.query.result.SingleColumnRowAdapter
34import androidx.room.verifier.DatabaseVerificaitonErrors
35import androidx.room.writer.QueryWriter
36import androidx.room.writer.RelationCollectorMethodWriter
37import com.squareup.javapoet.ArrayTypeName
38import com.squareup.javapoet.ClassName
39import com.squareup.javapoet.CodeBlock
40import com.squareup.javapoet.ParameterizedTypeName
41import com.squareup.javapoet.TypeName
42import stripNonJava
43import java.util.ArrayList
44import java.util.HashSet
45import javax.lang.model.type.TypeKind
46import javax.lang.model.type.TypeMirror
47
48/**
49 * Internal class that is used to manage fetching 1/N to N relationships.
50 */
51data class RelationCollector(val relation: Relation,
52                             val affinity: SQLTypeAffinity,
53                             val mapTypeName: ParameterizedTypeName,
54                             val keyTypeName: TypeName,
55                             val collectionTypeName: ParameterizedTypeName,
56                             val queryWriter: QueryWriter,
57                             val rowAdapter: RowAdapter,
58                             val loadAllQuery: ParsedQuery) {
59    // set when writing the code generator in writeInitCode
60    lateinit var varName: String
61
62    fun writeInitCode(scope: CodeGenScope) {
63        val tmpVar = scope.getTmpVar(
64                "_collection${relation.field.getPath().stripNonJava().capitalize()}")
65        scope.builder().addStatement("final $T $L = new $T()", mapTypeName, tmpVar, mapTypeName)
66        varName = tmpVar
67    }
68
69    // called after reading each item to extract the key if it exists
70    fun writeReadParentKeyCode(cursorVarName: String, itemVar: String,
71                               fieldsWithIndices: List<FieldWithIndex>, scope: CodeGenScope) {
72        val indexVar = fieldsWithIndices.firstOrNull {
73            it.field === relation.parentField
74        }?.indexVar
75        scope.builder().apply {
76            readKey(
77                    cursorVarName = cursorVarName,
78                    indexVar = indexVar,
79                    scope = scope
80            ) { tmpVar ->
81                val tmpCollectionVar = scope.getTmpVar("_tmpCollection")
82                addStatement("$T $L = $L.get($L)", collectionTypeName, tmpCollectionVar,
83                        varName, tmpVar)
84                beginControlFlow("if($L == null)", tmpCollectionVar).apply {
85                    addStatement("$L = new $T()", tmpCollectionVar, collectionTypeName)
86                    addStatement("$L.put($L, $L)", varName, tmpVar, tmpCollectionVar)
87                }
88                endControlFlow()
89                // set it on the item
90                relation.field.setter.writeSet(itemVar, tmpCollectionVar, this)
91            }
92        }
93    }
94
95    fun writeCollectionCode(scope: CodeGenScope) {
96        val method = scope.writer
97                .getOrCreateMethod(RelationCollectorMethodWriter(this))
98        scope.builder().apply {
99            addStatement("$N($L)", method, varName)
100        }
101    }
102
103    fun readKey(cursorVarName: String, indexVar: String?, scope: CodeGenScope,
104                postRead: CodeBlock.Builder.(String) -> Unit) {
105        val cursorGetter = when (affinity) {
106            SQLTypeAffinity.INTEGER -> "getLong"
107            SQLTypeAffinity.REAL -> "getDouble"
108            SQLTypeAffinity.TEXT -> "getString"
109            SQLTypeAffinity.BLOB -> "getBlob"
110            else -> {
111                "getString"
112            }
113        }
114        scope.builder().apply {
115            beginControlFlow("if (!$L.isNull($L))", cursorVarName, indexVar).apply {
116                val tmpVar = scope.getTmpVar("_tmpKey")
117                addStatement("final $T $L = $L.$L($L)", keyTypeName,
118                        tmpVar, cursorVarName, cursorGetter, indexVar)
119                this.postRead(tmpVar)
120            }
121            endControlFlow()
122        }
123    }
124
125    companion object {
126        fun createCollectors(
127                baseContext: Context,
128                relations: List<Relation>
129        ): List<RelationCollector> {
130            return relations.map { relation ->
131                // decide on the affinity
132                val context = baseContext.fork(relation.field.element)
133                val parentAffinity = relation.parentField.cursorValueReader?.affinity()
134                val childAffinity = relation.entityField.cursorValueReader?.affinity()
135                val affinity = if (parentAffinity != null && parentAffinity == childAffinity) {
136                    parentAffinity
137                } else {
138                    context.logger.w(Warning.RELATION_TYPE_MISMATCH, relation.field.element,
139                            relationAffinityMismatch(
140                                    parentColumn = relation.parentField.columnName,
141                                    childColumn = relation.entityField.columnName,
142                                    parentAffinity = parentAffinity,
143                                    childAffinity = childAffinity))
144                    SQLTypeAffinity.TEXT
145                }
146                val keyType = keyTypeFor(context, affinity)
147                val collectionTypeName = if (relation.field.typeName is ParameterizedTypeName) {
148                    val paramType = relation.field.typeName as ParameterizedTypeName
149                    if (paramType.rawType == CommonTypeNames.LIST) {
150                        ParameterizedTypeName.get(ClassName.get(ArrayList::class.java),
151                                relation.pojoTypeName)
152                    } else if (paramType.rawType == CommonTypeNames.SET) {
153                        ParameterizedTypeName.get(ClassName.get(HashSet::class.java),
154                                relation.pojoTypeName)
155                    } else {
156                        ParameterizedTypeName.get(ClassName.get(ArrayList::class.java),
157                                relation.pojoTypeName)
158                    }
159                } else {
160                    ParameterizedTypeName.get(ClassName.get(ArrayList::class.java),
161                            relation.pojoTypeName)
162                }
163
164                val canUseArrayMap = context.processingEnv.elementUtils
165                        .getTypeElement(AndroidTypeNames.ARRAY_MAP.toString()) != null
166                val mapClass = if (canUseArrayMap) {
167                    AndroidTypeNames.ARRAY_MAP
168                } else {
169                    ClassName.get(java.util.HashMap::class.java)
170                }
171                val tmpMapType = ParameterizedTypeName.get(mapClass, keyType, collectionTypeName)
172                val keyTypeMirror = keyTypeMirrorFor(context, affinity)
173                val set = context.processingEnv.elementUtils.getTypeElement("java.util.Set")
174                val keySet = context.processingEnv.typeUtils.getDeclaredType(set, keyTypeMirror)
175                val loadAllQuery = relation.createLoadAllSql()
176                val parsedQuery = SqlParser.parse(loadAllQuery)
177                context.checker.check(parsedQuery.errors.isEmpty(), relation.field.element,
178                        parsedQuery.errors.joinToString("\n"))
179                if (parsedQuery.errors.isEmpty()) {
180                    val resultInfo = context.databaseVerifier?.analyze(loadAllQuery)
181                    parsedQuery.resultInfo = resultInfo
182                    if (resultInfo?.error != null) {
183                        context.logger.e(relation.field.element,
184                                DatabaseVerificaitonErrors.cannotVerifyQuery(resultInfo.error))
185                    }
186                }
187                val resultInfo = parsedQuery.resultInfo
188
189                val queryParam = QueryParameter(
190                        name = RelationCollectorMethodWriter.KEY_SET_VARIABLE,
191                        sqlName = RelationCollectorMethodWriter.KEY_SET_VARIABLE,
192                        type = keySet,
193                        queryParamAdapter =
194                                context.typeAdapterStore.findQueryParameterAdapter(keySet))
195                val queryWriter = QueryWriter(
196                        parameters = listOf(queryParam),
197                        sectionToParamMapping = listOf(Pair(parsedQuery.bindSections.first(),
198                                queryParam)),
199                        query = parsedQuery
200                )
201
202                // row adapter that matches full response
203                fun getDefaultRowAdapter(): RowAdapter? {
204                    return context.typeAdapterStore.findRowAdapter(relation.pojoType, parsedQuery)
205                }
206                val rowAdapter = if (relation.projection.size == 1 && resultInfo != null &&
207                        (resultInfo.columns.size == 1 || resultInfo.columns.size == 2)) {
208                    // check for a column adapter first
209                    val cursorReader = context.typeAdapterStore.findCursorValueReader(
210                            relation.pojoType, resultInfo.columns.first().type)
211                    if (cursorReader == null) {
212                        getDefaultRowAdapter()
213                    } else {
214                        context.logger.d("Choosing cursor adapter for the return value since" +
215                                " the query returns only 1 or 2 columns and there is a cursor" +
216                                " adapter for the return type.")
217                        SingleColumnRowAdapter(cursorReader)
218                    }
219                } else {
220                    getDefaultRowAdapter()
221                }
222
223                if (rowAdapter == null) {
224                    context.logger.e(relation.field.element, CANNOT_FIND_QUERY_RESULT_ADAPTER)
225                    null
226                } else {
227                    RelationCollector(
228                            relation = relation,
229                            affinity = affinity,
230                            mapTypeName = tmpMapType,
231                            keyTypeName = keyType,
232                            collectionTypeName = collectionTypeName,
233                            queryWriter = queryWriter,
234                            rowAdapter = rowAdapter,
235                            loadAllQuery = parsedQuery
236                    )
237                }
238            }.filterNotNull()
239        }
240
241        private fun keyTypeMirrorFor(context: Context, affinity: SQLTypeAffinity): TypeMirror {
242            val types = context.processingEnv.typeUtils
243            val elements = context.processingEnv.elementUtils
244            return when (affinity) {
245                SQLTypeAffinity.INTEGER -> elements.getTypeElement("java.lang.Long").asType()
246                SQLTypeAffinity.REAL -> elements.getTypeElement("java.lang.Double").asType()
247                SQLTypeAffinity.TEXT -> context.COMMON_TYPES.STRING
248                SQLTypeAffinity.BLOB -> types.getArrayType(types.getPrimitiveType(TypeKind.BYTE))
249                else -> {
250                    context.COMMON_TYPES.STRING
251                }
252            }
253        }
254
255        private fun keyTypeFor(context: Context, affinity: SQLTypeAffinity): TypeName {
256            return when (affinity) {
257                SQLTypeAffinity.INTEGER -> TypeName.LONG.box()
258                SQLTypeAffinity.REAL -> TypeName.DOUBLE.box()
259                SQLTypeAffinity.TEXT -> TypeName.get(String::class.java)
260                SQLTypeAffinity.BLOB -> ArrayTypeName.of(TypeName.BYTE)
261                else -> {
262                    // no affinity select from type
263                    context.COMMON_TYPES.STRING.typeName()
264                }
265            }
266        }
267    }
268}
269