1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef ANDROID_HARDWARE_NEURALNETWORKS_V1_0_CALLBACKS_H
18#define ANDROID_HARDWARE_NEURALNETWORKS_V1_0_CALLBACKS_H
19
20#include <android/hardware/neuralnetworks/1.0/IExecutionCallback.h>
21#include <android/hardware/neuralnetworks/1.0/IPreparedModelCallback.h>
22#include <chrono>
23#include <condition_variable>
24#include <functional>
25#include <hidl/MQDescriptor.h>
26#include <hidl/Status.h>
27#include <mutex>
28#include <thread>
29
30namespace android {
31namespace hardware {
32namespace neuralnetworks {
33namespace V1_0 {
34namespace implementation {
35
36using ::android::hardware::hidl_array;
37using ::android::hardware::hidl_memory;
38using ::android::hardware::hidl_string;
39using ::android::hardware::hidl_vec;
40using ::android::hardware::Return;
41using ::android::hardware::Void;
42using ::android::sp;
43
/**
 * The CallbackBase class is used internally by the NeuralNetworks runtime to
 * synchronize between different threads. An asynchronous task is launched
 * paired with a callback object. When a client thread requires the output being
 * generated by the asynchronous task, the client thread can wait for the result
 * and be blocked until it has completed or a timeout condition has been
 * reached. Any wait* may safely be called concurrently, even on the same
 * callback object. When the asynchronous task has finished its workload, it
 * must immediately call "notify". If the asynchronous task has failed to launch,
 * the function that tried to launch the asynchronous task must immediately call
 * "notify". This "notify" call awakens any client threads waiting on the
 * callback object.
 *
 * The CallbackBase class implements some of the base synchronization common to
 * both PreparedModelCallback and ExecutionCallback. For consistency, any HIDL
 * callback class must inherit from CallbackBase as well as the HIDL callback
 * interface it implements.
 *
 * This class exists to enable synchronization across HIDL. When synchronization
 * is only required in the same process, consider using std::future, std::mutex,
 * std::condition_variable, or std::experimental::latch instead.
 */
class CallbackBase {
 public:
    CallbackBase();

    // Virtual because CallbackBase is used as a base class (of
    // PreparedModelCallback and ExecutionCallback); destroying a derived
    // callback through a CallbackBase pointer must run the derived destructor.
    virtual ~CallbackBase();

    /**
     * CallbackBase::wait blocks until notify has been called on the callback
     * object.
     */
    void wait();

    /**
     * CallbackBase::wait_for blocks until notify has been called on the
     * callback object or the time duration from the time the wait_for function
     * was called has expired, whichever comes first.
     *
     * When std::cv_status::no_timeout is returned, any thread bound with
     * CallbackBase::bind_thread has also been joined.
     *
     * @return Status std::cv_status::no_timeout if the callback was notified
     *                before the time duration expired, std::cv_status::timeout
     *                otherwise.
     */
    template<class Rep, class Period>
    std::cv_status wait_for(const std::chrono::duration<Rep,Period>& timeout_duration);

    /**
     * CallbackBase::on_finish binds a function to the callback object. This
     * bound function will be executed when CallbackBase::notify is called,
     * before any calls to wait* return. (Note that CallbackBase::wait_for can
     * return std::cv_status::timeout before CallbackBase::notify is called for
     * the first time, and hence before the bound function is executed.)
     *
     * The bound function must not synchronize with or otherwise access the
     * callback object it is bound to, as this could cause a deadlock.
     *
     * CallbackBase::on_finish can be called at most once on a given callback
     * object, and the call to CallbackBase::on_finish must finish before
     * CallbackBase::notify is called.
     *
     * @param post_work Function to be invoked the first time
     *                  CallbackBase::notify is called. Must have a target --
     *                  i.e., must not compare equal to nullptr. post_work
     *                  returns true if it successfully completes, false if it
     *                  fails.
     * @return bool True if the function was successfully bound, false if
     *              unsuccessful.
     *
     * TODO: Why does the return value of the callback matter?
     */
    bool on_finish(std::function<bool(void)> post_work);

    /**
     * CallbackBase::bind_thread binds a thread to the event for later use by
     * CallbackBase::join_thread.
     *
     * The thread must be passed using std::move.
     *
     * Once a thread is bound with CallbackBase::bind_thread, the client code
     * should ensure that one of the following occurs before the event is
     * destroyed:
     * - CallbackBase::join_thread has been called.
     * - CallbackBase::wait has been called.
     * - CallbackBase::wait_for has been called and returned
     *   std::cv_status::no_timeout (a timed-out wait_for does NOT join the
     *   bound thread).
     *
     * The bound thread shall not call any CallbackBase method with the
     * exception of CallbackBase::notify, which it must call when the thread has
     * finished its computation.
     *
     * CallbackBase::bind_thread can be called at most once on a given callback
     * object.
     *
     * @param asyncThread Thread to be bound to the callback object. The thread
     *                    object must represent a thread of execution -- i.e.,
     *                    asyncThread.joinable() must be true.
     * @return bool True if successful, false if thread was not properly bound.
     */
    bool bind_thread(std::thread&& asyncThread);

    /**
     * CallbackBase::join_thread ensures that the thread (if any) bound to this
     * event with CallbackBase::bind_thread has fully finished and cleaned its
     * resources. It is legal to call this function multiple times, concurrently
     * or sequentially.
     */
    void join_thread();

 protected:
    /**
     * CallbackBase::notify enables all prior and future wait* calls on the
     * callback object to proceed. The call to CallbackBase::notify happens
     * before any wait* calls on this callback object return (except in the case
     * of wait_for timing out). The asynchronous call the callback object is
     * paired with must ensure that any update to state that should be visible
     * to the caller of wait* happens before the call to CallbackBase::notify.
     *
     * CallbackBase::notify must be called exactly once on a given callback
     * object.
     */
    void notify();

 private:
    // Same as CallbackBase::join_thread but assumes we already hold a lock on
    // mMutex.
    void join_thread_locked();

    // Condition that wait* block on (read under mMutex); defaults to false so
    // the flag is never read uninitialized.
    bool                      mNotified = false;
    std::mutex                mMutex;
    std::condition_variable   mCondition;
    // Function registered with on_finish, to be run by notify.
    std::function<bool(void)> mPostWork;
    // Thread registered with bind_thread, to be joined by join_thread*.
    std::thread               mThread;
};
180
181/**
182 * The PreparedModelCallback class is used to receive the error status of
183 * preparing a model as well as the prepared model from a task executing
184 * asynchronously with respect to the runtime. If a calling thread calls wait*
185 * or get* on a PreparedModelCallback object and the corresponding asynchronous
186 * task has not finished preparing the model, the calling thread will block
187 * until the asynchronous task has called notify. For more information on the
188 * synchronization behavior, refer to the CallbackBase class.
189 *
190 * This class inherits the basic blocking and signaling calls from
191 * CallbackBase, and implements the HIDL notify call from
192 * IPreparedModelCallback. This callback object is passed as an argument to
193 * IDevice::prepareModel.
194 */
195class PreparedModelCallback : public CallbackBase, public IPreparedModelCallback {
196 public:
197    PreparedModelCallback();
198    ~PreparedModelCallback() override;
199
200    /**
201     * IPreparedModelCallback::notify marks the callback object with the return
202     * status of the asynchronous model preparation along with the prepared
203     * model, and calls CallbackBase::notify, enabling all prior and future
204     * wait* calls on the PreparedModelCallback object to proceed. For more
205     * information on the synchronization behavior, refer to the CallbackBase
206     * class.
207     *
208     * IPreparedModelCallback::notify must be called exactly once on a given
209     * PreparedModelCallback object.
210     *
211     * @param status Error status returned from asynchronously preparing the
212     *               model; will be:
213     *               - NONE if the asynchronous preparation was successful
214     *               - DEVICE_UNAVAILABLE if driver is offline or busy
215     *               - GENERAL_FAILURE if there is an unspecified error
216     *               - INVALID_ARGUMENT if the input model is invalid
217     * @param preparedModel Returned model that has been prepared for execution,
218     *                      nullptr if the model was unable to be prepared.
219     */
220    Return<void> notify(ErrorStatus status, const sp<IPreparedModel>& preparedModel) override;
221
222    /**
223     * Retrieves the error status returned from the asynchronous task launched
224     * by IDevice::prepareModel. If IDevice::prepareModel has not finished
225     * asynchronously preparing the model, this call will block until the
226     * asynchronous task notifies the object.
227     *
228     * @return status Error status returned from asynchronously preparing the
229     *                model; will be:
230     *                - NONE if the asynchronous preparation was successful
231     *                - DEVICE_UNAVAILABLE if driver is offline or busy
232     *                - GENERAL_FAILURE if there is an unspecified error
233     *                - INVALID_ARGUMENT if the input model is invalid
234     */
235    ErrorStatus getStatus();
236
237    /**
238     * Retrieves the model that has been prepared for execution from the
239     * asynchronous task launched by IDevice::prepareModel. If
240     * IDevice::prepareModel has not finished asynchronously preparing the
241     * model, this call will block until the asynchronous task notifies the
242     * object.
243     *
244     * @return preparedModel Returned model that has been prepared for
245     *                       execution, nullptr if the model was unable to be
246     *                       prepared.
247     */
248    sp<IPreparedModel> getPreparedModel();
249
250 private:
251    ErrorStatus        mErrorStatus;
252    sp<IPreparedModel> mPreparedModel;
253};
254
255/**
256 * The ExecutionCallback class is used to receive the error status of the
257 * execution from a task executing asynchronously with respect to the runtime.
258 * If a calling thread calls wait* or get* on a PreparedModelCallback object and
259 * the corresponding asynchronous task has not finished the execution, the
260 * calling thread will block until the asynchronous task has called notify. For
261 * more information on the synchronization behavior, refer to the CallbackBase
262 * class.
263 *
264 * This class inherits the basic blocking and signaling calls from
265 * CallbackBase, and implements the HIDL notify call from
266 * IExecutionCallback. This callback object is passed as an argument to
267 * IPreparedModel::execute.
268 */
269class ExecutionCallback : public CallbackBase,  public IExecutionCallback {
270 public:
271    ExecutionCallback();
272    ~ExecutionCallback() override;
273
274    /**
275     * IExecutionCallback::notify marks the callback object with the return
276     * status of the asynchronous execution that held this callback and enables
277     * all prior and future wait* calls on the ExecutionCallback object to
278     * proceed. For more information on the synchronization behavior, refer to
279     * the CallbackBase class.
280     *
281     * IExecutionCallback::notify must be called exactly once on a given
282     * ExecutionCallback object.
283     *
284     * @param status Error status returned from asynchronously preparing the
285     *               model; will be:
286     *               - NONE if the asynchronous execution was successful
287     *               - DEVICE_UNAVAILABLE if driver is offline or busy
288     *               - GENERAL_FAILURE if there is an unspecified error
289     *               - OUTPUT_INSUFFICIENT_SIZE if provided output buffer is
290     *                 not large enough to store the resultant values
291     *               - INVALID_ARGUMENT if the input request is invalid
292     */
293    Return<void> notify(ErrorStatus status) override;
294
295    /**
296     * Retrieves the error status returned from the asynchronous task launched
297     * by IPreparedModel::execute. If IPreparedModel::execute has not finished
298     * asynchronously executing, this call will block until the asynchronous task
299     * notifies the object.
300     *
301     * @return status Error status returned from asynchronously preparing the
302     *                model; will be:
303     *                - NONE if the asynchronous execution was successful
304     *                - DEVICE_UNAVAILABLE if driver is offline or busy
305     *                - GENERAL_FAILURE if there is an unspecified error
306     *                - OUTPUT_INSUFFICIENT_SIZE if provided output buffer is
307     *                  not large enough to store the resultant values
308     *                - INVALID_ARGUMENT if the input request is invalid
309     */
310    ErrorStatus getStatus();
311
312 private:
313    ErrorStatus mErrorStatus;
314};
315
316
317// template function implementation(s) below this point
318
319template<class Rep, class Period>
320std::cv_status CallbackBase::wait_for(const std::chrono::duration<Rep,Period>& timeout_duration) {
321    std::unique_lock<std::mutex> lock(mMutex);
322    std::cv_status status = mCondition.wait_for(lock, timeout_duration, [this]{return mNotified;});
323    if (status != std::cv_status::timeout) {
324        join_thread_locked();
325    }
326    return status;
327}
328
329}  // namespace implementation
330}  // namespace V1_0
331}  // namespace neuralnetworks
332}  // namespace hardware
333}  // namespace android
334
335#endif  // ANDROID_HARDWARE_NEURALNETWORKS_V1_0_CALLBACKS_H
336