1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15#include "tensorflow/core/framework/dataset.h" 16 17#include "tensorflow/core/graph/graph_def_builder.h" 18#include "tensorflow/core/graph/node_builder.h" 19 20namespace tensorflow { 21 22namespace { 23 24// A wrapper class for storing a `DatasetBase` instance in a DT_VARIANT tensor. 25// Objects of the wrapper class own a reference on an instance of `DatasetBase`, 26// and the wrapper's copy constructor and destructor take care of managing the 27// reference count. 28// 29// NOTE(mrry): This is not a feature-complete implementation of the DT_VARIANT 30// specification. In particular, we cannot currently serialize an arbitrary 31// `DatasetBase` object, so the `Encode()` and `Decode()` methods are not 32// implemented. 33class DatasetVariantWrapper { 34 public: 35 DatasetVariantWrapper() : dataset_(nullptr) {} 36 37 // Transfers ownership of `dataset` to `*this`. 38 explicit DatasetVariantWrapper(DatasetBase* dataset) : dataset_(dataset) {} 39 40 DatasetVariantWrapper(const DatasetVariantWrapper& other) 41 : dataset_(other.dataset_) { 42 if (dataset_) dataset_->Ref(); 43 } 44 45 ~DatasetVariantWrapper() { 46 if (dataset_) dataset_->Unref(); 47 } 48 49 DatasetBase* get() const { return dataset_; } 50 51 string TypeName() const { return "tensorflow::DatasetVariantWrapper"; } 52 string DebugString() const { 53 if (dataset_) { 54 return dataset_->DebugString(); 55 } else { 56 return "<Uninitialized DatasetVariantWrapper>"; 57 } 58 } 59 void Encode(VariantTensorData* data) const { 60 LOG(ERROR) << "The Encode() method is not implemented for " 61 "DatasetVariantWrapper objects."; 62 } 63 bool Decode(const VariantTensorData& data) { 64 LOG(ERROR) << "The Decode() method is not implemented for " 65 "DatasetVariantWrapper objects."; 66 return false; 67 } 68 69 private: 70 DatasetBase* const dataset_; // Owns one reference. 71}; 72 73} // namespace 74 75Status GraphDefBuilderWrapper::AddDataset( 76 const GraphDatasetBase* dataset, 77 const std::vector<std::pair<size_t, Node*>>& inputs, 78 const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs, 79 const std::vector<std::pair<StringPiece, AttrValue>>& attrs, 80 Node** output) { 81 const string& op_type_name = dataset->op_name(); 82 std::unique_ptr<const GraphDefBuilder::Options> opts( 83 new GraphDefBuilder::Options(b_->opts())); 84 // TODO(srbs|mrry): Not all datasets have output_types and output_shapes 85 // attributes defined. It will be nice to have a consistent pattern. 86 bool has_output_types_attr = HasAttr(op_type_name, "output_types"); 87 bool has_output_shapes_attr = HasAttr(op_type_name, "output_shapes"); 88 if (has_output_shapes_attr) { 89 opts.reset(new GraphDefBuilder::Options( 90 opts->WithAttr("output_shapes", dataset->output_shapes()))); 91 } 92 if (has_output_types_attr) { 93 opts.reset(new GraphDefBuilder::Options( 94 opts->WithAttr("output_types", dataset->output_dtypes()))); 95 } 96 for (auto attr : attrs) { 97 opts.reset( 98 new GraphDefBuilder::Options(opts->WithAttr(attr.first, attr.second))); 99 } 100 if (opts->HaveError()) { 101 return errors::Internal("AddDataset: Failed to build Options with error ", 102 opts->StatusToString()); 103 } 104 NodeBuilder node_builder(opts->GetNameForOp(op_type_name), op_type_name, 105 opts->op_registry()); 106 { 107 size_t total_size = inputs.size() + list_inputs.size(); 108 auto inputs_iter = inputs.begin(); 109 auto list_inputs_iter = list_inputs.begin(); 110 for (int i = 0; i < total_size; i++) { 111 if (inputs_iter != inputs.end() && inputs_iter->first == i) { 112 node_builder.Input(NodeBuilder::NodeOut(inputs_iter->second)); 113 inputs_iter++; 114 } else if (list_inputs_iter != list_inputs.end() && 115 list_inputs_iter->first == i) { 116 std::vector<NodeBuilder::NodeOut> nodeout_inputs; 117 nodeout_inputs.reserve(list_inputs_iter->second.size()); 118 for (Node* n : list_inputs_iter->second) { 119 nodeout_inputs.emplace_back(n); 120 } 121 node_builder.Input(nodeout_inputs); 122 list_inputs_iter++; 123 } else { 124 return errors::InvalidArgument("No input found for index ", i); 125 } 126 } 127 } 128 *output = opts->FinalizeBuilder(&node_builder); 129 if (*output == nullptr) { 130 return errors::Internal("AddDataset: Failed to build ", op_type_name, 131 " op with error ", opts->StatusToString()); 132 } 133 return Status::OK(); 134} 135 136Status GraphDefBuilderWrapper::AddFunction(OpKernelContext* ctx, 137 const string& function_name) { 138 if (b_->HasFunction(function_name)) { 139 LOG(INFO) << "Function with name " << function_name << "already exists in" 140 << " the graph. It will not be added again."; 141 return Status::OK(); 142 } 143 TF_RETURN_IF_ERROR(EnsureFunctionIsStateless(ctx, function_name)); 144 const FunctionLibraryDefinition* flib_def = 145 ctx->function_library()->GetFunctionLibraryDefinition(); 146 const FunctionDef* f_def = flib_def->Find(function_name); 147 if (f_def == nullptr) { 148 return errors::InvalidArgument("Unable to find FunctionDef for ", 149 function_name, " in the registry."); 150 } 151 FunctionDefLibrary def; 152 *def.add_function() = *f_def; 153 const string gradient_func = flib_def->FindGradient(function_name); 154 if (!gradient_func.empty()) { 155 GradientDef* g_def = def.add_gradient(); 156 g_def->set_function_name(function_name); 157 g_def->set_gradient_func(gradient_func); 158 } 159 TF_RETURN_IF_ERROR(b_->AddFunctionLibrary(def)); 160 161 // Recursively add functions in inputs of function_name. 162 for (const NodeDef& node_def : f_def->node_def()) { 163 const OpRegistrationData* op_reg_data = nullptr; 164 TF_RETURN_IF_ERROR(flib_def->LookUp(node_def.op(), &op_reg_data)); 165 if (op_reg_data->is_function_op) { 166 TF_RETURN_IF_ERROR(AddFunction(ctx, op_reg_data->op_def.name())); 167 } 168 // Recursively add functions in attrs of this NodeDef. 169 for (const auto& pair : node_def.attr()) { 170 TF_RETURN_IF_ERROR(AddAttrFunctions(pair.second, ctx)); 171 } 172 } 173 174 // Recursively add functions in attrs of function_name. 175 for (auto iter = f_def->attr().begin(); iter != f_def->attr().end(); iter++) { 176 TF_RETURN_IF_ERROR(AddAttrFunctions(iter->second, ctx)); 177 } 178 return Status::OK(); 179} 180 181void GraphDefBuilderWrapper::AddTensorInternal(const Tensor& val, 182 Node** output) { 183 *output = ops::SourceOp( 184 "Const", 185 b_->opts().WithAttr("dtype", val.dtype()).WithAttr("value", val)); 186} 187 188bool GraphDefBuilderWrapper::HasAttr(const string& op_type_name, 189 const string& attr_name) const { 190 const OpDef* op_def = nullptr; 191 Status s = b_->opts().op_registry()->LookUpOpDef(op_type_name, &op_def); 192 if (!s.ok() || op_def == nullptr) { 193 return false; 194 } 195 return HasAttr(op_def, attr_name); 196} 197 198Status GraphDatasetBase::Serialize(OpKernelContext* ctx, 199 string* serialized_graph_def, 200 string* output_node) const { 201 GraphDefBuilder b; 202 DatasetGraphDefBuilder db(&b); 203 Node* node = nullptr; 204 TF_RETURN_IF_ERROR(AsGraphDefInternal(ctx, &db, &node)); 205 *output_node = node->name(); 206 GraphDef graph_def; 207 TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def)); 208 graph_def.SerializeToString(serialized_graph_def); 209 return Status::OK(); 210} 211 212Status GetDatasetFromVariantTensor(const Tensor& tensor, 213 DatasetBase** out_dataset) { 214 if (!(tensor.dtype() == DT_VARIANT || 215 TensorShapeUtils::IsScalar(tensor.shape()))) { 216 return errors::InvalidArgument( 217 "Dataset tensor must be a scalar of dtype DT_VARIANT."); 218 } 219 const Variant& variant = tensor.scalar<Variant>()(); 220 const DatasetVariantWrapper* wrapper = variant.get<DatasetVariantWrapper>(); 221 if (wrapper == nullptr) { 222 return errors::InvalidArgument("Tensor must be a Dataset object."); 223 } 224 *out_dataset = wrapper->get(); 225 if (*out_dataset == nullptr) { 226 return errors::Internal("Read uninitialized Dataset variant."); 227 } 228 return Status::OK(); 229} 230 231Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor) { 232 if (!(tensor->dtype() == DT_VARIANT || 233 TensorShapeUtils::IsScalar(tensor->shape()))) { 234 return errors::InvalidArgument( 235 "Dataset tensor must be a scalar of dtype DT_VARIANT."); 236 } 237 tensor->scalar<Variant>()() = DatasetVariantWrapper(dataset); 238 return Status::OK(); 239} 240 241void DatasetOpKernel::Compute(OpKernelContext* ctx) { 242 DatasetBase* dataset = nullptr; 243 MakeDataset(ctx, &dataset); 244 if (ctx->status().ok()) { 245 Tensor* output = nullptr; 246 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output)); 247 OP_REQUIRES_OK(ctx, StoreDatasetInVariantTensor(dataset, output)); 248 } 249} 250 251void UnaryDatasetOpKernel::MakeDataset(OpKernelContext* ctx, 252 DatasetBase** output) { 253 DatasetBase* input; 254 OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &input)); 255 MakeDataset(ctx, input, output); 256} 257 258void BinaryDatasetOpKernel::MakeDataset(OpKernelContext* ctx, 259 DatasetBase** output) { 260 DatasetBase* input; 261 OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &input)); 262 DatasetBase* another_input; 263 OP_REQUIRES_OK(ctx, 264 GetDatasetFromVariantTensor(ctx->input(1), &another_input)); 265 MakeDataset(ctx, input, another_input, output); 266} 267 268const char GraphDatasetBase::kDatasetGraphKey[] = "_DATASET_GRAPH"; 269const char GraphDatasetBase::kDatasetGraphOutputNodeKey[] = 270 "_DATASET_GRAPH_OUTPUT_NODE"; 271 272} // namespace tensorflow 273