python_eager_op_gen.cc revision 8f1e63d5629bda4f6c91fdec7a3b8418ed96786e
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/python/eager/python_eager_op_gen.h"

#include <stdio.h>
#include <sstream>
#include <unordered_map>
#include "tensorflow/core/framework/api_def.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb_text.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/framework/tensor.pb_text.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/python/framework/python_op_gen_internal.h"

namespace tensorflow {
namespace {

const int kRightMargin = 78;

string AttrVarName(const string& attr_name,
                   std::unordered_map<string, string>* attr_expressions) {
  const string var = strings::StrCat("_attr_", attr_name);
  if (attr_expressions != nullptr) (*attr_expressions)[attr_name] = var;
  return var;
}

void AddInferredAttr(const string& attr_name, const string& value_expression,
                     string* result,
                     std::unordered_map<string, string>* attr_expressions) {
  strings::StrAppend(result, "  ", AttrVarName(attr_name, attr_expressions),
                     " = ", value_expression, "\n");
}

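// Renders a vector of strings as the text of a Python tuple expression,
// e.g. (illustrative) VectorToTuple({"a"}) yields "(a,)" and
// VectorToTuple({"a", "b"}) yields "(a, b)".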
string VectorToTuple(const std::vector<string>& l) {
  if (l.size() == 1) return strings::StrCat("(", l.front(), ",)");
  string ret = "(";
  for (int i = 0; i < l.size(); ++i) {
    if (i > 0) {
      strings::StrAppend(&ret, ", ");
    }
    strings::StrAppend(&ret, l[i]);
  }
  strings::StrAppend(&ret, ")");
  return ret;
}

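// Emits Python code that re-nests the list-valued entries of a flat result.
// For example (a sketch, assuming two outputs where only the second is a list
// of length _attr_N), with var = "_result" this appends:
//   <prefix>_result = _result[:1] + [_result[1:]]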
void Unflatten(const string& prefix, const std::vector<string>& output_sizes,
               const string& var, string* result) {
  for (int i = 0; i < output_sizes.size(); ++i) {
    if (!output_sizes[i].empty()) {
      strings::StrAppend(result, prefix, var, " = ");
      if (i > 0) strings::StrAppend(result, var, "[:", i, "] + ");
      if (i + 1 < output_sizes.size()) {
        // Special case i == 0 to avoid "0 +" in the generated code.
        if (i == 0) {
          strings::StrAppend(result, "[", var, "[:", output_sizes[i], "]] + ",
                             var, "[", output_sizes[i], ":]");
        } else {
          strings::StrAppend(result, "[", var, "[", i, ":", i, " + ",
                             output_sizes[i], "]] + ", var, "[", i, " + ",
                             output_sizes[i], ":]");
        }
      } else {
        strings::StrAppend(result, "[", var, "[", i, ":]]");
      }
      strings::StrAppend(result, "\n");
    }
  }
}

string TensorPBString(const TensorProto& pb) {
  // Note: This gets used in the argument list, and so must survive naive
  // word wrapping.
  return strings::StrCat("\"\"\"", ProtoShortDebugString(pb), "\"\"\"");
}

const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) {
  for (int i = 0; i < api_def.in_arg_size(); ++i) {
    if (api_def.in_arg(i).name() == name) {
      return &api_def.in_arg(i);
    }
  }
  return nullptr;
}

class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp {
 public:
  GenEagerPythonOp(const OpDef& op_def, const ApiDef& api_def,
                   const string& function_name)
      : python_op_gen_internal::GenPythonOp(op_def, api_def, function_name) {
    op_name_ = function_name_;
    op_name_.Consume("_");
  }
  ~GenEagerPythonOp() override {}

  string Code() override;

 protected:
  void ExpectListArg(const string& arg_name);
  void AddEagerInferredAttrs();
  void AddEagerInputCasts();
  void AddEagerAttrs();
  void AddEagerExecute(const string& num_outputs_expr);

  void AddAttrForArg(const string& attr, int arg_index) {
    gtl::InsertIfNotPresent(&inferred_attrs_, attr,
                            op_def_.input_arg(arg_index).name());
    auto iter = attr_to_args_.find(attr);
    if (iter == attr_to_args_.end()) {
      attr_to_args_.insert(AttrToArgMap::value_type(attr, {arg_index}));
    } else {
      iter->second.push_back(arg_index);
    }
  }

  // Returns a string expression representing a flattened list of all
  // the inputs given by `*input_indices` (or all inputs if
  // `input_indices` is nullptr).  `*output_sizes` can be used to unflatten.
  string FlattenInputs(const std::vector<int>* input_indices,
                       std::vector<string>* output_sizes) const;

  StringPiece op_name_;
  typedef std::unordered_map<string, std::vector<int>> AttrToArgMap;
  AttrToArgMap attr_to_args_;
  std::unordered_map<string, string> attr_expressions_;
};

string GetEagerPythonOp(const OpDef& op_def, const ApiDef& api_def,
                        const string& function_name) {
  return GenEagerPythonOp(op_def, api_def, function_name).Code();
}

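// For illustration (hypothetical op with a solo input "x" followed by a list
// input "values" carrying number_attr "N"), this returns "[x] + list(values)"
// and sets *output_sizes to {"", "_attr_N"}.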
string GenEagerPythonOp::FlattenInputs(
    const std::vector<int>* input_indices,
    std::vector<string>* output_sizes) const {
  string inputs;
  enum { STARTING, WAS_LIST_INPUT, WAS_SOLO_INPUT } inputs_state = STARTING;
  const int n = input_indices != nullptr ? input_indices->size()
                                         : op_def_.input_arg_size();
  for (int j = 0; j < n; ++j) {
    const int i = input_indices ? (*input_indices)[j] : j;
    const auto& arg(op_def_.input_arg(i));
    const bool is_list =
        !arg.type_list_attr().empty() || !arg.number_attr().empty();
    if (is_list) {
      if (inputs_state == WAS_SOLO_INPUT) {
        strings::StrAppend(&inputs, "] + ");
      } else if (inputs_state == WAS_LIST_INPUT) {
        strings::StrAppend(&inputs, " + ");
      }
      strings::StrAppend(&inputs, "list(", param_names_[i].GetRenameTo(), ")");
      inputs_state = WAS_LIST_INPUT;
      if (output_sizes != nullptr) {
        if (!arg.number_attr().empty()) {
          output_sizes->emplace_back(AttrVarName(arg.number_attr(), nullptr));
        } else {
          output_sizes->emplace_back(
              strings::StrCat("len(", param_names_[i].GetRenameTo(), ")"));
        }
      }
    } else {
      if (inputs_state == WAS_SOLO_INPUT) {
        strings::StrAppend(&inputs, ", ");
      } else if (inputs_state == WAS_LIST_INPUT) {
        strings::StrAppend(&inputs, " + [");
      } else {
        strings::StrAppend(&inputs, "[");
      }
      strings::StrAppend(&inputs, param_names_[i].GetRenameTo());
      inputs_state = WAS_SOLO_INPUT;
      if (output_sizes != nullptr) output_sizes->emplace_back();
    }
  }
  if (inputs_state == STARTING) return "[]";
  if (inputs_state == WAS_SOLO_INPUT) {
    strings::StrAppend(&inputs, "]");
  }
  return inputs;
}

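// Generates the complete Python wrapper for one op.  The emitted function
// looks roughly like this (a sketch for a hypothetical single-output op "Foo"
// with one inferred "type" attr T):
//
//   def foo(x, name=None):
//     r"""..."""
//     _ctx = _context.context()
//     if _ctx.in_graph_mode():
//       _, _, _op = _op_def_lib._apply_op_helper("Foo", x=x, name=name)
//       _result = _op.outputs[:]
//       _inputs_flat = _op.inputs
//       _attrs = ("T", _op.get_attr("T"))
//     else:
//       _attr_T, (x,) = _execute.args_to_matching_eager([x], _ctx)
//       _inputs_flat = [x]
//       _attrs = ("T", _attr_T)
//       _result = _execute.execute(b"Foo", 1, inputs=_inputs_flat,
//                                  attrs=_attrs, ctx=_ctx, name=name)
//     _execute.record_gradient("Foo", _inputs_flat, _attrs, _result, name)
//     _result, = _result
//     return _result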
string GenEagerPythonOp::Code() {
  if (api_def_.visibility() == ApiDef::SKIP) {
    return "";
  }
  // This has all the input args followed by those attrs that don't have
  // defaults.
  std::vector<python_op_gen_internal::ParamNames> params_no_default;
  // The parameters with defaults (these have to be listed after those without).
  // No input args are included, just attrs.
  std::vector<std::pair<python_op_gen_internal::ParamNames, string>>
      params_with_default;

  for (int i = 0; i < api_def_.arg_order_size(); ++i) {
    const auto& arg = *FindInputArg(api_def_.arg_order(i), op_def_);
    const auto& api_def_arg = *FindInputArg(api_def_.arg_order(i), api_def_);
    params_no_default.emplace_back(api_def_arg.name(), api_def_arg.rename_to());
    if (!arg.type_attr().empty()) {
      AddAttrForArg(arg.type_attr(), i);
    } else if (!arg.type_list_attr().empty()) {
      AddAttrForArg(arg.type_list_attr(), i);
    }
    if (!arg.number_attr().empty()) {
      AddAttrForArg(arg.number_attr(), i);
    }
  }
  for (int i = 0; i < op_def_.attr_size(); ++i) {
    const auto& attr(op_def_.attr(i));
    const auto& api_def_attr(api_def_.attr(i));
    // Do not add inferred attrs to the Python function signature.
    if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) {
      if (api_def_attr.has_default_value()) {
        if (attr.type() == "tensor") {
          params_with_default.emplace_back(
              python_op_gen_internal::ParamNames(api_def_attr.name(),
                                                 api_def_attr.rename_to()),
              strings::StrCat(
                  "_execute.make_tensor(",
                  TensorPBString(api_def_attr.default_value().tensor()), ", \"",
                  api_def_attr.rename_to(), "\")"));
        } else if (attr.type() == "list(tensor)") {
          std::vector<string> pbtxt;
          for (const auto& pb : api_def_attr.default_value().list().tensor()) {
            pbtxt.emplace_back(TensorPBString(pb));
          }
          params_with_default.emplace_back(
              python_op_gen_internal::ParamNames(api_def_attr.name(),
                                                 api_def_attr.rename_to()),
              strings::StrCat("[_execute.make_tensor(_pb, \"",
                              api_def_attr.rename_to(), "\") for _pb in ",
                              VectorToTuple(pbtxt), "]"));
        } else {
          params_with_default.emplace_back(
              python_op_gen_internal::ParamNames(api_def_attr.name(),
                                                 api_def_attr.rename_to()),
              python_op_gen_internal::AttrValueToPython(
                  attr.type(), api_def_attr.default_value(), "_dtypes."));
        }
      } else {
        params_no_default.emplace_back(api_def_attr.name(),
                                       api_def_attr.rename_to());
      }
    }
  }

  // Save the list of attr parameters (attrs that won't be inferred),
  // those with defaults go at the end.
  // Get the attrs in the order we want by taking the attrs without defaults
  // from the end of params_no_default, and adding params_no_default.
  attrs_.reserve(params_no_default.size() - op_def_.input_arg_size() +
                 params_with_default.size());
  for (int i = op_def_.input_arg_size(); i < params_no_default.size(); ++i) {
    attrs_.push_back(params_no_default[i].GetName());
  }
  for (const auto& p : params_with_default) {
    attrs_.push_back(p.first.GetName());
  }

  param_names_.reserve(params_no_default.size() + params_with_default.size());
  param_names_.insert(param_names_.begin(), params_no_default.begin(),
                      params_no_default.end());
  for (const auto& param_and_default : params_with_default) {
    param_names_.push_back(param_and_default.first);
  }

  string parameters;
  for (const auto& param : params_no_default) {
    if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
    strings::StrAppend(&parameters, param.GetRenameTo());
  }
  for (const auto& param_and_default : params_with_default) {
    if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
    strings::StrAppend(&parameters, param_and_default.first.GetRenameTo(), "=",
                       param_and_default.second);
  }
  if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
  strings::StrAppend(&parameters, "name=None");

  AddExport();
  AddDefLine(parameters);
  AddDocStringDescription();
  AddDocStringArgs();
  AddDocStringInputs();
  AddDocStringAttrs();
  AddDocStringNameArg();
  AddOutputGlobals();
  AddDocStringOutputs();
  strings::StrAppend(&result_, "  \"\"\"\n");

  // Function body.

  // Validate list inputs, infer length attrs.
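  // For a hypothetical list input "values" whose length becomes attr "N",
  // the emitted Python is roughly:
  //   if not isinstance(values, (list, tuple)):
  //     raise TypeError(...)
  //   _attr_N = len(values)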
  for (int i = 0; i < op_def_.attr_size(); ++i) {
    const auto& attr(op_def_.attr(i));
    if (attr.type() == "int") {
      auto arg_list = attr_to_args_.find(attr.name());
      if (arg_list != attr_to_args_.end()) {
        // Inferred int attrs are the lengths of inputs. Validate those
        // inputs are lists and have the same length.
        for (auto iter = arg_list->second.begin();
             iter != arg_list->second.end(); ++iter) {
          const string& arg_api_name = param_names_[*iter].GetRenameTo();
          ExpectListArg(arg_api_name);
          if (iter == arg_list->second.begin()) {
            AddInferredAttr(attr.name(),
                            strings::StrCat("len(", arg_api_name, ")"),
                            &result_, &attr_expressions_);
          } else {
            const auto& attr_var = attr_expressions_[attr.name()];
            strings::StrAppend(&result_, "  if len(", arg_api_name,
                               ") != ", attr_var,
                               ":\n"
                               "    raise ValueError(\n"
                               "        \"List argument '",
                               arg_api_name, "' to '", op_name_,
                               "' Op with length %d \"\n"
                               "        \"must match length %d of argument '",
                               inferred_attrs_[attr.name()],
                               "'.\" %\n"
                               "        (len(",
                               arg_api_name, "), ", attr_var, "))\n");
          }
        }
      }
    }
  }

  // Values for non-inferred attrs.
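  // e.g. (illustrative) for an int attr "axis" with a hypothetical default:
  //   if axis is None:
  //     axis = -1
  //   axis = _execute.make_int(axis, "axis")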
  for (int i = 0; i < attrs_.size(); ++i) {
    const string& attr_name = attrs_[i];
    const auto& param = param_names_[i + op_def_.input_arg_size()];
    const auto& attr = *FindAttr(attr_name, op_def_);
    const string& attr_api_name = param.GetRenameTo();
    StringPiece attr_type = attr.type();
    attr_expressions_[attr_name] = attr_api_name;
    const int default_index = i - (attrs_.size() - params_with_default.size());
    if (default_index >= 0) {
      const string& default_value = params_with_default[default_index].second;
      strings::StrAppend(&result_, "  if ", attr_api_name, " is None:\n");
      strings::StrAppend(&result_, "    ", attr_api_name, " = ", default_value,
                         "\n");
    }
    if (attr_type.starts_with("list(")) {
      ExpectListArg(attr_api_name);
    }

    if (attr_type == "string") {
      strings::StrAppend(&result_, "  ", attr_api_name, " = _execute.make_str(",
                         attr_api_name, ", \"", attr_api_name, "\")\n");
    } else if (attr_type == "list(string)") {
      strings::StrAppend(&result_, "  ", attr_api_name,
                         " = [_execute.make_str(_s, \"", attr_api_name,
                         "\") for _s in ", attr_api_name, "]\n");
    } else if (attr_type == "int") {
      strings::StrAppend(&result_, "  ", attr_api_name, " = _execute.make_int(",
                         attr_api_name, ", \"", attr_api_name, "\")\n");
    } else if (attr_type == "list(int)") {
      strings::StrAppend(&result_, "  ", attr_api_name,
                         " = [_execute.make_int(_i, \"", attr_api_name,
                         "\") for _i in ", attr_api_name, "]\n");
    } else if (attr_type == "float") {
      strings::StrAppend(&result_, "  ", attr_api_name,
                         " = _execute.make_float(", attr_api_name, ", \"",
                         attr_api_name, "\")\n");
    } else if (attr_type == "list(float)") {
      strings::StrAppend(&result_, "  ", attr_api_name,
                         " = [_execute.make_float(_f, \"", attr_api_name,
                         "\") for _f in ", attr_api_name, "]\n");
    } else if (attr_type == "bool") {
      strings::StrAppend(&result_, "  ", attr_api_name,
                         " = _execute.make_bool(", attr_api_name, ", \"",
                         attr_api_name, "\")\n");
    } else if (attr_type == "list(bool)") {
      strings::StrAppend(&result_, "  ", attr_api_name,
                         " = [_execute.make_bool(_b, \"", attr_api_name,
                         "\") for _b in ", attr_api_name, "]\n");
    } else if (attr_type == "type") {
      strings::StrAppend(&result_, "  ", attr_api_name,
                         " = _execute.make_type(", attr_api_name, ", \"",
                         attr_api_name, "\")\n");
    } else if (attr_type == "list(type)") {
      strings::StrAppend(&result_, "  ", attr_api_name,
                         " = [_execute.make_type(_t, \"", attr_api_name,
                         "\") for _t in ", attr_api_name, "]\n");
    } else if (attr_type == "shape") {
      strings::StrAppend(&result_, "  ", attr_api_name,
                         " = _execute.make_shape(", attr_api_name, ", \"",
                         attr_api_name, "\")\n");
    } else if (attr_type == "list(shape)") {
      strings::StrAppend(&result_, "  ", attr_api_name,
                         " = [_execute.make_shape(_s, \"", attr_api_name,
                         "\") for _s in ", attr_api_name, "]\n");
    } else if (attr_type == "tensor") {
      strings::StrAppend(&result_, "  ", attr_api_name,
                         " = _execute.make_tensor(", attr_api_name, ", \"",
                         attr_api_name, "\")\n");
    } else if (attr_type == "list(tensor)") {
      strings::StrAppend(&result_, "  ", attr_api_name,
                         " = [_execute.make_tensor(_t, \"", attr_api_name,
                         "\") for _t in ", attr_api_name, "]\n");
    } else if (attr_type != "func") {
      return strings::StrCat("# No definition for ", function_name_,
                             " since we don't support attrs with type\n"
                             "# '",
                             attr_type, "' right now.\n\n");
    }
  }

  // Figure out the list of inputs.
  const string inputs = FlattenInputs(nullptr, nullptr);

  // Handle graph-mode case
  strings::StrAppend(&result_,
                     "  _ctx = _context.context()\n"

                     "  if _ctx.in_graph_mode():\n"
                     "    _, _, _op = _op_def_lib._apply_op_helper(\n");
  AddBodyNoReturn("        ");
  if (num_outs_ > 0) {
    strings::StrAppend(&result_, "    _result = _op.outputs[:]\n");
    // Special case handling for stateful op with single list output
    // that might be empty.
    if (num_outs_ == 1 && op_def_.is_stateful() &&
        (!op_def_.output_arg(0).number_attr().empty() ||
         !op_def_.output_arg(0).type_list_attr().empty())) {
      // TODO(josh11b): Can skip this if the number_attr/type_list_attr has
      // a constraint indicating that this can never be empty.
      strings::StrAppend(&result_,
                         "    if not _result:\n"
                         "      return _op\n");
    }
    strings::StrAppend(&result_, "    _inputs_flat = _op.inputs\n");

    // Compute graph-mode attrs.
    if (op_def_.attr_size() > 0) {
      string attr_values;
      for (int i = 0; i < op_def_.attr_size(); ++i) {
        if (i > 0) strings::StrAppend(&attr_values, ", ");
        const auto& attr_name(op_def_.attr(i).name());
        strings::StrAppend(&attr_values, "\"", attr_name, "\", _op.get_attr(\"",
                           attr_name, "\")");
      }
      strings::StrAppend(&attr_values, ")");
      strings::StrAppend(&result_,
                         WordWrap("    _attrs = (", attr_values, kRightMargin),
                         "\n");
    } else {
      strings::StrAppend(&result_, "    _attrs = None\n");
    }
  } else {
    strings::StrAppend(&result_, "    return _op\n");
  }

  // Handle eager-mode case
  strings::StrAppend(&result_, "  else:\n");

  // Expression representing the number of outputs.
  int num_fixed_outputs = 0;
  string num_outputs_expr;
  // If output i is list output, output_sizes[i] will be set to a
  // string with the python expression that will evaluate to its
  // length. output_sizes[i] is empty for non-list outputs.
  std::vector<string> output_sizes(num_outs_);
  for (int i = 0; i < num_outs_; ++i) {
    const auto& arg(op_def_.output_arg(i));
    if (!arg.number_attr().empty()) {
      if (!num_outputs_expr.empty()) {
        strings::StrAppend(&num_outputs_expr, " + ");
      }
      output_sizes[i] = attr_expressions_[arg.number_attr()];
      strings::StrAppend(&num_outputs_expr, output_sizes[i]);
    } else if (!arg.type_list_attr().empty()) {
      if (!num_outputs_expr.empty()) {
        strings::StrAppend(&num_outputs_expr, " + ");
      }
      // Have to be careful to use an expression that works in both
      // graph and eager paths here.
      const auto iter = inferred_attrs_.find(arg.type_list_attr());
      if (iter == inferred_attrs_.end()) {
        output_sizes[i] = strings::StrCat(
            "len(", attr_expressions_[arg.type_list_attr()], ")");
      } else {
        output_sizes[i] = strings::StrCat("len(", iter->second, ")");
      }
      strings::StrAppend(&num_outputs_expr, output_sizes[i]);
    } else {
      ++num_fixed_outputs;
    }
  }
  if (num_fixed_outputs > 0) {
    if (!num_outputs_expr.empty()) {
      strings::StrAppend(&num_outputs_expr, " + ");
    }
    strings::StrAppend(&num_outputs_expr, num_fixed_outputs);
  } else if (num_outputs_expr.empty()) {
    num_outputs_expr = "0";
  }

  bool eager_allowed = true;
  string ref_arg;
  for (int i = 0; i < op_def_.input_arg_size(); ++i) {
    const auto& arg = op_def_.input_arg(i);
    if (arg.is_ref()) {
      eager_allowed = false;
      DCHECK_EQ(op_def_.input_arg(i).name(), api_def_.in_arg(i).name());
      ref_arg = api_def_.in_arg(i).rename_to();
    }
  }
  for (int i = 0; i < op_def_.output_arg_size(); ++i) {
    const auto& arg = op_def_.output_arg(i);
    if (arg.is_ref()) {
      eager_allowed = false;
      DCHECK_EQ(op_def_.output_arg(i).name(), api_def_.out_arg(i).name());
      ref_arg = api_def_.out_arg(i).rename_to();
    }
  }

  if (eager_allowed) {
    AddEagerInferredAttrs();
    AddEagerInputCasts();
    strings::StrAppend(&result_, "    _inputs_flat = ", inputs, "\n");
    AddEagerAttrs();
    AddEagerExecute(num_outputs_expr);
  } else {
    strings::StrAppend(&result_,
                       "    raise RuntimeError(\n"
                       "        \"",
                       op_name_, " op does not support eager execution. ",
                       "Arg '", ref_arg, "' is a ref.\")\n");
  }

  if (num_outs_ > 0) {
    strings::StrAppend(&result_, "  _execute.record_gradient(\n", "      \"",
                       op_def_.name(),
                       "\", _inputs_flat, _attrs, _result, name)\n");
    if (num_outs_ == 1 && !output_sizes[0].empty()) {
      // Single list result.
    } else if (num_outs_ == 1) {
      // Execute returns a single-element list which we need to destructure.
      strings::StrAppend(&result_, "  _result, = _result\n");
    } else {
      // Have multiple outputs, so we will need to reformat the return
      // value of execute() to be a list with one entry per op output
      // (that entry will be a list of tensors if that output is of list
      // type).
      // For list outputs, convert the right subrange of _result into a list.
      Unflatten("  ", output_sizes, "_result", &result_);
      // Convert to a named tuple.
      strings::StrAppend(&result_, "  _result = _", op_def_.name(),
                         "Output._make(_result)\n");
    }
  } else {
    strings::StrAppend(&result_, "    _result = None\n");
  }
  strings::StrAppend(&result_, "  return _result\n\n");
  return prelude_ + result_;
}

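// Emits a Python isinstance check, e.g. (illustrative) for arg "values" of a
// hypothetical op "foo":
//   if not isinstance(values, (list, tuple)):
//     raise TypeError(
//         "Expected list for 'values' argument to "
//         "'foo' Op, not %r." % values)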
void GenEagerPythonOp::ExpectListArg(const string& arg_name) {
  strings::StrAppend(&result_, "  if not isinstance(", arg_name,
                     ", (list, tuple)):\n"
                     "    raise TypeError(\n"
                     "        \"Expected list for '",
                     arg_name,
                     "' argument to \"\n"
                     "        \"'",
                     op_name_, "' Op, not %r.\" % ", arg_name, ")\n");
}

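// Emits the eager-mode conversion of inputs whose dtype is carried by an
// inferred attr.  For example (a sketch), a "type" attr T inferred from two
// inputs x and y produces:
//   _attr_T, _inputs_T = _execute.args_to_matching_eager([x, y], _ctx)
//   (x, y) = _inputs_T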
void GenEagerPythonOp::AddEagerInferredAttrs() {
  // Figure out values for inferred attrs, and cast to eager tensors.
  for (int i = 0; i < op_def_.attr_size(); ++i) {
    const auto& attr(op_def_.attr(i));
    const auto& api_def_attr(api_def_.attr(i));
    auto arg_list = attr_to_args_.find(attr.name());
    if (arg_list != attr_to_args_.end()) {
      if (attr.type() == "type") {
        std::vector<string> output_sizes;
        const string flattened =
            FlattenInputs(&arg_list->second, &output_sizes);
        string conversion = strings::StrCat("_execute.args_to_matching_eager(",
                                            flattened, ", _ctx");
        if (attr.has_default_value()) {
          strings::StrAppend(
              &conversion, ", ",
              python_op_gen_internal::AttrValueToPython(
                  attr.type(), api_def_attr.default_value(), "_dtypes."));
        }
        strings::StrAppend(&conversion, ")");
        const string var_name = AttrVarName(attr.name(), &attr_expressions_);
        if (output_sizes.size() == 1) {
          // Avoid creating a temporary variable in the case where
          // we can easily assign to the right value directly.
          const string inputs_var =
              param_names_[arg_list->second.front()].GetRenameTo();
          if (output_sizes.front().empty()) {
            strings::StrAppend(&result_, "    ", var_name, ", (", inputs_var,
                               ",) = ", conversion, "\n");
          } else {
            strings::StrAppend(&result_, "    ", var_name, ", ", inputs_var,
                               " = ", conversion, "\n");
          }
        } else {
          const string inputs_var = strings::StrCat("_inputs_", attr.name());
          strings::StrAppend(&result_, "    ", var_name, ", ", inputs_var,
                             " = ", conversion, "\n");
          // Convert from a flat list of eager tensors back to the
          // parameter variables.
          Unflatten("    ", output_sizes, inputs_var, &result_);
          std::vector<string> p;
          for (int j : arg_list->second) {
            p.emplace_back(param_names_[j].GetRenameTo());
          }
          strings::StrAppend(&result_, "    ", VectorToTuple(p), " = ",
                             inputs_var, "\n");
        }
      } else if (attr.type() == "list(type)") {
        // NOTE: We ignore default values for these attrs, since it is
        // unclear how you would use it, and the one use case is
        // parse_single_sequence_example which only needs it for
        // backwards compatibility.
        const string var_name = AttrVarName(attr.name(), &attr_expressions_);
        string inputs_var;
        string conversion;
        if (arg_list->second.size() > 1) {
          // If you have more than one list(tensor) argument, their types
          // have to match.
          std::vector<string> lists;
          for (auto iter = arg_list->second.begin();
               iter != arg_list->second.end(); ++iter) {
            lists.push_back(param_names_[*iter].GetRenameTo());
          }
          inputs_var = VectorToTuple(lists);
          conversion = "_execute.args_to_mixed_eager_tensors";
        } else {
          // For one list(tensor) argument, we just convert every
          // element of the list to an eager tensor.
          inputs_var = param_names_[arg_list->second.front()].GetRenameTo();
          conversion = "_execute.convert_to_mixed_eager_tensors";
        }
        strings::StrAppend(&result_, "    ", var_name, ", ", inputs_var, " = ",
                           conversion, "(", inputs_var, ", _ctx)\n");
      }
    }
  }
}

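// For a fixed-dtype input, e.g. (illustrative) an int32 input "indices":
//   indices = _ops.convert_to_tensor(indices, _dtypes.int32)
// List inputs carrying a number_attr go through _ops.convert_n_to_tensor
// instead.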
void GenEagerPythonOp::AddEagerInputCasts() {
  // Cast remaining args to eager tensors
  for (int i = 0; i < op_def_.input_arg_size(); ++i) {
    const auto& arg(op_def_.input_arg(i));
    if (!arg.type_attr().empty() || !arg.type_list_attr().empty()) continue;
    const string& param = param_names_[i].GetRenameTo();
    const string fn = arg.number_attr().empty() ? "" : "n_";
    const string dtype =
        python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes.");
    strings::StrAppend(&result_, "    ", param, " = _ops.convert_", fn,
                       "to_tensor(", param, ", ", dtype, ")\n");
  }
}

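// Emits the flat attr tuple passed to execute(), e.g. (illustrative):
//   _attrs = ("T", _attr_T, "N", _attr_N)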
void GenEagerPythonOp::AddEagerAttrs() {
  // Compute eager attrs
  if (op_def_.attr_size() > 0) {
    string attr_values;
    for (int i = 0; i < op_def_.attr_size(); ++i) {
      if (i > 0) strings::StrAppend(&attr_values, ", ");
      const auto& attr_name(op_def_.attr(i).name());
      strings::StrAppend(&attr_values, "\"", attr_name, "\", ",
                         attr_expressions_[attr_name]);
    }
    strings::StrAppend(&attr_values, ")");
    strings::StrAppend(
        &result_, WordWrap("    _attrs = (", attr_values, kRightMargin), "\n");
  } else {
    strings::StrAppend(&result_, "    _attrs = None\n");
  }
}

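// Emits the eager execute call, e.g. (illustrative) for a hypothetical op
// "Foo" with one fixed output:
//   _result = _execute.execute(b"Foo", 1, inputs=_inputs_flat, attrs=_attrs,
//                              ctx=_ctx, name=name)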
void GenEagerPythonOp::AddEagerExecute(const string& num_outputs_expr) {
  const string return_prefix = "    _result = _execute.execute(";
  const string return_args = strings::StrCat(
      "b\"", op_def_.name(), "\", ", num_outputs_expr,
      ", inputs=_inputs_flat, attrs=_attrs, ctx=_ctx, name=name)");
  strings::StrAppend(&result_,
                     // Wrap the arguments, and indent to the (.
                     WordWrap(return_prefix, return_args, kRightMargin), "\n");
}

string GetEagerPythonOps(const OpList& ops, const ApiDefMap& api_defs,
                         const std::vector<string>& hidden_ops,
                         bool require_shapes,
                         const string& source_file_name = "") {
  string result;
  // Header
  // TODO(josh11b): Mention the library for which wrappers are being generated.
  strings::StrAppend(&result, R"("""Python wrappers around TensorFlow ops.

This file is MACHINE GENERATED! Do not edit.
)");

  // Mention the original source file so someone tracing back through generated
  // Python code will know where to look next.
  if (!source_file_name.empty()) {
    strings::StrAppend(&result, "Original C++ source file: ");
    strings::StrAppend(&result, source_file_name);
    strings::StrAppend(&result, "\n");
  }

  strings::StrAppend(&result, R"("""

import collections as _collections

from tensorflow.python.eager import execute as _execute
from tensorflow.python.eager import context as _context
from tensorflow.python.eager import core as _core
from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.python.framework import tensor_shape as _tensor_shape

from tensorflow.core.framework import op_def_pb2 as _op_def_pb2
# Needed to trigger the call to _set_call_cpp_shape_fn.
from tensorflow.python.framework import common_shapes as _common_shapes
from tensorflow.python.framework import op_def_registry as _op_def_registry
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import op_def_library as _op_def_library
from tensorflow.python.util.tf_export import tf_export

)");

  // We'll make a copy of ops that filters out descriptions.
  OpList cleaned_ops;
  auto out = cleaned_ops.mutable_op();
  out->Reserve(ops.op_size());
  for (const auto& op_def : ops.op()) {
    bool is_hidden = false;
    for (const string& hidden : hidden_ops) {
      if (op_def.name() == hidden) {
        is_hidden = true;
        break;
      }
    }

    string function_name;
    python_op_gen_internal::GenerateLowerCaseOpName(op_def.name(),
                                                    &function_name);
    if (is_hidden) function_name = strings::StrCat("_", function_name);

    // When users create custom python wrappers, they may link in the
    // default op registry by accident, and because they can't
    // enumerate all 'hidden' symbols, this guard is to prevent
    // instantiating a python reserved word in their wrapper.
    if (python_op_gen_internal::IsPythonReserved(function_name)) {
      continue;
    }

    const auto* api_def = api_defs.GetApiDef(op_def.name());
    strings::StrAppend(&result,
                       GetEagerPythonOp(op_def, *api_def, function_name));

    if (!require_shapes) {
      strings::StrAppend(&result, "_ops.RegisterShape(\"", op_def.name(),
                         "\")(None)\n\n");
    }

    auto added = out->Add();
    *added = op_def;
    RemoveNonDeprecationDescriptionsFromOpDef(added);
  }

  result.append(R"(def _InitOpDefLibrary(op_list_proto_bytes):
  op_list = _op_def_pb2.OpList()
  op_list.ParseFromString(op_list_proto_bytes)
  _op_def_registry.register_op_list(op_list)
  op_def_lib = _op_def_library.OpDefLibrary()
  op_def_lib.add_op_list(op_list)
  return op_def_lib
)");

  result.append("# ");
  auto ops_text = ProtoDebugString(cleaned_ops);
  str_util::StripTrailingWhitespace(&ops_text);
  result.append(str_util::StringReplace(ops_text, "\n", "\n# ", true));
  result.append("\n");
  strings::Appendf(&result, "_op_def_lib = _InitOpDefLibrary(b\"%s\")\n",
                   str_util::CEscape(cleaned_ops.SerializeAsString()).c_str());
  return result;
}

}  // namespace

void PrintEagerPythonOps(const OpList& ops, const ApiDefMap& api_defs,
                         const std::vector<string>& hidden_ops,
                         bool require_shapes, const string& source_file_name) {
  printf("%s", GetEagerPythonOps(ops, api_defs, hidden_ops, require_shapes,
                                 source_file_name)
                   .c_str());
}

string GetEagerPythonWrappers(const char* op_list_buf, size_t op_list_len) {
  string op_list_str(op_list_buf, op_list_len);
  OpList ops;
  ops.ParseFromString(op_list_str);

  ApiDefMap api_def_map(ops);
  return GetEagerPythonOps(ops, api_def_map, {}, false);
}

}  // namespace tensorflow