// python_eager_op_gen.cc revision 8f1e63d5629bda4f6c91fdec7a3b8418ed96786e
1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15#include "tensorflow/python/eager/python_eager_op_gen.h" 16 17#include <stdio.h> 18#include <sstream> 19#include <unordered_map> 20#include "tensorflow/core/framework/api_def.pb.h" 21#include "tensorflow/core/framework/attr_value.pb.h" 22#include "tensorflow/core/framework/op.h" 23#include "tensorflow/core/framework/op_def.pb_text.h" 24#include "tensorflow/core/framework/op_def.pb.h" 25#include "tensorflow/core/framework/op_def_util.h" 26#include "tensorflow/core/framework/op_gen_lib.h" 27#include "tensorflow/core/framework/tensor.pb_text.h" 28#include "tensorflow/core/framework/types.h" 29#include "tensorflow/core/framework/types.pb.h" 30#include "tensorflow/core/lib/gtl/map_util.h" 31#include "tensorflow/core/lib/gtl/stl_util.h" 32#include "tensorflow/core/lib/strings/str_util.h" 33#include "tensorflow/core/lib/strings/strcat.h" 34#include "tensorflow/core/lib/strings/stringprintf.h" 35#include "tensorflow/core/platform/logging.h" 36#include "tensorflow/core/platform/macros.h" 37#include "tensorflow/core/platform/types.h" 38#include "tensorflow/python/framework/python_op_gen_internal.h" 39 40namespace tensorflow { 41namespace { 42 43const int kRightMargin = 78; 44 45string AttrVarName(const string& attr_name, 46 std::unordered_map<string, string>* attr_expressions) { 47 const string var = 
strings::StrCat("_attr_", attr_name); 48 if (attr_expressions != nullptr) (*attr_expressions)[attr_name] = var; 49 return var; 50} 51 52void AddInferredAttr(const string& attr_name, const string& value_expression, 53 string* result, 54 std::unordered_map<string, string>* attr_expressions) { 55 strings::StrAppend(result, " ", AttrVarName(attr_name, attr_expressions), 56 " = ", value_expression, "\n"); 57} 58 59string VectorToTuple(const std::vector<string>& l) { 60 if (l.size() == 1) return strings::StrCat("(", l.front(), ",)"); 61 string ret = "("; 62 for (int i = 0; i < l.size(); ++i) { 63 if (i > 0) { 64 strings::StrAppend(&ret, ", "); 65 } 66 strings::StrAppend(&ret, l[i]); 67 } 68 strings::StrAppend(&ret, ")"); 69 return ret; 70} 71 72void Unflatten(const string& prefix, const std::vector<string>& output_sizes, 73 const string& var, string* result) { 74 for (int i = 0; i < output_sizes.size(); ++i) { 75 if (!output_sizes[i].empty()) { 76 strings::StrAppend(result, prefix, var, " = "); 77 if (i > 0) strings::StrAppend(result, var, "[:", i, "] + "); 78 if (i + 1 < output_sizes.size()) { 79 // Special case i == 0 to avoid "0 +" in the generated code. 80 if (i == 0) { 81 strings::StrAppend(result, "[", var, "[:", output_sizes[i], "]] + ", 82 var, "[", output_sizes[i], ":]"); 83 } else { 84 strings::StrAppend(result, "[", var, "[", i, ":", i, " + ", 85 output_sizes[i], "]] + ", var, "[", i, " + ", 86 output_sizes[i], ":]"); 87 } 88 } else { 89 strings::StrAppend(result, "[", var, "[", i, ":]]"); 90 } 91 strings::StrAppend(result, "\n"); 92 } 93 } 94} 95 96string TensorPBString(const TensorProto& pb) { 97 // Note: This gets used in the argument list, and so must survive naive 98 // word wrapping. 
99 return strings::StrCat("\"\"\"", ProtoShortDebugString(pb), "\"\"\""); 100} 101 102const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) { 103 for (int i = 0; i < api_def.in_arg_size(); ++i) { 104 if (api_def.in_arg(i).name() == name) { 105 return &api_def.in_arg(i); 106 } 107 } 108 return nullptr; 109} 110 111class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp { 112 public: 113 GenEagerPythonOp(const OpDef& op_def, const ApiDef& api_def, 114 const string& function_name) 115 : python_op_gen_internal::GenPythonOp(op_def, api_def, function_name) { 116 op_name_ = function_name_; 117 op_name_.Consume("_"); 118 } 119 ~GenEagerPythonOp() override {} 120 121 string Code() override; 122 123 protected: 124 void ExpectListArg(const string& arg_name); 125 void AddEagerInferredAttrs(); 126 void AddEagerInputCasts(); 127 void AddEagerAttrs(); 128 void AddEagerExecute(const string& num_outputs_expr); 129 130 void AddAttrForArg(const string& attr, int arg_index) { 131 gtl::InsertIfNotPresent(&inferred_attrs_, attr, 132 op_def_.input_arg(arg_index).name()); 133 auto iter = attr_to_args_.find(attr); 134 if (iter == attr_to_args_.end()) { 135 attr_to_args_.insert(AttrToArgMap::value_type(attr, {arg_index})); 136 } else { 137 iter->second.push_back(arg_index); 138 } 139 } 140 141 // Returns a string expression representing a flattened list of all 142 // the inputs given by `*input_indices` (or all inputs if 143 // `input_indices` is nullptr). `*output_sizes` can be used to unflatten. 
144 string FlattenInputs(const std::vector<int>* input_indices, 145 std::vector<string>* output_sizes) const; 146 147 StringPiece op_name_; 148 typedef std::unordered_map<string, std::vector<int>> AttrToArgMap; 149 AttrToArgMap attr_to_args_; 150 std::unordered_map<string, string> attr_expressions_; 151}; 152 153string GetEagerPythonOp(const OpDef& op_def, const ApiDef& api_def, 154 const string& function_name) { 155 return GenEagerPythonOp(op_def, api_def, function_name).Code(); 156} 157 158string GenEagerPythonOp::FlattenInputs( 159 const std::vector<int>* input_indices, 160 std::vector<string>* output_sizes) const { 161 string inputs; 162 enum { STARTING, WAS_LIST_INPUT, WAS_SOLO_INPUT } inputs_state = STARTING; 163 const int n = input_indices != nullptr ? input_indices->size() 164 : op_def_.input_arg_size(); 165 for (int j = 0; j < n; ++j) { 166 const int i = input_indices ? (*input_indices)[j] : j; 167 const auto& arg(op_def_.input_arg(i)); 168 const bool is_list = 169 !arg.type_list_attr().empty() || !arg.number_attr().empty(); 170 if (is_list) { 171 if (inputs_state == WAS_SOLO_INPUT) { 172 strings::StrAppend(&inputs, "] + "); 173 } else if (inputs_state == WAS_LIST_INPUT) { 174 strings::StrAppend(&inputs, " + "); 175 } 176 strings::StrAppend(&inputs, "list(", param_names_[i].GetRenameTo(), ")"); 177 inputs_state = WAS_LIST_INPUT; 178 if (output_sizes != nullptr) { 179 if (!arg.number_attr().empty()) { 180 output_sizes->emplace_back(AttrVarName(arg.number_attr(), nullptr)); 181 } else { 182 output_sizes->emplace_back( 183 strings::StrCat("len(", param_names_[i].GetRenameTo(), ")")); 184 } 185 } 186 } else { 187 if (inputs_state == WAS_SOLO_INPUT) { 188 strings::StrAppend(&inputs, ", "); 189 } else if (inputs_state == WAS_LIST_INPUT) { 190 strings::StrAppend(&inputs, " + ["); 191 } else { 192 strings::StrAppend(&inputs, "["); 193 } 194 strings::StrAppend(&inputs, param_names_[i].GetRenameTo()); 195 inputs_state = WAS_SOLO_INPUT; 196 if (output_sizes != 
nullptr) output_sizes->emplace_back(); 197 } 198 } 199 if (inputs_state == STARTING) return "[]"; 200 if (inputs_state == WAS_SOLO_INPUT) { 201 strings::StrAppend(&inputs, "]"); 202 } 203 return inputs; 204} 205 206string GenEagerPythonOp::Code() { 207 if (api_def_.visibility() == ApiDef::SKIP) { 208 return ""; 209 } 210 // This has all the input args followed by those attrs that don't have 211 // defaults. 212 std::vector<python_op_gen_internal::ParamNames> params_no_default; 213 // The parameters with defaults (these have to be listed after those without). 214 // No input args are included, just attrs. 215 std::vector<std::pair<python_op_gen_internal::ParamNames, string>> 216 params_with_default; 217 218 for (int i = 0; i < api_def_.arg_order_size(); ++i) { 219 const auto& arg = *FindInputArg(api_def_.arg_order(i), op_def_); 220 const auto& api_def_arg = *FindInputArg(api_def_.arg_order(i), api_def_); 221 params_no_default.emplace_back(api_def_arg.name(), api_def_arg.rename_to()); 222 if (!arg.type_attr().empty()) { 223 AddAttrForArg(arg.type_attr(), i); 224 } else if (!arg.type_list_attr().empty()) { 225 AddAttrForArg(arg.type_list_attr(), i); 226 } 227 if (!arg.number_attr().empty()) { 228 AddAttrForArg(arg.number_attr(), i); 229 } 230 } 231 for (int i = 0; i < op_def_.attr_size(); ++i) { 232 const auto& attr(op_def_.attr(i)); 233 const auto& api_def_attr(api_def_.attr(i)); 234 // Do not add inferred attrs to the Python function signature. 
235 if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) { 236 if (api_def_attr.has_default_value()) { 237 if (attr.type() == "tensor") { 238 params_with_default.emplace_back( 239 python_op_gen_internal::ParamNames(api_def_attr.name(), 240 api_def_attr.rename_to()), 241 strings::StrCat( 242 "_execute.make_tensor(", 243 TensorPBString(api_def_attr.default_value().tensor()), ", \"", 244 api_def_attr.rename_to(), "\")")); 245 } else if (attr.type() == "list(tensor)") { 246 std::vector<string> pbtxt; 247 for (const auto& pb : api_def_attr.default_value().list().tensor()) { 248 pbtxt.emplace_back(TensorPBString(pb)); 249 } 250 params_with_default.emplace_back( 251 python_op_gen_internal::ParamNames(api_def_attr.name(), 252 api_def_attr.rename_to()), 253 strings::StrCat("[_execute.make_tensor(_pb, \"", 254 api_def_attr.rename_to(), "\") for _pb in ", 255 VectorToTuple(pbtxt), "]")); 256 } else { 257 params_with_default.emplace_back( 258 python_op_gen_internal::ParamNames(api_def_attr.name(), 259 api_def_attr.rename_to()), 260 python_op_gen_internal::AttrValueToPython( 261 attr.type(), api_def_attr.default_value(), "_dtypes.")); 262 } 263 } else { 264 params_no_default.emplace_back(api_def_attr.name(), 265 api_def_attr.rename_to()); 266 } 267 } 268 } 269 270 // Save the list of attr parameters (attrs that won't be inferred), 271 // those with defaults go at the end. 272 // Get the attrs in the order we want by taking the attrs without defaults 273 // from the end of params_no_default, and adding params_no_default. 
274 attrs_.reserve(params_no_default.size() - op_def_.input_arg_size() + 275 params_with_default.size()); 276 for (int i = op_def_.input_arg_size(); i < params_no_default.size(); ++i) { 277 attrs_.push_back(params_no_default[i].GetName()); 278 } 279 for (const auto& p : params_with_default) { 280 attrs_.push_back(p.first.GetName()); 281 } 282 283 param_names_.reserve(params_no_default.size() + params_with_default.size()); 284 param_names_.insert(param_names_.begin(), params_no_default.begin(), 285 params_no_default.end()); 286 for (const auto& param_and_default : params_with_default) { 287 param_names_.push_back(param_and_default.first); 288 } 289 290 string parameters; 291 for (const auto& param : params_no_default) { 292 if (!parameters.empty()) strings::StrAppend(¶meters, ", "); 293 strings::StrAppend(¶meters, param.GetRenameTo()); 294 } 295 for (const auto& param_and_default : params_with_default) { 296 if (!parameters.empty()) strings::StrAppend(¶meters, ", "); 297 strings::StrAppend(¶meters, param_and_default.first.GetRenameTo(), "=", 298 param_and_default.second); 299 } 300 if (!parameters.empty()) strings::StrAppend(¶meters, ", "); 301 strings::StrAppend(¶meters, "name=None"); 302 303 AddExport(); 304 AddDefLine(parameters); 305 AddDocStringDescription(); 306 AddDocStringArgs(); 307 AddDocStringInputs(); 308 AddDocStringAttrs(); 309 AddDocStringNameArg(); 310 AddOutputGlobals(); 311 AddDocStringOutputs(); 312 strings::StrAppend(&result_, " \"\"\"\n"); 313 314 // Function body. 315 316 // Validate list inputs, infer length attrs. 317 for (int i = 0; i < op_def_.attr_size(); ++i) { 318 const auto& attr(op_def_.attr(i)); 319 if (attr.type() == "int") { 320 auto arg_list = attr_to_args_.find(attr.name()); 321 if (arg_list != attr_to_args_.end()) { 322 // Inferred int attrs are the lengths of inputs. Validate those 323 // inputs are lists and have the same length. 
324 for (auto iter = arg_list->second.begin(); 325 iter != arg_list->second.end(); ++iter) { 326 const string& arg_api_name = param_names_[*iter].GetRenameTo(); 327 ExpectListArg(arg_api_name); 328 if (iter == arg_list->second.begin()) { 329 AddInferredAttr(attr.name(), 330 strings::StrCat("len(", arg_api_name, ")"), 331 &result_, &attr_expressions_); 332 } else { 333 const auto& attr_var = attr_expressions_[attr.name()]; 334 strings::StrAppend(&result_, " if len(", arg_api_name, 335 ") != ", attr_var, 336 ":\n" 337 " raise ValueError(\n" 338 " \"List argument '", 339 arg_api_name, "' to '", op_name_, 340 "' Op with length %d \"\n" 341 " \"must match length %d of argument '", 342 inferred_attrs_[attr.name()], 343 "'.\" %\n" 344 " (len(", 345 arg_api_name, "), ", attr_var, "))\n"); 346 } 347 } 348 } 349 } 350 } 351 352 // Values for non-inferred attrs. 353 for (int i = 0; i < attrs_.size(); ++i) { 354 const string& attr_name = attrs_[i]; 355 const auto& param = param_names_[i + op_def_.input_arg_size()]; 356 const auto& attr = *FindAttr(attr_name, op_def_); 357 const string& attr_api_name = param.GetRenameTo(); 358 StringPiece attr_type = attr.type(); 359 attr_expressions_[attr_name] = attr_api_name; 360 const int default_index = i - (attrs_.size() - params_with_default.size()); 361 if (default_index >= 0) { 362 const string& default_value = params_with_default[default_index].second; 363 strings::StrAppend(&result_, " if ", attr_api_name, " is None:\n"); 364 strings::StrAppend(&result_, " ", attr_api_name, " = ", default_value, 365 "\n"); 366 } 367 if (attr_type.starts_with("list(")) { 368 ExpectListArg(attr_api_name); 369 } 370 371 if (attr_type == "string") { 372 strings::StrAppend(&result_, " ", attr_api_name, " = _execute.make_str(", 373 attr_api_name, ", \"", attr_api_name, "\")\n"); 374 } else if (attr_type == "list(string)") { 375 strings::StrAppend(&result_, " ", attr_api_name, 376 " = [_execute.make_str(_s, \"", attr_api_name, 377 "\") for _s in ", 
attr_api_name, "]\n"); 378 } else if (attr_type == "int") { 379 strings::StrAppend(&result_, " ", attr_api_name, " = _execute.make_int(", 380 attr_api_name, ", \"", attr_api_name, "\")\n"); 381 } else if (attr_type == "list(int)") { 382 strings::StrAppend(&result_, " ", attr_api_name, 383 " = [_execute.make_int(_i, \"", attr_api_name, 384 "\") for _i in ", attr_api_name, "]\n"); 385 } else if (attr_type == "float") { 386 strings::StrAppend(&result_, " ", attr_api_name, 387 " = _execute.make_float(", attr_api_name, ", \"", 388 attr_api_name, "\")\n"); 389 } else if (attr_type == "list(float)") { 390 strings::StrAppend(&result_, " ", attr_api_name, 391 " = [_execute.make_float(_f, \"", attr_api_name, 392 "\") for _f in ", attr_api_name, "]\n"); 393 } else if (attr_type == "bool") { 394 strings::StrAppend(&result_, " ", attr_api_name, 395 " = _execute.make_bool(", attr_api_name, ", \"", 396 attr_api_name, "\")\n"); 397 } else if (attr_type == "list(bool)") { 398 strings::StrAppend(&result_, " ", attr_api_name, 399 " = [_execute.make_bool(_b, \"", attr_api_name, 400 "\") for _b in ", attr_api_name, "]\n"); 401 } else if (attr_type == "type") { 402 strings::StrAppend(&result_, " ", attr_api_name, 403 " = _execute.make_type(", attr_api_name, ", \"", 404 attr_api_name, "\")\n"); 405 } else if (attr_type == "list(type)") { 406 strings::StrAppend(&result_, " ", attr_api_name, 407 " = [_execute.make_type(_t, \"", attr_api_name, 408 "\") for _t in ", attr_api_name, "]\n"); 409 } else if (attr_type == "shape") { 410 strings::StrAppend(&result_, " ", attr_api_name, 411 " = _execute.make_shape(", attr_api_name, ", \"", 412 attr_api_name, "\")\n"); 413 } else if (attr_type == "list(shape)") { 414 strings::StrAppend(&result_, " ", attr_api_name, 415 " = [_execute.make_shape(_s, \"", attr_api_name, 416 "\") for _s in ", attr_api_name, "]\n"); 417 } else if (attr_type == "tensor") { 418 strings::StrAppend(&result_, " ", attr_api_name, 419 " = _execute.make_tensor(", attr_api_name, 
", \"", 420 attr_api_name, "\")\n"); 421 } else if (attr_type == "list(tensor)") { 422 strings::StrAppend(&result_, " ", attr_api_name, 423 " = [_execute.make_tensor(_t, \"", attr_api_name, 424 "\") for _t in ", attr_api_name, "]\n"); 425 } else if (attr_type != "func") { 426 return strings::StrCat("# No definition for ", function_name_, 427 " since we don't support attrs with type\n" 428 "# '", 429 attr_type, "' right now.\n\n"); 430 } 431 } 432 433 // Figure out the list of inputs. 434 const string inputs = FlattenInputs(nullptr, nullptr); 435 436 // Handle graph-mode case 437 strings::StrAppend(&result_, 438 " _ctx = _context.context()\n" 439 440 " if _ctx.in_graph_mode():\n" 441 " _, _, _op = _op_def_lib._apply_op_helper(\n"); 442 AddBodyNoReturn(" "); 443 if (num_outs_ > 0) { 444 strings::StrAppend(&result_, " _result = _op.outputs[:]\n"); 445 // Special case handling for stateful op with single list output 446 // that might be empty. 447 if (num_outs_ == 1 && op_def_.is_stateful() && 448 (!op_def_.output_arg(0).number_attr().empty() || 449 !op_def_.output_arg(0).type_list_attr().empty())) { 450 // TODO(josh11b): Can skip this if the number_attr/type_list_attr has 451 // a constraint indicating that this can never be empty. 452 strings::StrAppend(&result_, 453 " if not _result:\n" 454 " return _op\n"); 455 } 456 strings::StrAppend(&result_, " _inputs_flat = _op.inputs\n"); 457 458 // Compute graph-mode attrs. 
459 if (op_def_.attr_size() > 0) { 460 string attr_values; 461 for (int i = 0; i < op_def_.attr_size(); ++i) { 462 if (i > 0) strings::StrAppend(&attr_values, ", "); 463 const auto& attr_name(op_def_.attr(i).name()); 464 strings::StrAppend(&attr_values, "\"", attr_name, "\", _op.get_attr(\"", 465 attr_name, "\")"); 466 } 467 strings::StrAppend(&attr_values, ")"); 468 strings::StrAppend(&result_, 469 WordWrap(" _attrs = (", attr_values, kRightMargin), 470 "\n"); 471 } else { 472 strings::StrAppend(&result_, " _attrs = None\n"); 473 } 474 } else { 475 strings::StrAppend(&result_, " return _op\n"); 476 } 477 478 // Handle eager-mode case 479 strings::StrAppend(&result_, " else:\n"); 480 481 // Expression representing the number of outputs. 482 int num_fixed_outputs = 0; 483 string num_outputs_expr; 484 // If output i is list output, output_sizes[i] will be set to a 485 // string with the python expression that will evaluate to its 486 // length. output_sizes[i] is empty for non-list outputs. 487 std::vector<string> output_sizes(num_outs_); 488 for (int i = 0; i < num_outs_; ++i) { 489 const auto& arg(op_def_.output_arg(i)); 490 if (!arg.number_attr().empty()) { 491 if (!num_outputs_expr.empty()) { 492 strings::StrAppend(&num_outputs_expr, " + "); 493 } 494 output_sizes[i] = attr_expressions_[arg.number_attr()]; 495 strings::StrAppend(&num_outputs_expr, output_sizes[i]); 496 } else if (!arg.type_list_attr().empty()) { 497 if (!num_outputs_expr.empty()) { 498 strings::StrAppend(&num_outputs_expr, " + "); 499 } 500 // Have to be careful to use an expression that works in both 501 // graph and eager paths here. 
502 const auto iter = inferred_attrs_.find(arg.type_list_attr()); 503 if (iter == inferred_attrs_.end()) { 504 output_sizes[i] = strings::StrCat( 505 "len(", attr_expressions_[arg.type_list_attr()], ")"); 506 } else { 507 output_sizes[i] = strings::StrCat("len(", iter->second, ")"); 508 } 509 strings::StrAppend(&num_outputs_expr, output_sizes[i]); 510 } else { 511 ++num_fixed_outputs; 512 } 513 } 514 if (num_fixed_outputs > 0) { 515 if (!num_outputs_expr.empty()) { 516 strings::StrAppend(&num_outputs_expr, " + "); 517 } 518 strings::StrAppend(&num_outputs_expr, num_fixed_outputs); 519 } else if (num_outputs_expr.empty()) { 520 num_outputs_expr = "0"; 521 } 522 523 bool eager_allowed = true; 524 string ref_arg; 525 for (int i = 0; i < op_def_.input_arg_size(); ++i) { 526 const auto& arg = op_def_.input_arg(i); 527 if (arg.is_ref()) { 528 eager_allowed = false; 529 DCHECK_EQ(op_def_.input_arg(i).name(), api_def_.in_arg(i).name()); 530 ref_arg = api_def_.in_arg(i).rename_to(); 531 } 532 } 533 for (int i = 0; i < op_def_.output_arg_size(); ++i) { 534 const auto& arg = op_def_.output_arg(i); 535 if (arg.is_ref()) { 536 eager_allowed = false; 537 DCHECK_EQ(op_def_.output_arg(i).name(), api_def_.out_arg(i).name()); 538 ref_arg = api_def_.out_arg(i).rename_to(); 539 } 540 } 541 542 if (eager_allowed) { 543 AddEagerInferredAttrs(); 544 AddEagerInputCasts(); 545 strings::StrAppend(&result_, " _inputs_flat = ", inputs, "\n"); 546 AddEagerAttrs(); 547 AddEagerExecute(num_outputs_expr); 548 } else { 549 strings::StrAppend(&result_, 550 " raise RuntimeError(\n" 551 " \"", 552 op_name_, " op does not support eager execution. ", 553 "Arg '", ref_arg, "'' is a ref.\")\n"); 554 } 555 556 if (num_outs_ > 0) { 557 strings::StrAppend(&result_, " _execute.record_gradient(\n", " \"", 558 op_def_.name(), 559 "\", _inputs_flat, _attrs, _result, name)\n"); 560 if (num_outs_ == 1 && !output_sizes[0].empty()) { 561 // Single list result. 
562 } else if (num_outs_ == 1) { 563 // Execute returns a single-element list which we need to destructure. 564 strings::StrAppend(&result_, " _result, = _result\n"); 565 } else { 566 // Have multiple outputs, so we will need to reformat the return 567 // value of execute() to be a list with one entry per op output 568 // (that entry will be a list of tensors if that output is of list 569 // type). 570 // For list outputs, convert the right subrange of _result into a list. 571 Unflatten(" ", output_sizes, "_result", &result_); 572 // Convert to a named tuple. 573 strings::StrAppend(&result_, " _result = _", op_def_.name(), 574 "Output._make(_result)\n"); 575 } 576 } else { 577 strings::StrAppend(&result_, " _result = None\n"); 578 } 579 strings::StrAppend(&result_, " return _result\n\n"); 580 return prelude_ + result_; 581} 582 583void GenEagerPythonOp::ExpectListArg(const string& arg_name) { 584 strings::StrAppend(&result_, " if not isinstance(", arg_name, 585 ", (list, tuple)):\n" 586 " raise TypeError(\n" 587 " \"Expected list for '", 588 arg_name, 589 "' argument to \"\n" 590 " \"'", 591 op_name_, "' Op, not %r.\" % ", arg_name, ")\n"); 592} 593 594void GenEagerPythonOp::AddEagerInferredAttrs() { 595 // Figure out values for inferred attrs, and cast to eager tensors. 
596 for (int i = 0; i < op_def_.attr_size(); ++i) { 597 const auto& attr(op_def_.attr(i)); 598 const auto& api_def_attr(api_def_.attr(i)); 599 auto arg_list = attr_to_args_.find(attr.name()); 600 if (arg_list != attr_to_args_.end()) { 601 if (attr.type() == "type") { 602 std::vector<string> output_sizes; 603 const string flattened = 604 FlattenInputs(&arg_list->second, &output_sizes); 605 string conversion = strings::StrCat("_execute.args_to_matching_eager(", 606 flattened, ", _ctx"); 607 if (attr.has_default_value()) { 608 strings::StrAppend( 609 &conversion, ", ", 610 python_op_gen_internal::AttrValueToPython( 611 attr.type(), api_def_attr.default_value(), "_dtypes.")); 612 } 613 strings::StrAppend(&conversion, ")"); 614 const string var_name = AttrVarName(attr.name(), &attr_expressions_); 615 if (output_sizes.size() == 1) { 616 // Avoid creating a temporary variable in the case where 617 // we can easily assign to the right value directly. 618 const string inputs_var = 619 param_names_[arg_list->second.front()].GetRenameTo(); 620 if (output_sizes.front().empty()) { 621 strings::StrAppend(&result_, " ", var_name, ", (", inputs_var, 622 ",) = ", conversion, "\n"); 623 } else { 624 strings::StrAppend(&result_, " ", var_name, ", ", inputs_var, 625 " = ", conversion, "\n"); 626 } 627 } else { 628 const string inputs_var = strings::StrCat("_inputs_", attr.name()); 629 strings::StrAppend(&result_, " ", var_name, ", ", inputs_var, 630 " = ", conversion, "\n"); 631 // Convert from a flat list of eager tensors back to the 632 // parameter variables. 
633 Unflatten(" ", output_sizes, inputs_var, &result_); 634 std::vector<string> p; 635 for (int j : arg_list->second) { 636 p.emplace_back(param_names_[j].GetRenameTo()); 637 } 638 strings::StrAppend(&result_, " ", VectorToTuple(p), " = ", 639 inputs_var, "\n"); 640 } 641 } else if (attr.type() == "list(type)") { 642 // NOTE: We ignore default values for these attrs, since it is 643 // unclear how you would use it, and the one use case is 644 // parse_single_sequence_example which only needs it for 645 // backwards compatibility. 646 const string var_name = AttrVarName(attr.name(), &attr_expressions_); 647 string inputs_var; 648 string conversion; 649 if (arg_list->second.size() > 1) { 650 // If you have more than one list(tensor) argument, their types 651 // have to match. 652 std::vector<string> lists; 653 for (auto iter = arg_list->second.begin(); 654 iter != arg_list->second.end(); ++iter) { 655 lists.push_back(param_names_[*iter].GetRenameTo()); 656 } 657 inputs_var = VectorToTuple(lists); 658 conversion = "_execute.args_to_mixed_eager_tensors"; 659 } else { 660 // For one list(tensor) argument, we just convert every 661 // element of the list to an eager tensor. 662 inputs_var = param_names_[arg_list->second.front()].GetRenameTo(); 663 conversion = "_execute.convert_to_mixed_eager_tensors"; 664 } 665 strings::StrAppend(&result_, " ", var_name, ", ", inputs_var, " = ", 666 conversion, "(", inputs_var, ", _ctx)\n"); 667 } 668 } 669 } 670} 671 672void GenEagerPythonOp::AddEagerInputCasts() { 673 // Cast remaining args to eager tensors 674 for (int i = 0; i < op_def_.input_arg_size(); ++i) { 675 const auto& arg(op_def_.input_arg(i)); 676 if (!arg.type_attr().empty() || !arg.type_list_attr().empty()) continue; 677 const string& param = param_names_[i].GetRenameTo(); 678 const string fn = arg.number_attr().empty() ? 
"" : "n_"; 679 const string dtype = 680 python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes."); 681 strings::StrAppend(&result_, " ", param, " = _ops.convert_", fn, 682 "to_tensor(", param, ", ", dtype, ")\n"); 683 } 684} 685 686void GenEagerPythonOp::AddEagerAttrs() { 687 // Compute eager attrs 688 if (op_def_.attr_size() > 0) { 689 string attr_values; 690 for (int i = 0; i < op_def_.attr_size(); ++i) { 691 if (i > 0) strings::StrAppend(&attr_values, ", "); 692 const auto& attr_name(op_def_.attr(i).name()); 693 strings::StrAppend(&attr_values, "\"", attr_name, "\", ", 694 attr_expressions_[attr_name]); 695 } 696 strings::StrAppend(&attr_values, ")"); 697 strings::StrAppend( 698 &result_, WordWrap(" _attrs = (", attr_values, kRightMargin), "\n"); 699 } else { 700 strings::StrAppend(&result_, " _attrs = None\n"); 701 } 702} 703 704void GenEagerPythonOp::AddEagerExecute(const string& num_outputs_expr) { 705 const string return_prefix = " _result = _execute.execute("; 706 const string return_args = strings::StrCat( 707 "b\"", op_def_.name(), "\", ", num_outputs_expr, 708 ", inputs=_inputs_flat, attrs=_attrs, ctx=_ctx, name=name)"); 709 strings::StrAppend(&result_, 710 // Wrap the arguments, and indent to the (. 711 WordWrap(return_prefix, return_args, kRightMargin), "\n"); 712} 713 714string GetEagerPythonOps(const OpList& ops, const ApiDefMap& api_defs, 715 const std::vector<string>& hidden_ops, 716 bool require_shapes, 717 const string& source_file_name = "") { 718 string result; 719 // Header 720 // TODO(josh11b): Mention the library for which wrappers are being generated. 721 strings::StrAppend(&result, R"("""Python wrappers around TensorFlow ops. 722 723This file is MACHINE GENERATED! Do not edit. 724)"); 725 726 // Mention the original source file so someone tracing back through generated 727 // Python code will know where to look next. 
728 if (!source_file_name.empty()) { 729 strings::StrAppend(&result, "Original C++ source file: "); 730 strings::StrAppend(&result, source_file_name); 731 strings::StrAppend(&result, "\n"); 732 } 733 734 strings::StrAppend(&result, R"(""" 735 736import collections as _collections 737 738from tensorflow.python.eager import execute as _execute 739from tensorflow.python.eager import context as _context 740from tensorflow.python.eager import core as _core 741from tensorflow.python.framework import dtypes as _dtypes 742from tensorflow.python.framework import tensor_shape as _tensor_shape 743 744from tensorflow.core.framework import op_def_pb2 as _op_def_pb2 745# Needed to trigger the call to _set_call_cpp_shape_fn. 746from tensorflow.python.framework import common_shapes as _common_shapes 747from tensorflow.python.framework import op_def_registry as _op_def_registry 748from tensorflow.python.framework import ops as _ops 749from tensorflow.python.framework import op_def_library as _op_def_library 750from tensorflow.python.util.tf_export import tf_export 751 752)"); 753 754 // We'll make a copy of ops that filters out descriptions. 755 OpList cleaned_ops; 756 auto out = cleaned_ops.mutable_op(); 757 out->Reserve(ops.op_size()); 758 for (const auto& op_def : ops.op()) { 759 bool is_hidden = false; 760 for (const string& hidden : hidden_ops) { 761 if (op_def.name() == hidden) { 762 is_hidden = true; 763 break; 764 } 765 } 766 767 string function_name; 768 python_op_gen_internal::GenerateLowerCaseOpName(op_def.name(), 769 &function_name); 770 if (is_hidden) function_name = strings::StrCat("_", function_name); 771 772 // When users create custom python wrappers, they may link in the 773 // default op registry by accident, and because they can't 774 // enumerate all 'hidden' symbols, this guard is to prevent 775 // instantiating a python reserved word in their wrapper. 
776 if (python_op_gen_internal::IsPythonReserved(function_name)) { 777 continue; 778 } 779 780 const auto* api_def = api_defs.GetApiDef(op_def.name()); 781 strings::StrAppend(&result, 782 GetEagerPythonOp(op_def, *api_def, function_name)); 783 784 if (!require_shapes) { 785 strings::StrAppend(&result, "_ops.RegisterShape(\"", op_def.name(), 786 "\")(None)\n\n"); 787 } 788 789 auto added = out->Add(); 790 *added = op_def; 791 RemoveNonDeprecationDescriptionsFromOpDef(added); 792 } 793 794 result.append(R"(def _InitOpDefLibrary(op_list_proto_bytes): 795 op_list = _op_def_pb2.OpList() 796 op_list.ParseFromString(op_list_proto_bytes) 797 _op_def_registry.register_op_list(op_list) 798 op_def_lib = _op_def_library.OpDefLibrary() 799 op_def_lib.add_op_list(op_list) 800 return op_def_lib 801)"); 802 803 result.append("# "); 804 auto ops_text = ProtoDebugString(cleaned_ops); 805 str_util::StripTrailingWhitespace(&ops_text); 806 result.append(str_util::StringReplace(ops_text, "\n", "\n# ", true)); 807 result.append("\n"); 808 strings::Appendf(&result, "_op_def_lib = _InitOpDefLibrary(b\"%s\")\n", 809 str_util::CEscape(cleaned_ops.SerializeAsString()).c_str()); 810 return result; 811} 812 813} // namespace 814 815void PrintEagerPythonOps(const OpList& ops, const ApiDefMap& api_defs, 816 const std::vector<string>& hidden_ops, 817 bool require_shapes, const string& source_file_name) { 818 printf("%s", GetEagerPythonOps(ops, api_defs, hidden_ops, require_shapes, 819 source_file_name) 820 .c_str()); 821} 822 823string GetEagerPythonWrappers(const char* op_list_buf, size_t op_list_len) { 824 string op_list_str(op_list_buf, op_list_len); 825 OpList ops; 826 ops.ParseFromString(op_list_str); 827 828 ApiDefMap api_def_map(ops); 829 return GetEagerPythonOps(ops, api_def_map, {}, false); 830} 831 832} // namespace tensorflow 833