python_eager_op_gen.cc revision 40f0cf009641ef0d827729bb01e8ed50d97fd109
1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15#include "tensorflow/python/eager/python_eager_op_gen.h" 16 17#include <stdio.h> 18#include <sstream> 19#include <unordered_map> 20#include "tensorflow/core/framework/api_def.pb.h" 21#include "tensorflow/core/framework/attr_value.pb.h" 22#include "tensorflow/core/framework/op.h" 23#include "tensorflow/core/framework/op_def.pb_text.h" 24#include "tensorflow/core/framework/op_def.pb.h" 25#include "tensorflow/core/framework/op_def_util.h" 26#include "tensorflow/core/framework/op_gen_lib.h" 27#include "tensorflow/core/framework/tensor.pb_text.h" 28#include "tensorflow/core/framework/types.h" 29#include "tensorflow/core/framework/types.pb.h" 30#include "tensorflow/core/lib/gtl/map_util.h" 31#include "tensorflow/core/lib/gtl/stl_util.h" 32#include "tensorflow/core/lib/strings/str_util.h" 33#include "tensorflow/core/lib/strings/strcat.h" 34#include "tensorflow/core/lib/strings/stringprintf.h" 35#include "tensorflow/core/platform/logging.h" 36#include "tensorflow/core/platform/macros.h" 37#include "tensorflow/core/platform/types.h" 38#include "tensorflow/python/framework/python_op_gen_internal.h" 39 40namespace tensorflow { 41namespace { 42 43const int kRightMargin = 78; 44 45constexpr char kEagerFallbackSuffix[] = "_eager_fallback"; 46 47string AttrVarName(const string& attr_name, 48 std::unordered_map<string, string>* attr_expressions) { 49 const string var = strings::StrCat("_attr_", attr_name); 50 if (attr_expressions != nullptr) (*attr_expressions)[attr_name] = var; 51 return var; 52} 53 54void AddInferredAttr(const string& indentation, const string& attr_name, 55 const string& value_expression, string* result, 56 std::unordered_map<string, string>* attr_expressions) { 57 strings::StrAppend(result, indentation, 58 AttrVarName(attr_name, attr_expressions), " = ", 59 value_expression, "\n"); 60} 61 62string VectorToTuple(const std::vector<string>& l) { 63 if (l.size() == 1) return strings::StrCat("(", l.front(), ",)"); 64 string ret = "("; 65 for (int i = 0; i < l.size(); ++i) { 66 if (i > 0) { 67 strings::StrAppend(&ret, ", "); 68 } 69 strings::StrAppend(&ret, l[i]); 70 } 71 strings::StrAppend(&ret, ")"); 72 return ret; 73} 74 75void Unflatten(const string& prefix, const std::vector<string>& output_sizes, 76 const string& var, string* result) { 77 for (int i = 0; i < output_sizes.size(); ++i) { 78 if (!output_sizes[i].empty()) { 79 strings::StrAppend(result, prefix, var, " = "); 80 if (i > 0) strings::StrAppend(result, var, "[:", i, "] + "); 81 if (i + 1 < output_sizes.size()) { 82 // Special case i == 0 to avoid "0 +" in the generated code. 83 if (i == 0) { 84 strings::StrAppend(result, "[", var, "[:", output_sizes[i], "]] + ", 85 var, "[", output_sizes[i], ":]"); 86 } else { 87 strings::StrAppend(result, "[", var, "[", i, ":", i, " + ", 88 output_sizes[i], "]] + ", var, "[", i, " + ", 89 output_sizes[i], ":]"); 90 } 91 } else { 92 strings::StrAppend(result, "[", var, "[", i, ":]]"); 93 } 94 strings::StrAppend(result, "\n"); 95 } 96 } 97} 98 99string TensorPBString(const TensorProto& pb) { 100 // Note: This gets used in the argument list, and so must survive naive 101 // word wrapping. 102 return strings::StrCat("\"\"\"", ProtoShortDebugString(pb), "\"\"\""); 103} 104 105const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) { 106 for (int i = 0; i < api_def.in_arg_size(); ++i) { 107 if (api_def.in_arg(i).name() == name) { 108 return &api_def.in_arg(i); 109 } 110 } 111 return nullptr; 112} 113 114class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp { 115 public: 116 GenEagerPythonOp(const OpDef& op_def, const ApiDef& api_def, 117 const string& function_name) 118 : python_op_gen_internal::GenPythonOp(op_def, api_def, function_name) { 119 op_name_ = function_name_; 120 op_name_.Consume("_"); 121 } 122 ~GenEagerPythonOp() override {} 123 124 string Code() override; 125 126 protected: 127 void HandleGraphMode(const string& function_setup); 128 129 string GetEagerNotAllowedError(); 130 void ExpectListArg(const string& indentation, const string& arg_name, 131 string* output); 132 bool GetEagerFunctionSetup(const string& indentation, string* function_setup); 133 void GetOutputSizesAndNumOutputsExpr(std::vector<string>* output_sizes, 134 string* num_outputs_expr); 135 136 void AddEagerFunctionTeardown(const string& indentation, 137 const std::vector<string>& output_sizes, 138 bool execute_record_gradient); 139 140 bool AddEagerFastPathAndGraphCode(const string& parameters, 141 const std::vector<string>& output_sizes, 142 const string& eager_not_allowed_error); 143 bool AddEagerFallbackCode(const string& parameters, 144 const std::vector<string>& output_sizes, 145 const string& num_outputs_expr, 146 const string& eager_not_allowed_error); 147 void AddEagerFastPathExecute(); 148 149 void AddEagerInferredAttrs(const string& indentation); 150 void AddEagerInputCasts(const string& indentation); 151 void AddEagerAttrs(const string& indentation); 152 void AddEagerExecute(const string& indentation, 153 const string& num_outputs_expr); 154 155 void AddAttrForArg(const string& attr, int arg_index) { 156 gtl::InsertIfNotPresent(&inferred_attrs_, attr, 157 op_def_.input_arg(arg_index).name()); 158 auto iter = attr_to_args_.find(attr); 159 if (iter == attr_to_args_.end()) { 160 attr_to_args_.insert(AttrToArgMap::value_type(attr, {arg_index})); 161 } else { 162 iter->second.push_back(arg_index); 163 } 164 } 165 166 // Returns a string expression representing a flattened list of all 167 // the inputs given by `*input_indices` (or all inputs if 168 // `input_indices` is nullptr). `*output_sizes` can be used to unflatten. 169 string FlattenInputs(const std::vector<int>* input_indices, 170 std::vector<string>* output_sizes) const; 171 172 StringPiece op_name_; 173 typedef std::unordered_map<string, std::vector<int>> AttrToArgMap; 174 AttrToArgMap attr_to_args_; 175 std::unordered_map<string, string> attr_expressions_; 176 // This has all the input args followed by those attrs that don't have 177 // defaults. 178 std::vector<python_op_gen_internal::ParamNames> params_no_default_; 179 // The parameters with defaults (these have to be listed after those without). 180 // No input args are included, just attrs. 181 std::vector<std::pair<python_op_gen_internal::ParamNames, string>> 182 params_with_default_; 183}; 184 185string GetEagerPythonOp(const OpDef& op_def, const ApiDef& api_def, 186 const string& function_name) { 187 return GenEagerPythonOp(op_def, api_def, function_name).Code(); 188} 189 190string GenEagerPythonOp::FlattenInputs( 191 const std::vector<int>* input_indices, 192 std::vector<string>* output_sizes) const { 193 string inputs; 194 enum { STARTING, WAS_LIST_INPUT, WAS_SOLO_INPUT } inputs_state = STARTING; 195 const int n = input_indices != nullptr ? input_indices->size() 196 : op_def_.input_arg_size(); 197 for (int j = 0; j < n; ++j) { 198 const int i = input_indices ? (*input_indices)[j] : j; 199 const auto& arg(op_def_.input_arg(i)); 200 const bool is_list = 201 !arg.type_list_attr().empty() || !arg.number_attr().empty(); 202 if (is_list) { 203 if (inputs_state == WAS_SOLO_INPUT) { 204 strings::StrAppend(&inputs, "] + "); 205 } else if (inputs_state == WAS_LIST_INPUT) { 206 strings::StrAppend(&inputs, " + "); 207 } 208 strings::StrAppend(&inputs, "list(", param_names_[i].GetRenameTo(), ")"); 209 inputs_state = WAS_LIST_INPUT; 210 if (output_sizes != nullptr) { 211 if (!arg.number_attr().empty()) { 212 output_sizes->emplace_back(AttrVarName(arg.number_attr(), nullptr)); 213 } else { 214 output_sizes->emplace_back( 215 strings::StrCat("len(", param_names_[i].GetRenameTo(), ")")); 216 } 217 } 218 } else { 219 if (inputs_state == WAS_SOLO_INPUT) { 220 strings::StrAppend(&inputs, ", "); 221 } else if (inputs_state == WAS_LIST_INPUT) { 222 strings::StrAppend(&inputs, " + ["); 223 } else { 224 strings::StrAppend(&inputs, "["); 225 } 226 strings::StrAppend(&inputs, param_names_[i].GetRenameTo()); 227 inputs_state = WAS_SOLO_INPUT; 228 if (output_sizes != nullptr) output_sizes->emplace_back(); 229 } 230 } 231 if (inputs_state == STARTING) return "[]"; 232 if (inputs_state == WAS_SOLO_INPUT) { 233 strings::StrAppend(&inputs, "]"); 234 } 235 return inputs; 236} 237 238string GenEagerPythonOp::Code() { 239 if (api_def_.visibility() == ApiDef::SKIP) { 240 return ""; 241 } 242 243 for (int i = 0; i < api_def_.arg_order_size(); ++i) { 244 const auto& arg = *FindInputArg(api_def_.arg_order(i), op_def_); 245 const auto& api_def_arg = *FindInputArg(api_def_.arg_order(i), api_def_); 246 params_no_default_.emplace_back(api_def_arg.name(), 247 api_def_arg.rename_to()); 248 if (!arg.type_attr().empty()) { 249 AddAttrForArg(arg.type_attr(), i); 250 } else if (!arg.type_list_attr().empty()) { 251 AddAttrForArg(arg.type_list_attr(), i); 252 } 253 if (!arg.number_attr().empty()) { 254 AddAttrForArg(arg.number_attr(), i); 255 } 256 } 257 for (int i = 0; i < op_def_.attr_size(); ++i) { 258 const auto& attr(op_def_.attr(i)); 259 const auto& api_def_attr(api_def_.attr(i)); 260 // Do not add inferred attrs to the Python function signature. 261 if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) { 262 if (api_def_attr.has_default_value()) { 263 if (attr.type() == "tensor") { 264 params_with_default_.emplace_back( 265 python_op_gen_internal::ParamNames(api_def_attr.name(), 266 api_def_attr.rename_to()), 267 strings::StrCat( 268 "_execute.make_tensor(", 269 TensorPBString(api_def_attr.default_value().tensor()), ", \"", 270 api_def_attr.rename_to(), "\")")); 271 } else if (attr.type() == "list(tensor)") { 272 std::vector<string> pbtxt; 273 for (const auto& pb : api_def_attr.default_value().list().tensor()) { 274 pbtxt.emplace_back(TensorPBString(pb)); 275 } 276 params_with_default_.emplace_back( 277 python_op_gen_internal::ParamNames(api_def_attr.name(), 278 api_def_attr.rename_to()), 279 strings::StrCat("[_execute.make_tensor(_pb, \"", 280 api_def_attr.rename_to(), "\") for _pb in ", 281 VectorToTuple(pbtxt), "]")); 282 } else { 283 params_with_default_.emplace_back( 284 python_op_gen_internal::ParamNames(api_def_attr.name(), 285 api_def_attr.rename_to()), 286 python_op_gen_internal::AttrValueToPython( 287 attr.type(), api_def_attr.default_value(), "_dtypes.")); 288 } 289 } else { 290 params_no_default_.emplace_back(api_def_attr.name(), 291 api_def_attr.rename_to()); 292 } 293 } 294 } 295 296 // Save the list of attr parameters (attrs that won't be inferred), 297 // those with defaults go at the end. 298 // Get the attrs in the order we want by taking the attrs without defaults 299 // from the end of params_no_default_, and adding params_no_default_. 300 attrs_.reserve(params_no_default_.size() - op_def_.input_arg_size() + 301 params_with_default_.size()); 302 for (int i = op_def_.input_arg_size(); i < params_no_default_.size(); ++i) { 303 attrs_.push_back(params_no_default_[i].GetName()); 304 } 305 for (const auto& p : params_with_default_) { 306 attrs_.push_back(p.first.GetName()); 307 } 308 309 param_names_.reserve(params_no_default_.size() + params_with_default_.size()); 310 param_names_.insert(param_names_.begin(), params_no_default_.begin(), 311 params_no_default_.end()); 312 for (const auto& param_and_default : params_with_default_) { 313 param_names_.push_back(param_and_default.first); 314 } 315 316 string parameters; 317 for (const auto& param : params_no_default_) { 318 if (!parameters.empty()) strings::StrAppend(¶meters, ", "); 319 strings::StrAppend(¶meters, param.GetRenameTo()); 320 } 321 for (const auto& param_and_default : params_with_default_) { 322 if (!parameters.empty()) strings::StrAppend(¶meters, ", "); 323 strings::StrAppend(¶meters, param_and_default.first.GetRenameTo(), "=", 324 param_and_default.second); 325 } 326 if (!parameters.empty()) strings::StrAppend(¶meters, ", "); 327 strings::StrAppend(¶meters, "name=None"); 328 329 // Add attr_expressions_ for attrs that are params. 330 for (int i = 0; i < attrs_.size(); ++i) { 331 const string& attr_name = attrs_[i]; 332 const string& attr_api_name = 333 param_names_[i + op_def_.input_arg_size()].GetRenameTo(); 334 attr_expressions_[attr_name] = attr_api_name; 335 } 336 // Add attr_expressions_ for attrs that are inferred. 337 for (int i = 0; i < op_def_.attr_size(); ++i) { 338 const auto& attr(op_def_.attr(i)); 339 if (attr.type() == "int") { 340 auto arg_list = attr_to_args_.find(attr.name()); 341 if (arg_list != attr_to_args_.end()) { 342 AttrVarName(attr.name(), &attr_expressions_); 343 } 344 } 345 } 346 347 string num_outputs_expr; 348 std::vector<string> output_sizes(num_outs_); 349 GetOutputSizesAndNumOutputsExpr(&output_sizes, &num_outputs_expr); 350 351 string eager_not_allowed_error = GetEagerNotAllowedError(); 352 353 if (!AddEagerFastPathAndGraphCode(parameters, output_sizes, 354 eager_not_allowed_error)) { 355 return result_; 356 } 357 358 if (!AddEagerFallbackCode(parameters, output_sizes, num_outputs_expr, 359 eager_not_allowed_error)) { 360 return result_; 361 } 362 363 return prelude_ + result_; 364} 365 366void GenEagerPythonOp::HandleGraphMode(const string& function_setup) { 367 // Handle graph-mode case 368 strings::StrAppend(&result_, 369 " _ctx = _context.context()\n" 370 " if _ctx.in_graph_mode():\n", 371 function_setup, 372 " _, _, _op = _op_def_lib._apply_op_helper(\n"); 373 AddBodyNoReturn(" "); 374 if (num_outs_ > 0) { 375 strings::StrAppend(&result_, " _result = _op.outputs[:]\n"); 376 // Special case handling for stateful op with single list output 377 // that might be empty. 378 if (num_outs_ == 1 && op_def_.is_stateful() && 379 (!op_def_.output_arg(0).number_attr().empty() || 380 !op_def_.output_arg(0).type_list_attr().empty())) { 381 // TODO(josh11b): Can skip this if the number_attr/type_list_attr has 382 // a constraint indicating that this can never be empty. 383 strings::StrAppend(&result_, 384 " if not _result:\n" 385 " return _op\n"); 386 } 387 strings::StrAppend(&result_, " _inputs_flat = _op.inputs\n"); 388 389 // Compute graph-mode attrs. 390 if (op_def_.attr_size() > 0) { 391 string attr_values; 392 for (int i = 0; i < op_def_.attr_size(); ++i) { 393 if (i > 0) strings::StrAppend(&attr_values, ", "); 394 const auto& attr_name(op_def_.attr(i).name()); 395 strings::StrAppend(&attr_values, "\"", attr_name, "\", _op.get_attr(\"", 396 attr_name, "\")"); 397 } 398 strings::StrAppend(&attr_values, ")"); 399 strings::StrAppend(&result_, 400 WordWrap(" _attrs = (", attr_values, kRightMargin), 401 "\n"); 402 } else { 403 strings::StrAppend(&result_, " _attrs = None\n"); 404 } 405 } else { 406 strings::StrAppend(&result_, " return _op\n"); 407 } 408} 409 410string GenEagerPythonOp::GetEagerNotAllowedError() { 411 bool eager_allowed = true; 412 string ref_arg; 413 for (int i = 0; i < op_def_.input_arg_size(); ++i) { 414 const auto& arg = op_def_.input_arg(i); 415 if (arg.is_ref()) { 416 eager_allowed = false; 417 DCHECK_EQ(op_def_.input_arg(i).name(), api_def_.in_arg(i).name()); 418 ref_arg = api_def_.in_arg(i).rename_to(); 419 } 420 } 421 for (int i = 0; i < op_def_.output_arg_size(); ++i) { 422 const auto& arg = op_def_.output_arg(i); 423 if (arg.is_ref()) { 424 eager_allowed = false; 425 DCHECK_EQ(op_def_.output_arg(i).name(), api_def_.out_arg(i).name()); 426 ref_arg = api_def_.out_arg(i).rename_to(); 427 } 428 } 429 430 if (eager_allowed) return ""; 431 432 return strings::StrCat("raise RuntimeError(\"", op_name_, 433 " op does not support eager execution. ", "Arg '", 434 ref_arg, "' is a ref.\")\n"); 435} 436 437void GenEagerPythonOp::ExpectListArg(const string& indentation, 438 const string& arg_name, string* output) { 439 strings::StrAppend(output, indentation, "if not isinstance(", arg_name, 440 ", (list, tuple)):\n", indentation, " raise TypeError(\n", 441 indentation, " \"Expected list for '", arg_name, 442 "' argument to \"\n", indentation, " \"'", op_name_, 443 "' Op, not %r.\" % ", arg_name, ")\n"); 444} 445 446bool GenEagerPythonOp::GetEagerFunctionSetup(const string& indentation, 447 string* function_setup) { 448 // Validate list inputs, infer length attrs. 449 for (int i = 0; i < op_def_.attr_size(); ++i) { 450 const auto& attr(op_def_.attr(i)); 451 if (attr.type() == "int") { 452 auto arg_list = attr_to_args_.find(attr.name()); 453 if (arg_list != attr_to_args_.end()) { 454 // Inferred int attrs are the lengths of inputs. Validate those 455 // inputs are lists and have the same length. 456 for (auto iter = arg_list->second.begin(); 457 iter != arg_list->second.end(); ++iter) { 458 const string& arg_api_name = param_names_[*iter].GetRenameTo(); 459 ExpectListArg(indentation, arg_api_name, function_setup); 460 if (iter == arg_list->second.begin()) { 461 AddInferredAttr(indentation, attr.name(), 462 strings::StrCat("len(", arg_api_name, ")"), 463 function_setup, &attr_expressions_); 464 } else { 465 const auto& attr_var = attr_expressions_[attr.name()]; 466 strings::StrAppend( 467 function_setup, indentation, "if len(", arg_api_name, 468 ") != ", attr_var, ":\n", indentation, " raise ValueError(\n", 469 indentation, " \"List argument '", arg_api_name, "' to '", 470 op_name_, "' Op with length %d \"\n", indentation, 471 " \"must match length %d of argument '", 472 inferred_attrs_[attr.name()], "'.\" %\n", indentation, 473 " (len(", arg_api_name, "), ", attr_var, "))\n"); 474 } 475 } 476 } 477 } 478 } 479 480 for (int i = 0; i < attrs_.size(); ++i) { 481 const string& attr_name = attrs_[i]; 482 const auto& param = param_names_[i + op_def_.input_arg_size()]; 483 const auto& attr = *FindAttr(attr_name, op_def_); 484 const string& attr_api_name = param.GetRenameTo(); 485 StringPiece attr_type = attr.type(); 486 attr_expressions_[attr_name] = attr_api_name; 487 const int default_index = i - (attrs_.size() - params_with_default_.size()); 488 if (default_index >= 0) { 489 const string& default_value = params_with_default_[default_index].second; 490 strings::StrAppend(function_setup, indentation, "if ", attr_api_name, 491 " is None:\n"); 492 strings::StrAppend(function_setup, indentation, " ", attr_api_name, 493 " = ", default_value, "\n"); 494 } 495 if (attr_type.starts_with("list(")) { 496 ExpectListArg(indentation, attr_api_name, function_setup); 497 } 498 499 if (attr_type == "string") { 500 strings::StrAppend(function_setup, indentation, attr_api_name, 501 " = _execute.make_str(", attr_api_name, ", \"", 502 attr_api_name, "\")\n"); 503 } else if (attr_type == "list(string)") { 504 strings::StrAppend(function_setup, indentation, attr_api_name, 505 " = [_execute.make_str(_s, \"", attr_api_name, 506 "\") for _s in ", attr_api_name, "]\n"); 507 } else if (attr_type == "int") { 508 strings::StrAppend(function_setup, indentation, attr_api_name, 509 " = _execute.make_int(", attr_api_name, ", \"", 510 attr_api_name, "\")\n"); 511 } else if (attr_type == "list(int)") { 512 strings::StrAppend(function_setup, indentation, attr_api_name, 513 " = [_execute.make_int(_i, \"", attr_api_name, 514 "\") for _i in ", attr_api_name, "]\n"); 515 } else if (attr_type == "float") { 516 strings::StrAppend(function_setup, indentation, attr_api_name, 517 " = _execute.make_float(", attr_api_name, ", \"", 518 attr_api_name, "\")\n"); 519 } else if (attr_type == "list(float)") { 520 strings::StrAppend(function_setup, indentation, attr_api_name, 521 " = [_execute.make_float(_f, \"", attr_api_name, 522 "\") for _f in ", attr_api_name, "]\n"); 523 } else if (attr_type == "bool") { 524 strings::StrAppend(function_setup, indentation, attr_api_name, 525 " = _execute.make_bool(", attr_api_name, ", \"", 526 attr_api_name, "\")\n"); 527 } else if (attr_type == "list(bool)") { 528 strings::StrAppend(function_setup, indentation, attr_api_name, 529 " = [_execute.make_bool(_b, \"", attr_api_name, 530 "\") for _b in ", attr_api_name, "]\n"); 531 } else if (attr_type == "type") { 532 strings::StrAppend(function_setup, indentation, attr_api_name, 533 " = _execute.make_type(", attr_api_name, ", \"", 534 attr_api_name, "\")\n"); 535 } else if (attr_type == "list(type)") { 536 strings::StrAppend(function_setup, indentation, attr_api_name, 537 " = [_execute.make_type(_t, \"", attr_api_name, 538 "\") for _t in ", attr_api_name, "]\n"); 539 } else if (attr_type == "shape") { 540 strings::StrAppend(function_setup, indentation, attr_api_name, 541 " = _execute.make_shape(", attr_api_name, ", \"", 542 attr_api_name, "\")\n"); 543 } else if (attr_type == "list(shape)") { 544 strings::StrAppend(function_setup, indentation, attr_api_name, 545 " = [_execute.make_shape(_s, \"", attr_api_name, 546 "\") for _s in ", attr_api_name, "]\n"); 547 } else if (attr_type == "tensor") { 548 strings::StrAppend(function_setup, indentation, attr_api_name, 549 " = _execute.make_tensor(", attr_api_name, ", \"", 550 attr_api_name, "\")\n"); 551 } else if (attr_type == "list(tensor)") { 552 strings::StrAppend(function_setup, indentation, attr_api_name, 553 " = [_execute.make_tensor(_t, \"", attr_api_name, 554 "\") for _t in ", attr_api_name, "]\n"); 555 } else if (attr_type != "func") { 556 *function_setup = 557 strings::StrCat("# No definition for ", function_name_, 558 " since we don't support attrs with type\n" 559 "# '", 560 attr_type, "' right now.\n\n"); 561 return false; 562 } 563 } 564 return true; 565} 566 567// If output i is list output, output_sizes[i] will be set to a 568// string with the python expression that will evaluate to its 569// length. output_sizes[i] is empty for non-list outputs. 570void GenEagerPythonOp::GetOutputSizesAndNumOutputsExpr( 571 std::vector<string>* output_sizes, string* num_outputs_expr) { 572 // Expression representing the number of outputs. 573 int num_fixed_outputs = 0; 574 for (int i = 0; i < num_outs_; ++i) { 575 const auto& arg(op_def_.output_arg(i)); 576 if (!arg.number_attr().empty()) { 577 if (!num_outputs_expr->empty()) { 578 strings::StrAppend(num_outputs_expr, " + "); 579 } 580 (*output_sizes)[i] = attr_expressions_[arg.number_attr()]; 581 strings::StrAppend(num_outputs_expr, (*output_sizes)[i]); 582 } else if (!arg.type_list_attr().empty()) { 583 if (!num_outputs_expr->empty()) { 584 strings::StrAppend(num_outputs_expr, " + "); 585 } 586 // Have to be careful to use an expression that works in both 587 // graph and eager paths here. 588 const auto iter = inferred_attrs_.find(arg.type_list_attr()); 589 if (iter == inferred_attrs_.end()) { 590 (*output_sizes)[i] = strings::StrCat( 591 "len(", attr_expressions_[arg.type_list_attr()], ")"); 592 } else { 593 (*output_sizes)[i] = strings::StrCat("len(", iter->second, ")"); 594 } 595 strings::StrAppend(num_outputs_expr, (*output_sizes)[i]); 596 } else { 597 ++num_fixed_outputs; 598 } 599 } 600 if (num_fixed_outputs > 0) { 601 if (!num_outputs_expr->empty()) { 602 strings::StrAppend(num_outputs_expr, " + "); 603 } 604 strings::StrAppend(num_outputs_expr, num_fixed_outputs); 605 } else if (num_outputs_expr->empty()) { 606 *num_outputs_expr = "0"; 607 } 608} 609 610void GenEagerPythonOp::AddEagerFunctionTeardown( 611 const string& indentation, const std::vector<string>& output_sizes, 612 bool execute_record_gradient) { 613 if (num_outs_ > 0) { 614 if (execute_record_gradient) { 615 strings::StrAppend(&result_, indentation, "_execute.record_gradient(\n", 616 " \"", op_def_.name(), 617 "\", _inputs_flat, _attrs, _result, name)\n"); 618 } 619 if (num_outs_ == 1 && !output_sizes[0].empty()) { 620 // Single list result. 621 } else if (num_outs_ == 1) { 622 // Execute returns a single-element list which we need to destructure. 623 strings::StrAppend(&result_, indentation, "_result, = _result\n"); 624 } else { 625 // Have multiple outputs, so we will need to reformat the return 626 // value of execute() to be a list with one entry per op output 627 // (that entry will be a list of tensors if that output is of list 628 // type). 629 // For list outputs, convert the right subrange of _result into a list. 630 Unflatten(indentation, output_sizes, "_result", &result_); 631 // Convert to a named tuple. 632 strings::StrAppend(&result_, indentation, "_result = _", op_def_.name(), 633 "Output._make(_result)\n"); 634 } 635 } else { 636 strings::StrAppend(&result_, indentation, "_result = None\n"); 637 } 638 strings::StrAppend(&result_, indentation, "return _result\n\n"); 639} 640 641bool GenEagerPythonOp::AddEagerFastPathAndGraphCode( 642 const string& parameters, const std::vector<string>& output_sizes, 643 const string& eager_not_allowed_error) { 644 AddDefLine(function_name_, parameters); 645 AddDocStringDescription(); 646 AddDocStringArgs(); 647 AddDocStringInputs(); 648 AddDocStringAttrs(); 649 AddDocStringNameArg(); 650 AddOutputGlobals(); // Added to prelude_ 651 AddDocStringOutputs(); 652 strings::StrAppend(&result_, " \"\"\"\n"); 653 654 // Handle graph-mode case 655 string function_setup; 656 if (!GetEagerFunctionSetup(" ", &function_setup)) { 657 result_ = function_setup; 658 return false; 659 } 660 HandleGraphMode(function_setup); 661 AddEagerFunctionTeardown(" ", output_sizes, 662 true /* execute_record_gradient */); 663 664 // Handle eager-mode case 665 strings::StrAppend(&result_, " else:\n"); 666 667 if (eager_not_allowed_error.empty()) { 668 AddEagerFastPathExecute(); 669 } else { 670 strings::StrAppend(&result_, " ", eager_not_allowed_error); 671 } 672 673 strings::StrAppend(&result_, "\n\n"); 674 return true; 675} 676 677bool GenEagerPythonOp::AddEagerFallbackCode( 678 const string& parameters, const std::vector<string>& output_sizes, 679 const string& num_outputs_expr, const string& eager_not_allowed_error) { 680 if (!eager_not_allowed_error.empty()) { 681 strings::StrAppend(&result_, " ", eager_not_allowed_error); 682 return true; 683 } 684 685 AddDefLine(strings::StrCat(function_name_, kEagerFallbackSuffix), parameters); 686 strings::StrAppend( 687 &result_, " r\"\"\"This is the slowpath function for Eager mode.\n"); 688 strings::StrAppend(&result_, " This is for function ", function_name_, 689 "\n \"\"\"\n"); 690 691 strings::StrAppend(&result_, " _ctx = _context.context()\n"); 692 693 string function_setup; 694 if (!GetEagerFunctionSetup(" ", &function_setup)) { 695 result_ = function_setup; 696 return false; 697 } 698 strings::StrAppend(&result_, function_setup); 699 700 AddEagerInferredAttrs(" "); 701 AddEagerInputCasts(" "); 702 strings::StrAppend( 703 &result_, " _inputs_flat = ", FlattenInputs(nullptr, nullptr), "\n"); 704 AddEagerAttrs(" "); 705 AddEagerExecute(" ", num_outputs_expr); 706 707 AddEagerFunctionTeardown(" ", output_sizes, 708 true /* execute_record_gradient */); 709 710 return true; 711} 712 713void GenEagerPythonOp::AddEagerFastPathExecute() { 714 string fastpath_execute_params = strings::StrCat( 715 "_ctx._handle, _ctx.device_name, \"", op_def_.name(), "\", ", 716 "_execute.record_gradient, name, _ctx._post_execution_callbacks"); 717 string fallback_params; 718 719 for (int i = 0; i < api_def_.in_arg_size(); i++) { 720 const string param_name = param_names_[i].GetRenameTo(); 721 strings::StrAppend(&fastpath_execute_params, ", ", param_name); 722 if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", "); 723 strings::StrAppend(&fallback_params, param_name); 724 } 725 726 for (const auto& attr : api_def_.attr()) { 727 if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) { 728 strings::StrAppend(&fastpath_execute_params, ", \"", attr.name(), "\", ", 729 attr.rename_to()); 730 731 if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", "); 732 strings::StrAppend(&fallback_params, attr.rename_to(), "=", 733 attr.rename_to()); 734 } 735 } 736 737 if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", "); 738 strings::StrAppend(&fallback_params, "name=name"); 739 740 strings::StrAppend(&result_, " try:\n"); 741 strings::StrAppend( 742 &result_, " ", 743 "_result = _pywrap_tensorflow.TFE_Py_FastPathExecute(\n", 744 WordWrap(strings::StrCat(" "), 745 strings::StrCat(fastpath_execute_params, ")"), kRightMargin), 746 "\n"); 747 748 if (op_def_.output_arg_size() > 1) { 749 const string output_tuple_name = 750 strings::StrCat("_", op_def_.name(), "Output"); 751 strings::StrAppend(&result_, " ", "_result = ", output_tuple_name, 752 "._make(_result)\n"); 753 } 754 strings::StrAppend(&result_, " ", "return _result\n"); 755 756 // Handle fallback. 757 strings::StrAppend(&result_, " ", "except _core._FallbackException:\n"); 758 strings::StrAppend( 759 &result_, " ", "return ", function_name_, kEagerFallbackSuffix, 760 "(\n", 761 WordWrap(strings::StrCat(" "), 762 strings::StrCat(fallback_params, ")"), kRightMargin), 763 "\n"); 764 765 // Any errors thrown from execute need to be unwrapped from 766 // _NotOkStatusException. 767 strings::StrAppend(&result_, " ", 768 "except _core._NotOkStatusException as e:\n"); 769 strings::StrAppend(&result_, " ", "if name is not None:\n"); 770 strings::StrAppend(&result_, " ", 771 "message = e.message + \" name: \" + name\n"); 772 strings::StrAppend(&result_, " ", "else:\n"); 773 strings::StrAppend(&result_, " ", "message = e.message\n"); 774 strings::StrAppend( 775 &result_, " ", 776 "_six.raise_from(_core._status_to_exception(e.code, message), None)\n"); 777} 778 779void GenEagerPythonOp::AddEagerInferredAttrs(const string& indentation) { 780 // Figure out values for inferred attrs, and cast to eager tensors. 781 for (int i = 0; i < op_def_.attr_size(); ++i) { 782 const auto& attr(op_def_.attr(i)); 783 const auto& api_def_attr(api_def_.attr(i)); 784 auto arg_list = attr_to_args_.find(attr.name()); 785 if (arg_list != attr_to_args_.end()) { 786 if (attr.type() == "type") { 787 std::vector<string> output_sizes; 788 const string flattened = 789 FlattenInputs(&arg_list->second, &output_sizes); 790 string conversion = strings::StrCat("_execute.args_to_matching_eager(", 791 flattened, ", _ctx"); 792 if (attr.has_default_value()) { 793 strings::StrAppend( 794 &conversion, ", ", 795 python_op_gen_internal::AttrValueToPython( 796 attr.type(), api_def_attr.default_value(), "_dtypes.")); 797 } 798 strings::StrAppend(&conversion, ")"); 799 const string var_name = AttrVarName(attr.name(), &attr_expressions_); 800 if (output_sizes.size() == 1) { 801 // Avoid creating a temporary variable in the case where 802 // we can easily assign to the right value directly. 803 const string inputs_var = 804 param_names_[arg_list->second.front()].GetRenameTo(); 805 if (output_sizes.front().empty()) { 806 strings::StrAppend(&result_, indentation, var_name, ", (", 807 inputs_var, ",) = ", conversion, "\n"); 808 } else { 809 strings::StrAppend(&result_, indentation, var_name, ", ", 810 inputs_var, " = ", conversion, "\n"); 811 } 812 } else { 813 const string inputs_var = strings::StrCat("_inputs_", attr.name()); 814 strings::StrAppend(&result_, indentation, var_name, ", ", inputs_var, 815 " = ", conversion, "\n"); 816 // Convert from a flat list of eager tensors back to the 817 // parameter variables. 818 Unflatten(indentation, output_sizes, inputs_var, &result_); 819 std::vector<string> p; 820 for (int j : arg_list->second) { 821 p.emplace_back(param_names_[j].GetRenameTo()); 822 } 823 strings::StrAppend(&result_, indentation, VectorToTuple(p), " = ", 824 inputs_var, "\n"); 825 } 826 } else if (attr.type() == "list(type)") { 827 // NOTE: We ignore default values for these attrs, since it is 828 // unclear how you would use it, and the one use case is 829 // parse_single_sequence_example which only needs it for 830 // backwards compatibility. 831 const string var_name = AttrVarName(attr.name(), &attr_expressions_); 832 string inputs_var; 833 string conversion; 834 if (arg_list->second.size() > 1) { 835 // If you have more than one list(tensor) argument, their types 836 // have to match. 837 std::vector<string> lists; 838 for (auto iter = arg_list->second.begin(); 839 iter != arg_list->second.end(); ++iter) { 840 lists.push_back(param_names_[*iter].GetRenameTo()); 841 } 842 inputs_var = VectorToTuple(lists); 843 conversion = "_execute.args_to_mixed_eager_tensors"; 844 } else { 845 // For one list(tensor) argument, we just convert every 846 // element of the list to an eager tensor. 847 inputs_var = param_names_[arg_list->second.front()].GetRenameTo(); 848 conversion = "_execute.convert_to_mixed_eager_tensors"; 849 } 850 strings::StrAppend(&result_, indentation, var_name, ", ", inputs_var, 851 " = ", conversion, "(", inputs_var, ", _ctx)\n"); 852 } 853 } 854 } 855} 856 857void GenEagerPythonOp::AddEagerInputCasts(const string& indentation) { 858 // Cast remaining args to eager tensors 859 for (int i = 0; i < op_def_.input_arg_size(); ++i) { 860 const auto& arg(op_def_.input_arg(i)); 861 if (!arg.type_attr().empty() || !arg.type_list_attr().empty()) continue; 862 const string& param = param_names_[i].GetRenameTo(); 863 const string fn = arg.number_attr().empty() ? "" : "n_"; 864 const string dtype = 865 python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes."); 866 strings::StrAppend(&result_, indentation, param, " = _ops.convert_", fn, 867 "to_tensor(", param, ", ", dtype, ")\n"); 868 } 869} 870 871void GenEagerPythonOp::AddEagerAttrs(const string& indentation) { 872 // Compute eager attrs 873 if (op_def_.attr_size() > 0) { 874 string attr_values; 875 for (int i = 0; i < op_def_.attr_size(); ++i) { 876 if (i > 0) strings::StrAppend(&attr_values, ", "); 877 const auto& attr_name(op_def_.attr(i).name()); 878 strings::StrAppend(&attr_values, "\"", attr_name, "\", ", 879 attr_expressions_[attr_name]); 880 } 881 strings::StrAppend(&attr_values, ")"); 882 strings::StrAppend( 883 &result_, 884 WordWrap(indentation, strings::StrCat("_attrs = (", attr_values), 885 kRightMargin), 886 "\n"); 887 } else { 888 strings::StrAppend(&result_, indentation, "_attrs = None\n"); 889 } 890} 891 892void GenEagerPythonOp::AddEagerExecute(const string& indentation, 893 const string& num_outputs_expr) { 894 const string return_prefix = 895 strings::StrCat(indentation, "_result = _execute.execute("); 896 const string return_args = strings::StrCat( 897 "b\"", op_def_.name(), "\", ", num_outputs_expr, 898 ", inputs=_inputs_flat, attrs=_attrs, ctx=_ctx, name=name)"); 899 strings::StrAppend(&result_, 900 // Wrap the arguments, and indent to the (. 901 WordWrap(return_prefix, return_args, kRightMargin), "\n"); 902} 903 904string GetEagerPythonOps(const OpList& ops, const ApiDefMap& api_defs, 905 const std::vector<string>& hidden_ops, 906 bool require_shapes, 907 const string& source_file_name = "") { 908 string result; 909 // Header 910 // TODO(josh11b): Mention the library for which wrappers are being generated. 911 strings::StrAppend(&result, R"("""Python wrappers around TensorFlow ops. 912 913This file is MACHINE GENERATED! Do not edit. 914)"); 915 916 // Mention the original source file so someone tracing back through 917 // generated Python code will know where to look next. 918 if (!source_file_name.empty()) { 919 strings::StrAppend(&result, "Original C++ source file: "); 920 strings::StrAppend(&result, source_file_name); 921 strings::StrAppend(&result, "\n"); 922 } 923 924 strings::StrAppend(&result, R"(""" 925 926import collections as _collections 927import six as _six 928 929from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow 930from tensorflow.python.eager import context as _context 931from tensorflow.python.eager import core as _core 932from tensorflow.python.eager import execute as _execute 933from tensorflow.python.framework import dtypes as _dtypes 934from tensorflow.python.framework import errors as _errors 935from tensorflow.python.framework import tensor_shape as _tensor_shape 936 937from tensorflow.core.framework import op_def_pb2 as _op_def_pb2 938# Needed to trigger the call to _set_call_cpp_shape_fn. 939from tensorflow.python.framework import common_shapes as _common_shapes 940from tensorflow.python.framework import op_def_registry as _op_def_registry 941from tensorflow.python.framework import ops as _ops 942from tensorflow.python.framework import op_def_library as _op_def_library 943from tensorflow.python.util.tf_export import tf_export 944 945)"); 946 947 // We'll make a copy of ops that filters out descriptions. 948 OpList cleaned_ops; 949 auto out = cleaned_ops.mutable_op(); 950 out->Reserve(ops.op_size()); 951 for (const auto& op_def : ops.op()) { 952 const auto* api_def = api_defs.GetApiDef(op_def.name()); 953 954 if (api_def->visibility() == ApiDef::SKIP) { 955 continue; 956 } 957 958 // An op is hidden if either its ApiDef visibility is HIDDEN 959 // or it is in the hidden_ops list. 960 bool is_hidden = api_def->visibility() == ApiDef::HIDDEN; 961 if (!is_hidden) { 962 for (const string& hidden : hidden_ops) { 963 if (op_def.name() == hidden) { 964 is_hidden = true; 965 break; 966 } 967 } 968 } 969 970 string function_name; 971 python_op_gen_internal::GenerateLowerCaseOpName(op_def.name(), 972 &function_name); 973 if (is_hidden) function_name = strings::StrCat("_", function_name); 974 975 // When users create custom python wrappers, they may link in the 976 // default op registry by accident, and because they can't 977 // enumerate all 'hidden' symbols, this guard is to prevent 978 // instantiating a python reserved word in their wrapper. 979 if (python_op_gen_internal::IsPythonReserved(function_name)) { 980 continue; 981 } 982 983 strings::StrAppend(&result, 984 GetEagerPythonOp(op_def, *api_def, function_name)); 985 986 if (!require_shapes) { 987 strings::StrAppend(&result, "_ops.RegisterShape(\"", op_def.name(), 988 "\")(None)\n\n"); 989 } 990 991 auto added = out->Add(); 992 *added = op_def; 993 RemoveNonDeprecationDescriptionsFromOpDef(added); 994 } 995 996 result.append(R"(def _InitOpDefLibrary(op_list_proto_bytes): 997 op_list = _op_def_pb2.OpList() 998 op_list.ParseFromString(op_list_proto_bytes) 999 _op_def_registry.register_op_list(op_list) 1000 op_def_lib = _op_def_library.OpDefLibrary() 1001 op_def_lib.add_op_list(op_list) 1002 return op_def_lib 1003)"); 1004 1005 result.append("# "); 1006 auto ops_text = ProtoDebugString(cleaned_ops); 1007 str_util::StripTrailingWhitespace(&ops_text); 1008 result.append(str_util::StringReplace(ops_text, "\n", "\n# ", true)); 1009 result.append("\n"); 1010 strings::Appendf(&result, "_op_def_lib = _InitOpDefLibrary(b\"%s\")\n", 1011 str_util::CEscape(cleaned_ops.SerializeAsString()).c_str()); 1012 return result; 1013} 1014 1015} // namespace 1016 1017void PrintEagerPythonOps(const OpList& ops, const ApiDefMap& api_defs, 1018 const std::vector<string>& hidden_ops, 1019 bool require_shapes, const string& source_file_name) { 1020 printf("%s", GetEagerPythonOps(ops, api_defs, hidden_ops, require_shapes, 1021 source_file_name) 1022 .c_str()); 1023} 1024 1025string GetEagerPythonWrappers(const char* op_list_buf, size_t op_list_len) { 1026 string op_list_str(op_list_buf, op_list_len); 1027 OpList ops; 1028 ops.ParseFromString(op_list_str); 1029 1030 ApiDefMap api_def_map(ops); 1031 return GetEagerPythonOps(ops, api_def_map, {}, false); 1032} 1033 1034} // namespace tensorflow 1035