1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/core/framework/dataset.h"
16
17#include "tensorflow/core/graph/graph_def_builder.h"
18#include "tensorflow/core/graph/node_builder.h"
19
20namespace tensorflow {
21
22namespace {
23
24// A wrapper class for storing a `DatasetBase` instance in a DT_VARIANT tensor.
25// Objects of the wrapper class own a reference on an instance of `DatasetBase`,
26// and the wrapper's copy constructor and destructor take care of managing the
27// reference count.
28//
29// NOTE(mrry): This is not a feature-complete implementation of the DT_VARIANT
30// specification. In particular, we cannot currently serialize an arbitrary
31// `DatasetBase` object, so the `Encode()` and `Decode()` methods are not
32// implemented.
class DatasetVariantWrapper {
 public:
  DatasetVariantWrapper() : dataset_(nullptr) {}

  // Transfers ownership of `dataset` to `*this`: the caller's reference is
  // adopted by the wrapper (no additional Ref() is taken).
  explicit DatasetVariantWrapper(DatasetBase* dataset) : dataset_(dataset) {}

  // Copying takes an additional reference on the shared `DatasetBase`, so
  // every live wrapper owns exactly one reference.
  DatasetVariantWrapper(const DatasetVariantWrapper& other)
      : dataset_(other.dataset_) {
    if (dataset_) dataset_->Ref();
  }

  ~DatasetVariantWrapper() {
    if (dataset_) dataset_->Unref();
  }

  // Returns a borrowed (non-owning) pointer. May be nullptr for a
  // default-constructed wrapper.
  DatasetBase* get() const { return dataset_; }

  string TypeName() const { return "tensorflow::DatasetVariantWrapper"; }
  string DebugString() const {
    if (dataset_) {
      return dataset_->DebugString();
    } else {
      return "<Uninitialized DatasetVariantWrapper>";
    }
  }
  // Serialization is intentionally unsupported (see the NOTE above the
  // class): an arbitrary `DatasetBase` cannot currently be encoded.
  void Encode(VariantTensorData* data) const {
    LOG(ERROR) << "The Encode() method is not implemented for "
                  "DatasetVariantWrapper objects.";
  }
  // Always fails; logs the reason and returns false so callers see an
  // undecoded variant rather than a silently-empty dataset.
  bool Decode(const VariantTensorData& data) {
    LOG(ERROR) << "The Decode() method is not implemented for "
                  "DatasetVariantWrapper objects.";
    return false;
  }

 private:
  DatasetBase* const dataset_;  // Owns one reference. The `const` pointer
                                // also makes copy-assignment implicitly
                                // deleted, so ownership can only be
                                // transferred via construction.
};
72
73}  // namespace
74
75Status GraphDefBuilderWrapper::AddDataset(
76    const GraphDatasetBase* dataset,
77    const std::vector<std::pair<size_t, Node*>>& inputs,
78    const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
79    const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
80    Node** output) {
81  const string& op_type_name = dataset->op_name();
82  std::unique_ptr<const GraphDefBuilder::Options> opts(
83      new GraphDefBuilder::Options(b_->opts()));
84  // TODO(srbs|mrry): Not all datasets have output_types and output_shapes
85  // attributes defined. It will be nice to have a consistent pattern.
86  bool has_output_types_attr = HasAttr(op_type_name, "output_types");
87  bool has_output_shapes_attr = HasAttr(op_type_name, "output_shapes");
88  if (has_output_shapes_attr) {
89    opts.reset(new GraphDefBuilder::Options(
90        opts->WithAttr("output_shapes", dataset->output_shapes())));
91  }
92  if (has_output_types_attr) {
93    opts.reset(new GraphDefBuilder::Options(
94        opts->WithAttr("output_types", dataset->output_dtypes())));
95  }
96  for (auto attr : attrs) {
97    opts.reset(
98        new GraphDefBuilder::Options(opts->WithAttr(attr.first, attr.second)));
99  }
100  if (opts->HaveError()) {
101    return errors::Internal("AddDataset: Failed to build Options with error ",
102                            opts->StatusToString());
103  }
104  NodeBuilder node_builder(opts->GetNameForOp(op_type_name), op_type_name,
105                           opts->op_registry());
106  {
107    size_t total_size = inputs.size() + list_inputs.size();
108    auto inputs_iter = inputs.begin();
109    auto list_inputs_iter = list_inputs.begin();
110    for (int i = 0; i < total_size; i++) {
111      if (inputs_iter != inputs.end() && inputs_iter->first == i) {
112        node_builder.Input(NodeBuilder::NodeOut(inputs_iter->second));
113        inputs_iter++;
114      } else if (list_inputs_iter != list_inputs.end() &&
115                 list_inputs_iter->first == i) {
116        std::vector<NodeBuilder::NodeOut> nodeout_inputs;
117        nodeout_inputs.reserve(list_inputs_iter->second.size());
118        for (Node* n : list_inputs_iter->second) {
119          nodeout_inputs.emplace_back(n);
120        }
121        node_builder.Input(nodeout_inputs);
122        list_inputs_iter++;
123      } else {
124        return errors::InvalidArgument("No input found for index ", i);
125      }
126    }
127  }
128  *output = opts->FinalizeBuilder(&node_builder);
129  if (*output == nullptr) {
130    return errors::Internal("AddDataset: Failed to build ", op_type_name,
131                            " op with error ", opts->StatusToString());
132  }
133  return Status::OK();
134}
135
136Status GraphDefBuilderWrapper::AddFunction(OpKernelContext* ctx,
137                                           const string& function_name) {
138  if (b_->HasFunction(function_name)) {
139    LOG(INFO) << "Function with name " << function_name << "already exists in"
140              << " the graph. It will not be added again.";
141    return Status::OK();
142  }
143  TF_RETURN_IF_ERROR(EnsureFunctionIsStateless(ctx, function_name));
144  const FunctionLibraryDefinition* flib_def =
145      ctx->function_library()->GetFunctionLibraryDefinition();
146  const FunctionDef* f_def = flib_def->Find(function_name);
147  if (f_def == nullptr) {
148    return errors::InvalidArgument("Unable to find FunctionDef for ",
149                                   function_name, " in the registry.");
150  }
151  FunctionDefLibrary def;
152  *def.add_function() = *f_def;
153  const string gradient_func = flib_def->FindGradient(function_name);
154  if (!gradient_func.empty()) {
155    GradientDef* g_def = def.add_gradient();
156    g_def->set_function_name(function_name);
157    g_def->set_gradient_func(gradient_func);
158  }
159  TF_RETURN_IF_ERROR(b_->AddFunctionLibrary(def));
160
161  // Recursively add functions in inputs of function_name.
162  for (const NodeDef& node_def : f_def->node_def()) {
163    const OpRegistrationData* op_reg_data = nullptr;
164    TF_RETURN_IF_ERROR(flib_def->LookUp(node_def.op(), &op_reg_data));
165    if (op_reg_data->is_function_op) {
166      TF_RETURN_IF_ERROR(AddFunction(ctx, op_reg_data->op_def.name()));
167    }
168    // Recursively add functions in attrs of this NodeDef.
169    for (const auto& pair : node_def.attr()) {
170      TF_RETURN_IF_ERROR(AddAttrFunctions(pair.second, ctx));
171    }
172  }
173
174  // Recursively add functions in attrs of function_name.
175  for (auto iter = f_def->attr().begin(); iter != f_def->attr().end(); iter++) {
176    TF_RETURN_IF_ERROR(AddAttrFunctions(iter->second, ctx));
177  }
178  return Status::OK();
179}
180
181void GraphDefBuilderWrapper::AddTensorInternal(const Tensor& val,
182                                               Node** output) {
183  *output = ops::SourceOp(
184      "Const",
185      b_->opts().WithAttr("dtype", val.dtype()).WithAttr("value", val));
186}
187
188bool GraphDefBuilderWrapper::HasAttr(const string& op_type_name,
189                                     const string& attr_name) const {
190  const OpDef* op_def = nullptr;
191  Status s = b_->opts().op_registry()->LookUpOpDef(op_type_name, &op_def);
192  if (!s.ok() || op_def == nullptr) {
193    return false;
194  }
195  return HasAttr(op_def, attr_name);
196}
197
198Status GraphDatasetBase::Serialize(OpKernelContext* ctx,
199                                   string* serialized_graph_def,
200                                   string* output_node) const {
201  GraphDefBuilder b;
202  DatasetGraphDefBuilder db(&b);
203  Node* node = nullptr;
204  TF_RETURN_IF_ERROR(AsGraphDefInternal(ctx, &db, &node));
205  *output_node = node->name();
206  GraphDef graph_def;
207  TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def));
208  graph_def.SerializeToString(serialized_graph_def);
209  return Status::OK();
210}
211
212Status GetDatasetFromVariantTensor(const Tensor& tensor,
213                                   DatasetBase** out_dataset) {
214  if (!(tensor.dtype() == DT_VARIANT ||
215        TensorShapeUtils::IsScalar(tensor.shape()))) {
216    return errors::InvalidArgument(
217        "Dataset tensor must be a scalar of dtype DT_VARIANT.");
218  }
219  const Variant& variant = tensor.scalar<Variant>()();
220  const DatasetVariantWrapper* wrapper = variant.get<DatasetVariantWrapper>();
221  if (wrapper == nullptr) {
222    return errors::InvalidArgument("Tensor must be a Dataset object.");
223  }
224  *out_dataset = wrapper->get();
225  if (*out_dataset == nullptr) {
226    return errors::Internal("Read uninitialized Dataset variant.");
227  }
228  return Status::OK();
229}
230
231Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor) {
232  if (!(tensor->dtype() == DT_VARIANT ||
233        TensorShapeUtils::IsScalar(tensor->shape()))) {
234    return errors::InvalidArgument(
235        "Dataset tensor must be a scalar of dtype DT_VARIANT.");
236  }
237  tensor->scalar<Variant>()() = DatasetVariantWrapper(dataset);
238  return Status::OK();
239}
240
241void DatasetOpKernel::Compute(OpKernelContext* ctx) {
242  DatasetBase* dataset = nullptr;
243  MakeDataset(ctx, &dataset);
244  if (ctx->status().ok()) {
245    Tensor* output = nullptr;
246    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
247    OP_REQUIRES_OK(ctx, StoreDatasetInVariantTensor(dataset, output));
248  }
249}
250
251void UnaryDatasetOpKernel::MakeDataset(OpKernelContext* ctx,
252                                       DatasetBase** output) {
253  DatasetBase* input;
254  OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &input));
255  MakeDataset(ctx, input, output);
256}
257
258void BinaryDatasetOpKernel::MakeDataset(OpKernelContext* ctx,
259                                        DatasetBase** output) {
260  DatasetBase* input;
261  OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &input));
262  DatasetBase* another_input;
263  OP_REQUIRES_OK(ctx,
264                 GetDatasetFromVariantTensor(ctx->input(1), &another_input));
265  MakeDataset(ctx, input, another_input, output);
266}
267
// String keys identifying a dataset's serialized graph and its output node
// name. Declared on GraphDatasetBase; not referenced elsewhere in this file —
// presumably consumed by serialization/restore code in other translation
// units (verify against callers).
const char GraphDatasetBase::kDatasetGraphKey[] = "_DATASET_GRAPH";
const char GraphDatasetBase::kDatasetGraphOutputNodeKey[] =
    "_DATASET_GRAPH_OUTPUT_NODE";
271
272}  // namespace tensorflow
273