1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "Callbacks.h"
18#include <android-base/logging.h>
19
20namespace android {
21namespace hardware {
22namespace neuralnetworks {
23namespace V1_0 {
24namespace implementation {
25
26CallbackBase::CallbackBase() : mNotified(false) {}
27
28CallbackBase::~CallbackBase() {
29    // Note that we cannot call CallbackBase::join_thread from here:
30    // CallbackBase is intended to be reference counted, and it is possible that
31    // the reference count drops to zero in the bound thread, causing the
32    // bound thread to call this destructor. If a thread tries to join
33    // itself, it throws an exception, producing a message like the
34    // following:
35    //
36    //     terminating with uncaught exception of type std::__1::system_error:
37    //     thread::join failed: Resource deadlock would occur
38}
39
40void CallbackBase::wait() {
41    std::unique_lock<std::mutex> lock(mMutex);
42    mCondition.wait(lock, [this]{return mNotified;});
43    join_thread_locked();
44}
45
46bool CallbackBase::on_finish(std::function<bool(void)> post_work) {
47    std::lock_guard<std::mutex> lock(mMutex);
48    if (mPostWork != nullptr) {
49        LOG(ERROR) << "CallbackBase::on_finish -- a post-work function has already been bound to "
50                   "this callback object";
51        return false;
52    }
53    if (post_work == nullptr) {
54        LOG(ERROR) << "CallbackBase::on_finish -- the new post-work function is invalid";
55        return false;
56    }
57    mPostWork = std::move(post_work);
58    return true;
59}
60
61bool CallbackBase::bind_thread(std::thread&& asyncThread) {
62    std::lock_guard<std::mutex> lock(mMutex);
63    if (mThread.joinable()) {
64        LOG(ERROR) << "CallbackBase::bind_thread -- a thread has already been bound to this "
65                   "callback object";
66        return false;
67    }
68    if (!asyncThread.joinable()) {
69        LOG(ERROR) << "CallbackBase::bind_thread -- the new thread is not joinable";
70        return false;
71    }
72    mThread = std::move(asyncThread);
73    return true;
74}
75
76void CallbackBase::join_thread() {
77    std::lock_guard<std::mutex> lock(mMutex);
78    join_thread_locked();
79}
80
81void CallbackBase::notify() {
82    {
83        std::lock_guard<std::mutex> lock(mMutex);
84        mNotified = true;
85        if (mPostWork != nullptr) {
86            bool success = mPostWork();
87            if (!success) {
88                LOG(ERROR) << "CallbackBase::notify -- post work failed";
89            }
90        }
91    }
92    mCondition.notify_all();
93}
94
95void CallbackBase::join_thread_locked() {
96    if (mThread.joinable()) {
97        mThread.join();
98    }
99}
100
101PreparedModelCallback::PreparedModelCallback() :
102        mErrorStatus(ErrorStatus::GENERAL_FAILURE), mPreparedModel(nullptr) {}
103
104PreparedModelCallback::~PreparedModelCallback() {}
105
106Return<void> PreparedModelCallback::notify(ErrorStatus errorStatus,
107                                           const sp<IPreparedModel>& preparedModel) {
108    mErrorStatus = errorStatus;
109    mPreparedModel = preparedModel;
110    CallbackBase::notify();
111    return Void();
112}
113
114ErrorStatus PreparedModelCallback::getStatus() {
115    wait();
116    return mErrorStatus;
117}
118
119sp<IPreparedModel> PreparedModelCallback::getPreparedModel() {
120    wait();
121    return mPreparedModel;
122}
123
124ExecutionCallback::ExecutionCallback() : mErrorStatus(ErrorStatus::GENERAL_FAILURE) {}
125
126ExecutionCallback::~ExecutionCallback() {}
127
128Return<void> ExecutionCallback::notify(ErrorStatus errorStatus) {
129    mErrorStatus = errorStatus;
130    CallbackBase::notify();
131    return Void();
132}
133
134ErrorStatus ExecutionCallback::getStatus() {
135    wait();
136    return mErrorStatus;
137}
138
139}  // namespace implementation
140}  // namespace V1_0
141}  // namespace neuralnetworks
142}  // namespace hardware
143}  // namespace android
144