// python_eager_op_gen.cc revision fd3882dd5cb75773fb12b3a84962411c2df2a300
1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15#include "tensorflow/python/eager/python_eager_op_gen.h" 16 17#include <stdio.h> 18#include <sstream> 19#include <unordered_map> 20#include "tensorflow/core/framework/attr_value.pb.h" 21#include "tensorflow/core/framework/op.h" 22#include "tensorflow/core/framework/op_def.pb_text.h" 23#include "tensorflow/core/framework/op_def.pb.h" 24#include "tensorflow/core/framework/op_def_util.h" 25#include "tensorflow/core/framework/op_gen_lib.h" 26#include "tensorflow/core/framework/tensor.pb_text.h" 27#include "tensorflow/core/framework/types.h" 28#include "tensorflow/core/framework/types.pb.h" 29#include "tensorflow/core/lib/gtl/map_util.h" 30#include "tensorflow/core/lib/gtl/stl_util.h" 31#include "tensorflow/core/lib/strings/str_util.h" 32#include "tensorflow/core/lib/strings/strcat.h" 33#include "tensorflow/core/lib/strings/stringprintf.h" 34#include "tensorflow/core/platform/logging.h" 35#include "tensorflow/core/platform/macros.h" 36#include "tensorflow/core/platform/types.h" 37#include "tensorflow/python/framework/python_op_gen_internal.h" 38 39namespace tensorflow { 40namespace { 41 42const int kRightMargin = 78; 43 44string AttrVarName(const string& attr_name, 45 std::unordered_map<string, string>* attr_expressions) { 46 const string var = strings::StrCat("_attr_", attr_name); 47 if 
(attr_expressions != nullptr) (*attr_expressions)[attr_name] = var; 48 return var; 49} 50 51void AddInferredAttr(const string& attr_name, const string& value_expression, 52 string* result, 53 std::unordered_map<string, string>* attr_expressions) { 54 strings::StrAppend(result, " ", AttrVarName(attr_name, attr_expressions), 55 " = ", value_expression, "\n"); 56} 57 58string VectorToTuple(const std::vector<string>& l) { 59 if (l.size() == 1) return strings::StrCat("(", l.front(), ",)"); 60 string ret = "("; 61 for (int i = 0; i < l.size(); ++i) { 62 if (i > 0) { 63 strings::StrAppend(&ret, ", "); 64 } 65 strings::StrAppend(&ret, l[i]); 66 } 67 strings::StrAppend(&ret, ")"); 68 return ret; 69} 70 71void Unflatten(const string& prefix, const std::vector<string>& output_sizes, 72 const string& var, string* result) { 73 for (int i = 0; i < output_sizes.size(); ++i) { 74 if (!output_sizes[i].empty()) { 75 strings::StrAppend(result, prefix, var, " = "); 76 if (i > 0) strings::StrAppend(result, var, "[:", i, "] + "); 77 if (i + 1 < output_sizes.size()) { 78 // Special case i == 0 to avoid "0 +" in the generated code. 79 if (i == 0) { 80 strings::StrAppend(result, "[", var, "[:", output_sizes[i], "]] + ", 81 var, "[", output_sizes[i], ":]"); 82 } else { 83 strings::StrAppend(result, "[", var, "[", i, ":", i, " + ", 84 output_sizes[i], "]] + ", var, "[", i, " + ", 85 output_sizes[i], ":]"); 86 } 87 } else { 88 strings::StrAppend(result, "[", var, "[", i, ":]]"); 89 } 90 strings::StrAppend(result, "\n"); 91 } 92 } 93} 94 95string TensorPBString(const TensorProto& pb) { 96 // Note: This gets used in the argument list, and so must survive naive 97 // word wrapping. 
98 return strings::StrCat("\"\"\"", ProtoShortDebugString(pb), "\"\"\""); 99} 100 101class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp { 102 public: 103 GenEagerPythonOp(const OpDef& op_def, const string& function_name) 104 : python_op_gen_internal::GenPythonOp(op_def, function_name) { 105 op_name_ = function_name_; 106 op_name_.Consume("_"); 107 } 108 ~GenEagerPythonOp() override {} 109 110 string Code() override; 111 112 protected: 113 void ExpectListArg(const string& arg_name); 114 void AddEagerInferredAttrs(); 115 void AddEagerInputCasts(); 116 void AddEagerAttrs(); 117 void AddEagerExecute(const string& num_outputs_expr); 118 119 void AddAttrForArg(const string& attr, int arg_index) { 120 gtl::InsertIfNotPresent(&inferred_attrs_, attr, 121 op_def_.input_arg(arg_index).name()); 122 auto iter = attr_to_args_.find(attr); 123 if (iter == attr_to_args_.end()) { 124 attr_to_args_.insert(AttrToArgMap::value_type(attr, {arg_index})); 125 } else { 126 iter->second.push_back(arg_index); 127 } 128 } 129 130 // Returns a string expression representing a flattened list of all 131 // the inputs given by `*input_indices` (or all inputs if 132 // `input_indices` is nullptr). `*output_sizes` can be used to unflatten. 133 string FlattenInputs(const std::vector<int>* input_indices, 134 std::vector<string>* output_sizes) const; 135 136 StringPiece op_name_; 137 typedef std::unordered_map<string, std::vector<int>> AttrToArgMap; 138 AttrToArgMap attr_to_args_; 139 std::unordered_map<string, string> attr_expressions_; 140}; 141 142string GetEagerPythonOp(const OpDef& op_def, const string& function_name) { 143 return GenEagerPythonOp(op_def, function_name).Code(); 144} 145 146string GenEagerPythonOp::FlattenInputs( 147 const std::vector<int>* input_indices, 148 std::vector<string>* output_sizes) const { 149 string inputs; 150 enum { STARTING, WAS_LIST_INPUT, WAS_SOLO_INPUT } inputs_state = STARTING; 151 const int n = input_indices != nullptr ? 
input_indices->size() 152 : op_def_.input_arg_size(); 153 for (int j = 0; j < n; ++j) { 154 const int i = input_indices ? (*input_indices)[j] : j; 155 const auto& arg(op_def_.input_arg(i)); 156 const bool is_list = 157 !arg.type_list_attr().empty() || !arg.number_attr().empty(); 158 if (is_list) { 159 if (inputs_state == WAS_SOLO_INPUT) { 160 strings::StrAppend(&inputs, "] + "); 161 } else if (inputs_state == WAS_LIST_INPUT) { 162 strings::StrAppend(&inputs, " + "); 163 } 164 strings::StrAppend(&inputs, "list(", param_names_[i], ")"); 165 inputs_state = WAS_LIST_INPUT; 166 if (output_sizes != nullptr) { 167 if (!arg.number_attr().empty()) { 168 output_sizes->emplace_back(AttrVarName(arg.number_attr(), nullptr)); 169 } else { 170 output_sizes->emplace_back( 171 strings::StrCat("len(", param_names_[i], ")")); 172 } 173 } 174 } else { 175 if (inputs_state == WAS_SOLO_INPUT) { 176 strings::StrAppend(&inputs, ", "); 177 } else if (inputs_state == WAS_LIST_INPUT) { 178 strings::StrAppend(&inputs, " + ["); 179 } else { 180 strings::StrAppend(&inputs, "["); 181 } 182 strings::StrAppend(&inputs, param_names_[i]); 183 inputs_state = WAS_SOLO_INPUT; 184 if (output_sizes != nullptr) output_sizes->emplace_back(); 185 } 186 } 187 if (inputs_state == STARTING) return "[]"; 188 if (inputs_state == WAS_SOLO_INPUT) { 189 strings::StrAppend(&inputs, "]"); 190 } 191 return inputs; 192} 193 194string GenEagerPythonOp::Code() { 195 // This has all the input args followed by those attrs that don't have 196 // defaults. 197 std::vector<string> args_no_default; 198 // The parameters with defaults (these have to be listed after those without). 199 // No input args are included, just attrs. 
200 std::vector<std::pair<string, string>> args_with_defaults; 201 for (int i = 0; i < op_def_.input_arg_size(); ++i) { 202 const auto& arg(op_def_.input_arg(i)); 203 args_no_default.push_back(arg.name()); 204 if (!arg.type_attr().empty()) { 205 AddAttrForArg(arg.type_attr(), i); 206 } else if (!arg.type_list_attr().empty()) { 207 AddAttrForArg(arg.type_list_attr(), i); 208 } 209 if (!arg.number_attr().empty()) { 210 AddAttrForArg(arg.number_attr(), i); 211 } 212 } 213 for (int i = 0; i < op_def_.attr_size(); ++i) { 214 const auto& attr(op_def_.attr(i)); 215 // Do not add inferred attrs to the Python function signature. 216 if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) { 217 if (attr.has_default_value()) { 218 if (attr.type() == "tensor") { 219 args_with_defaults.emplace_back( 220 attr.name(), 221 strings::StrCat("_execute.make_tensor(", 222 TensorPBString(attr.default_value().tensor()), 223 ", \"", attr.name(), "\")")); 224 } else if (attr.type() == "list(tensor)") { 225 std::vector<string> pbtxt; 226 for (const auto& pb : attr.default_value().list().tensor()) { 227 pbtxt.emplace_back(TensorPBString(pb)); 228 } 229 args_with_defaults.emplace_back( 230 attr.name(), 231 strings::StrCat("[_execute.make_tensor(_pb, \"", attr.name(), 232 "\") for _pb in ", VectorToTuple(pbtxt), "]")); 233 } else { 234 args_with_defaults.emplace_back( 235 attr.name(), python_op_gen_internal::AttrValueToPython( 236 attr.type(), attr.default_value(), "_dtypes.")); 237 } 238 } else { 239 args_no_default.push_back(attr.name()); 240 } 241 } 242 } 243 244 // Save the list of attr parameters (attrs that won't be inferred), 245 // those with defaults go at the end. 246 // Get the attrs in the order we want by taking the attrs without defaults 247 // from the end of args_no_default, and adding args_no_default. 
248 attrs_.reserve(args_no_default.size() - op_def_.input_arg_size() + 249 args_with_defaults.size()); 250 attrs_.insert(attrs_.end(), 251 args_no_default.begin() + op_def_.input_arg_size(), 252 args_no_default.end()); 253 for (const auto& a : args_with_defaults) { 254 attrs_.push_back(a.first); 255 } 256 257 param_names_.reserve(args_no_default.size() + args_with_defaults.size()); 258 string parameters; 259 for (const string& name : args_no_default) { 260 if (!parameters.empty()) strings::StrAppend(¶meters, ", "); 261 const string param = python_op_gen_internal::AvoidPythonReserved(name); 262 strings::StrAppend(¶meters, param); 263 param_names_.push_back(param); 264 } 265 for (const auto& name_default : args_with_defaults) { 266 if (!parameters.empty()) strings::StrAppend(¶meters, ", "); 267 const string param = 268 python_op_gen_internal::AvoidPythonReserved(name_default.first); 269 strings::StrAppend(¶meters, param, "=", name_default.second); 270 param_names_.push_back(param); 271 } 272 if (!parameters.empty()) strings::StrAppend(¶meters, ", "); 273 strings::StrAppend(¶meters, "name=None"); 274 275 AddDefLine(parameters); 276 AddDocStringDescription(); 277 AddDocStringArgs(); 278 AddDocStringInputs(); 279 AddDocStringAttrs(); 280 AddDocStringNameArg(); 281 AddOutputGlobals(); 282 AddDocStringOutputs(); 283 strings::StrAppend(&result_, " \"\"\"\n"); 284 285 // Function body. 286 287 // Validate list inputs, infer length attrs. 288 for (int i = 0; i < op_def_.attr_size(); ++i) { 289 const auto& attr(op_def_.attr(i)); 290 if (attr.type() == "int") { 291 auto arg_list = attr_to_args_.find(attr.name()); 292 if (arg_list != attr_to_args_.end()) { 293 // Inferred int attrs are the lengths of inputs. Validate those 294 // inputs are lists and have the same length. 
295 for (auto iter = arg_list->second.begin(); 296 iter != arg_list->second.end(); ++iter) { 297 const string& arg_name = param_names_[*iter]; 298 ExpectListArg(arg_name); 299 if (iter == arg_list->second.begin()) { 300 AddInferredAttr(attr.name(), strings::StrCat("len(", arg_name, ")"), 301 &result_, &attr_expressions_); 302 } else { 303 const auto& attr_var = attr_expressions_[attr.name()]; 304 strings::StrAppend(&result_, " if len(", arg_name, 305 ") != ", attr_var, 306 ":\n" 307 " raise ValueError(\n" 308 " \"List argument '", 309 arg_name, "' to '", op_name_, 310 "' Op with length %d \"\n" 311 " \"must match length %d of argument '", 312 inferred_attrs_[attr.name()], 313 "'.\" %\n" 314 " (len(", 315 arg_name, "), ", attr_var, "))\n"); 316 } 317 } 318 } 319 } 320 } 321 322 // Values for non-inferred attrs. 323 for (int i = 0; i < attrs_.size(); ++i) { 324 const string& attr_name = attrs_[i]; 325 const string& param = param_names_[i + op_def_.input_arg_size()]; 326 const auto& attr = *FindAttr(attr_name, op_def_); 327 StringPiece attr_type = attr.type(); 328 attr_expressions_[attr_name] = param; 329 const int default_index = i - (attrs_.size() - args_with_defaults.size()); 330 if (default_index >= 0) { 331 const string& default_value = args_with_defaults[default_index].second; 332 strings::StrAppend(&result_, " if ", param, " is None:\n"); 333 strings::StrAppend(&result_, " ", param, " = ", default_value, "\n"); 334 } 335 if (attr_type.starts_with("list(")) { 336 ExpectListArg(param); 337 } 338 339 if (attr_type == "string") { 340 strings::StrAppend(&result_, " ", param, " = _execute.make_str(", param, 341 ", \"", param, "\")\n"); 342 } else if (attr_type == "list(string)") { 343 strings::StrAppend(&result_, " ", param, " = [_execute.make_str(_s, \"", 344 param, "\") for _s in ", param, "]\n"); 345 } else if (attr_type == "int") { 346 strings::StrAppend(&result_, " ", param, " = _execute.make_int(", param, 347 ", \"", param, "\")\n"); 348 } else if (attr_type == 
"list(int)") { 349 strings::StrAppend(&result_, " ", param, " = [_execute.make_int(_i, \"", 350 param, "\") for _i in ", param, "]\n"); 351 } else if (attr_type == "float") { 352 strings::StrAppend(&result_, " ", param, " = _execute.make_float(", 353 param, ", \"", param, "\")\n"); 354 } else if (attr_type == "list(float)") { 355 strings::StrAppend(&result_, " ", param, 356 " = [_execute.make_float(_f, \"", param, 357 "\") for _f in ", param, "]\n"); 358 } else if (attr_type == "bool") { 359 strings::StrAppend(&result_, " ", param, " = _execute.make_bool(", param, 360 ", \"", param, "\")\n"); 361 } else if (attr_type == "list(bool)") { 362 strings::StrAppend(&result_, " ", param, " = [_execute.make_bool(_b, \"", 363 param, "\") for _b in ", param, "]\n"); 364 } else if (attr_type == "type") { 365 strings::StrAppend(&result_, " ", param, " = _execute.make_type(", param, 366 ", \"", param, "\")\n"); 367 } else if (attr_type == "list(type)") { 368 strings::StrAppend(&result_, " ", param, " = [_execute.make_type(_t, \"", 369 param, "\") for _t in ", param, "]\n"); 370 } else if (attr_type == "shape") { 371 strings::StrAppend(&result_, " ", param, " = _execute.make_shape(", 372 param, ", \"", param, "\")\n"); 373 } else if (attr_type == "list(shape)") { 374 strings::StrAppend(&result_, " ", param, 375 " = [_execute.make_shape(_s, \"", param, 376 "\") for _s in ", param, "]\n"); 377 } else if (attr_type == "tensor") { 378 strings::StrAppend(&result_, " ", param, " = _execute.make_tensor(", 379 param, ", \"", param, "\")\n"); 380 } else if (attr_type == "list(tensor)") { 381 strings::StrAppend(&result_, " ", param, 382 " = [_execute.make_tensor(_t, \"", param, 383 "\") for _t in ", param, "]\n"); 384 } else if (attr_type != "func") { 385 return strings::StrCat("# No definition for ", function_name_, 386 " since we don't support attrs with type\n" 387 "# '", 388 attr_type, "' right now.\n\n"); 389 } 390 } 391 392 // Figure out the list of inputs. 
393 const string inputs = FlattenInputs(nullptr, nullptr); 394 395 // Handle graph-mode case 396 strings::StrAppend(&result_, 397 " _ctx = _context.context()\n" 398 399 " if _ctx.in_graph_mode():\n" 400 " _, _, _op = _op_def_lib._apply_op_helper(\n"); 401 AddBodyNoReturn(" "); 402 if (num_outs_ > 0) { 403 strings::StrAppend(&result_, " _result = _op.outputs[:]\n"); 404 // Special case handling for stateful op with single list output 405 // that might be empty. 406 if (num_outs_ == 1 && op_def_.is_stateful() && 407 (!op_def_.output_arg(0).number_attr().empty() || 408 !op_def_.output_arg(0).type_list_attr().empty())) { 409 // TODO(josh11b): Can skip this if the number_attr/type_list_attr has 410 // a constraint indicating that this can never be empty. 411 strings::StrAppend(&result_, 412 " if not _result:\n" 413 " return _op\n"); 414 } 415 strings::StrAppend(&result_, " _inputs_flat = ", inputs, "\n"); 416 417 // Compute graph-mode attrs. 418 if (op_def_.attr_size() > 0) { 419 string attr_values; 420 for (int i = 0; i < op_def_.attr_size(); ++i) { 421 if (i > 0) strings::StrAppend(&attr_values, ", "); 422 const auto& attr_name(op_def_.attr(i).name()); 423 strings::StrAppend(&attr_values, "\"", attr_name, "\", _op.get_attr(\"", 424 attr_name, "\")"); 425 } 426 strings::StrAppend(&attr_values, ")"); 427 strings::StrAppend(&result_, 428 WordWrap(" _attrs = (", attr_values, kRightMargin), 429 "\n"); 430 } else { 431 strings::StrAppend(&result_, " _attrs = None\n"); 432 } 433 } else { 434 strings::StrAppend(&result_, " return _op\n"); 435 } 436 437 // Handle eager-mode case 438 strings::StrAppend(&result_, " else:\n"); 439 440 // Expression representing the number of outputs. 441 int num_fixed_outputs = 0; 442 string num_outputs_expr; 443 // If output i is list output, output_sizes[i] will be set to a 444 // string with the python expression that will evaluate to its 445 // length. output_sizes[i] is empty for non-list outputs. 
446 std::vector<string> output_sizes(num_outs_); 447 for (int i = 0; i < num_outs_; ++i) { 448 const auto& arg(op_def_.output_arg(i)); 449 if (!arg.number_attr().empty()) { 450 if (!num_outputs_expr.empty()) { 451 strings::StrAppend(&num_outputs_expr, " + "); 452 } 453 output_sizes[i] = attr_expressions_[arg.number_attr()]; 454 strings::StrAppend(&num_outputs_expr, output_sizes[i]); 455 } else if (!arg.type_list_attr().empty()) { 456 if (!num_outputs_expr.empty()) { 457 strings::StrAppend(&num_outputs_expr, " + "); 458 } 459 // Have to be careful to use an expression that works in both 460 // graph and eager paths here. 461 const auto iter = inferred_attrs_.find(arg.type_list_attr()); 462 if (iter == inferred_attrs_.end()) { 463 output_sizes[i] = strings::StrCat( 464 "len(", attr_expressions_[arg.type_list_attr()], ")"); 465 } else { 466 output_sizes[i] = strings::StrCat("len(", iter->second, ")"); 467 } 468 strings::StrAppend(&num_outputs_expr, output_sizes[i]); 469 } else { 470 ++num_fixed_outputs; 471 } 472 } 473 if (num_fixed_outputs > 0) { 474 if (!num_outputs_expr.empty()) { 475 strings::StrAppend(&num_outputs_expr, " + "); 476 } 477 strings::StrAppend(&num_outputs_expr, num_fixed_outputs); 478 } else if (num_outputs_expr.empty()) { 479 num_outputs_expr = "0"; 480 } 481 482 bool eager_allowed = true; 483 string ref_arg; 484 for (const auto& arg : op_def_.input_arg()) { 485 if (arg.is_ref()) { 486 eager_allowed = false; 487 ref_arg = arg.name(); 488 } 489 } 490 for (const auto& arg : op_def_.output_arg()) { 491 if (arg.is_ref()) { 492 eager_allowed = false; 493 ref_arg = arg.name(); 494 } 495 } 496 497 if (eager_allowed) { 498 AddEagerInferredAttrs(); 499 AddEagerInputCasts(); 500 strings::StrAppend(&result_, " _inputs_flat = ", inputs, "\n"); 501 AddEagerAttrs(); 502 AddEagerExecute(num_outputs_expr); 503 } else { 504 strings::StrAppend(&result_, 505 " raise RuntimeError(\n" 506 " \"", 507 op_name_, " op does not support eager execution. 
", 508 "Arg '", ref_arg, "'' is a ref.\")\n"); 509 } 510 511 if (num_outs_ > 0) { 512 strings::StrAppend(&result_, " _execute.record_gradient(\n", " \"", 513 op_def_.name(), 514 "\", _inputs_flat, _attrs, _result, _ctx, name)\n"); 515 if (num_outs_ == 1 && !output_sizes[0].empty()) { 516 // Single list result. 517 } else if (num_outs_ == 1) { 518 // Execute returns a single-element list which we need to destructure. 519 strings::StrAppend(&result_, " _result, = _result\n"); 520 } else { 521 // Have multiple outputs, so we will need to reformat the return 522 // value of execute() to be a list with one entry per op output 523 // (that entry will be a list of tensors if that output is of list 524 // type). 525 // For list outputs, convert the right subrange of _result into a list. 526 Unflatten(" ", output_sizes, "_result", &result_); 527 // Convert to a named tuple. 528 strings::StrAppend(&result_, " _result = _", op_def_.name(), 529 "Output._make(_result)\n"); 530 } 531 } 532 strings::StrAppend(&result_, " return _result\n\n"); 533 return prelude_ + result_; 534} 535 536void GenEagerPythonOp::ExpectListArg(const string& arg_name) { 537 strings::StrAppend(&result_, " if not isinstance(", arg_name, 538 ", (list, tuple)):\n" 539 " raise TypeError(\n" 540 " \"Expected list for '", 541 arg_name, 542 "' argument to \"\n" 543 " \"'", 544 op_name_, "' Op, not %r.\" % ", arg_name, ")\n"); 545} 546 547void GenEagerPythonOp::AddEagerInferredAttrs() { 548 // Figure out values for inferred attrs, and cast to eager tensors. 
549 for (int i = 0; i < op_def_.attr_size(); ++i) { 550 const auto& attr(op_def_.attr(i)); 551 auto arg_list = attr_to_args_.find(attr.name()); 552 if (arg_list != attr_to_args_.end()) { 553 if (attr.type() == "type") { 554 std::vector<string> output_sizes; 555 const string flattened = 556 FlattenInputs(&arg_list->second, &output_sizes); 557 string conversion = strings::StrCat("_execute.args_to_matching_eager(", 558 flattened, ", _ctx"); 559 if (attr.has_default_value()) { 560 strings::StrAppend( 561 &conversion, ", ", 562 python_op_gen_internal::AttrValueToPython( 563 attr.type(), attr.default_value(), "_dtypes.")); 564 } 565 strings::StrAppend(&conversion, ")"); 566 const string var_name = AttrVarName(attr.name(), &attr_expressions_); 567 if (output_sizes.size() == 1) { 568 // Avoid creating a temporary variable in the case where 569 // we can easily assign to the right value directly. 570 const string inputs_var = param_names_[arg_list->second.front()]; 571 if (output_sizes.front().empty()) { 572 strings::StrAppend(&result_, " ", var_name, ", (", inputs_var, 573 ",) = ", conversion, "\n"); 574 } else { 575 strings::StrAppend(&result_, " ", var_name, ", ", inputs_var, 576 " = ", conversion, "\n"); 577 } 578 } else { 579 const string inputs_var = strings::StrCat("_inputs_", attr.name()); 580 strings::StrAppend(&result_, " ", var_name, ", ", inputs_var, 581 " = ", conversion, "\n"); 582 // Convert from a flat list of eager tensors back to the 583 // parameter variables. 
584 Unflatten(" ", output_sizes, inputs_var, &result_); 585 std::vector<string> p; 586 for (int j : arg_list->second) { 587 p.emplace_back(param_names_[j]); 588 } 589 strings::StrAppend(&result_, " ", VectorToTuple(p), " = ", 590 inputs_var, "\n"); 591 } 592 strings::StrAppend(&result_, " ", var_name, " = ", var_name, 593 ".as_datatype_enum\n"); 594 } else if (attr.type() == "list(type)") { 595 // NOTE: We ignore default values for these attrs, since it is 596 // unclear how you would use it, and the one use case is 597 // parse_single_sequence_example which only needs it for 598 // backwards compatibility. 599 const string var_name = AttrVarName(attr.name(), &attr_expressions_); 600 string inputs_var; 601 string conversion; 602 if (arg_list->second.size() > 1) { 603 // If you have more than one list(tensor) argument, their types 604 // have to match. 605 std::vector<string> lists; 606 for (auto iter = arg_list->second.begin(); 607 iter != arg_list->second.end(); ++iter) { 608 lists.push_back(param_names_[*iter]); 609 } 610 inputs_var = VectorToTuple(lists); 611 conversion = "_execute.args_to_mixed_eager_tensors"; 612 } else { 613 // For one list(tensor) argument, we just convert every 614 // element of the list to an eager tensor. 615 inputs_var = param_names_[arg_list->second.front()]; 616 conversion = "_execute.convert_to_mixed_eager_tensors"; 617 } 618 strings::StrAppend(&result_, " ", var_name, ", ", inputs_var, " = ", 619 conversion, "(", inputs_var, ", _ctx)\n"); 620 strings::StrAppend(&result_, " ", var_name, 621 " = [_t.as_datatype_enum for _t in ", var_name, 622 "]\n"); 623 } 624 } 625 } 626} 627 628void GenEagerPythonOp::AddEagerInputCasts() { 629 // Cast remaining args to eager tensors 630 for (int i = 0; i < op_def_.input_arg_size(); ++i) { 631 const auto& arg(op_def_.input_arg(i)); 632 if (!arg.type_attr().empty() || !arg.type_list_attr().empty()) continue; 633 const string& param = param_names_[i]; 634 const string fn = arg.number_attr().empty() ? 
"" : "n_"; 635 const string dtype = 636 python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes."); 637 strings::StrAppend(&result_, " ", param, " = _ops.convert_", fn, 638 "to_tensor(", param, ", ", dtype, ")\n"); 639 } 640} 641 642void GenEagerPythonOp::AddEagerAttrs() { 643 // Compute eager attrs 644 if (op_def_.attr_size() > 0) { 645 string attr_values; 646 for (int i = 0; i < op_def_.attr_size(); ++i) { 647 if (i > 0) strings::StrAppend(&attr_values, ", "); 648 const auto& attr_name(op_def_.attr(i).name()); 649 strings::StrAppend(&attr_values, "\"", attr_name, "\", ", 650 attr_expressions_[attr_name]); 651 } 652 strings::StrAppend(&attr_values, ")"); 653 strings::StrAppend( 654 &result_, WordWrap(" _attrs = (", attr_values, kRightMargin), "\n"); 655 } else { 656 strings::StrAppend(&result_, " _attrs = None\n"); 657 } 658} 659 660void GenEagerPythonOp::AddEagerExecute(const string& num_outputs_expr) { 661 const string return_prefix = " _result = _execute.execute("; 662 const string return_args = strings::StrCat( 663 "b\"", op_def_.name(), "\", ", num_outputs_expr, 664 ", inputs=_inputs_flat, attrs=_attrs, ctx=_ctx, name=name)"); 665 strings::StrAppend(&result_, 666 // Wrap the arguments, and indent to the (. 667 WordWrap(return_prefix, return_args, kRightMargin), "\n"); 668} 669 670string GetEagerPythonOps(const OpList& ops, 671 const std::vector<string>& hidden_ops, 672 bool require_shapes, 673 const string& source_file_name = "") { 674 string result; 675 // Header 676 // TODO(josh11b): Mention the library for which wrappers are being generated. 677 strings::StrAppend(&result, R"("""Python wrappers around TensorFlow ops. 678 679This file is MACHINE GENERATED! Do not edit. 680)"); 681 682 // Mention the original source file so someone tracing back through generated 683 // Python code will know where to look next. 
684 if (!source_file_name.empty()) { 685 strings::StrAppend(&result, "Original C++ source file: "); 686 strings::StrAppend(&result, source_file_name); 687 strings::StrAppend(&result, "\n"); 688 } 689 690 strings::StrAppend(&result, R"(""" 691 692import collections as _collections 693 694from tensorflow.python.eager import execute as _execute 695from tensorflow.python.eager import context as _context 696from tensorflow.python.eager import core as _core 697from tensorflow.python.framework import dtypes as _dtypes 698from tensorflow.python.framework import tensor_shape as _tensor_shape 699 700from tensorflow.core.framework import op_def_pb2 as _op_def_pb2 701# Needed to trigger the call to _set_call_cpp_shape_fn. 702from tensorflow.python.framework import common_shapes as _common_shapes 703from tensorflow.python.framework import op_def_registry as _op_def_registry 704from tensorflow.python.framework import ops as _ops 705from tensorflow.python.framework import op_def_library as _op_def_library 706 707)"); 708 709 // We'll make a copy of ops that filters out descriptions. 710 OpList cleaned_ops; 711 auto out = cleaned_ops.mutable_op(); 712 out->Reserve(ops.op_size()); 713 for (const auto& op_def : ops.op()) { 714 bool is_hidden = false; 715 for (const string& hidden : hidden_ops) { 716 if (op_def.name() == hidden) { 717 is_hidden = true; 718 break; 719 } 720 } 721 722 string function_name; 723 python_op_gen_internal::GenerateLowerCaseOpName(op_def.name(), 724 &function_name); 725 if (is_hidden) function_name = strings::StrCat("_", function_name); 726 727 // When users create custom python wrappers, they may link in the 728 // default op registry by accident, and because they can't 729 // enumerate all 'hidden' symbols, this guard is to prevent 730 // instantiating a python reserved word in their wrapper. 
731 if (python_op_gen_internal::IsPythonReserved(function_name)) { 732 continue; 733 } 734 735 strings::StrAppend(&result, GetEagerPythonOp(op_def, function_name)); 736 737 if (!require_shapes) { 738 strings::StrAppend(&result, "_ops.RegisterShape(\"", op_def.name(), 739 "\")(None)\n\n"); 740 } 741 742 auto added = out->Add(); 743 *added = op_def; 744 RemoveNonDeprecationDescriptionsFromOpDef(added); 745 } 746 747 result.append(R"(def _InitOpDefLibrary(op_list_proto_bytes): 748 op_list = _op_def_pb2.OpList() 749 op_list.ParseFromString(op_list_proto_bytes) 750 _op_def_registry.register_op_list(op_list) 751 op_def_lib = _op_def_library.OpDefLibrary() 752 op_def_lib.add_op_list(op_list) 753 return op_def_lib 754)"); 755 756 result.append("# "); 757 auto ops_text = ProtoDebugString(cleaned_ops); 758 str_util::StripTrailingWhitespace(&ops_text); 759 result.append(str_util::StringReplace(ops_text, "\n", "\n# ", true)); 760 result.append("\n"); 761 strings::Appendf(&result, "_op_def_lib = _InitOpDefLibrary(b\"%s\")\n", 762 str_util::CEscape(cleaned_ops.SerializeAsString()).c_str()); 763 return result; 764} 765 766} // namespace 767 768void PrintEagerPythonOps(const OpList& ops, 769 const std::vector<string>& hidden_ops, 770 bool require_shapes, const string& source_file_name) { 771 printf("%s", 772 GetEagerPythonOps(ops, hidden_ops, require_shapes, source_file_name) 773 .c_str()); 774} 775 776string GetEagerPythonWrappers(const char* op_list_buf, size_t op_list_len) { 777 string op_list_str(op_list_buf, op_list_len); 778 OpList ops; 779 ops.ParseFromString(op_list_str); 780 return GetEagerPythonOps(ops, {}, false); 781} 782 783} // namespace tensorflow 784