1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/core/common_runtime/function.h"
16#include "tensorflow/core/framework/partial_tensor_shape.h"
17#include "tensorflow/core/framework/tensor.h"
18#include "tensorflow/core/kernels/data/captured_function.h"
19#include "tensorflow/core/kernels/data/dataset.h"
20#include "tensorflow/core/lib/random/random.h"
21
22namespace tensorflow {
23
24namespace {
25
26// See documentation in ../ops/dataset_ops.cc for a high-level
27// description of the following op.
28
29class MapDatasetOp : public UnaryDatasetOpKernel {
30 public:
31  explicit MapDatasetOp(OpKernelConstruction* ctx)
32      : UnaryDatasetOpKernel(ctx),
33        graph_def_version_(ctx->graph_def_version()) {
34    OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
35    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
36    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
37  }
38
39  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
40                   DatasetBase** output) override {
41    OpInputList inputs;
42    OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
43    std::vector<Tensor> other_arguments;
44    other_arguments.reserve(inputs.size());
45    for (const Tensor& t : inputs) {
46      other_arguments.push_back(t);
47    }
48
49    std::unique_ptr<CapturedFunction> captured_func;
50    OP_REQUIRES_OK(ctx, CapturedFunction::Create(
51                            func_, std::move(other_arguments), &captured_func));
52
53    *output = new Dataset(ctx, input, func_, std::move(captured_func),
54                          output_types_, output_shapes_);
55  }
56
57 private:
58  class Dataset : public GraphDatasetBase {
59   public:
60    Dataset(OpKernelContext* ctx, const DatasetBase* input,
61            const NameAttrList& func,
62            std::unique_ptr<CapturedFunction> captured_func,
63            const DataTypeVector& output_types,
64            const std::vector<PartialTensorShape>& output_shapes)
65        : GraphDatasetBase(ctx),
66          input_(input),
67          func_(func),
68          captured_func_(std::move(captured_func)),
69          output_types_(output_types),
70          output_shapes_(output_shapes) {
71      input_->Ref();
72    }
73
74    ~Dataset() override { input_->Unref(); }
75
76    std::unique_ptr<IteratorBase> MakeIterator(
77        const string& prefix) const override {
78      return std::unique_ptr<IteratorBase>(
79          new Iterator({this, strings::StrCat(prefix, "::Map")}));
80    }
81
82    const DataTypeVector& output_dtypes() const override {
83      return output_types_;
84    }
85    const std::vector<PartialTensorShape>& output_shapes() const override {
86      return output_shapes_;
87    }
88
89    string DebugString() override { return "MapDatasetOp::Dataset"; }
90
91   protected:
92    Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
93                              Node** output) const override {
94      TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name()));
95      Node* input_graph_node = nullptr;
96      TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node));
97
98      DataTypeVector other_arguments_types;
99      other_arguments_types.reserve(captured_func_->captured_inputs().size());
100      std::vector<Node*> other_arguments;
101      other_arguments.reserve(captured_func_->captured_inputs().size());
102      for (const Tensor& t : captured_func_->captured_inputs()) {
103        Node* node;
104        TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
105        other_arguments.emplace_back(node);
106        other_arguments_types.emplace_back(t.dtype());
107      }
108      AttrValue f;
109      b->BuildAttrValue(func_, &f);
110      AttrValue other_arguments_types_attr;
111      b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
112
113      TF_RETURN_IF_ERROR(b->AddDataset(
114          this, {std::make_pair(0, input_graph_node)},  // Single tensor inputs.
115          {std::make_pair(1, other_arguments)},         // Tensor list inputs.
116          {std::make_pair("f", f),
117           std::make_pair("Targuments", other_arguments_types_attr)},  // Attrs
118          output));
119      return Status::OK();
120    }
121
122   private:
123    class Iterator : public DatasetIterator<Dataset> {
124     public:
125      explicit Iterator(const Params& params)
126          : DatasetIterator<Dataset>(params),
127            input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
128
129      Status GetNextInternal(IteratorContext* ctx,
130                             std::vector<Tensor>* out_tensors,
131                             bool* end_of_sequence) override {
132        // NOTE(mrry): This method is thread-safe as long as
133        // `input_impl_` and `f` are thread-safe. However, if multiple
134        // threads enter this method, outputs may be observed in a
135        // non-deterministic order.
136
137        std::vector<Tensor> args;
138        TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &args, end_of_sequence));
139        if (*end_of_sequence) {
140          return Status::OK();
141        }
142
143        // TODO(mrry): Avoid blocking a threadpool thread. We will need to
144        // stack-rip the iterators and use async kernels.
145        Status s =
146            dataset()->captured_func_->Run(ctx, std::move(args), out_tensors);
147        if (errors::IsOutOfRange(s)) {
148          // `f` may deliberately raise `errors::OutOfRange` to indicate
149          // that we should terminate the iteration early.
150          *end_of_sequence = true;
151          return Status::OK();
152        } else {
153          return s;
154        }
155      }
156
157     protected:
158      Status SaveInternal(IteratorStateWriter* writer) override {
159        TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
160        return Status::OK();
161      }
162
163      Status RestoreInternal(IteratorContext* ctx,
164                             IteratorStateReader* reader) override {
165        TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
166        return Status::OK();
167      }
168
169     private:
170      const std::unique_ptr<IteratorBase> input_impl_;
171    };
172
173    const DatasetBase* const input_;
174    const NameAttrList func_;
175    const std::unique_ptr<CapturedFunction> captured_func_;
176    const DataTypeVector output_types_;
177    const std::vector<PartialTensorShape> output_shapes_;
178  };
179
180  const int graph_def_version_;
181  DataTypeVector output_types_;
182  std::vector<PartialTensorShape> output_shapes_;
183  NameAttrList func_;
184};
185
// Registers MapDatasetOp as the kernel for the "MapDataset" op on DEVICE_CPU.
186REGISTER_KERNEL_BUILDER(Name("MapDataset").Device(DEVICE_CPU), MapDatasetOp);
187
188}  // namespace
189
190}  // namespace tensorflow
191