/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <deque>

#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"

namespace tensorflow {

namespace {

// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.

class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
 public:
  explicit ParallelInterleaveDatasetOp(OpKernelConstruction* ctx)
      : UnaryDatasetOpKernel(ctx),
        graph_def_version_(ctx->graph_def_version()) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
  }

  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                   DatasetBase** output) override {
    OpInputList inputs;
    OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
    std::vector<Tensor> other_arguments;
    other_arguments.reserve(inputs.size());
    for (const Tensor& t : inputs) {
      other_arguments.push_back(t);
    }

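    // Parse and validate the scalar arguments that control the interleave:
    // how many input datasets are interleaved at once (`cycle_length`), how
    // many consecutive values are pulled from each before moving on
    // (`block_length`), whether output order may be relaxed (`sloppy`), the
    // per-worker output buffer size (`buffer_output_elements`), and how many
    // extra input elements to prefetch (`prefetch_input_elements`).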
    int64 cycle_length = 0;
    OP_REQUIRES_OK(ctx,
                   ParseScalarArgument(ctx, "cycle_length", &cycle_length));
    OP_REQUIRES(ctx, cycle_length > 0,
                errors::InvalidArgument("`cycle_length` must be > 0"));

    int64 block_length = 0;
    OP_REQUIRES_OK(ctx,
                   ParseScalarArgument(ctx, "block_length", &block_length));
    OP_REQUIRES(ctx, block_length > 0,
                errors::InvalidArgument("`block_length` must be > 0"));

    bool sloppy = false;
    OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "sloppy", &sloppy));

    int64 buffer_output_elements = 0;
    OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "buffer_output_elements",
                                            &buffer_output_elements));
    OP_REQUIRES(
        ctx, buffer_output_elements > 0,
        errors::InvalidArgument("`buffer_output_elements` must be > 0"));

    int64 prefetch_input_elements = 0;
    OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "prefetch_input_elements",
                                            &prefetch_input_elements));
    OP_REQUIRES(
        ctx, prefetch_input_elements >= 0,
        errors::InvalidArgument("`prefetch_input_elements` must be >= 0"));

    std::unique_ptr<CapturedFunction> captured_func;
    OP_REQUIRES_OK(ctx, CapturedFunction::Create(
                            func_, std::move(other_arguments), &captured_func));

    *output =
        new Dataset(input, std::move(captured_func), cycle_length, block_length,
                    sloppy, buffer_output_elements, prefetch_input_elements,
                    output_types_, output_shapes_);
  }

 private:
  class Dataset : public DatasetBase {
   public:
    Dataset(const DatasetBase* input,
            std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length,
            int64 block_length, bool sloppy, int64 buffer_output_elements,
            int64 prefetch_input_elements, const DataTypeVector& output_types,
            const std::vector<PartialTensorShape>& output_shapes)
        : input_(input),
          captured_func_(std::move(captured_func)),
          cycle_length_(cycle_length),
          block_length_(block_length),
          sloppy_(sloppy),
          buffer_output_elements_(buffer_output_elements),
          prefetch_input_elements_(prefetch_input_elements),
          output_types_(output_types),
          output_shapes_(output_shapes) {
      input_->Ref();
    }

    ~Dataset() override { input_->Unref(); }

    std::unique_ptr<IteratorBase> MakeIterator(
        const string& prefix) const override {
      return std::unique_ptr<IteratorBase>(new Iterator(
          {this, strings::StrCat(prefix, "::ParallelInterleave")}));
    }

    const DataTypeVector& output_dtypes() const override {
      return output_types_;
    }
    const std::vector<PartialTensorShape>& output_shapes() const override {
      return output_shapes_;
    }

    string DebugString() override {
      return "ParallelInterleaveDatasetOp::Dataset";
    }

   private:
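    // One worker thread is created per actively interleaved iterator plus one
    // per prefetched (staged) iterator, hence
    // `cycle_length_ + prefetch_input_elements_`.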
    int64 num_threads() const {
      return cycle_length_ + prefetch_input_elements_;
    }

    // Parallel interleave's implementation is designed around a few principles:
    //  1. Thread creation is relatively expensive. (Not reusing threads causes
    //     a number of indirect costs, such as poorer tcmalloc performance due
    //     to thread-local caches.) We therefore allocate a fixed number of
    //     threads at the start and never change it. This is why we have fused
    //     functionality that is theoretically orthogonal (i.e. .prefetch())
    //     into the implementation.
    //  2. Drop-in replacement for the standard interleave. The goal is to
    //     automatically opt users into an optimized implementation without
    //     any work on their part. We thus take great pains to maintain an
    //     identical iteration order and full determinism (which can be
    //     disabled only via a flag); see the illustrative example at the end
    //     of this comment.
    //  3. Performance across a variety of environments and I/O envelopes.
    //
    // The actual implementation centers around a collection of worker threads
    // and their corresponding worker state (tracked in the `workers_` vector).
    // Worker threads repeatedly receive a vector of Tensors that are used as
    // input to the flat-map function (`captured_func_`). The output of this
    // function must be a dataset. The worker thread then repeatedly calls
    // `GetNext()`, maintaining a buffer of elements to minimize the likelihood
    // that a caller will block waiting for an element to be produced.
    //
    // Pointers to these worker states are kept in two disjoint data structures:
    //  1. `interleave_` is a vector containing pointers to `WorkerState`s
    //     that we are interleaving. Worker threads backing these WorkerStates
    //     should be regularly producing values.
    //  2. `staging_` is a deque containing pointers to WorkerStates that we
    //     will move to `interleave_` when an iterator in `interleave_` is
    //     exhausted.
    //
    // The client calls `GetNext[Internal]()` to retrieve an output element. The
    // internal implementation updates the state of `interleave_` and `staging_`
    // as output iterators (run by the worker threads) are exhausted.
    //
    // `input_impl_` is the input iterator that generates arguments for the
    // flat-map function (`captured_func_`). It is set to an iterator at
    // Iterator construction, and is fixed until we consume all input elements.
    // Once it is exhausted, we reset the unique_ptr to eagerly deallocate
    // memory.
    //
    // A few invariants are maintained:
    //  1. No element in `interleave_` should be a nullptr unless `staging_`
    //     is empty and `input_impl_` is empty (i.e. exhausted and reset).
    //  2. Every `workers_` element is pointed to by at most one element of
    //     the union of `interleave_` and `staging_`.
    //  3. Unless `input_impl_` is empty, every `workers_` element must be
    //     pointed to by an element in `interleave_` or `staging_`.
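    //
    // Purely illustrative example (not taken from the op's documentation):
    // with sloppy == false, cycle_length == 2, block_length == 2, and inputs
    // that map to datasets A = {A0, A1, ...} and B = {B0, B1, ...}, the output
    // order is A0, A1, B0, B1, A2, A3, B2, B3, ... -- the same order the
    // non-parallel interleave would produce, just with the elements computed
    // ahead of time by the worker threads.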
    class Iterator : public DatasetIterator<Dataset> {
     public:
      explicit Iterator(const Params& params)
          : DatasetIterator<Dataset>(params),
            input_impl_(params.dataset->input_->MakeIterator(params.prefix)),
            workers_(dataset()->num_threads()) {}

      ~Iterator() override {
        mutex_lock l(mu_);
        cancelled_ = true;
        // Notify all workers in case they are blocked.
        for (auto& worker : workers_) {
          worker.cond_var.notify_all();
        }
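        // The worker threads observe `cancelled_` and exit; they are joined
        // when `worker_threads_` (declared last among the members below) is
        // destroyed.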
      }

      // It is implemented so that it matches the deterministic interleave
      // order unless getting the next element would block and we are allowed
      // to be sloppy.
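      // For each worker in `interleave_` (starting at `next_index_`), this
      // either returns a buffered element, waits on the worker that must
      // produce the next element (in deterministic mode), or retires an
      // exhausted worker and promotes a prefetched one from `staging_`.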
      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(EnsureWorkerThreadsStarted(ctx));
        while (!cancelled_) {
          // Wait for an item to become available, blocking if necessary. If we
          // are allowed to be sloppy, we can skip over input datasets that do
          // not have an item readily available.
          bool can_produce_elements = false;
          bool must_wait_for_input = true;
          for (int64 i = 0; i < interleave_.size(); ++i) {
            int64 index = (next_index_ + i) % interleave_.size();
            WorkerState* current_worker = interleave_[index];
            if (!current_worker) continue;  // Empty interleave elements.
            can_produce_elements |= current_worker->MayHaveElements();
            if (!current_worker->outputs.empty()) {
              // We have an element!
              next_index_ = index;
              if (i == 0) {
                block_count_++;
                if (block_count_ == dataset()->block_length_) {
                  next_index_ = (index + 1) % interleave_.size();
                  block_count_ = 0;
                }
              } else {
                block_count_ = 0;
              }
              *end_of_sequence = false;
              Status s = current_worker->outputs.front().status;
              current_worker->outputs.front().output.swap(*out_tensors);
              current_worker->outputs.pop_front();
              current_worker->cond_var.notify_one();
              return s;
            } else if (current_worker->is_producing && !dataset()->sloppy_) {
              // current_worker->outputs.empty(), and we must wait for this
              // iterator.
              if (next_index_ != index) {
                // We have advanced to a new iterator; reset block counts.
                next_index_ = index;
                block_count_ = 0;
              }
              break;
            } else if (!current_worker->is_producing) {
              // This iterator has reached end of input.
              interleave_[index] = nullptr;
              if (input_impl_) {
                // Start prefetching a new iterator.
                std::vector<Tensor> args;
                bool end_of_input = false;
                Status s = input_impl_->GetNext(ctx, &args, &end_of_input);
                if (end_of_input) {
                  input_impl_.reset();
                } else {
                  current_worker->SetInputs(s, std::move(args));
                  staging_.emplace_back(current_worker);
                }
              }

              if (!staging_.empty()) {
                // Move a worker from `staging_` to `interleave_`.
                interleave_[index] = staging_.front();
                staging_.pop_front();

                next_index_ = (index + 1) % interleave_.size();
                block_count_ = 0;
                // Restart the inner [for] loop.
                can_produce_elements = true;
                must_wait_for_input = false;
                break;
              }
            }
          }

          if (!can_produce_elements && !input_impl_) {
            // No potential for future values.
            *end_of_sequence = true;
            return Status::OK();
          }

          if (must_wait_for_input) {
            // Wait for elements to become available.
            if (dataset()->sloppy_) {
              sloppy_cond_var_.wait(l);
            } else {
              interleave_[next_index_]->cond_var.wait(l);
            }
          }
        }
        return errors::Cancelled(
            "ParallelInterleaveDatasetOp::Dataset::Iterator::GetNext");
      }

     private:
      // OutputElem contains the information from a call to GetNext by an output
      // iterator.
      struct OutputElem {
        // The output iterator sets `status` if getting the output element
        // fails.
        Status status;
        // The buffered data element.
        std::vector<Tensor> output;

        explicit OutputElem(const Status& s) : status(s) {}
      };

      // Worker threads operate on their relevant WorkerState structs.
      //
      // WorkerState's fields are all protected by mu_.
      struct WorkerState {
        // The arguments to be used to construct an output iterator.
        std::vector<Tensor> input;
        // The buffered output elements.
        std::deque<OutputElem> outputs;
        // Set to true iff the worker thread expects to append more elements to
        // outputs. is_producing can be false despite !outputs.empty().
        // Concretely, all output elements will have been consumed only when:
        // is_producing == false && outputs.empty().
        bool is_producing = false;
        // Condition variable used to coordinate between threads. The worker
        // thread waits on this condition variable when it is either (1) waiting
        // for the main thread to add arguments to `input`, or (2) waiting for
        // the main thread to consume an element of `outputs`. The main thread
        // waits on cond_var if it is waiting for the worker thread to produce
        // an element into `outputs` (this implies sloppy_==false).
        condition_variable cond_var;

        inline bool MayHaveElements() const {
          return is_producing || !outputs.empty();
        }

        // Sets inputs for a worker thread and notifies it to start processing.
        void SetInputs(const Status& s, std::vector<Tensor> input_arguments) {
          if (s.ok()) {
            DCHECK(!MayHaveElements())
                << "Tried to start inputs, despite already producing!";
            input = std::move(input_arguments);
            is_producing = true;
            cond_var.notify_one();
          } else {
            outputs.emplace_back(s);
          }
        }
      };

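      // Seeds each WorkerState with one input element and starts its worker
      // thread. The first `cycle_length_` workers go into `interleave_` and
      // the remainder into `staging_`. If the input is exhausted before all
      // workers are seeded, fewer threads are started.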
      Status EnsureWorkerThreadsStarted(IteratorContext* ctx)
          EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        if (worker_threads_.empty()) {
          worker_threads_.reserve(dataset()->num_threads());
          for (int64 i = 0; i < dataset()->num_threads(); ++i) {
            std::vector<Tensor> args;
            bool end_of_input = false;
            Status s = input_impl_->GetNext(ctx, &args, &end_of_input);
            if (end_of_input) {
              input_impl_.reset();
              return Status::OK();
            }
            workers_[i].SetInputs(s, std::move(args));
            worker_threads_.emplace_back(ctx->env()->StartThread(
                {}, "worker_thread",
                std::bind(&Iterator::WorkerThread, this,
                          new IteratorContext(*ctx), i)));
            if (i < dataset()->cycle_length_) {
              interleave_.push_back(&workers_[i]);
            } else {
              staging_.push_back(&workers_[i]);
            }
          }
          DCHECK(interleave_.size() == dataset()->cycle_length_);
          DCHECK(staging_.size() == dataset()->prefetch_input_elements_);
        }
        return Status::OK();
      }

      // Produces elements into the worker's output buffers.
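      // Each worker repeatedly: (1) waits for the main thread to hand it
      // input arguments, (2) builds an output iterator from those arguments
      // via `captured_func_`, and (3) fills its bounded output buffer until
      // that iterator is exhausted or the parent iterator is cancelled.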
      void WorkerThread(IteratorContext* ctx_ptr, const int64 thread_index) {
        // Arguments bound into a std::function must be copy-constructible, so
        // we pass a raw IteratorContext pointer and immediately wrap it in a
        // unique_ptr here to ensure correct ownership.
        std::unique_ptr<IteratorContext> ctx(ctx_ptr);
        auto cleanup = gtl::MakeCleanup([this, thread_index] {
          mutex_lock l(mu_);
          workers_[thread_index].cond_var.notify_all();
        });

        while (true) {
          // 1. Wait for input.
          std::vector<Tensor> input;
          {
            mutex_lock l(mu_);
            while (!cancelled_ && !workers_[thread_index].is_producing) {
              workers_[thread_index].cond_var.wait(l);
            }
            if (cancelled_) return;
            input.swap(workers_[thread_index].input);
          }

          // 2. Run the user-defined function to produce a new iterator.
          std::unique_ptr<IteratorBase> iterator;
          Status s = dataset::MakeIteratorFromInputElement(
              ctx.get(), input, thread_index, dataset()->captured_func_.get(),
              prefix(), &iterator);
          input.clear();  // Release memory as early as possible.

          if (!s.ok()) {
            mutex_lock l(mu_);
            workers_[thread_index].outputs.emplace_back(s);
            workers_[thread_index].is_producing = false;
            workers_[thread_index].cond_var.notify_one();
          } else {
            // 3. Produce elements.
            bool end_of_sequence = false;
            while (!end_of_sequence) {
              // 3.a Produce an element!
              std::vector<Tensor> output_elem;
              s = iterator->GetNext(ctx.get(), &output_elem, &end_of_sequence);

              // 3.b Make it available to the client.
              {
                mutex_lock l(mu_);

                // Wait for space in the prefetch queue.
                while (!cancelled_ && workers_[thread_index].outputs.size() ==
                                          dataset()->buffer_output_elements_) {
                  workers_[thread_index].cond_var.wait(l);
                }
                if (cancelled_) return;

                // Output the element.
                workers_[thread_index].is_producing = !end_of_sequence;
                if (!end_of_sequence) {
                  workers_[thread_index].outputs.emplace_back(s);
                  workers_[thread_index].outputs.back().output.swap(
                      output_elem);
                }
                if (dataset()->sloppy_) {
                  sloppy_cond_var_.notify_one();
                } else {
                  workers_[thread_index].cond_var.notify_one();
                }
              }
            }
          }
        }
      }

      // Mutex & condition variable to guard mutable iterator internals and
      // coordinate among worker threads and client thread[s].
      mutex mu_;
      // The main thread waits on this condition variable if running in sloppy
      // mode and no values are available.
      condition_variable sloppy_cond_var_;

      // The iterator producing elements which are converted to datasets by
      // `dataset()->captured_func_` and then interleaved together.
      // `input_impl_` is reset once its input has been exhausted.
      std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);

      // The WorkerState structs the worker threads operate on.
      // workers_ elements are in at most one of interleave_ and staging_.
      std::vector<WorkerState> workers_ GUARDED_BY(mu_);

      // The iterators to interleave.
      std::vector<WorkerState*> interleave_ GUARDED_BY(mu_);
      // Prefetched iterators.
      std::deque<WorkerState*> staging_ GUARDED_BY(mu_);

      // The index into `interleave_` of the next element to produce.
      size_t next_index_ GUARDED_BY(mu_) = 0;
      // The number of items produced so far within the current block.
      size_t block_count_ GUARDED_BY(mu_) = 0;
      // Flag to instruct the worker threads to exit.
      bool cancelled_ GUARDED_BY(mu_) = false;
      // The worker threads. This must be last to ensure the
      // threads have exited before any other members are deallocated.
      // TODO(b/65178177): Avoid allocating additional threads.
      std::vector<std::unique_ptr<Thread>> worker_threads_ GUARDED_BY(mu_);
    };

    const DatasetBase* const input_;
    const std::unique_ptr<CapturedFunction> captured_func_;
    const int64 cycle_length_;
    const int64 block_length_;
    const bool sloppy_;
    const int64 buffer_output_elements_;
    const int64 prefetch_input_elements_;
    const DataTypeVector output_types_;
    const std::vector<PartialTensorShape> output_shapes_;
  };

  const int graph_def_version_;
  DataTypeVector output_types_;
  std::vector<PartialTensorShape> output_shapes_;
  NameAttrList func_;
};

REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDataset").Device(DEVICE_CPU),
                        ParallelInterleaveDatasetOp);

}  // namespace

}  // namespace tensorflow