1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15#include <deque> 16 17#include "tensorflow/core/common_runtime/function.h" 18#include "tensorflow/core/framework/partial_tensor_shape.h" 19#include "tensorflow/core/framework/tensor.h" 20#include "tensorflow/core/kernels/data/captured_function.h" 21#include "tensorflow/core/kernels/data/dataset.h" 22#include "tensorflow/core/kernels/data/dataset_utils.h" 23#include "tensorflow/core/lib/gtl/cleanup.h" 24#include "tensorflow/core/lib/random/random.h" 25 26namespace tensorflow { 27 28namespace { 29 30// See documentation in ../ops/dataset_ops.cc for a high-level 31// description of the following op. 32 33class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { 34 public: 35 explicit ParallelInterleaveDatasetOp(OpKernelConstruction* ctx) 36 : UnaryDatasetOpKernel(ctx), 37 graph_def_version_(ctx->graph_def_version()) { 38 OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); 39 OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); 40 OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); 41 } 42 43 void MakeDataset(OpKernelContext* ctx, DatasetBase* input, 44 DatasetBase** output) override { 45 OpInputList inputs; 46 OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs)); 47 std::vector<Tensor> other_arguments; 48 other_arguments.reserve(inputs.size()); 49 for (const Tensor& t : inputs) { 50 other_arguments.push_back(t); 51 } 52 53 int64 cycle_length = 0; 54 OP_REQUIRES_OK(ctx, 55 ParseScalarArgument(ctx, "cycle_length", &cycle_length)); 56 OP_REQUIRES(ctx, cycle_length > 0, 57 errors::InvalidArgument("`cycle_length` must be > 0")); 58 59 int64 block_length = 0; 60 OP_REQUIRES_OK(ctx, 61 ParseScalarArgument(ctx, "block_length", &block_length)); 62 OP_REQUIRES(ctx, block_length > 0, 63 errors::InvalidArgument("`block_length` must be > 0")); 64 65 bool sloppy = false; 66 OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "sloppy", &sloppy)); 67 68 int64 buffer_output_elements = 0; 69 OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "buffer_output_elements", 70 &buffer_output_elements)); 71 OP_REQUIRES( 72 ctx, buffer_output_elements > 0, 73 errors::InvalidArgument("`buffer_output_elements` must be > 0")); 74 75 int64 prefetch_input_elements = 0; 76 OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "prefetch_input_elements", 77 &prefetch_input_elements)); 78 OP_REQUIRES( 79 ctx, prefetch_input_elements >= 0, 80 errors::InvalidArgument("`prefetch_input_elements` must be >= 0")); 81 82 std::unique_ptr<CapturedFunction> captured_func; 83 OP_REQUIRES_OK(ctx, CapturedFunction::Create( 84 func_, std::move(other_arguments), &captured_func)); 85 86 *output = 87 new Dataset(input, std::move(captured_func), cycle_length, block_length, 88 sloppy, buffer_output_elements, prefetch_input_elements, 89 output_types_, output_shapes_); 90 } 91 92 private: 93 class Dataset : public DatasetBase { 94 public: 95 Dataset(const DatasetBase* input, 96 std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length, 97 int64 block_length, bool sloppy, int64 buffer_output_elements, 98 int64 prefetch_input_elements, const DataTypeVector& output_types, 99 const std::vector<PartialTensorShape>& output_shapes) 100 : input_(input), 101 captured_func_(std::move(captured_func)), 102 cycle_length_(cycle_length), 103 block_length_(block_length), 104 sloppy_(sloppy), 105 buffer_output_elements_(buffer_output_elements), 106 prefetch_input_elements_(prefetch_input_elements), 107 output_types_(output_types), 108 output_shapes_(output_shapes) { 109 input_->Ref(); 110 } 111 112 ~Dataset() override { input_->Unref(); } 113 114 std::unique_ptr<IteratorBase> MakeIterator( 115 const string& prefix) const override { 116 return std::unique_ptr<IteratorBase>(new Iterator( 117 {this, strings::StrCat(prefix, "::ParallelInterleave")})); 118 } 119 120 const DataTypeVector& output_dtypes() const override { 121 return output_types_; 122 } 123 const std::vector<PartialTensorShape>& output_shapes() const override { 124 return output_shapes_; 125 } 126 127 string DebugString() override { 128 return "ParallelInterleaveDatasetOp::Dataset"; 129 } 130 131 private: 132 int64 num_threads() const { 133 return cycle_length_ + prefetch_input_elements_; 134 } 135 136 // Parallel interleave's implementation is designed around a few principles: 137 // 1. Thread creation is relatively expensive. (Not reusing 138 // threads causes a number of indirect costs such as poorer tcmalloc 139 // performance due to thread-local caches, etc.) We allocate a fixed 140 // number of threads at the start and never change. This is why we've 141 // fused functionality that is theoretically orthogonal (i.e. 142 // .prefetch()) into the implementation. 143 // 2. Drop-in replacement for standard interleave. The goal will be to 144 // auto-opt people into an optimized implementation without any work 145 // on the customer's part. We thus go through great pains to maintain 146 // identical iteration orders, full determinism (disabled only via a 147 // flag, etc.) 148 // 3. Performance across a variety of environments and I/O envelopes. 149 // 150 // The actual implementation centers around a collection of worker threads 151 // and their corresponding worker state (tracked in the `workers_` vector). 152 // Worker threads repeatedly receive a vector of Tensors that are used as 153 // input to the flat-map function (`captured_func_`). The output of this 154 // function must be a dataset. The worker thread then repeatedly calls 155 // `GetNext()`, maintaining a buffer of elements to minimize the likelihood 156 // that a caller will block waiting for an element to be produced. 157 // 158 // Pointers to these worker states are kept in 2 disjoint data structures: 159 // 1. `interleave_` is a vector containing pointers to `WorkerState`s that 160 // we 161 // are interleaving. Worker threads backing these WorkerStates should 162 // be regularly producing values. 163 // 2. `staging_` is a deque containing pointers to WorkerStates that we 164 // will move to `interleave_` when an iterator in `interleave_` is 165 // exhausted. 166 // 167 // The client calls `GetNext[Internal]()` to retrieve an output element. The 168 // internal implementation updates the state of `interleave_` and `staging_` 169 // as output iterators (run by the worker threads) are exhausted. 170 // 171 // `input_impl_` is the input iterator that generates arguments for the 172 // flat-map function (`captured_func_`). It is set to an iterator at 173 // Iterator construction, and is fixed until we consume all input elements. 174 // Once it is exhausted, we reset the unique_ptr to eagerly deallocate 175 // memory. 176 // 177 // A few invariants are maintained: 178 // 1. No element in interleave_ should be a nullptr unless `staging_` is 179 // empty and `input_impl_` is empty. 180 // 2. Every `worker_` element is pointed to by at most one element of the 181 // union of `interleave_` and `staging_`. 182 // 3. Unless `input_impl_` is empty, every `worker_` must be pointed to by 183 // an element in `interleave_` or `staging_`. 184 class Iterator : public DatasetIterator<Dataset> { 185 public: 186 explicit Iterator(const Params& params) 187 : DatasetIterator<Dataset>(params), 188 input_impl_(params.dataset->input_->MakeIterator(params.prefix)), 189 workers_(dataset()->num_threads()) {} 190 191 ~Iterator() override { 192 mutex_lock l(mu_); 193 cancelled_ = true; 194 // Notify all workers in case they are blocked. 195 for (auto& worker : workers_) { 196 worker.cond_var.notify_all(); 197 } 198 } 199 200 // It is implemented so that it matches the deterministic interleave 201 // unless getting the next element would block and we are allowed to be 202 // sloppy. 203 Status GetNextInternal(IteratorContext* ctx, 204 std::vector<Tensor>* out_tensors, 205 bool* end_of_sequence) override { 206 mutex_lock l(mu_); 207 TF_RETURN_IF_ERROR(EnsureWorkerThreadsStarted(ctx)); 208 while (!cancelled_) { 209 // Wait for an item to become available, blocking if necessary. If we 210 // are allowed to be sloppy, we can skip over input datasets that do 211 // not have an item readily available. 212 bool can_produce_elements = false; 213 bool must_wait_for_input = true; 214 for (int64 i = 0; i < interleave_.size(); ++i) { 215 int64 index = (next_index_ + i) % interleave_.size(); 216 WorkerState* current_worker = interleave_[index]; 217 if (!current_worker) continue; // Empty interleave elements. 218 can_produce_elements |= current_worker->MayHaveElements(); 219 if (!current_worker->outputs.empty()) { 220 // We have an element! 221 next_index_ = index; 222 if (i == 0) { 223 block_count_++; 224 if (block_count_ == dataset()->block_length_) { 225 next_index_ = (index + 1) % interleave_.size(); 226 block_count_ = 0; 227 } 228 } else { 229 block_count_ = 0; 230 } 231 *end_of_sequence = false; 232 Status s = current_worker->outputs.front().status; 233 current_worker->outputs.front().output.swap(*out_tensors); 234 current_worker->outputs.pop_front(); 235 current_worker->cond_var.notify_one(); 236 return s; 237 } else if (current_worker->is_producing && !dataset()->sloppy_) { 238 // current_worker.outputs.empty(), and we must wait for this 239 // iterator. 240 if (next_index_ != index) { 241 // We have advanced to a new iterator; reset block counts. 242 next_index_ = index; 243 block_count_ = 0; 244 } 245 break; 246 } else if (!current_worker->is_producing) { 247 // This iterator has reached end of input. 248 interleave_[index] = nullptr; 249 if (input_impl_) { 250 // Start prefetching a new iterator. 251 std::vector<Tensor> args; 252 bool end_of_input = false; 253 Status s = input_impl_->GetNext(ctx, &args, &end_of_input); 254 if (end_of_input) { 255 input_impl_.reset(); 256 } else { 257 current_worker->SetInputs(s, std::move(args)); 258 staging_.emplace_back(current_worker); 259 } 260 } 261 262 if (!staging_.empty()) { 263 // Move a worker from `staging_` to `interleave_`. 264 interleave_[index] = staging_.front(); 265 staging_.pop_front(); 266 267 next_index_ = (index + 1) % interleave_.size(); 268 block_count_ = 0; 269 // Restart the inner [for] loop 270 can_produce_elements = true; 271 must_wait_for_input = false; 272 break; 273 } 274 } 275 } 276 277 if (!can_produce_elements && !input_impl_) { 278 // No potential for future values. 279 *end_of_sequence = true; 280 return Status::OK(); 281 } 282 283 if (must_wait_for_input) { 284 // Wait for elements to become available. 285 if (dataset()->sloppy_) { 286 sloppy_cond_var_.wait(l); 287 } else { 288 interleave_[next_index_]->cond_var.wait(l); 289 } 290 } 291 } 292 return errors::Cancelled( 293 "ParallelInterleaveDatasetOp::Dataset::Iterator::GetNext"); 294 } 295 296 private: 297 // OutputElem contains the information from a call to GetNext by an output 298 // iterator. 299 struct OutputElem { 300 // The output iterator sets `status` if getting the output element 301 // fails. 302 Status status; 303 // The buffered data element. 304 std::vector<Tensor> output; 305 306 explicit OutputElem(const Status& s) : status(s) {} 307 }; 308 309 // Worker threads operate on their relevant WorkerState structs. 310 // 311 // WorkerState's fields are all protected by mu_; 312 struct WorkerState { 313 // The arguments to be used to construct an output iterator. 314 std::vector<Tensor> input; 315 // The buffered output elements. 316 std::deque<OutputElem> outputs; 317 // Set to true iff the worker thread expects to append more elements to 318 // outputs. is_producing can be false despite !outputs.empty(). 319 // Concretely, all output elements will have been consumed only when: 320 // is_producing == false && outputs.empty(); 321 bool is_producing = false; 322 // Condition variable used to coordinate between threads. The worker 323 // thread waits on this condition variable when it is either (1) waiting 324 // for the main thread to add arguments to `input`, or (2) waiting for 325 // the main thread to consume an element of `outputs`. The main thread 326 // waits on cond_var if it is waiting for the worker thread to produce 327 // an element into `outputs` (this implies sloppy_==false). 328 condition_variable cond_var; 329 330 inline bool MayHaveElements() const { 331 return is_producing || !outputs.empty(); 332 } 333 334 // Sets inputs for a worker thread and notifies it to start processing. 335 void SetInputs(const Status& s, std::vector<Tensor> input_arguments) { 336 if (s.ok()) { 337 DCHECK(!MayHaveElements()) 338 << "Tried to start inputs, despite already producing!"; 339 input = std::move(input_arguments); 340 is_producing = true; 341 cond_var.notify_one(); 342 } else { 343 outputs.emplace_back(s); 344 } 345 } 346 }; 347 348 Status EnsureWorkerThreadsStarted(IteratorContext* ctx) 349 EXCLUSIVE_LOCKS_REQUIRED(mu_) { 350 if (worker_threads_.empty()) { 351 worker_threads_.reserve(dataset()->num_threads()); 352 for (int64 i = 0; i < dataset()->num_threads(); ++i) { 353 std::vector<Tensor> args; 354 bool end_of_input = false; 355 Status s = input_impl_->GetNext(ctx, &args, &end_of_input); 356 if (end_of_input) { 357 input_impl_.reset(); 358 return Status::OK(); 359 } 360 workers_[i].SetInputs(s, std::move(args)); 361 worker_threads_.emplace_back(ctx->env()->StartThread( 362 {}, "worker_thread", 363 std::bind(&Iterator::WorkerThread, this, 364 new IteratorContext(*ctx), i))); 365 if (i < dataset()->cycle_length_) { 366 interleave_.push_back(&workers_[i]); 367 } else { 368 staging_.push_back(&workers_[i]); 369 } 370 } 371 DCHECK(interleave_.size() == dataset()->cycle_length_); 372 DCHECK(staging_.size() == dataset()->prefetch_input_elements_); 373 } 374 return Status::OK(); 375 } 376 377 // Produces elements into the worker's output buffers. 378 void WorkerThread(IteratorContext* ctx_ptr, const int64 thread_index) { 379 // std::function arguments are copy-constructable, so we pass raw 380 // pointers, and then immediately wrap them to ensure correct ownership. 381 std::unique_ptr<IteratorContext> ctx(ctx_ptr); 382 auto cleanup = gtl::MakeCleanup([this, thread_index] { 383 mutex_lock l(mu_); 384 workers_[thread_index].cond_var.notify_all(); 385 }); 386 387 while (true) { 388 // 1. Wait for input. 389 std::vector<Tensor> input; 390 { 391 mutex_lock l(mu_); 392 while (!cancelled_ && !workers_[thread_index].is_producing) { 393 workers_[thread_index].cond_var.wait(l); 394 } 395 if (cancelled_) return; 396 input.swap(workers_[thread_index].input); 397 } 398 399 // 2. Run the user defined function to produce a new iterator. 400 std::unique_ptr<IteratorBase> iterator; 401 Status s = dataset::MakeIteratorFromInputElement( 402 ctx.get(), input, thread_index, dataset()->captured_func_.get(), 403 prefix(), &iterator); 404 input.clear(); // Release memory as early as possible. 405 406 if (!s.ok()) { 407 mutex_lock l(mu_); 408 workers_[thread_index].outputs.emplace_back(s); 409 workers_[thread_index].is_producing = false; 410 workers_[thread_index].cond_var.notify_one(); 411 } else { 412 // 3. Produce elements 413 bool end_of_sequence = false; 414 while (!end_of_sequence) { 415 // 3.a Produce an element! 416 std::vector<Tensor> output_elem; 417 s = iterator->GetNext(ctx.get(), &output_elem, &end_of_sequence); 418 419 // 3.b Make it available to the client. 420 { 421 mutex_lock l(mu_); 422 423 // Wait for space in the prefetch queue. 424 while (!cancelled_ && workers_[thread_index].outputs.size() == 425 dataset()->buffer_output_elements_) { 426 workers_[thread_index].cond_var.wait(l); 427 } 428 if (cancelled_) return; 429 430 // Output the element. 431 workers_[thread_index].is_producing = !end_of_sequence; 432 if (!end_of_sequence) { 433 workers_[thread_index].outputs.emplace_back(s); 434 workers_[thread_index].outputs.back().output.swap( 435 output_elem); 436 } 437 if (dataset()->sloppy_) { 438 sloppy_cond_var_.notify_one(); 439 } else { 440 workers_[thread_index].cond_var.notify_one(); 441 } 442 } 443 } 444 } 445 } 446 } 447 448 // Mutex & condition variable to guard mutable iterator internals and 449 // coordinate among worker threads and client thread[s]. 450 mutex mu_; 451 // The main thread waits on this condition variable if running in sloppy 452 // mode and no values are available. 453 condition_variable sloppy_cond_var_; 454 455 // The iterator producing elements which are converted to datasets by 456 // the dataset()->captured_func_ then interleaved together. 457 // input_impl_ is reset when we have exhausted its input. 458 std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); 459 460 // The WorkerState structs the worker threads operate on. 461 // workers_ elements are in at most one of interleave_ and staging_. 462 std::vector<WorkerState> workers_ GUARDED_BY(mu_); 463 464 // The iterators to interleave 465 std::vector<WorkerState*> interleave_ GUARDED_BY(mu_); 466 // Prefetched iterators 467 std::deque<WorkerState*> staging_ GUARDED_BY(mu_); 468 469 // The index into output_elements_ for next element to produce. 470 size_t next_index_ GUARDED_BY(mu_) = 0; 471 // The number of items produced so far within the block 472 size_t block_count_ GUARDED_BY(mu_) = 0; 473 // Flag to instruct the worker threads to exit. 474 bool cancelled_ GUARDED_BY(mu_) = false; 475 // The worker threads. This must be last to ensure the 476 // threads have exited before any other members are deallocated. 477 // TODO(b/65178177): Avoid allocating additional threads. 478 std::vector<std::unique_ptr<Thread>> worker_threads_ GUARDED_BY(mu_); 479 }; 480 481 const DatasetBase* const input_; 482 const std::unique_ptr<CapturedFunction> captured_func_; 483 const int64 cycle_length_; 484 const int64 block_length_; 485 const bool sloppy_; 486 const int64 buffer_output_elements_; 487 const int64 prefetch_input_elements_; 488 const DataTypeVector output_types_; 489 const std::vector<PartialTensorShape> output_shapes_; 490 }; 491 492 const int graph_def_version_; 493 DataTypeVector output_types_; 494 std::vector<PartialTensorShape> output_shapes_; 495 NameAttrList func_; 496}; 497 498REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDataset").Device(DEVICE_CPU), 499 ParallelInterleaveDatasetOp); 500 501} // namespace 502 503} // namespace tensorflow 504