1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17// Contains all the entry points to the C Neural Networks API.
18// We do basic validation of the operands and then call the class
19// that implements the functionality.
20
21#define LOG_TAG "NeuralNetworks"
22
#include "NeuralNetworks.h"

#include "Callbacks.h"
#include "CompilationBuilder.h"
#include "ExecutionBuilder.h"
#include "Manager.h"
#include "Memory.h"
#include "ModelBuilder.h"
#include "NeuralNetworksOEM.h"

#include <memory>
#include <new>
#include <vector>
35
36// Make sure the constants defined in the header files have not changed values.
37// IMPORTANT: When adding new values, update kNumberOfDataTypes or kNumberOfDataTypesOEM
38// in Utils.h.
39static_assert(ANEURALNETWORKS_FLOAT32 == 0, "ANEURALNETWORKS_FLOAT32 has changed");
40static_assert(ANEURALNETWORKS_INT32 == 1, "ANEURALNETWORKS_INT32 has changed");
41static_assert(ANEURALNETWORKS_UINT32 == 2, "ANEURALNETWORKS_UINT32 has changed");
42static_assert(ANEURALNETWORKS_TENSOR_FLOAT32 == 3,
43              "ANEURALNETWORKS_TENSOR_FLOAT32 has changed");
44static_assert(ANEURALNETWORKS_TENSOR_INT32 == 4, "ANEURALNETWORKS_TENSOR_INT32 has changed");
45static_assert(ANEURALNETWORKS_TENSOR_QUANT8_ASYMM == 5,
46              "ANEURALNETWORKS_TENSOR_QUANT8_ASYMM has changed");
47static_assert(ANEURALNETWORKS_OEM_SCALAR == 10000, "ANEURALNETWORKS_OEM_SCALAR has changed");
48static_assert(ANEURALNETWORKS_TENSOR_OEM_BYTE == 10001,
49              "ANEURALNETWORKS_TENSOR_OEM_BYTE has changed");
50
// IMPORTANT: When adding new values, update kNumberOfOperationTypes or
// kNumberOfOperationTypesOEM in Utils.h.
53static_assert(ANEURALNETWORKS_ADD == 0, "ANEURALNETWORKS_ADD has changed");
54static_assert(ANEURALNETWORKS_AVERAGE_POOL_2D == 1,
55              "ANEURALNETWORKS_AVERAGE_POOL_2D has changed");
56static_assert(ANEURALNETWORKS_CONCATENATION == 2, "ANEURALNETWORKS_CONCATENATION has changed");
57static_assert(ANEURALNETWORKS_CONV_2D == 3, "ANEURALNETWORKS_CONV_2D has changed");
58static_assert(ANEURALNETWORKS_DEPTHWISE_CONV_2D == 4,
59              "ANEURALNETWORKS_DEPTHWISE_CONV_2D has changed");
60static_assert(ANEURALNETWORKS_DEPTH_TO_SPACE == 5,
61              "ANEURALNETWORKS_DEPTH_TO_SPACE has changed");
62static_assert(ANEURALNETWORKS_DEQUANTIZE == 6, "ANEURALNETWORKS_DEQUANTIZE has changed");
63static_assert(ANEURALNETWORKS_EMBEDDING_LOOKUP == 7,
64              "ANEURALNETWORKS_EMBEDDING_LOOKUP has changed");
65static_assert(ANEURALNETWORKS_FLOOR == 8, "ANEURALNETWORKS_FLOOR has changed");
66static_assert(ANEURALNETWORKS_FULLY_CONNECTED == 9,
67              "ANEURALNETWORKS_FULLY_CONNECTED has changed");
68static_assert(ANEURALNETWORKS_HASHTABLE_LOOKUP == 10,
69              "ANEURALNETWORKS_HASHTABLE_LOOKUP has changed");
70static_assert(ANEURALNETWORKS_L2_NORMALIZATION == 11,
71              "ANEURALNETWORKS_L2_NORMALIZATION has changed");
72static_assert(ANEURALNETWORKS_L2_POOL_2D == 12, "ANEURALNETWORKS_L2_POOL has changed");
73static_assert(ANEURALNETWORKS_LOCAL_RESPONSE_NORMALIZATION == 13,
74              "ANEURALNETWORKS_LOCAL_RESPONSE_NORMALIZATION has changed");
75static_assert(ANEURALNETWORKS_LOGISTIC == 14, "ANEURALNETWORKS_LOGISTIC has changed");
76static_assert(ANEURALNETWORKS_LSH_PROJECTION == 15,
77              "ANEURALNETWORKS_LSH_PROJECTION has changed");
78static_assert(ANEURALNETWORKS_LSTM == 16, "ANEURALNETWORKS_LSTM has changed");
79static_assert(ANEURALNETWORKS_MAX_POOL_2D == 17, "ANEURALNETWORKS_MAX_POOL has changed");
80static_assert(ANEURALNETWORKS_MUL == 18, "ANEURALNETWORKS_MUL has changed");
81static_assert(ANEURALNETWORKS_RELU == 19, "ANEURALNETWORKS_RELU has changed");
82static_assert(ANEURALNETWORKS_RELU1 == 20, "ANEURALNETWORKS_RELU1 has changed");
83static_assert(ANEURALNETWORKS_RELU6 == 21, "ANEURALNETWORKS_RELU6 has changed");
84static_assert(ANEURALNETWORKS_RESHAPE == 22, "ANEURALNETWORKS_RESHAPE has changed");
85static_assert(ANEURALNETWORKS_RESIZE_BILINEAR == 23,
86              "ANEURALNETWORKS_RESIZE_BILINEAR has changed");
87static_assert(ANEURALNETWORKS_RNN == 24, "ANEURALNETWORKS_RNN has changed");
88static_assert(ANEURALNETWORKS_SOFTMAX == 25, "ANEURALNETWORKS_SOFTMAX has changed");
89static_assert(ANEURALNETWORKS_SPACE_TO_DEPTH == 26,
90              "ANEURALNETWORKS_SPACE_TO_DEPTH has changed");
91static_assert(ANEURALNETWORKS_SVDF == 27, "ANEURALNETWORKS_SVDF has changed");
92static_assert(ANEURALNETWORKS_TANH == 28, "ANEURALNETWORKS_TANH has changed");
93static_assert(ANEURALNETWORKS_OEM_OPERATION == 10000,
94              "ANEURALNETWORKS_OEM_OPERATION has changed");
95
96static_assert(ANEURALNETWORKS_FUSED_NONE == 0, "ANEURALNETWORKS_FUSED_NONE has changed");
97static_assert(ANEURALNETWORKS_FUSED_RELU == 1, "ANEURALNETWORKS_FUSED_RELU has changed");
98static_assert(ANEURALNETWORKS_FUSED_RELU1 == 2, "ANEURALNETWORKS_FUSED_RELU1 has changed");
99static_assert(ANEURALNETWORKS_FUSED_RELU6 == 3, "ANEURALNETWORKS_FUSED_RELU6 has changed");
100
101static_assert(ANEURALNETWORKS_PREFER_LOW_POWER == 0,
102              "ANEURALNETWORKS_PREFER_LOW_POWER has changed");
103static_assert(ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER == 1,
104              "ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER has changed");
105static_assert(ANEURALNETWORKS_PREFER_SUSTAINED_SPEED == 2,
106              "ANEURALNETWORKS_PREFER_SUSTAINED_SPEED has changed");
107
108static_assert(ANEURALNETWORKS_NO_ERROR == 0, "ANEURALNETWORKS_NO_ERROR has changed");
109static_assert(ANEURALNETWORKS_OUT_OF_MEMORY == 1, "ANEURALNETWORKS_OUT_OF_MEMORY has changed");
110static_assert(ANEURALNETWORKS_INCOMPLETE == 2, "ANEURALNETWORKS_INCOMPLETE has changed");
111static_assert(ANEURALNETWORKS_UNEXPECTED_NULL == 3,
112              "ANEURALNETWORKS_UNEXPECTED_NULL has changed");
113static_assert(ANEURALNETWORKS_BAD_DATA == 4, "ANEURALNETWORKS_BAD_DATA has changed");
114static_assert(ANEURALNETWORKS_OP_FAILED == 5, "ANEURALNETWORKS_OP_FAILED has changed");
115static_assert(ANEURALNETWORKS_BAD_STATE == 6, "ANEURALNETWORKS_BAD_STATE has changed");
116
117static_assert(ANEURALNETWORKS_MAX_SIZE_OF_IMMEDIATELY_COPIED_VALUES == 128,
118              "ANEURALNETWORKS_MAX_SIZE_OF_IMMEDIATELY_COPIED_VALUES has changed");
119
120// Make sure that the constants are compatible with the values defined in
121// hardware/interfaces/neuralnetworks/1.0/types.hal.
122static_assert(static_cast<int32_t>(OperandType::OEM) == ANEURALNETWORKS_OEM_SCALAR,
123              "OEM != ANEURALNETWORKS_OEM");
124static_assert(static_cast<int32_t>(OperandType::FLOAT32) == ANEURALNETWORKS_FLOAT32,
125              "FLOAT32 != ANEURALNETWORKS_FLOAT32");
126static_assert(static_cast<int32_t>(OperandType::INT32) == ANEURALNETWORKS_INT32,
127              "INT32 != ANEURALNETWORKS_INT32");
128static_assert(static_cast<int32_t>(OperandType::UINT32) == ANEURALNETWORKS_UINT32,
129              "UINT32 != ANEURALNETWORKS_UINT32");
130static_assert(static_cast<int32_t>(OperandType::TENSOR_OEM_BYTE) == ANEURALNETWORKS_TENSOR_OEM_BYTE,
131              "TENSOR_OEM_BYTE != ANEURALNETWORKS_TENSOR_OEM_BYTE");
132static_assert(static_cast<int32_t>(OperandType::TENSOR_FLOAT32) == ANEURALNETWORKS_TENSOR_FLOAT32,
133              "TENSOR_FLOAT32 != ANEURALNETWORKS_TENSOR_FLOAT32");
134static_assert(static_cast<int32_t>(OperandType::TENSOR_QUANT8_ASYMM) ==
135                          ANEURALNETWORKS_TENSOR_QUANT8_ASYMM,
136              "TENSOR_QUANT8_ASYMM != ANEURALNETWORKS_TENSOR_QUANT8_ASYMM");
137
138static_assert(static_cast<int32_t>(OperationType::ADD) == ANEURALNETWORKS_ADD,
139              "OperationType::ADD != ANEURALNETWORKS_ADD");
140static_assert(static_cast<int32_t>(OperationType::AVERAGE_POOL_2D) ==
141                          ANEURALNETWORKS_AVERAGE_POOL_2D,
142              "OperationType::AVERAGE_POOL_2D != ANEURALNETWORKS_AVERAGE_POOL_2D");
143static_assert(static_cast<int32_t>(OperationType::CONV_2D) == ANEURALNETWORKS_CONV_2D,
144              "OperationType::CONV_2D != ANEURALNETWORKS_CONV_2D");
145static_assert(static_cast<int32_t>(OperationType::DEPTHWISE_CONV_2D) ==
146                          ANEURALNETWORKS_DEPTHWISE_CONV_2D,
147              "OperationType::DEPTHWISE_CONV_2D != ANEURALNETWORKS_DEPTHWISE_CONV_2D");
148static_assert(static_cast<int32_t>(OperationType::DEPTH_TO_SPACE) ==
149                          ANEURALNETWORKS_DEPTH_TO_SPACE,
150              "OperationType::DEPTH_TO_SPACE != ANEURALNETWORKS_DEPTH_TO_SPACE");
151static_assert(static_cast<int32_t>(OperationType::DEQUANTIZE) == ANEURALNETWORKS_DEQUANTIZE,
152              "OperationType::DEQUANTIZE != ANEURALNETWORKS_DEQUANTIZE");
153static_assert(static_cast<int32_t>(OperationType::EMBEDDING_LOOKUP) ==
154                          ANEURALNETWORKS_EMBEDDING_LOOKUP,
155              "OperationType::EMBEDDING_LOOKUP != ANEURALNETWORKS_EMBEDDING_LOOKUP");
156static_assert(static_cast<int32_t>(OperationType::FLOOR) == ANEURALNETWORKS_FLOOR,
157              "OperationType::FLOOR != ANEURALNETWORKS_FLOOR");
158static_assert(static_cast<int32_t>(OperationType::FULLY_CONNECTED) ==
159                          ANEURALNETWORKS_FULLY_CONNECTED,
160              "OperationType::FULLY_CONNECTED != ANEURALNETWORKS_FULLY_CONNECTED");
161static_assert(static_cast<int32_t>(OperationType::HASHTABLE_LOOKUP) ==
162                          ANEURALNETWORKS_HASHTABLE_LOOKUP,
163              "OperationType::HASHTABLE_LOOKUP != ANEURALNETWORKS_HASHTABLE_LOOKUP");
164static_assert(static_cast<int32_t>(OperationType::L2_NORMALIZATION) ==
165                          ANEURALNETWORKS_L2_NORMALIZATION,
166              "OperationType::L2_NORMALIZATION != ANEURALNETWORKS_L2_NORMALIZATION");
167static_assert(static_cast<int32_t>(OperationType::L2_POOL_2D) == ANEURALNETWORKS_L2_POOL_2D,
168              "OperationType::L2_POOL_2D != ANEURALNETWORKS_L2_POOL_2D");
169static_assert(static_cast<int32_t>(OperationType::LOCAL_RESPONSE_NORMALIZATION) ==
170                          ANEURALNETWORKS_LOCAL_RESPONSE_NORMALIZATION,
171              "OperationType::LOCAL_RESPONSE_NORMALIZATION != "
172              "ANEURALNETWORKS_LOCAL_RESPONSE_NORMALIZATION");
173static_assert(static_cast<int32_t>(OperationType::LOGISTIC) == ANEURALNETWORKS_LOGISTIC,
174              "OperationType::LOGISTIC != ANEURALNETWORKS_LOGISTIC");
175static_assert(static_cast<int32_t>(OperationType::LSH_PROJECTION) ==
176                          ANEURALNETWORKS_LSH_PROJECTION,
177              "OperationType::LSH_PROJECTION != ANEURALNETWORKS_LSH_PROJECTION");
178static_assert(static_cast<int32_t>(OperationType::LSTM) == ANEURALNETWORKS_LSTM,
179              "OperationType::LSTM != ANEURALNETWORKS_LSTM");
180static_assert(static_cast<int32_t>(OperationType::MAX_POOL_2D) == ANEURALNETWORKS_MAX_POOL_2D,
181              "OperationType::MAX_POOL_2D != ANEURALNETWORKS_MAX_POOL_2D");
182static_assert(static_cast<int32_t>(OperationType::MUL) == ANEURALNETWORKS_MUL,
183              "OperationType::MUL != ANEURALNETWORKS_MUL");
184static_assert(static_cast<int32_t>(OperationType::RELU) == ANEURALNETWORKS_RELU,
185              "OperationType::RELU != ANEURALNETWORKS_RELU");
186static_assert(static_cast<int32_t>(OperationType::RELU1) == ANEURALNETWORKS_RELU1,
187              "OperationType::RELU1 != ANEURALNETWORKS_RELU1");
188static_assert(static_cast<int32_t>(OperationType::RELU6) == ANEURALNETWORKS_RELU6,
189              "OperationType::RELU6 != ANEURALNETWORKS_RELU6");
190static_assert(static_cast<int32_t>(OperationType::RESHAPE) == ANEURALNETWORKS_RESHAPE,
191              "OperationType::RESHAPE != ANEURALNETWORKS_RESHAPE");
192static_assert(static_cast<int32_t>(OperationType::RESIZE_BILINEAR) ==
193                          ANEURALNETWORKS_RESIZE_BILINEAR,
194              "OperationType::RESIZE_BILINEAR != ANEURALNETWORKS_RESIZE_BILINEAR");
195static_assert(static_cast<int32_t>(OperationType::RNN) == ANEURALNETWORKS_RNN,
196              "OperationType::RNN != ANEURALNETWORKS_RNN");
197static_assert(static_cast<int32_t>(OperationType::SOFTMAX) == ANEURALNETWORKS_SOFTMAX,
198              "OperationType::SOFTMAX != ANEURALNETWORKS_SOFTMAX");
199static_assert(static_cast<int32_t>(OperationType::SPACE_TO_DEPTH) ==
200                          ANEURALNETWORKS_SPACE_TO_DEPTH,
201              "OperationType::SPACE_TO_DEPTH != ANEURALNETWORKS_SPACE_TO_DEPTH");
202static_assert(static_cast<int32_t>(OperationType::SVDF) == ANEURALNETWORKS_SVDF,
203              "OperationType::SVDF != ANEURALNETWORKS_SVDF");
204static_assert(static_cast<int32_t>(OperationType::TANH) == ANEURALNETWORKS_TANH,
205              "OperationType::TANH != ANEURALNETWORKS_TANH");
206
207static_assert(static_cast<int32_t>(FusedActivationFunc::NONE) == ANEURALNETWORKS_FUSED_NONE,
208              "FusedActivationFunc::NONE != ANEURALNETWORKS_FUSED_NONE");
209static_assert(static_cast<int32_t>(FusedActivationFunc::RELU) == ANEURALNETWORKS_FUSED_RELU,
210              "FusedActivationFunc::RELU != ANEURALNETWORKS_FUSED_RELU");
211static_assert(static_cast<int32_t>(FusedActivationFunc::RELU1) == ANEURALNETWORKS_FUSED_RELU1,
212              "FusedActivationFunc::RELU1 != ANEURALNETWORKS_FUSED_RELU1");
213static_assert(static_cast<int32_t>(FusedActivationFunc::RELU6) == ANEURALNETWORKS_FUSED_RELU6,
214              "FusedActivationFunc::RELU6 != ANEURALNETWORKS_FUSED_RELU6");
215
216using android::sp;
217using namespace android::nn;
218
219int ANeuralNetworksMemory_createFromFd(size_t size, int prot, int fd, size_t offset,
220                                       ANeuralNetworksMemory** memory) {
221    *memory = nullptr;
222    std::unique_ptr<MemoryFd> m = std::make_unique<MemoryFd>();
223    if (m == nullptr) {
224        return ANEURALNETWORKS_OUT_OF_MEMORY;
225    }
226    int n = m->set(size, prot, fd, offset);
227    if (n != ANEURALNETWORKS_NO_ERROR) {
228        return n;
229    }
230    *memory = reinterpret_cast<ANeuralNetworksMemory*>(m.release());
231    return ANEURALNETWORKS_NO_ERROR;
232}
233
234void ANeuralNetworksMemory_free(ANeuralNetworksMemory* memory) {
235    // No validation.  Free of nullptr is valid.
236    Memory* m = reinterpret_cast<Memory*>(memory);
237    delete m;
238}
239
240int ANeuralNetworksModel_create(ANeuralNetworksModel** model) {
241    initVLogMask();
242    if (!model) {
243        LOG(ERROR) << "ANeuralNetworksModel_create passed a nullptr";
244        return ANEURALNETWORKS_UNEXPECTED_NULL;
245    }
246    ModelBuilder* m = new ModelBuilder();
247    if (m == nullptr) {
248        *model = nullptr;
249        return ANEURALNETWORKS_OUT_OF_MEMORY;
250    }
251    *model = reinterpret_cast<ANeuralNetworksModel*>(m);
252    return ANEURALNETWORKS_NO_ERROR;
253}
254
255void ANeuralNetworksModel_free(ANeuralNetworksModel* model) {
256    // No validation.  Free of nullptr is valid.
257    ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
258    delete m;
259}
260
261int ANeuralNetworksModel_finish(ANeuralNetworksModel* model) {
262    if (!model) {
263        LOG(ERROR) << "ANeuralNetworksModel_finish passed a nullptr";
264        return ANEURALNETWORKS_UNEXPECTED_NULL;
265    }
266    ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
267    return m->finish();
268}
269
270int ANeuralNetworksModel_addOperand(ANeuralNetworksModel* model,
271                                    const ANeuralNetworksOperandType* type) {
272    if (!model || !type) {
273        LOG(ERROR) << "ANeuralNetworksModel_addOperand passed a nullptr";
274        return ANEURALNETWORKS_UNEXPECTED_NULL;
275    }
276    ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
277    return m->addOperand(*type);
278}
279
280int ANeuralNetworksModel_setOperandValue(ANeuralNetworksModel* model, int32_t index,
281                                         const void* buffer, size_t length) {
282    if (!model || !buffer) {
283        LOG(ERROR) << "ANeuralNetworksModel_setOperandValue passed a nullptr";
284        return ANEURALNETWORKS_UNEXPECTED_NULL;
285    }
286    ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
287    return m->setOperandValue(index, buffer, length);
288}
289
290int ANeuralNetworksModel_setOperandValueFromMemory(ANeuralNetworksModel* model, int32_t index,
291                                                   const ANeuralNetworksMemory* memory,
292                                                   size_t offset, size_t length) {
293    if (!model || !memory) {
294        LOG(ERROR) << "ANeuralNetworksModel_setOperandValue passed a nullptr";
295        return ANEURALNETWORKS_UNEXPECTED_NULL;
296    }
297    const Memory* mem = reinterpret_cast<const Memory*>(memory);
298    ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
299    return m->setOperandValueFromMemory(index, mem, offset, length);
300}
301
302int ANeuralNetworksModel_addOperation(ANeuralNetworksModel* model,
303                                      ANeuralNetworksOperationType type, uint32_t inputCount,
304                                      const uint32_t* inputs, uint32_t outputCount,
305                                      const uint32_t* outputs) {
306    if (!model || !inputs || !outputs) {
307        LOG(ERROR) << "ANeuralNetworksModel_addOperation passed a nullptr";
308        return ANEURALNETWORKS_UNEXPECTED_NULL;
309    }
310    ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
311    return m->addOperation(type, inputCount, inputs, outputCount, outputs);
312}
313
314int ANeuralNetworksModel_identifyInputsAndOutputs(ANeuralNetworksModel* model, uint32_t inputCount,
315                                                  const uint32_t* inputs, uint32_t outputCount,
316                                                  const uint32_t* outputs) {
317    if (!model || !inputs || !outputs) {
318        LOG(ERROR) << ("ANeuralNetworksModel_identifyInputsAndOutputs passed a nullptr");
319        return ANEURALNETWORKS_UNEXPECTED_NULL;
320    }
321    ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
322    return m->identifyInputsAndOutputs(inputCount, inputs, outputCount, outputs);
323}
324
325int ANeuralNetworksCompilation_create(ANeuralNetworksModel* model,
326                                      ANeuralNetworksCompilation** compilation) {
327    if (!model || !compilation) {
328        LOG(ERROR) << "ANeuralNetworksCompilation_create passed a nullptr";
329        return ANEURALNETWORKS_UNEXPECTED_NULL;
330    }
331
332    ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
333    CompilationBuilder* c = nullptr;
334    int result = m->createCompilation(&c);
335    *compilation = reinterpret_cast<ANeuralNetworksCompilation*>(c);
336    return result;
337}
338
339void ANeuralNetworksCompilation_free(ANeuralNetworksCompilation* compilation) {
340    // No validation.  Free of nullptr is valid.
341    // TODO specification says that a compilation-in-flight can be deleted
342    CompilationBuilder* c = reinterpret_cast<CompilationBuilder*>(compilation);
343    delete c;
344}
345
346int ANeuralNetworksCompilation_setPreference(ANeuralNetworksCompilation* compilation,
347                                             int32_t preference) {
348    if (!compilation) {
349        LOG(ERROR) << "ANeuralNetworksCompilation_setPreference passed a nullptr";
350        return ANEURALNETWORKS_UNEXPECTED_NULL;
351    }
352    CompilationBuilder* c = reinterpret_cast<CompilationBuilder*>(compilation);
353    return c->setPreference(preference);
354}
355
356int ANeuralNetworksCompilation_finish(ANeuralNetworksCompilation* compilation) {
357    if (!compilation) {
358        LOG(ERROR) << "ANeuralNetworksCompilation_finish passed a nullptr";
359        return ANEURALNETWORKS_UNEXPECTED_NULL;
360    }
361    CompilationBuilder* c = reinterpret_cast<CompilationBuilder*>(compilation);
362    return c->finish();
363}
364
365int ANeuralNetworksExecution_create(ANeuralNetworksCompilation* compilation,
366                                    ANeuralNetworksExecution** execution) {
367    if (!compilation || !execution) {
368        LOG(ERROR) << "ANeuralNetworksExecution_create passed a nullptr";
369        return ANEURALNETWORKS_UNEXPECTED_NULL;
370    }
371
372    CompilationBuilder* c = reinterpret_cast<CompilationBuilder*>(compilation);
373    ExecutionBuilder* r = nullptr;
374    int result = c->createExecution(&r);
375    *execution = reinterpret_cast<ANeuralNetworksExecution*>(r);
376    return result;
377}
378
379void ANeuralNetworksExecution_free(ANeuralNetworksExecution* execution) {
380    // TODO specification says that an execution-in-flight can be deleted
381    // No validation.  Free of nullptr is valid.
382    ExecutionBuilder* r = reinterpret_cast<ExecutionBuilder*>(execution);
383    delete r;
384}
385
386int ANeuralNetworksExecution_setInput(ANeuralNetworksExecution* execution, int32_t index,
387                                      const ANeuralNetworksOperandType* type, const void* buffer,
388                                      size_t length) {
389    // TODO: For a non-optional input, also verify that buffer is not null.
390    if (!execution) {
391        LOG(ERROR) << "ANeuralNetworksExecution_setInput passed a nullptr";
392        return ANEURALNETWORKS_UNEXPECTED_NULL;
393    }
394    ExecutionBuilder* r = reinterpret_cast<ExecutionBuilder*>(execution);
395    return r->setInput(index, type, buffer, length);
396}
397
398int ANeuralNetworksExecution_setInputFromMemory(ANeuralNetworksExecution* execution, int32_t index,
399                                                const ANeuralNetworksOperandType* type,
400                                                const ANeuralNetworksMemory* memory, size_t offset,
401                                                size_t length) {
402    if (!execution || !memory) {
403        LOG(ERROR) << "ANeuralNetworksExecution_setInputFromMemory passed a nullptr";
404        return ANEURALNETWORKS_UNEXPECTED_NULL;
405    }
406
407    const Memory* m = reinterpret_cast<const Memory*>(memory);
408    ExecutionBuilder* r = reinterpret_cast<ExecutionBuilder*>(execution);
409    return r->setInputFromMemory(index, type, m, offset, length);
410}
411
412int ANeuralNetworksExecution_setOutput(ANeuralNetworksExecution* execution, int32_t index,
413                                       const ANeuralNetworksOperandType* type, void* buffer,
414                                       size_t length) {
415    if (!execution || !buffer) {
416        LOG(ERROR) << "ANeuralNetworksExecution_setOutput passed a nullptr";
417        return ANEURALNETWORKS_UNEXPECTED_NULL;
418    }
419    ExecutionBuilder* r = reinterpret_cast<ExecutionBuilder*>(execution);
420    return r->setOutput(index, type, buffer, length);
421}
422
423int ANeuralNetworksExecution_setOutputFromMemory(ANeuralNetworksExecution* execution, int32_t index,
424                                                 const ANeuralNetworksOperandType* type,
425                                                 const ANeuralNetworksMemory* memory, size_t offset,
426                                                 size_t length) {
427    if (!execution || !memory) {
428        LOG(ERROR) << "ANeuralNetworksExecution_setOutputFromMemory passed a nullptr";
429        return ANEURALNETWORKS_UNEXPECTED_NULL;
430    }
431
432    ExecutionBuilder* r = reinterpret_cast<ExecutionBuilder*>(execution);
433    const Memory* m = reinterpret_cast<const Memory*>(memory);
434    return r->setOutputFromMemory(index, type, m, offset, length);
435}
436
437int ANeuralNetworksExecution_startCompute(ANeuralNetworksExecution* execution,
438                                          ANeuralNetworksEvent** event) {
439    if (!execution || !event) {
440        LOG(ERROR) << "ANeuralNetworksExecution_startCompute passed a nullptr";
441        return ANEURALNETWORKS_UNEXPECTED_NULL;
442    }
443    // TODO validate the rest
444
445    ExecutionBuilder* r = reinterpret_cast<ExecutionBuilder*>(execution);
446
447    // Dynamically allocate an sp to wrap an ExecutionCallback, seen in the NN
448    // API as an abstract event object. The sp<ExecutionCallback> object is
449    // returned when the execution has been successfully launched, otherwise a
450    // nullptr is returned. The sp is used for ref-counting purposes. Without
451    // it, the HIDL service could attempt to communicate with a dead callback
452    // object.
453    std::unique_ptr<sp<ExecutionCallback>> e = std::make_unique<sp<ExecutionCallback>>();
454    *event = nullptr;
455
456    int n = r->startCompute(e.get());
457    if (n != ANEURALNETWORKS_NO_ERROR) {
458        return n;
459    }
460    *event = reinterpret_cast<ANeuralNetworksEvent*>(e.release());
461    return ANEURALNETWORKS_NO_ERROR;
462}
463
464int ANeuralNetworksEvent_wait(ANeuralNetworksEvent* event) {
465    if (event == nullptr) {
466        LOG(ERROR) << "ANeuralNetworksEvent_wait passed a nullptr";
467        return ANEURALNETWORKS_UNEXPECTED_NULL;
468    }
469
470    sp<ExecutionCallback>* e = reinterpret_cast<sp<ExecutionCallback>*>(event);
471    (*e)->wait();
472    return ANEURALNETWORKS_NO_ERROR;
473}
474
475void ANeuralNetworksEvent_free(ANeuralNetworksEvent* event) {
476    // No validation.  Free of nullptr is valid.
477    if (event) {
478        sp<ExecutionCallback>* e = reinterpret_cast<sp<ExecutionCallback>*>(event);
479        (*e)->wait();
480        delete e;
481    }
482}
483