1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package androidx.lifecycle
18
19import androidx.lifecycle.model.AdapterClass
20import androidx.lifecycle.model.EventMethodCall
21import androidx.lifecycle.model.getAdapterName
22import com.squareup.javapoet.AnnotationSpec
23import com.squareup.javapoet.ClassName
24import com.squareup.javapoet.FieldSpec
25import com.squareup.javapoet.JavaFile
26import com.squareup.javapoet.MethodSpec
27import com.squareup.javapoet.ParameterSpec
28import com.squareup.javapoet.TypeName
29import com.squareup.javapoet.TypeSpec
30import javax.annotation.processing.ProcessingEnvironment
31import javax.lang.model.element.Modifier
32import javax.lang.model.element.TypeElement
33import javax.tools.StandardLocation
34
/**
 * Generates one adapter class, together with its keep rule, for every model in [infos], in
 * list order.
 *
 * @param infos the adapter models to write.
 * @param processingEnv the environment whose filer receives the generated files.
 */
fun writeModels(infos: List<AdapterClass>, processingEnv: ProcessingEnvironment) {
    for (info in infos) {
        writeAdapter(info, processingEnv)
    }
}
38
// Package and simple name of the optional `@Generated` annotation put on generated adapters.
private const val GENERATED_PACKAGE = "javax.annotation"
private const val GENERATED_NAME = "Generated"
private val LIFECYCLE_EVENT = Lifecycle.Event::class.java

// JavaPoet format placeholders. They are spliced into Kotlin string templates (e.g. "$N.$L()")
// so that the resulting JavaPoet format strings contain a literal '$'.
private const val T = "\$T"
private const val N = "\$N"
private const val L = "\$L"
private const val S = "\$S"

/** `owner` parameter shared by `callMethods` and the synthetic accessor methods. */
private val OWNER_PARAM: ParameterSpec = ParameterSpec.builder(
        ClassName.get(LifecycleOwner::class.java), "owner").build()

/** `event` parameter shared by `callMethods` and the synthetic accessor methods. */
private val EVENT_PARAM: ParameterSpec = ParameterSpec.builder(
        ClassName.get(LIFECYCLE_EVENT), "event").build()

/** `boolean onAny` parameter of `callMethods`. */
private val ON_ANY_PARAM: ParameterSpec = ParameterSpec.builder(TypeName.BOOLEAN, "onAny").build()

/** `logger` parameter of `callMethods`; the generated code null-checks it before use. */
private val METHODS_LOGGER: ParameterSpec = ParameterSpec.builder(
        ClassName.get(MethodCallsLogger::class.java), "logger").build()

/** Name of the generated local boolean that records whether a logger was supplied. */
private const val HAS_LOGGER_VAR = "hasLogger"
58
/**
 * Generates the lifecycle adapter class for [adapter], writes it through the filer, and then
 * writes the matching keep rule.
 *
 * The generated class is public and implements [GeneratedAdapter]. It stores the observer in a
 * final `mReceiver` field that its constructor assigns, and it overrides `callMethods` to
 * dispatch lifecycle events. It also declares one public static accessor per entry in
 * `adapter.syntheticMethods`.
 *
 * @param adapter model of the observer class and the calls its adapter must dispatch.
 * @param processingEnv environment whose filer receives the generated source and keep rule.
 */
private fun writeAdapter(adapter: AdapterClass, processingEnv: ProcessingEnvironment) {
    val observerType = ClassName.get(adapter.type)
    val receiverField: FieldSpec = FieldSpec.builder(observerType, "mReceiver", Modifier.FINAL)
            .build()
    val receiverParam: ParameterSpec = ParameterSpec.builder(observerType, "receiver").build()

    val constructor = MethodSpec.constructorBuilder()
            .addParameter(receiverParam)
            .addStatement("this.$N = $N", receiverField, receiverParam)
            .build()

    val adapterTypeSpecBuilder = TypeSpec.classBuilder(getAdapterName(adapter.type))
            .addModifiers(Modifier.PUBLIC)
            .addSuperinterface(ClassName.get(GeneratedAdapter::class.java))
            .addField(receiverField)
            .addMethod(constructor)
            .addMethod(buildCallMethodsSpec(adapter, receiverField))
            .addMethods(buildSyntheticAccessors(adapter, receiverParam))

    addGeneratedAnnotationIfAvailable(adapterTypeSpecBuilder, processingEnv)

    JavaFile.builder(adapter.type.getPackageQName(), adapterTypeSpecBuilder.build())
            .build()
            .writeTo(processingEnv.filer)

    generateKeepRule(adapter.type, processingEnv)
}

/**
 * Builds the `callMethods(owner, event, onAny, logger)` override.
 *
 * ON_ANY handlers are emitted under `if (onAny)`. Every other event gets its own
 * `if (event == Lifecycle.Event.X)` branch, in the order in which each event first appears in
 * `adapter.calls`. Each branch ends with the `return` that [writeMethodCalls] appends.
 */
private fun buildCallMethodsSpec(adapter: AdapterClass, receiverField: FieldSpec): MethodSpec {
    val callsByEventType = adapter.calls.groupBy { call -> call.method.onLifecycleEvent.value }
    val onAnyCalls = callsByEventType[Lifecycle.Event.ON_ANY] ?: emptyList()
    return MethodSpec.methodBuilder("callMethods")
            .returns(TypeName.VOID)
            .addParameter(OWNER_PARAM)
            .addParameter(EVENT_PARAM)
            .addParameter(ON_ANY_PARAM)
            .addParameter(METHODS_LOGGER)
            .addModifiers(Modifier.PUBLIC)
            .addAnnotation(Override::class.java)
            .apply {
                addStatement("boolean $L = $N != null", HAS_LOGGER_VAR, METHODS_LOGGER)

                beginControlFlow("if ($N)", ON_ANY_PARAM)
                writeMethodCalls(onAnyCalls, receiverField)
                endControlFlow()

                callsByEventType
                        .filterKeys { event -> event != Lifecycle.Event.ON_ANY }
                        .forEach { (event, calls) ->
                            beginControlFlow("if ($N == $T.$L)", EVENT_PARAM, LIFECYCLE_EVENT,
                                    event)
                            writeMethodCalls(calls, receiverField)
                            endControlFlow()
                        }
            }
            .build()
}

/**
 * Builds one `public static void` accessor per entry in `adapter.syntheticMethods`.
 *
 * Each accessor takes the receiver, followed by `owner` when the target declares at least one
 * parameter and `event` when it declares two. It forwards that same number of arguments to the
 * target method on the receiver.
 */
private fun buildSyntheticAccessors(
    adapter: AdapterClass,
    receiverParam: ParameterSpec
): List<MethodSpec> = adapter.syntheticMethods.map { target ->
    val argCount = target.parameters.size
    val accessor = MethodSpec.methodBuilder(syntheticName(target))
            .returns(TypeName.VOID)
            .addModifiers(Modifier.PUBLIC)
            .addModifiers(Modifier.STATIC)
            .addParameter(receiverParam)
    if (argCount >= 1) {
        accessor.addParameter(OWNER_PARAM)
    }
    if (argCount == 2) {
        accessor.addParameter(EVENT_PARAM)
    }
    accessor.addStatement("$N.$L(${generateParamString(argCount)})", receiverParam,
            target.name(), *takeParams(argCount, OWNER_PARAM, EVENT_PARAM))
    accessor.build()
}
131
/**
 * Adds `@Generated("<canonical name of LifecycleProcessor>")` to [adapterTypeSpecBuilder].
 *
 * The annotation type is the one named by GENERATED_PACKAGE and GENERATED_NAME. It is added only
 * when [processingEnv]'s element utilities can resolve that type; otherwise the builder is left
 * untouched.
 */
private fun addGeneratedAnnotationIfAvailable(adapterTypeSpecBuilder: TypeSpec.Builder,
                                              processingEnv: ProcessingEnvironment) {
    val generatedTypeName = "$GENERATED_PACKAGE.$GENERATED_NAME"
    if (processingEnv.elementUtils.getTypeElement(generatedTypeName) == null) {
        // Annotation type not on the compile classpath: emit the adapter without it.
        return
    }
    val annotation = AnnotationSpec.builder(ClassName.get(GENERATED_PACKAGE, GENERATED_NAME))
            .addMember("value", S, LifecycleProcessor::class.java.canonicalName)
            .build()
    adapterTypeSpecBuilder.addAnnotation(annotation)
}
146
/**
 * Writes a conditional ProGuard/R8 rule that keeps the generated adapter's constructors
 * whenever the observer class [type] is kept.
 *
 * The rule is written as a class-output resource under `META-INF/proguard/`, named after the
 * observer's fully qualified name.
 *
 * @param type the observer class the adapter was generated for.
 * @param processingEnv environment providing element utilities and the filer.
 */
private fun generateKeepRule(type: TypeElement, processingEnv: ProcessingEnvironment) {
    val packageName = type.getPackageQName()
    val adapterName = getAdapterName(type)
    // An empty package name means the adapter is emitted without a package, so it must not be
    // prefixed with a bare "." in the rule.
    val adapterClass = if (packageName.isEmpty()) adapterName else "$packageName.$adapterName"
    // ProGuard/R8 class specifications name nested classes by their binary name
    // ("Outer$Inner"). The canonical name from toString() ("Outer.Inner") would never match a
    // nested observer.
    val observerClass = processingEnv.elementUtils.getBinaryName(type).toString()
    val keepRule = """# Generated keep rule for Lifecycle observer adapter.
        |-if class $observerClass {
        |    <init>(...);
        |}
        |-keep class $adapterClass {
        |    <init>(...);
        |}
        |""".trimMargin()

    // Write the keep rule to the META-INF/proguard directory of the jar file. The file name
    // contains the fully qualified observer name so that file names are unique, which keeps
    // jar merging from overwriting one observer's rule file with another's. The canonical name
    // is kept here so that the output path is unchanged from earlier versions.
    val observerCanonicalName = type.toString()
    val path = "META-INF/proguard/$observerCanonicalName.pro"
    // Record the observer as the originating element so that incremental builds can tie this
    // resource to its source.
    val out = processingEnv.filer.createResource(StandardLocation.CLASS_OUTPUT, "", path, type)
    out.openWriter().use { writer -> writer.write(keepRule) }
}
166
/**
 * Appends one guarded invocation per entry in [calls], then a trailing `return` statement.
 *
 * Each invocation is wrapped in `if (!hasLogger || logger.approveCall("<name>", <callType>))`,
 * where `callType` is `1 << argCount` (1, 2 or 4 for zero, one or two arguments).
 * - When a call has no synthetic-access type, the method is invoked directly on [receiverField].
 * - Otherwise, the static accessor on that type's generated adapter is invoked, with the
 *   receiver as its first argument.
 */
private fun MethodSpec.Builder.writeMethodCalls(calls: List<EventMethodCall>,
                                                receiverField: FieldSpec) {
    for ((eventMethod, accessorOwner) in calls) {
        val target = eventMethod.method
        val argCount = target.parameters.size
        val methodName = target.name()
        beginControlFlow("if (!$L || $N.approveCall($S, ${1 shl argCount}))",
                HAS_LOGGER_VAR, METHODS_LOGGER, methodName)
        if (accessorOwner != null) {
            // Go through the adapter that owns the accessor and pass the receiver explicitly.
            val accessorClass = ClassName.get(accessorOwner.getPackageQName(),
                    getAdapterName(accessorOwner))
            addStatement("$T.$L(${generateParamString(argCount + 1)})", accessorClass,
                    syntheticName(target),
                    *takeParams(argCount + 1, receiverField, OWNER_PARAM, EVENT_PARAM))
        } else {
            addStatement("$N.$L(${generateParamString(argCount)})", receiverField,
                    methodName,
                    *takeParams(argCount, OWNER_PARAM, EVENT_PARAM))
        }
        endControlFlow()
    }
    addStatement("return")
}
194
/**
 * Returns the first [count] entries of [params] as an array that can be spread into a JavaPoet
 * vararg call. If [count] exceeds the number of entries, all of them are returned.
 *
 * @throws IllegalArgumentException if [count] is negative.
 */
private fun takeParams(count: Int, vararg params: Any): Array<Any> {
    require(count >= 0) { "Requested element count $count is less than zero." }
    return Array(minOf(count, params.size)) { index -> params[index] }
}
196
/**
 * Returns [count] `$N` placeholders joined by "," with no spaces (for example `$N,$N` for 2).
 * Returns an empty string when [count] is zero or negative.
 */
private fun generateParamString(count: Int): String = (1..count).joinToString(separator = ",") { N }
198