1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/cc/training/queue_runner.h"
17#include "tensorflow/core/kernels/ops_util.h"
18#include "tensorflow/core/platform/env.h"
19
20namespace tensorflow {
21
22Status QueueRunner::New(const QueueRunnerDef& queue_runner_def,
23                        std::unique_ptr<QueueRunner>* result) {
24  result->reset(new QueueRunner());
25  return (*result)->Init(queue_runner_def);
26}
27
28Status QueueRunner::New(const QueueRunnerDef& queue_runner_def,
29                        Coordinator* coord,
30                        std::unique_ptr<QueueRunner>* result) {
31  result->reset(new QueueRunner());
32  (*result)->coord_ = coord;
33  return (*result)->Init(queue_runner_def);
34}
35
36void QueueRunner::AddErrorCallback(const std::function<void(Status)>& cb) {
37  mutex_lock l(cb_mu_);
38  callbacks_.push_back(cb);
39}
40
41void QueueRunner::ClearErrorCallbacks() {
42  mutex_lock l(cb_mu_);
43  callbacks_.clear();
44}
45
46Status QueueRunner::Init(const QueueRunnerDef& queue_runner_def) {
47  queue_name_ = queue_runner_def.queue_name();
48  enqueue_op_names_.clear();
49  enqueue_op_names_.insert(enqueue_op_names_.end(),
50                           queue_runner_def.enqueue_op_name().begin(),
51                           queue_runner_def.enqueue_op_name().end());
52  size_t op_names_size = enqueue_op_names_.size();
53  if (op_names_size > kint32max) {
54    return Status(error::INVALID_ARGUMENT,
55                  "Enqueue ops to run cannot exceed kint32max");
56  }
57  runs_ = static_cast<int>(op_names_size);
58  if (runs_ == 0) {
59    return Status(error::INVALID_ARGUMENT, "Empty enqueue ops to run.");
60  }
61  close_op_name_ = queue_runner_def.close_op_name();
62  cancel_op_name_ = queue_runner_def.cancel_op_name();
63  if (queue_runner_def.queue_closed_exception_types_size() == 0) {
64    queue_closed_exception_types_.insert(error::OUT_OF_RANGE);
65  } else {
66    for (const auto& code : queue_runner_def.queue_closed_exception_types()) {
67      queue_closed_exception_types_.insert(static_cast<int>(code));
68    }
69  }
70
71  int nthreads = runs_;
72  if (coord_) {
73    // One more thread to call Stop()
74    nthreads++;
75  }
76  thread_pool_.reset(new thread::ThreadPool(
77      Env::Default(), SanitizeThreadSuffix(queue_name_), nthreads));
78
79  return Status::OK();
80}
81
82QueueRunner::~QueueRunner() {
83  // Cannot run Stop() here because the session might already be closed or
84  // destroyed.
85  Join().IgnoreError();
86}
87
88Status QueueRunner::Start(Session* sess) { return Start(sess, 0); }
89
90Status QueueRunner::StartAndCollectCostGraph(Session* sess,
91                                             const RunOptions& run_options) {
92  SetRunArgumentsAndCostGraph(run_options);
93  return Start(sess, 0);
94}
95
96Status QueueRunner::Start(Session* sess, int wait_for) {
97  counter_.reset(new BlockingCounter(runs_));
98  for (const string& enqueue_op : enqueue_op_names_) {
99    thread_pool_->Schedule(
100        std::bind(&QueueRunner::Run, this, sess, enqueue_op));
101  }
102  if (coord_) {
103    thread_pool_->Schedule(std::bind(&QueueRunner::Stop, this, sess));
104  }
105  // Wait for up to 'wait_for' milliseconds.
106  if (wait_for > 0) {
107    if (!counter_->WaitFor(std::chrono::milliseconds(wait_for))) {
108      return Status(error::DEADLINE_EXCEEDED,
109                    "Queues not fed before the timeout");
110    }
111    // Check the status of the queue runner as well as the result of the enqueue
112    // operations.
113    mutex_lock l(mu_);
114    if (!enqueue_status_.ok()) {
115      return enqueue_status_;
116    } else {
117      return status_;
118    }
119  }
120  return Status::OK();
121}
122
123Status QueueRunner::StartAndCollectCostGraph(Session* session, int wait_for_ms,
124                                             const RunOptions& run_options) {
125  SetRunArgumentsAndCostGraph(run_options);
126  return Start(session, wait_for_ms);
127}
128
129void QueueRunner::Stop(Session* sess) {
130  if (coord_ != nullptr) {
131    coord_->WaitForStop();
132  }
133  if (!cancel_op_name_.empty()) {
134    UpdateStatus(RealRun(sess, cancel_op_name_, false));
135  }
136  stopped_ = true;
137}
138
139Status QueueRunner::Join() {
140  thread_pool_.reset();
141  mutex_lock l(mu_);
142  return status_;
143}
144
145void QueueRunner::UpdateStatus(const Status& status) {
146  {
147    mutex_lock l(mu_);
148    if (!status_.ok() || status.ok() || IsQueueClosed(status)) {
149      return;
150    }
151    status_ = status;
152  }
153  if (coord_) {
154    coord_->ReportStatus(status);
155  }
156  mutex_lock l(cb_mu_);
157  for (auto& cb : callbacks_) {
158    cb(status);
159  }
160}
161
162void QueueRunner::Run(Session* sess, const string& enqueue_op) {
163  bool first_iteration = true;
164  Status status;
165  while (status.ok()) {
166    if (coord_ && coord_->ShouldStop()) {
167      break;
168    }
169    status = RealRun(sess, enqueue_op, true);
170    if (first_iteration) {
171      if (!status.ok()) {
172        mutex_lock l(mu_);
173        enqueue_status_ = status;
174      }
175      counter_->DecrementCount();
176      first_iteration = false;
177    }
178  }
179  bool last_run = false;
180  {
181    mutex_lock l(mu_);
182    runs_--;
183    last_run = (runs_ == 0);
184  }
185
186  // Close the queue unless the coordinator is shutting down since the cancel op
187  // will be run anway in this case.
188  if (IsQueueClosed(status) && (!coord_ || !coord_->ShouldStop())) {
189    if (last_run && !close_op_name_.empty()) {
190      UpdateStatus(RealRun(sess, close_op_name_, false));
191    }
192  } else if (!status.ok()) {
193    LOG(ERROR) << "Queue runner thread got a failure status: "
194               << status.ToString();
195    UpdateStatus(status);
196    if (coord_) {
197      coord_->RequestStop().IgnoreError();
198    }
199  }
200}
201
202Status QueueRunner::GetStatus() {
203  mutex_lock l(mu_);
204  return status_;
205}
206
207Status QueueRunner::ExportCostGraph(CostGraphDef* cost_graph) const {
208  if (!cg_mu_) {
209    return Status(error::FAILED_PRECONDITION,
210                  "This QueueRunner doesn't collect a cost graph.");
211  }
212  mutex_lock l(*cg_mu_);
213  cost_graph->MergeFrom(*cost_graph_);
214  return Status::OK();
215}
216
217void QueueRunner::SetRunArgumentsAndCostGraph(const RunOptions& run_options) {
218  cg_mu_.reset(new mutex());
219  {
220    mutex_lock l(*cg_mu_);
221    cost_graph_.reset(new CostGraphDef());
222  }
223  run_options_ = run_options;
224}
225
226Status QueueRunner::RealRun(Session* sess, const string& op,
227                            bool update_costs) {
228  Status s;
229  if (update_costs && cg_mu_) {
230    RunMetadata metadata;
231    s = sess->Run(run_options_, {}, {}, {op}, nullptr, &metadata);
232    mutex_lock l(*cg_mu_);
233    cost_graph_->Swap(metadata.mutable_cost_graph());
234  } else {
235    s = sess->Run({}, {}, {op}, nullptr);
236  }
237  return s;
238}
239
240}  // namespace tensorflow
241