1/*
2 * Copyright (C) 2016 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package android.arch.persistence.room.writer
18
19import android.arch.persistence.room.ext.L
20import android.arch.persistence.room.ext.RoomTypeNames
21import android.arch.persistence.room.ext.S
22import android.arch.persistence.room.ext.SupportDbTypeNames
23import android.arch.persistence.room.solver.CodeGenScope
24import android.arch.persistence.room.vo.Entity
25import android.arch.persistence.room.vo.FieldWithIndex
26import com.squareup.javapoet.ClassName
27import com.squareup.javapoet.MethodSpec
28import com.squareup.javapoet.ParameterSpec
29import com.squareup.javapoet.ParameterizedTypeName
30import com.squareup.javapoet.TypeName
31import com.squareup.javapoet.TypeSpec
32import javax.lang.model.element.Modifier.PUBLIC
33
class EntityInsertionAdapterWriter(val entity: Entity, val onConflict: String) {
    /**
     * Generates an anonymous subclass of `RoomTypeNames.INSERTION_ADAPTER`, parameterized on
     * [entity]'s type name and constructed with [dbParam] as its constructor argument.
     *
     * The generated class overrides two methods:
     *  - `createQuery`, which returns the `INSERT OR <onConflict>` SQL for [entity];
     *  - `bind`, which binds every entity field (in declaration order) to the statement.
     *
     * @param classWriter writer used to create the [CodeGenScope] for the binding code.
     * @param dbParam code expression emitted as the anonymous class's constructor argument.
     * @return the anonymous class spec.
     */
    fun createAnonymous(classWriter: ClassWriter, dbParam: String): TypeSpec {
        @Suppress("RemoveSingleExpressionStringTemplate")
        return TypeSpec.anonymousClassBuilder("$L", dbParam)
                .superclass(ParameterizedTypeName.get(RoomTypeNames.INSERTION_ADAPTER, entity.typeName))
                .addMethod(createQueryMethod())
                .addMethod(bindMethod(classWriter))
                .build()
    }

    /**
     * Returns the first primary-key field when the primary key is auto-generated and that field's
     * statement binder works on a primitive type; otherwise returns null.
     *
     * A primitive key cannot hold null, so such a field reports "not set" as 0. The insertion SQL
     * therefore wraps its placeholder in `nullif(?, 0)`.
     */
    private fun primitiveAutoGenerateField() =
            entity.primaryKey
                    .takeIf { key -> key.autoGenerateId }
                    ?.fields
                    ?.firstOrNull()
                    ?.takeIf { candidate ->
                        val binderType = candidate.statementBinder?.typeMirror()
                        binderType != null && binderType.kind.isPrimitive
                    }

    /**
     * Builds the public `createQuery` override. It returns the insertion SQL: every column is
     * quoted with backticks, and the placeholder of a primitive auto-generated key becomes
     * `nullif(?, 0)`.
     */
    private fun createQueryMethod(): MethodSpec {
        val autoGenerated = primitiveAutoGenerateField()
        val columnList = entity.fields.joinToString(",") { column -> "`${column.columnName}`" }
        val placeholderList = entity.fields.joinToString(",") { column ->
            if (autoGenerated == column) "nullif(?, 0)" else "?"
        }
        val sql = "INSERT OR $onConflict INTO `${entity.tableName}`($columnList) VALUES ($placeholderList)"
        return MethodSpec.methodBuilder("createQuery")
                .addAnnotation(Override::class.java)
                .returns(ClassName.get("java.lang", "String"))
                .addModifiers(PUBLIC)
                .addStatement("return $S", sql)
                .build()
    }

    /**
     * Builds the public `void bind(stmt, value)` override. The binding code is generated by
     * `FieldReadWriteWriter.bindToStatement`, using the field order produced by
     * `FieldWithIndex.byOrder`.
     */
    private fun bindMethod(classWriter: ClassWriter): MethodSpec {
        val scope = CodeGenScope(classWriter)
        val statementName = "stmt"
        val entityName = "value"
        FieldReadWriteWriter.bindToStatement(
                ownerVar = entityName,
                stmtParamVar = statementName,
                fieldsWithIndices = FieldWithIndex.byOrder(entity.fields),
                scope = scope
        )
        return MethodSpec.methodBuilder("bind")
                .addAnnotation(Override::class.java)
                .addParameter(ParameterSpec.builder(SupportDbTypeNames.SQLITE_STMT, statementName).build())
                .addParameter(ParameterSpec.builder(entity.typeName, entityName).build())
                .returns(TypeName.VOID)
                .addModifiers(PUBLIC)
                .addCode(scope.builder().build())
                .build()
    }
}
97