// python_eager_op_gen.cc — revision 40f0cf009641ef0d827729bb01e8ed50d97fd109
1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/python/eager/python_eager_op_gen.h"
16
17#include <stdio.h>
18#include <sstream>
19#include <unordered_map>
20#include "tensorflow/core/framework/api_def.pb.h"
21#include "tensorflow/core/framework/attr_value.pb.h"
22#include "tensorflow/core/framework/op.h"
23#include "tensorflow/core/framework/op_def.pb_text.h"
24#include "tensorflow/core/framework/op_def.pb.h"
25#include "tensorflow/core/framework/op_def_util.h"
26#include "tensorflow/core/framework/op_gen_lib.h"
27#include "tensorflow/core/framework/tensor.pb_text.h"
28#include "tensorflow/core/framework/types.h"
29#include "tensorflow/core/framework/types.pb.h"
30#include "tensorflow/core/lib/gtl/map_util.h"
31#include "tensorflow/core/lib/gtl/stl_util.h"
32#include "tensorflow/core/lib/strings/str_util.h"
33#include "tensorflow/core/lib/strings/strcat.h"
34#include "tensorflow/core/lib/strings/stringprintf.h"
35#include "tensorflow/core/platform/logging.h"
36#include "tensorflow/core/platform/macros.h"
37#include "tensorflow/core/platform/types.h"
38#include "tensorflow/python/framework/python_op_gen_internal.h"
39
40namespace tensorflow {
41namespace {
42
// Column at which generated Python code is word-wrapped.
const int kRightMargin = 78;

// Suffix appended to an op's function name to form the name of the slow-path
// function invoked when the eager fast path raises _FallbackException.
constexpr char kEagerFallbackSuffix[] = "_eager_fallback";
46
47string AttrVarName(const string& attr_name,
48                   std::unordered_map<string, string>* attr_expressions) {
49  const string var = strings::StrCat("_attr_", attr_name);
50  if (attr_expressions != nullptr) (*attr_expressions)[attr_name] = var;
51  return var;
52}
53
54void AddInferredAttr(const string& indentation, const string& attr_name,
55                     const string& value_expression, string* result,
56                     std::unordered_map<string, string>* attr_expressions) {
57  strings::StrAppend(result, indentation,
58                     AttrVarName(attr_name, attr_expressions), " = ",
59                     value_expression, "\n");
60}
61
62string VectorToTuple(const std::vector<string>& l) {
63  if (l.size() == 1) return strings::StrCat("(", l.front(), ",)");
64  string ret = "(";
65  for (int i = 0; i < l.size(); ++i) {
66    if (i > 0) {
67      strings::StrAppend(&ret, ", ");
68    }
69    strings::StrAppend(&ret, l[i]);
70  }
71  strings::StrAppend(&ret, ")");
72  return ret;
73}
74
// Emits Python statements that rebuild the structured (per-output) result
// list from the flat list of tensors named `var`.  For each list-typed
// output i (non-empty output_sizes[i]), the flat slice belonging to that
// output is collapsed into a single nested-list entry of length
// output_sizes[i]; solo outputs (empty size) are left in place.  Each
// generated line is prefixed with `prefix`.
// NOTE(review): the generated slices index `var` by the output index i,
// which appears to rely on earlier iterations having already collapsed
// preceding list outputs to one entry each — confirm against callers.
void Unflatten(const string& prefix, const std::vector<string>& output_sizes,
               const string& var, string* result) {
  for (int i = 0; i < output_sizes.size(); ++i) {
    if (!output_sizes[i].empty()) {
      strings::StrAppend(result, prefix, var, " = ");
      if (i > 0) strings::StrAppend(result, var, "[:", i, "] + ");
      if (i + 1 < output_sizes.size()) {
        // Special case i == 0 to avoid "0 +" in the generated code.
        if (i == 0) {
          strings::StrAppend(result, "[", var, "[:", output_sizes[i], "]] + ",
                             var, "[", output_sizes[i], ":]");
        } else {
          strings::StrAppend(result, "[", var, "[", i, ":", i, " + ",
                             output_sizes[i], "]] + ", var, "[", i, " + ",
                             output_sizes[i], ":]");
        }
      } else {
        // Last output: everything from position i onward forms its list.
        strings::StrAppend(result, "[", var, "[", i, ":]]");
      }
      strings::StrAppend(result, "\n");
    }
  }
}
98
99string TensorPBString(const TensorProto& pb) {
100  // Note: This gets used in the argument list, and so must survive naive
101  // word wrapping.
102  return strings::StrCat("\"\"\"", ProtoShortDebugString(pb), "\"\"\"");
103}
104
105const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) {
106  for (int i = 0; i < api_def.in_arg_size(); ++i) {
107    if (api_def.in_arg(i).name() == name) {
108      return &api_def.in_arg(i);
109    }
110  }
111  return nullptr;
112}
113
// Generates the Python source for a single op: the public function (graph
// mode plus the eager fast path) and the "*_eager_fallback" slow path used
// when the fast path raises _FallbackException.
class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp {
 public:
  GenEagerPythonOp(const OpDef& op_def, const ApiDef& api_def,
                   const string& function_name)
      : python_op_gen_internal::GenPythonOp(op_def, api_def, function_name) {
    op_name_ = function_name_;
    // Drop a leading '_' (if any) so generated messages use the plain name.
    op_name_.Consume("_");
  }
  ~GenEagerPythonOp() override {}

  // Returns the generated Python source ("" when the op is skipped).
  string Code() override;

 protected:
  // Appends the graph-mode branch of the generated function body.
  void HandleGraphMode(const string& function_setup);

  // Returns "" when eager execution is permitted; otherwise Python source
  // raising a RuntimeError (ref-typed args are not eager-compatible).
  string GetEagerNotAllowedError();
  // Emits a Python check that `arg_name` is a list or tuple.
  void ExpectListArg(const string& indentation, const string& arg_name,
                     string* output);
  // Emits input validation, length-attr inference, default substitution and
  // attr type coercion; returns false if an attr type is unsupported.
  bool GetEagerFunctionSetup(const string& indentation, string* function_setup);
  // Fills per-output length expressions and the total-output-count expression.
  void GetOutputSizesAndNumOutputsExpr(std::vector<string>* output_sizes,
                                       string* num_outputs_expr);

  // Emits gradient recording, result restructuring and the final return.
  void AddEagerFunctionTeardown(const string& indentation,
                                const std::vector<string>& output_sizes,
                                bool execute_record_gradient);

  // Emits the public function; returns false when setup generation fails.
  bool AddEagerFastPathAndGraphCode(const string& parameters,
                                    const std::vector<string>& output_sizes,
                                    const string& eager_not_allowed_error);
  // Emits the *_eager_fallback function; returns false when setup fails.
  bool AddEagerFallbackCode(const string& parameters,
                            const std::vector<string>& output_sizes,
                            const string& num_outputs_expr,
                            const string& eager_not_allowed_error);
  // Emits the TFE_Py_FastPathExecute call with fallback/error handling.
  void AddEagerFastPathExecute();

  void AddEagerInferredAttrs(const string& indentation);
  void AddEagerInputCasts(const string& indentation);
  void AddEagerAttrs(const string& indentation);
  void AddEagerExecute(const string& indentation,
                       const string& num_outputs_expr);

  // Records that `attr` is inferred from input `arg_index`.  The first such
  // input wins in inferred_attrs_; attr_to_args_ keeps every input index.
  void AddAttrForArg(const string& attr, int arg_index) {
    gtl::InsertIfNotPresent(&inferred_attrs_, attr,
                            op_def_.input_arg(arg_index).name());
    auto iter = attr_to_args_.find(attr);
    if (iter == attr_to_args_.end()) {
      attr_to_args_.insert(AttrToArgMap::value_type(attr, {arg_index}));
    } else {
      iter->second.push_back(arg_index);
    }
  }

  // Returns a string expression representing a flattened list of all
  // the inputs given by `*input_indices` (or all inputs if
  // `input_indices` is nullptr).  `*output_sizes` can be used to unflatten.
  string FlattenInputs(const std::vector<int>* input_indices,
                       std::vector<string>* output_sizes) const;

  // function_name_ with any leading '_' removed; used in generated messages.
  StringPiece op_name_;
  typedef std::unordered_map<string, std::vector<int>> AttrToArgMap;
  // Maps an inferred attr name to every input arg index it derives from.
  AttrToArgMap attr_to_args_;
  // Maps an attr name to the Python expression holding its value.
  std::unordered_map<string, string> attr_expressions_;
  // This has all the input args followed by those attrs that don't have
  // defaults.
  std::vector<python_op_gen_internal::ParamNames> params_no_default_;
  // The parameters with defaults (these have to be listed after those without).
  // No input args are included, just attrs.
  std::vector<std::pair<python_op_gen_internal::ParamNames, string>>
      params_with_default_;
};
184
185string GetEagerPythonOp(const OpDef& op_def, const ApiDef& api_def,
186                        const string& function_name) {
187  return GenEagerPythonOp(op_def, api_def, function_name).Code();
188}
189
// Returns a Python expression evaluating to a flat list of the op's inputs
// (all of them, or just those in `*input_indices` when non-null).  Runs of
// solo inputs become "[a, b]" literals; list inputs are spliced in as
// "list(x)" terms joined with "+".  When `output_sizes` is non-null, one
// entry is appended per input: empty for a solo input, otherwise a Python
// expression for the list input's length (the inferred number-attr variable
// when present, else "len(x)").
string GenEagerPythonOp::FlattenInputs(
    const std::vector<int>* input_indices,
    std::vector<string>* output_sizes) const {
  string inputs;
  // Tracks how the previous input was emitted, which determines the
  // separator/bracketing needed before the next one.
  enum { STARTING, WAS_LIST_INPUT, WAS_SOLO_INPUT } inputs_state = STARTING;
  const int n = input_indices != nullptr ? input_indices->size()
                                         : op_def_.input_arg_size();
  for (int j = 0; j < n; ++j) {
    const int i = input_indices ? (*input_indices)[j] : j;
    const auto& arg(op_def_.input_arg(i));
    const bool is_list =
        !arg.type_list_attr().empty() || !arg.number_attr().empty();
    if (is_list) {
      if (inputs_state == WAS_SOLO_INPUT) {
        // Close the bracket of the preceding run of solo inputs.
        strings::StrAppend(&inputs, "] + ");
      } else if (inputs_state == WAS_LIST_INPUT) {
        strings::StrAppend(&inputs, " + ");
      }
      strings::StrAppend(&inputs, "list(", param_names_[i].GetRenameTo(), ")");
      inputs_state = WAS_LIST_INPUT;
      if (output_sizes != nullptr) {
        if (!arg.number_attr().empty()) {
          // Length is given by the inferred "_attr_<number_attr>" variable.
          output_sizes->emplace_back(AttrVarName(arg.number_attr(), nullptr));
        } else {
          output_sizes->emplace_back(
              strings::StrCat("len(", param_names_[i].GetRenameTo(), ")"));
        }
      }
    } else {
      if (inputs_state == WAS_SOLO_INPUT) {
        strings::StrAppend(&inputs, ", ");
      } else if (inputs_state == WAS_LIST_INPUT) {
        // Open a fresh list literal for this run of solo inputs.
        strings::StrAppend(&inputs, " + [");
      } else {
        strings::StrAppend(&inputs, "[");
      }
      strings::StrAppend(&inputs, param_names_[i].GetRenameTo());
      inputs_state = WAS_SOLO_INPUT;
      if (output_sizes != nullptr) output_sizes->emplace_back();
    }
  }
  if (inputs_state == STARTING) return "[]";
  if (inputs_state == WAS_SOLO_INPUT) {
    // Close the trailing run of solo inputs.
    strings::StrAppend(&inputs, "]");
  }
  return inputs;
}
237
// Top-level driver: computes the Python parameter list and attr bookkeeping
// for this op, then emits the public function and its eager fallback.
// Returns "" for ops marked SKIP, or an explanatory comment string when an
// attr type is unsupported.
string GenEagerPythonOp::Code() {
  if (api_def_.visibility() == ApiDef::SKIP) {
    return "";
  }

  // Input args (in API order) become required parameters; record which
  // attrs are inferred from them so they stay out of the signature.
  for (int i = 0; i < api_def_.arg_order_size(); ++i) {
    const auto& arg = *FindInputArg(api_def_.arg_order(i), op_def_);
    const auto& api_def_arg = *FindInputArg(api_def_.arg_order(i), api_def_);
    params_no_default_.emplace_back(api_def_arg.name(),
                                    api_def_arg.rename_to());
    if (!arg.type_attr().empty()) {
      AddAttrForArg(arg.type_attr(), i);
    } else if (!arg.type_list_attr().empty()) {
      AddAttrForArg(arg.type_list_attr(), i);
    }
    if (!arg.number_attr().empty()) {
      AddAttrForArg(arg.number_attr(), i);
    }
  }
  // Non-inferred attrs become parameters: with a Python default expression
  // when the ApiDef provides one, otherwise required.
  for (int i = 0; i < op_def_.attr_size(); ++i) {
    const auto& attr(op_def_.attr(i));
    const auto& api_def_attr(api_def_.attr(i));
    // Do not add inferred attrs to the Python function signature.
    if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) {
      if (api_def_attr.has_default_value()) {
        if (attr.type() == "tensor") {
          params_with_default_.emplace_back(
              python_op_gen_internal::ParamNames(api_def_attr.name(),
                                                 api_def_attr.rename_to()),
              strings::StrCat(
                  "_execute.make_tensor(",
                  TensorPBString(api_def_attr.default_value().tensor()), ", \"",
                  api_def_attr.rename_to(), "\")"));
        } else if (attr.type() == "list(tensor)") {
          std::vector<string> pbtxt;
          for (const auto& pb : api_def_attr.default_value().list().tensor()) {
            pbtxt.emplace_back(TensorPBString(pb));
          }
          params_with_default_.emplace_back(
              python_op_gen_internal::ParamNames(api_def_attr.name(),
                                                 api_def_attr.rename_to()),
              strings::StrCat("[_execute.make_tensor(_pb, \"",
                              api_def_attr.rename_to(), "\") for _pb in ",
                              VectorToTuple(pbtxt), "]"));
        } else {
          params_with_default_.emplace_back(
              python_op_gen_internal::ParamNames(api_def_attr.name(),
                                                 api_def_attr.rename_to()),
              python_op_gen_internal::AttrValueToPython(
                  attr.type(), api_def_attr.default_value(), "_dtypes."));
        }
      } else {
        params_no_default_.emplace_back(api_def_attr.name(),
                                        api_def_attr.rename_to());
      }
    }
  }

  // Save the list of attr parameters (attrs that won't be inferred),
  // those with defaults go at the end.
  // Get the attrs in the order we want by taking the attrs without defaults
  // from the end of params_no_default_, and adding params_no_default_.
  attrs_.reserve(params_no_default_.size() - op_def_.input_arg_size() +
                 params_with_default_.size());
  for (int i = op_def_.input_arg_size(); i < params_no_default_.size(); ++i) {
    attrs_.push_back(params_no_default_[i].GetName());
  }
  for (const auto& p : params_with_default_) {
    attrs_.push_back(p.first.GetName());
  }

  param_names_.reserve(params_no_default_.size() + params_with_default_.size());
  param_names_.insert(param_names_.begin(), params_no_default_.begin(),
                      params_no_default_.end());
  for (const auto& param_and_default : params_with_default_) {
    param_names_.push_back(param_and_default.first);
  }

  // Build the textual parameter list; "name=None" is always last.
  string parameters;
  for (const auto& param : params_no_default_) {
    if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
    strings::StrAppend(&parameters, param.GetRenameTo());
  }
  for (const auto& param_and_default : params_with_default_) {
    if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
    strings::StrAppend(&parameters, param_and_default.first.GetRenameTo(), "=",
                       param_and_default.second);
  }
  if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
  strings::StrAppend(&parameters, "name=None");

  // Add attr_expressions_ for attrs that are params.
  for (int i = 0; i < attrs_.size(); ++i) {
    const string& attr_name = attrs_[i];
    const string& attr_api_name =
        param_names_[i + op_def_.input_arg_size()].GetRenameTo();
    attr_expressions_[attr_name] = attr_api_name;
  }
  // Add attr_expressions_ for attrs that are inferred.
  for (int i = 0; i < op_def_.attr_size(); ++i) {
    const auto& attr(op_def_.attr(i));
    if (attr.type() == "int") {
      auto arg_list = attr_to_args_.find(attr.name());
      if (arg_list != attr_to_args_.end()) {
        // Registers "_attr_<name>" as the expression for this length attr.
        AttrVarName(attr.name(), &attr_expressions_);
      }
    }
  }

  string num_outputs_expr;
  std::vector<string> output_sizes(num_outs_);
  GetOutputSizesAndNumOutputsExpr(&output_sizes, &num_outputs_expr);

  string eager_not_allowed_error = GetEagerNotAllowedError();

  // Either emitter may bail out (e.g. unsupported attr type); in that case
  // result_ already holds the explanatory comment to return.
  if (!AddEagerFastPathAndGraphCode(parameters, output_sizes,
                                    eager_not_allowed_error)) {
    return result_;
  }

  if (!AddEagerFallbackCode(parameters, output_sizes, num_outputs_expr,
                            eager_not_allowed_error)) {
    return result_;
  }

  return prelude_ + result_;
}
365
// Appends the graph-mode branch: apply the op via _op_def_lib, then gather
// _result, _inputs_flat and _attrs for gradient recording (or return the op
// object directly when it has no outputs).
void GenEagerPythonOp::HandleGraphMode(const string& function_setup) {
  // Handle graph-mode case
  strings::StrAppend(&result_,
                     "  _ctx = _context.context()\n"
                     "  if _ctx.in_graph_mode():\n",
                     function_setup,
                     "    _, _, _op = _op_def_lib._apply_op_helper(\n");
  AddBodyNoReturn("        ");
  if (num_outs_ > 0) {
    strings::StrAppend(&result_, "    _result = _op.outputs[:]\n");
    // Special case handling for stateful op with single list output
    // that might be empty.
    if (num_outs_ == 1 && op_def_.is_stateful() &&
        (!op_def_.output_arg(0).number_attr().empty() ||
         !op_def_.output_arg(0).type_list_attr().empty())) {
      // TODO(josh11b): Can skip this if the number_attr/type_list_attr has
      // a constraint indicating that this can never be empty.
      strings::StrAppend(&result_,
                         "    if not _result:\n"
                         "      return _op\n");
    }
    strings::StrAppend(&result_, "    _inputs_flat = _op.inputs\n");

    // Compute graph-mode attrs.
    if (op_def_.attr_size() > 0) {
      // Build the ("name", value, "name", value, ...) tuple expression.
      string attr_values;
      for (int i = 0; i < op_def_.attr_size(); ++i) {
        if (i > 0) strings::StrAppend(&attr_values, ", ");
        const auto& attr_name(op_def_.attr(i).name());
        strings::StrAppend(&attr_values, "\"", attr_name, "\", _op.get_attr(\"",
                           attr_name, "\")");
      }
      strings::StrAppend(&attr_values, ")");
      strings::StrAppend(&result_,
                         WordWrap("    _attrs = (", attr_values, kRightMargin),
                         "\n");
    } else {
      strings::StrAppend(&result_, "    _attrs = None\n");
    }
  } else {
    // No outputs: nothing to record a gradient for.
    strings::StrAppend(&result_, "    return _op\n");
  }
}
409
410string GenEagerPythonOp::GetEagerNotAllowedError() {
411  bool eager_allowed = true;
412  string ref_arg;
413  for (int i = 0; i < op_def_.input_arg_size(); ++i) {
414    const auto& arg = op_def_.input_arg(i);
415    if (arg.is_ref()) {
416      eager_allowed = false;
417      DCHECK_EQ(op_def_.input_arg(i).name(), api_def_.in_arg(i).name());
418      ref_arg = api_def_.in_arg(i).rename_to();
419    }
420  }
421  for (int i = 0; i < op_def_.output_arg_size(); ++i) {
422    const auto& arg = op_def_.output_arg(i);
423    if (arg.is_ref()) {
424      eager_allowed = false;
425      DCHECK_EQ(op_def_.output_arg(i).name(), api_def_.out_arg(i).name());
426      ref_arg = api_def_.out_arg(i).rename_to();
427    }
428  }
429
430  if (eager_allowed) return "";
431
432  return strings::StrCat("raise RuntimeError(\"", op_name_,
433                         " op does not support eager execution. ", "Arg '",
434                         ref_arg, "' is a ref.\")\n");
435}
436
437void GenEagerPythonOp::ExpectListArg(const string& indentation,
438                                     const string& arg_name, string* output) {
439  strings::StrAppend(output, indentation, "if not isinstance(", arg_name,
440                     ", (list, tuple)):\n", indentation, "  raise TypeError(\n",
441                     indentation, "      \"Expected list for '", arg_name,
442                     "' argument to \"\n", indentation, "      \"'", op_name_,
443                     "' Op, not %r.\" % ", arg_name, ")\n");
444}
445
446bool GenEagerPythonOp::GetEagerFunctionSetup(const string& indentation,
447                                             string* function_setup) {
448  // Validate list inputs, infer length attrs.
449  for (int i = 0; i < op_def_.attr_size(); ++i) {
450    const auto& attr(op_def_.attr(i));
451    if (attr.type() == "int") {
452      auto arg_list = attr_to_args_.find(attr.name());
453      if (arg_list != attr_to_args_.end()) {
454        // Inferred int attrs are the lengths of inputs. Validate those
455        // inputs are lists and have the same length.
456        for (auto iter = arg_list->second.begin();
457             iter != arg_list->second.end(); ++iter) {
458          const string& arg_api_name = param_names_[*iter].GetRenameTo();
459          ExpectListArg(indentation, arg_api_name, function_setup);
460          if (iter == arg_list->second.begin()) {
461            AddInferredAttr(indentation, attr.name(),
462                            strings::StrCat("len(", arg_api_name, ")"),
463                            function_setup, &attr_expressions_);
464          } else {
465            const auto& attr_var = attr_expressions_[attr.name()];
466            strings::StrAppend(
467                function_setup, indentation, "if len(", arg_api_name,
468                ") != ", attr_var, ":\n", indentation, "  raise ValueError(\n",
469                indentation, "      \"List argument '", arg_api_name, "' to '",
470                op_name_, "' Op with length %d \"\n", indentation,
471                "      \"must match length %d of argument '",
472                inferred_attrs_[attr.name()], "'.\" %\n", indentation,
473                "      (len(", arg_api_name, "), ", attr_var, "))\n");
474          }
475        }
476      }
477    }
478  }
479
480  for (int i = 0; i < attrs_.size(); ++i) {
481    const string& attr_name = attrs_[i];
482    const auto& param = param_names_[i + op_def_.input_arg_size()];
483    const auto& attr = *FindAttr(attr_name, op_def_);
484    const string& attr_api_name = param.GetRenameTo();
485    StringPiece attr_type = attr.type();
486    attr_expressions_[attr_name] = attr_api_name;
487    const int default_index = i - (attrs_.size() - params_with_default_.size());
488    if (default_index >= 0) {
489      const string& default_value = params_with_default_[default_index].second;
490      strings::StrAppend(function_setup, indentation, "if ", attr_api_name,
491                         " is None:\n");
492      strings::StrAppend(function_setup, indentation, "  ", attr_api_name,
493                         " = ", default_value, "\n");
494    }
495    if (attr_type.starts_with("list(")) {
496      ExpectListArg(indentation, attr_api_name, function_setup);
497    }
498
499    if (attr_type == "string") {
500      strings::StrAppend(function_setup, indentation, attr_api_name,
501                         " = _execute.make_str(", attr_api_name, ", \"",
502                         attr_api_name, "\")\n");
503    } else if (attr_type == "list(string)") {
504      strings::StrAppend(function_setup, indentation, attr_api_name,
505                         " = [_execute.make_str(_s, \"", attr_api_name,
506                         "\") for _s in ", attr_api_name, "]\n");
507    } else if (attr_type == "int") {
508      strings::StrAppend(function_setup, indentation, attr_api_name,
509                         " = _execute.make_int(", attr_api_name, ", \"",
510                         attr_api_name, "\")\n");
511    } else if (attr_type == "list(int)") {
512      strings::StrAppend(function_setup, indentation, attr_api_name,
513                         " = [_execute.make_int(_i, \"", attr_api_name,
514                         "\") for _i in ", attr_api_name, "]\n");
515    } else if (attr_type == "float") {
516      strings::StrAppend(function_setup, indentation, attr_api_name,
517                         " = _execute.make_float(", attr_api_name, ", \"",
518                         attr_api_name, "\")\n");
519    } else if (attr_type == "list(float)") {
520      strings::StrAppend(function_setup, indentation, attr_api_name,
521                         " = [_execute.make_float(_f, \"", attr_api_name,
522                         "\") for _f in ", attr_api_name, "]\n");
523    } else if (attr_type == "bool") {
524      strings::StrAppend(function_setup, indentation, attr_api_name,
525                         " = _execute.make_bool(", attr_api_name, ", \"",
526                         attr_api_name, "\")\n");
527    } else if (attr_type == "list(bool)") {
528      strings::StrAppend(function_setup, indentation, attr_api_name,
529                         " = [_execute.make_bool(_b, \"", attr_api_name,
530                         "\") for _b in ", attr_api_name, "]\n");
531    } else if (attr_type == "type") {
532      strings::StrAppend(function_setup, indentation, attr_api_name,
533                         " = _execute.make_type(", attr_api_name, ", \"",
534                         attr_api_name, "\")\n");
535    } else if (attr_type == "list(type)") {
536      strings::StrAppend(function_setup, indentation, attr_api_name,
537                         " = [_execute.make_type(_t, \"", attr_api_name,
538                         "\") for _t in ", attr_api_name, "]\n");
539    } else if (attr_type == "shape") {
540      strings::StrAppend(function_setup, indentation, attr_api_name,
541                         " = _execute.make_shape(", attr_api_name, ", \"",
542                         attr_api_name, "\")\n");
543    } else if (attr_type == "list(shape)") {
544      strings::StrAppend(function_setup, indentation, attr_api_name,
545                         " = [_execute.make_shape(_s, \"", attr_api_name,
546                         "\") for _s in ", attr_api_name, "]\n");
547    } else if (attr_type == "tensor") {
548      strings::StrAppend(function_setup, indentation, attr_api_name,
549                         " = _execute.make_tensor(", attr_api_name, ", \"",
550                         attr_api_name, "\")\n");
551    } else if (attr_type == "list(tensor)") {
552      strings::StrAppend(function_setup, indentation, attr_api_name,
553                         " = [_execute.make_tensor(_t, \"", attr_api_name,
554                         "\") for _t in ", attr_api_name, "]\n");
555    } else if (attr_type != "func") {
556      *function_setup =
557          strings::StrCat("# No definition for ", function_name_,
558                          " since we don't support attrs with type\n"
559                          "# '",
560                          attr_type, "' right now.\n\n");
561      return false;
562    }
563  }
564  return true;
565}
566
567// If output i is list output, output_sizes[i] will be set to a
568// string with the python expression that will evaluate to its
569// length. output_sizes[i] is empty for non-list outputs.
570void GenEagerPythonOp::GetOutputSizesAndNumOutputsExpr(
571    std::vector<string>* output_sizes, string* num_outputs_expr) {
572  // Expression representing the number of outputs.
573  int num_fixed_outputs = 0;
574  for (int i = 0; i < num_outs_; ++i) {
575    const auto& arg(op_def_.output_arg(i));
576    if (!arg.number_attr().empty()) {
577      if (!num_outputs_expr->empty()) {
578        strings::StrAppend(num_outputs_expr, " + ");
579      }
580      (*output_sizes)[i] = attr_expressions_[arg.number_attr()];
581      strings::StrAppend(num_outputs_expr, (*output_sizes)[i]);
582    } else if (!arg.type_list_attr().empty()) {
583      if (!num_outputs_expr->empty()) {
584        strings::StrAppend(num_outputs_expr, " + ");
585      }
586      // Have to be careful to use an expression that works in both
587      // graph and eager paths here.
588      const auto iter = inferred_attrs_.find(arg.type_list_attr());
589      if (iter == inferred_attrs_.end()) {
590        (*output_sizes)[i] = strings::StrCat(
591            "len(", attr_expressions_[arg.type_list_attr()], ")");
592      } else {
593        (*output_sizes)[i] = strings::StrCat("len(", iter->second, ")");
594      }
595      strings::StrAppend(num_outputs_expr, (*output_sizes)[i]);
596    } else {
597      ++num_fixed_outputs;
598    }
599  }
600  if (num_fixed_outputs > 0) {
601    if (!num_outputs_expr->empty()) {
602      strings::StrAppend(num_outputs_expr, " + ");
603    }
604    strings::StrAppend(num_outputs_expr, num_fixed_outputs);
605  } else if (num_outputs_expr->empty()) {
606    *num_outputs_expr = "0";
607  }
608}
609
// Appends the common function epilogue: optionally record the gradient,
// reshape the flat _result list to match the op's output structure (named
// tuple for multiple outputs), and return _result.
void GenEagerPythonOp::AddEagerFunctionTeardown(
    const string& indentation, const std::vector<string>& output_sizes,
    bool execute_record_gradient) {
  if (num_outs_ > 0) {
    if (execute_record_gradient) {
      strings::StrAppend(&result_, indentation, "_execute.record_gradient(\n",
                         "      \"", op_def_.name(),
                         "\", _inputs_flat, _attrs, _result, name)\n");
    }
    if (num_outs_ == 1 && !output_sizes[0].empty()) {
      // Single list result.
    } else if (num_outs_ == 1) {
      // Execute returns a single-element list which we need to destructure.
      strings::StrAppend(&result_, indentation, "_result, = _result\n");
    } else {
      // Have multiple outputs, so we will need to reformat the return
      // value of execute() to be a list with one entry per op output
      // (that entry will be a list of tensors if that output is of list
      // type).
      // For list outputs, convert the right subrange of _result into a list.
      Unflatten(indentation, output_sizes, "_result", &result_);
      // Convert to a named tuple.
      strings::StrAppend(&result_, indentation, "_result = _", op_def_.name(),
                         "Output._make(_result)\n");
    }
  } else {
    // No outputs: the generated function returns None.
    strings::StrAppend(&result_, indentation, "_result = None\n");
  }
  strings::StrAppend(&result_, indentation, "return _result\n\n");
}
640
// Emits the public op function: def line, docstring, graph-mode body, and
// the eager branch (fast-path execute, or the eager-not-allowed error).
// Returns false — with result_ replaced by an explanatory comment — when
// function setup cannot be generated.
bool GenEagerPythonOp::AddEagerFastPathAndGraphCode(
    const string& parameters, const std::vector<string>& output_sizes,
    const string& eager_not_allowed_error) {
  AddDefLine(function_name_, parameters);
  AddDocStringDescription();
  AddDocStringArgs();
  AddDocStringInputs();
  AddDocStringAttrs();
  AddDocStringNameArg();
  AddOutputGlobals();  // Added to prelude_
  AddDocStringOutputs();
  strings::StrAppend(&result_, "  \"\"\"\n");

  // Handle graph-mode case
  string function_setup;
  if (!GetEagerFunctionSetup("    ", &function_setup)) {
    result_ = function_setup;
    return false;
  }
  HandleGraphMode(function_setup);
  AddEagerFunctionTeardown("    ", output_sizes,
                           true /* execute_record_gradient */);

  // Handle eager-mode case
  strings::StrAppend(&result_, "  else:\n");

  if (eager_not_allowed_error.empty()) {
    AddEagerFastPathExecute();
  } else {
    strings::StrAppend(&result_, "    ", eager_not_allowed_error);
  }

  strings::StrAppend(&result_, "\n\n");
  return true;
}
676
// Emits the "<op>_eager_fallback" slow-path function used when the fast
// path raises _FallbackException: attr setup, input casts, flattening, and
// the generic _execute call.  When eager is not allowed for this op, only
// the raising stub is emitted.  Returns false when setup generation fails.
bool GenEagerPythonOp::AddEagerFallbackCode(
    const string& parameters, const std::vector<string>& output_sizes,
    const string& num_outputs_expr, const string& eager_not_allowed_error) {
  if (!eager_not_allowed_error.empty()) {
    strings::StrAppend(&result_, "  ", eager_not_allowed_error);
    return true;
  }

  AddDefLine(strings::StrCat(function_name_, kEagerFallbackSuffix), parameters);
  strings::StrAppend(
      &result_, "  r\"\"\"This is the slowpath function for Eager mode.\n");
  strings::StrAppend(&result_, "  This is for function ", function_name_,
                     "\n  \"\"\"\n");

  strings::StrAppend(&result_, "  _ctx = _context.context()\n");

  string function_setup;
  if (!GetEagerFunctionSetup("  ", &function_setup)) {
    result_ = function_setup;
    return false;
  }
  strings::StrAppend(&result_, function_setup);

  AddEagerInferredAttrs("  ");
  AddEagerInputCasts("  ");
  strings::StrAppend(
      &result_, "  _inputs_flat = ", FlattenInputs(nullptr, nullptr), "\n");
  AddEagerAttrs("  ");
  AddEagerExecute("  ", num_outputs_expr);

  AddEagerFunctionTeardown("  ", output_sizes,
                           true /* execute_record_gradient */);

  return true;
}
712
713void GenEagerPythonOp::AddEagerFastPathExecute() {
714  string fastpath_execute_params = strings::StrCat(
715      "_ctx._handle, _ctx.device_name, \"", op_def_.name(), "\", ",
716      "_execute.record_gradient, name, _ctx._post_execution_callbacks");
717  string fallback_params;
718
719  for (int i = 0; i < api_def_.in_arg_size(); i++) {
720    const string param_name = param_names_[i].GetRenameTo();
721    strings::StrAppend(&fastpath_execute_params, ", ", param_name);
722    if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
723    strings::StrAppend(&fallback_params, param_name);
724  }
725
726  for (const auto& attr : api_def_.attr()) {
727    if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) {
728      strings::StrAppend(&fastpath_execute_params, ", \"", attr.name(), "\", ",
729                         attr.rename_to());
730
731      if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
732      strings::StrAppend(&fallback_params, attr.rename_to(), "=",
733                         attr.rename_to());
734    }
735  }
736
737  if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
738  strings::StrAppend(&fallback_params, "name=name");
739
740  strings::StrAppend(&result_, "    try:\n");
741  strings::StrAppend(
742      &result_, "      ",
743      "_result = _pywrap_tensorflow.TFE_Py_FastPathExecute(\n",
744      WordWrap(strings::StrCat("        "),
745               strings::StrCat(fastpath_execute_params, ")"), kRightMargin),
746      "\n");
747
748  if (op_def_.output_arg_size() > 1) {
749    const string output_tuple_name =
750        strings::StrCat("_", op_def_.name(), "Output");
751    strings::StrAppend(&result_, "      ", "_result = ", output_tuple_name,
752                       "._make(_result)\n");
753  }
754  strings::StrAppend(&result_, "      ", "return _result\n");
755
756  // Handle fallback.
757  strings::StrAppend(&result_, "    ", "except _core._FallbackException:\n");
758  strings::StrAppend(
759      &result_, "      ", "return ", function_name_, kEagerFallbackSuffix,
760      "(\n",
761      WordWrap(strings::StrCat("          "),
762               strings::StrCat(fallback_params, ")"), kRightMargin),
763      "\n");
764
765  // Any errors thrown from execute need to be unwrapped from
766  // _NotOkStatusException.
767  strings::StrAppend(&result_, "    ",
768                     "except _core._NotOkStatusException as e:\n");
769  strings::StrAppend(&result_, "      ", "if name is not None:\n");
770  strings::StrAppend(&result_, "        ",
771                     "message = e.message + \" name: \" + name\n");
772  strings::StrAppend(&result_, "      ", "else:\n");
773  strings::StrAppend(&result_, "        ", "message = e.message\n");
774  strings::StrAppend(
775      &result_, "      ",
776      "_six.raise_from(_core._status_to_exception(e.code, message), None)\n");
777}
778
779void GenEagerPythonOp::AddEagerInferredAttrs(const string& indentation) {
780  // Figure out values for inferred attrs, and cast to eager tensors.
781  for (int i = 0; i < op_def_.attr_size(); ++i) {
782    const auto& attr(op_def_.attr(i));
783    const auto& api_def_attr(api_def_.attr(i));
784    auto arg_list = attr_to_args_.find(attr.name());
785    if (arg_list != attr_to_args_.end()) {
786      if (attr.type() == "type") {
787        std::vector<string> output_sizes;
788        const string flattened =
789            FlattenInputs(&arg_list->second, &output_sizes);
790        string conversion = strings::StrCat("_execute.args_to_matching_eager(",
791                                            flattened, ", _ctx");
792        if (attr.has_default_value()) {
793          strings::StrAppend(
794              &conversion, ", ",
795              python_op_gen_internal::AttrValueToPython(
796                  attr.type(), api_def_attr.default_value(), "_dtypes."));
797        }
798        strings::StrAppend(&conversion, ")");
799        const string var_name = AttrVarName(attr.name(), &attr_expressions_);
800        if (output_sizes.size() == 1) {
801          // Avoid creating a temporary variable in the case where
802          // we can easily assign to the right value directly.
803          const string inputs_var =
804              param_names_[arg_list->second.front()].GetRenameTo();
805          if (output_sizes.front().empty()) {
806            strings::StrAppend(&result_, indentation, var_name, ", (",
807                               inputs_var, ",) = ", conversion, "\n");
808          } else {
809            strings::StrAppend(&result_, indentation, var_name, ", ",
810                               inputs_var, " = ", conversion, "\n");
811          }
812        } else {
813          const string inputs_var = strings::StrCat("_inputs_", attr.name());
814          strings::StrAppend(&result_, indentation, var_name, ", ", inputs_var,
815                             " = ", conversion, "\n");
816          // Convert from a flat list of eager tensors back to the
817          // parameter variables.
818          Unflatten(indentation, output_sizes, inputs_var, &result_);
819          std::vector<string> p;
820          for (int j : arg_list->second) {
821            p.emplace_back(param_names_[j].GetRenameTo());
822          }
823          strings::StrAppend(&result_, indentation, VectorToTuple(p), " = ",
824                             inputs_var, "\n");
825        }
826      } else if (attr.type() == "list(type)") {
827        // NOTE: We ignore default values for these attrs, since it is
828        // unclear how you would use it, and the one use case is
829        // parse_single_sequence_example which only needs it for
830        // backwards compatibility.
831        const string var_name = AttrVarName(attr.name(), &attr_expressions_);
832        string inputs_var;
833        string conversion;
834        if (arg_list->second.size() > 1) {
835          // If you have more than one list(tensor) argument, their types
836          // have to match.
837          std::vector<string> lists;
838          for (auto iter = arg_list->second.begin();
839               iter != arg_list->second.end(); ++iter) {
840            lists.push_back(param_names_[*iter].GetRenameTo());
841          }
842          inputs_var = VectorToTuple(lists);
843          conversion = "_execute.args_to_mixed_eager_tensors";
844        } else {
845          // For one list(tensor) argument, we just convert every
846          // element of the list to an eager tensor.
847          inputs_var = param_names_[arg_list->second.front()].GetRenameTo();
848          conversion = "_execute.convert_to_mixed_eager_tensors";
849        }
850        strings::StrAppend(&result_, indentation, var_name, ", ", inputs_var,
851                           " = ", conversion, "(", inputs_var, ", _ctx)\n");
852      }
853    }
854  }
855}
856
857void GenEagerPythonOp::AddEagerInputCasts(const string& indentation) {
858  // Cast remaining args to eager tensors
859  for (int i = 0; i < op_def_.input_arg_size(); ++i) {
860    const auto& arg(op_def_.input_arg(i));
861    if (!arg.type_attr().empty() || !arg.type_list_attr().empty()) continue;
862    const string& param = param_names_[i].GetRenameTo();
863    const string fn = arg.number_attr().empty() ? "" : "n_";
864    const string dtype =
865        python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes.");
866    strings::StrAppend(&result_, indentation, param, " = _ops.convert_", fn,
867                       "to_tensor(", param, ", ", dtype, ")\n");
868  }
869}
870
871void GenEagerPythonOp::AddEagerAttrs(const string& indentation) {
872  // Compute eager attrs
873  if (op_def_.attr_size() > 0) {
874    string attr_values;
875    for (int i = 0; i < op_def_.attr_size(); ++i) {
876      if (i > 0) strings::StrAppend(&attr_values, ", ");
877      const auto& attr_name(op_def_.attr(i).name());
878      strings::StrAppend(&attr_values, "\"", attr_name, "\", ",
879                         attr_expressions_[attr_name]);
880    }
881    strings::StrAppend(&attr_values, ")");
882    strings::StrAppend(
883        &result_,
884        WordWrap(indentation, strings::StrCat("_attrs = (", attr_values),
885                 kRightMargin),
886        "\n");
887  } else {
888    strings::StrAppend(&result_, indentation, "_attrs = None\n");
889  }
890}
891
892void GenEagerPythonOp::AddEagerExecute(const string& indentation,
893                                       const string& num_outputs_expr) {
894  const string return_prefix =
895      strings::StrCat(indentation, "_result = _execute.execute(");
896  const string return_args = strings::StrCat(
897      "b\"", op_def_.name(), "\", ", num_outputs_expr,
898      ", inputs=_inputs_flat, attrs=_attrs, ctx=_ctx, name=name)");
899  strings::StrAppend(&result_,
900                     // Wrap the arguments, and indent to the (.
901                     WordWrap(return_prefix, return_args, kRightMargin), "\n");
902}
903
// Generates the full Python wrapper module source for `ops`: a header
// docstring, the module imports, one wrapper function per visible op
// (hidden ops get a leading-underscore name), and a trailing
// _InitOpDefLibrary section embedding the serialized op list with
// non-deprecation descriptions stripped.
string GetEagerPythonOps(const OpList& ops, const ApiDefMap& api_defs,
                         const std::vector<string>& hidden_ops,
                         bool require_shapes,
                         const string& source_file_name = "") {
  string result;
  // Header
  // TODO(josh11b): Mention the library for which wrappers are being generated.
  strings::StrAppend(&result, R"("""Python wrappers around TensorFlow ops.

This file is MACHINE GENERATED! Do not edit.
)");

  // Mention the original source file so someone tracing back through
  // generated Python code will know where to look next.
  if (!source_file_name.empty()) {
    strings::StrAppend(&result, "Original C++ source file: ");
    strings::StrAppend(&result, source_file_name);
    strings::StrAppend(&result, "\n");
  }

  strings::StrAppend(&result, R"("""

import collections as _collections
import six as _six

from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
from tensorflow.python.eager import context as _context
from tensorflow.python.eager import core as _core
from tensorflow.python.eager import execute as _execute
from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.python.framework import errors as _errors
from tensorflow.python.framework import tensor_shape as _tensor_shape

from tensorflow.core.framework import op_def_pb2 as _op_def_pb2
# Needed to trigger the call to _set_call_cpp_shape_fn.
from tensorflow.python.framework import common_shapes as _common_shapes
from tensorflow.python.framework import op_def_registry as _op_def_registry
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import op_def_library as _op_def_library
from tensorflow.python.util.tf_export import tf_export

)");

  // We'll make a copy of ops that filters out descriptions.
  OpList cleaned_ops;
  auto out = cleaned_ops.mutable_op();
  out->Reserve(ops.op_size());
  for (const auto& op_def : ops.op()) {
    // NOTE(review): api_def is dereferenced without a null check —
    // presumably ApiDefMap always has an entry for every op in `ops`;
    // confirm against ApiDefMap's construction.
    const auto* api_def = api_defs.GetApiDef(op_def.name());

    // Ops marked SKIP get no wrapper at all.
    if (api_def->visibility() == ApiDef::SKIP) {
      continue;
    }

    // An op is hidden if either its ApiDef visibility is HIDDEN
    // or it is in the hidden_ops list.
    bool is_hidden = api_def->visibility() == ApiDef::HIDDEN;
    if (!is_hidden) {
      for (const string& hidden : hidden_ops) {
        if (op_def.name() == hidden) {
          is_hidden = true;
          break;
        }
      }
    }

    // Hidden ops still get a wrapper, but with a leading underscore so
    // the generated function is private to the module.
    string function_name;
    python_op_gen_internal::GenerateLowerCaseOpName(op_def.name(),
                                                    &function_name);
    if (is_hidden) function_name = strings::StrCat("_", function_name);

    // When users create custom python wrappers, they may link in the
    // default op registry by accident, and because they can't
    // enumerate all 'hidden' symbols, this guard is to prevent
    // instantiating a python reserved word in their wrapper.
    if (python_op_gen_internal::IsPythonReserved(function_name)) {
      continue;
    }

    strings::StrAppend(&result,
                       GetEagerPythonOp(op_def, *api_def, function_name));

    if (!require_shapes) {
      strings::StrAppend(&result, "_ops.RegisterShape(\"", op_def.name(),
                         "\")(None)\n\n");
    }

    // Keep a cleaned copy of this op for the serialized op list that is
    // embedded at the bottom of the generated module.
    auto added = out->Add();
    *added = op_def;
    RemoveNonDeprecationDescriptionsFromOpDef(added);
  }

  result.append(R"(def _InitOpDefLibrary(op_list_proto_bytes):
  op_list = _op_def_pb2.OpList()
  op_list.ParseFromString(op_list_proto_bytes)
  _op_def_registry.register_op_list(op_list)
  op_def_lib = _op_def_library.OpDefLibrary()
  op_def_lib.add_op_list(op_list)
  return op_def_lib
)");

  // Append the cleaned op list twice: once as a human-readable Python
  // comment, and once as a C-escaped serialized proto passed to
  // _InitOpDefLibrary.
  result.append("# ");
  auto ops_text = ProtoDebugString(cleaned_ops);
  str_util::StripTrailingWhitespace(&ops_text);
  result.append(str_util::StringReplace(ops_text, "\n", "\n# ", true));
  result.append("\n");
  strings::Appendf(&result, "_op_def_lib = _InitOpDefLibrary(b\"%s\")\n",
                   str_util::CEscape(cleaned_ops.SerializeAsString()).c_str());
  return result;
}
1014
1015}  // namespace
1016
1017void PrintEagerPythonOps(const OpList& ops, const ApiDefMap& api_defs,
1018                         const std::vector<string>& hidden_ops,
1019                         bool require_shapes, const string& source_file_name) {
1020  printf("%s", GetEagerPythonOps(ops, api_defs, hidden_ops, require_shapes,
1021                                 source_file_name)
1022                   .c_str());
1023}
1024
1025string GetEagerPythonWrappers(const char* op_list_buf, size_t op_list_len) {
1026  string op_list_str(op_list_buf, op_list_len);
1027  OpList ops;
1028  ops.ParseFromString(op_list_str);
1029
1030  ApiDefMap api_def_map(ops);
1031  return GetEagerPythonOps(ops, api_def_map, {}, false);
1032}
1033
1034}  // namespace tensorflow
1035