// python_eager_op_gen.cc revision fd3882dd5cb75773fb12b3a84962411c2df2a300
1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/python/eager/python_eager_op_gen.h"
16
17#include <stdio.h>
18#include <sstream>
19#include <unordered_map>
20#include "tensorflow/core/framework/attr_value.pb.h"
21#include "tensorflow/core/framework/op.h"
22#include "tensorflow/core/framework/op_def.pb_text.h"
23#include "tensorflow/core/framework/op_def.pb.h"
24#include "tensorflow/core/framework/op_def_util.h"
25#include "tensorflow/core/framework/op_gen_lib.h"
26#include "tensorflow/core/framework/tensor.pb_text.h"
27#include "tensorflow/core/framework/types.h"
28#include "tensorflow/core/framework/types.pb.h"
29#include "tensorflow/core/lib/gtl/map_util.h"
30#include "tensorflow/core/lib/gtl/stl_util.h"
31#include "tensorflow/core/lib/strings/str_util.h"
32#include "tensorflow/core/lib/strings/strcat.h"
33#include "tensorflow/core/lib/strings/stringprintf.h"
34#include "tensorflow/core/platform/logging.h"
35#include "tensorflow/core/platform/macros.h"
36#include "tensorflow/core/platform/types.h"
37#include "tensorflow/python/framework/python_op_gen_internal.h"
38
39namespace tensorflow {
40namespace {
41
// Column at which generated Python lines are word-wrapped (see WordWrap).
const int kRightMargin = 78;
43
44string AttrVarName(const string& attr_name,
45                   std::unordered_map<string, string>* attr_expressions) {
46  const string var = strings::StrCat("_attr_", attr_name);
47  if (attr_expressions != nullptr) (*attr_expressions)[attr_name] = var;
48  return var;
49}
50
51void AddInferredAttr(const string& attr_name, const string& value_expression,
52                     string* result,
53                     std::unordered_map<string, string>* attr_expressions) {
54  strings::StrAppend(result, "  ", AttrVarName(attr_name, attr_expressions),
55                     " = ", value_expression, "\n");
56}
57
58string VectorToTuple(const std::vector<string>& l) {
59  if (l.size() == 1) return strings::StrCat("(", l.front(), ",)");
60  string ret = "(";
61  for (int i = 0; i < l.size(); ++i) {
62    if (i > 0) {
63      strings::StrAppend(&ret, ", ");
64    }
65    strings::StrAppend(&ret, l[i]);
66  }
67  strings::StrAppend(&ret, ")");
68  return ret;
69}
70
// Emits Python statements (each indented by `prefix`) that rewrite the
// flat tensor list `var` in place so that every list-typed output becomes
// a nested sub-list.  output_sizes[i] is a Python expression for the
// length of output i, or "" for a single-tensor output (left untouched).
// After output i has been processed it occupies exactly one slot of
// `var`, which is why var[:i] below denotes the i earlier outputs.
void Unflatten(const string& prefix, const std::vector<string>& output_sizes,
               const string& var, string* result) {
  for (int i = 0; i < output_sizes.size(); ++i) {
    if (!output_sizes[i].empty()) {
      strings::StrAppend(result, prefix, var, " = ");
      // Keep the already-processed outputs before position i unchanged.
      if (i > 0) strings::StrAppend(result, var, "[:", i, "] + ");
      if (i + 1 < output_sizes.size()) {
        // Special case i == 0 to avoid "0 +" in the generated code.
        if (i == 0) {
          strings::StrAppend(result, "[", var, "[:", output_sizes[i], "]] + ",
                             var, "[", output_sizes[i], ":]");
        } else {
          // Wrap var[i : i + size] in a list; the remaining tail stays
          // flat for later iterations to consume.
          strings::StrAppend(result, "[", var, "[", i, ":", i, " + ",
                             output_sizes[i], "]] + ", var, "[", i, " + ",
                             output_sizes[i], ":]");
        }
      } else {
        // Last list output: everything from position i onward belongs to it.
        strings::StrAppend(result, "[", var, "[", i, ":]]");
      }
      strings::StrAppend(result, "\n");
    }
  }
}
94
95string TensorPBString(const TensorProto& pb) {
96  // Note: This gets used in the argument list, and so must survive naive
97  // word wrapping.
98  return strings::StrCat("\"\"\"", ProtoShortDebugString(pb), "\"\"\"");
99}
100
101class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp {
102 public:
103  GenEagerPythonOp(const OpDef& op_def, const string& function_name)
104      : python_op_gen_internal::GenPythonOp(op_def, function_name) {
105    op_name_ = function_name_;
106    op_name_.Consume("_");
107  }
108  ~GenEagerPythonOp() override {}
109
110  string Code() override;
111
112 protected:
113  void ExpectListArg(const string& arg_name);
114  void AddEagerInferredAttrs();
115  void AddEagerInputCasts();
116  void AddEagerAttrs();
117  void AddEagerExecute(const string& num_outputs_expr);
118
119  void AddAttrForArg(const string& attr, int arg_index) {
120    gtl::InsertIfNotPresent(&inferred_attrs_, attr,
121                            op_def_.input_arg(arg_index).name());
122    auto iter = attr_to_args_.find(attr);
123    if (iter == attr_to_args_.end()) {
124      attr_to_args_.insert(AttrToArgMap::value_type(attr, {arg_index}));
125    } else {
126      iter->second.push_back(arg_index);
127    }
128  }
129
130  // Returns a string expression representing a flattened list of all
131  // the inputs given by `*input_indices` (or all inputs if
132  // `input_indices` is nullptr).  `*output_sizes` can be used to unflatten.
133  string FlattenInputs(const std::vector<int>* input_indices,
134                       std::vector<string>* output_sizes) const;
135
136  StringPiece op_name_;
137  typedef std::unordered_map<string, std::vector<int>> AttrToArgMap;
138  AttrToArgMap attr_to_args_;
139  std::unordered_map<string, string> attr_expressions_;
140};
141
142string GetEagerPythonOp(const OpDef& op_def, const string& function_name) {
143  return GenEagerPythonOp(op_def, function_name).Code();
144}
145
// Builds a Python expression that evaluates to a single flat list of all
// the selected inputs, e.g. "[a, b] + list(c) + [d]".  Solo inputs are
// grouped into bracketed list literals; list inputs are wrapped in
// list(...) and concatenated with "+".  When `output_sizes` is non-null,
// one entry is appended per input: a length expression for list inputs,
// "" for solo inputs (the format Unflatten expects).
string GenEagerPythonOp::FlattenInputs(
    const std::vector<int>* input_indices,
    std::vector<string>* output_sizes) const {
  string inputs;
  // Tracks whether the previous input opened a "[...]" group (solo) or a
  // list(...) term, so we know how to join the next input.
  enum { STARTING, WAS_LIST_INPUT, WAS_SOLO_INPUT } inputs_state = STARTING;
  const int n = input_indices != nullptr ? input_indices->size()
                                         : op_def_.input_arg_size();
  for (int j = 0; j < n; ++j) {
    const int i = input_indices ? (*input_indices)[j] : j;
    const auto& arg(op_def_.input_arg(i));
    const bool is_list =
        !arg.type_list_attr().empty() || !arg.number_attr().empty();
    if (is_list) {
      // Close a pending "[...]" group, or join to the previous list term.
      if (inputs_state == WAS_SOLO_INPUT) {
        strings::StrAppend(&inputs, "] + ");
      } else if (inputs_state == WAS_LIST_INPUT) {
        strings::StrAppend(&inputs, " + ");
      }
      strings::StrAppend(&inputs, "list(", param_names_[i], ")");
      inputs_state = WAS_LIST_INPUT;
      if (output_sizes != nullptr) {
        if (!arg.number_attr().empty()) {
          // Length is the (inferred) number attr's Python variable.
          output_sizes->emplace_back(AttrVarName(arg.number_attr(), nullptr));
        } else {
          // type_list_attr input: length is len() of the parameter itself.
          output_sizes->emplace_back(
              strings::StrCat("len(", param_names_[i], ")"));
        }
      }
    } else {
      // Solo input: continue or open a "[...]" group.
      if (inputs_state == WAS_SOLO_INPUT) {
        strings::StrAppend(&inputs, ", ");
      } else if (inputs_state == WAS_LIST_INPUT) {
        strings::StrAppend(&inputs, " + [");
      } else {
        strings::StrAppend(&inputs, "[");
      }
      strings::StrAppend(&inputs, param_names_[i]);
      inputs_state = WAS_SOLO_INPUT;
      if (output_sizes != nullptr) output_sizes->emplace_back();
    }
  }
  if (inputs_state == STARTING) return "[]";
  if (inputs_state == WAS_SOLO_INPUT) {
    // Close the trailing "[...]" group.
    strings::StrAppend(&inputs, "]");
  }
  return inputs;
}
193
194string GenEagerPythonOp::Code() {
195  // This has all the input args followed by those attrs that don't have
196  // defaults.
197  std::vector<string> args_no_default;
198  // The parameters with defaults (these have to be listed after those without).
199  // No input args are included, just attrs.
200  std::vector<std::pair<string, string>> args_with_defaults;
201  for (int i = 0; i < op_def_.input_arg_size(); ++i) {
202    const auto& arg(op_def_.input_arg(i));
203    args_no_default.push_back(arg.name());
204    if (!arg.type_attr().empty()) {
205      AddAttrForArg(arg.type_attr(), i);
206    } else if (!arg.type_list_attr().empty()) {
207      AddAttrForArg(arg.type_list_attr(), i);
208    }
209    if (!arg.number_attr().empty()) {
210      AddAttrForArg(arg.number_attr(), i);
211    }
212  }
213  for (int i = 0; i < op_def_.attr_size(); ++i) {
214    const auto& attr(op_def_.attr(i));
215    // Do not add inferred attrs to the Python function signature.
216    if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) {
217      if (attr.has_default_value()) {
218        if (attr.type() == "tensor") {
219          args_with_defaults.emplace_back(
220              attr.name(),
221              strings::StrCat("_execute.make_tensor(",
222                              TensorPBString(attr.default_value().tensor()),
223                              ", \"", attr.name(), "\")"));
224        } else if (attr.type() == "list(tensor)") {
225          std::vector<string> pbtxt;
226          for (const auto& pb : attr.default_value().list().tensor()) {
227            pbtxt.emplace_back(TensorPBString(pb));
228          }
229          args_with_defaults.emplace_back(
230              attr.name(),
231              strings::StrCat("[_execute.make_tensor(_pb, \"", attr.name(),
232                              "\") for _pb in ", VectorToTuple(pbtxt), "]"));
233        } else {
234          args_with_defaults.emplace_back(
235              attr.name(), python_op_gen_internal::AttrValueToPython(
236                               attr.type(), attr.default_value(), "_dtypes."));
237        }
238      } else {
239        args_no_default.push_back(attr.name());
240      }
241    }
242  }
243
244  // Save the list of attr parameters (attrs that won't be inferred),
245  // those with defaults go at the end.
246  // Get the attrs in the order we want by taking the attrs without defaults
247  // from the end of args_no_default, and adding args_no_default.
248  attrs_.reserve(args_no_default.size() - op_def_.input_arg_size() +
249                 args_with_defaults.size());
250  attrs_.insert(attrs_.end(),
251                args_no_default.begin() + op_def_.input_arg_size(),
252                args_no_default.end());
253  for (const auto& a : args_with_defaults) {
254    attrs_.push_back(a.first);
255  }
256
257  param_names_.reserve(args_no_default.size() + args_with_defaults.size());
258  string parameters;
259  for (const string& name : args_no_default) {
260    if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
261    const string param = python_op_gen_internal::AvoidPythonReserved(name);
262    strings::StrAppend(&parameters, param);
263    param_names_.push_back(param);
264  }
265  for (const auto& name_default : args_with_defaults) {
266    if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
267    const string param =
268        python_op_gen_internal::AvoidPythonReserved(name_default.first);
269    strings::StrAppend(&parameters, param, "=", name_default.second);
270    param_names_.push_back(param);
271  }
272  if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
273  strings::StrAppend(&parameters, "name=None");
274
275  AddDefLine(parameters);
276  AddDocStringDescription();
277  AddDocStringArgs();
278  AddDocStringInputs();
279  AddDocStringAttrs();
280  AddDocStringNameArg();
281  AddOutputGlobals();
282  AddDocStringOutputs();
283  strings::StrAppend(&result_, "  \"\"\"\n");
284
285  // Function body.
286
287  // Validate list inputs, infer length attrs.
288  for (int i = 0; i < op_def_.attr_size(); ++i) {
289    const auto& attr(op_def_.attr(i));
290    if (attr.type() == "int") {
291      auto arg_list = attr_to_args_.find(attr.name());
292      if (arg_list != attr_to_args_.end()) {
293        // Inferred int attrs are the lengths of inputs. Validate those
294        // inputs are lists and have the same length.
295        for (auto iter = arg_list->second.begin();
296             iter != arg_list->second.end(); ++iter) {
297          const string& arg_name = param_names_[*iter];
298          ExpectListArg(arg_name);
299          if (iter == arg_list->second.begin()) {
300            AddInferredAttr(attr.name(), strings::StrCat("len(", arg_name, ")"),
301                            &result_, &attr_expressions_);
302          } else {
303            const auto& attr_var = attr_expressions_[attr.name()];
304            strings::StrAppend(&result_, "  if len(", arg_name,
305                               ") != ", attr_var,
306                               ":\n"
307                               "    raise ValueError(\n"
308                               "        \"List argument '",
309                               arg_name, "' to '", op_name_,
310                               "' Op with length %d \"\n"
311                               "        \"must match length %d of argument '",
312                               inferred_attrs_[attr.name()],
313                               "'.\" %\n"
314                               "        (len(",
315                               arg_name, "), ", attr_var, "))\n");
316          }
317        }
318      }
319    }
320  }
321
322  // Values for non-inferred attrs.
323  for (int i = 0; i < attrs_.size(); ++i) {
324    const string& attr_name = attrs_[i];
325    const string& param = param_names_[i + op_def_.input_arg_size()];
326    const auto& attr = *FindAttr(attr_name, op_def_);
327    StringPiece attr_type = attr.type();
328    attr_expressions_[attr_name] = param;
329    const int default_index = i - (attrs_.size() - args_with_defaults.size());
330    if (default_index >= 0) {
331      const string& default_value = args_with_defaults[default_index].second;
332      strings::StrAppend(&result_, "  if ", param, " is None:\n");
333      strings::StrAppend(&result_, "    ", param, " = ", default_value, "\n");
334    }
335    if (attr_type.starts_with("list(")) {
336      ExpectListArg(param);
337    }
338
339    if (attr_type == "string") {
340      strings::StrAppend(&result_, "  ", param, " = _execute.make_str(", param,
341                         ", \"", param, "\")\n");
342    } else if (attr_type == "list(string)") {
343      strings::StrAppend(&result_, "  ", param, " = [_execute.make_str(_s, \"",
344                         param, "\") for _s in ", param, "]\n");
345    } else if (attr_type == "int") {
346      strings::StrAppend(&result_, "  ", param, " = _execute.make_int(", param,
347                         ", \"", param, "\")\n");
348    } else if (attr_type == "list(int)") {
349      strings::StrAppend(&result_, "  ", param, " = [_execute.make_int(_i, \"",
350                         param, "\") for _i in ", param, "]\n");
351    } else if (attr_type == "float") {
352      strings::StrAppend(&result_, "  ", param, " = _execute.make_float(",
353                         param, ", \"", param, "\")\n");
354    } else if (attr_type == "list(float)") {
355      strings::StrAppend(&result_, "  ", param,
356                         " = [_execute.make_float(_f, \"", param,
357                         "\") for _f in ", param, "]\n");
358    } else if (attr_type == "bool") {
359      strings::StrAppend(&result_, "  ", param, " = _execute.make_bool(", param,
360                         ", \"", param, "\")\n");
361    } else if (attr_type == "list(bool)") {
362      strings::StrAppend(&result_, "  ", param, " = [_execute.make_bool(_b, \"",
363                         param, "\") for _b in ", param, "]\n");
364    } else if (attr_type == "type") {
365      strings::StrAppend(&result_, "  ", param, " = _execute.make_type(", param,
366                         ", \"", param, "\")\n");
367    } else if (attr_type == "list(type)") {
368      strings::StrAppend(&result_, "  ", param, " = [_execute.make_type(_t, \"",
369                         param, "\") for _t in ", param, "]\n");
370    } else if (attr_type == "shape") {
371      strings::StrAppend(&result_, "  ", param, " = _execute.make_shape(",
372                         param, ", \"", param, "\")\n");
373    } else if (attr_type == "list(shape)") {
374      strings::StrAppend(&result_, "  ", param,
375                         " = [_execute.make_shape(_s, \"", param,
376                         "\") for _s in ", param, "]\n");
377    } else if (attr_type == "tensor") {
378      strings::StrAppend(&result_, "  ", param, " = _execute.make_tensor(",
379                         param, ", \"", param, "\")\n");
380    } else if (attr_type == "list(tensor)") {
381      strings::StrAppend(&result_, "  ", param,
382                         " = [_execute.make_tensor(_t, \"", param,
383                         "\") for _t in ", param, "]\n");
384    } else if (attr_type != "func") {
385      return strings::StrCat("# No definition for ", function_name_,
386                             " since we don't support attrs with type\n"
387                             "# '",
388                             attr_type, "' right now.\n\n");
389    }
390  }
391
392  // Figure out the list of inputs.
393  const string inputs = FlattenInputs(nullptr, nullptr);
394
395  // Handle graph-mode case
396  strings::StrAppend(&result_,
397                     "  _ctx = _context.context()\n"
398
399                     "  if _ctx.in_graph_mode():\n"
400                     "    _, _, _op = _op_def_lib._apply_op_helper(\n");
401  AddBodyNoReturn("        ");
402  if (num_outs_ > 0) {
403    strings::StrAppend(&result_, "    _result = _op.outputs[:]\n");
404    // Special case handling for stateful op with single list output
405    // that might be empty.
406    if (num_outs_ == 1 && op_def_.is_stateful() &&
407        (!op_def_.output_arg(0).number_attr().empty() ||
408         !op_def_.output_arg(0).type_list_attr().empty())) {
409      // TODO(josh11b): Can skip this if the number_attr/type_list_attr has
410      // a constraint indicating that this can never be empty.
411      strings::StrAppend(&result_,
412                         "    if not _result:\n"
413                         "      return _op\n");
414    }
415    strings::StrAppend(&result_, "    _inputs_flat = ", inputs, "\n");
416
417    // Compute graph-mode attrs.
418    if (op_def_.attr_size() > 0) {
419      string attr_values;
420      for (int i = 0; i < op_def_.attr_size(); ++i) {
421        if (i > 0) strings::StrAppend(&attr_values, ", ");
422        const auto& attr_name(op_def_.attr(i).name());
423        strings::StrAppend(&attr_values, "\"", attr_name, "\", _op.get_attr(\"",
424                           attr_name, "\")");
425      }
426      strings::StrAppend(&attr_values, ")");
427      strings::StrAppend(&result_,
428                         WordWrap("    _attrs = (", attr_values, kRightMargin),
429                         "\n");
430    } else {
431      strings::StrAppend(&result_, "    _attrs = None\n");
432    }
433  } else {
434    strings::StrAppend(&result_, "    return _op\n");
435  }
436
437  // Handle eager-mode case
438  strings::StrAppend(&result_, "  else:\n");
439
440  // Expression representing the number of outputs.
441  int num_fixed_outputs = 0;
442  string num_outputs_expr;
443  // If output i is list output, output_sizes[i] will be set to a
444  // string with the python expression that will evaluate to its
445  // length. output_sizes[i] is empty for non-list outputs.
446  std::vector<string> output_sizes(num_outs_);
447  for (int i = 0; i < num_outs_; ++i) {
448    const auto& arg(op_def_.output_arg(i));
449    if (!arg.number_attr().empty()) {
450      if (!num_outputs_expr.empty()) {
451        strings::StrAppend(&num_outputs_expr, " + ");
452      }
453      output_sizes[i] = attr_expressions_[arg.number_attr()];
454      strings::StrAppend(&num_outputs_expr, output_sizes[i]);
455    } else if (!arg.type_list_attr().empty()) {
456      if (!num_outputs_expr.empty()) {
457        strings::StrAppend(&num_outputs_expr, " + ");
458      }
459      // Have to be careful to use an expression that works in both
460      // graph and eager paths here.
461      const auto iter = inferred_attrs_.find(arg.type_list_attr());
462      if (iter == inferred_attrs_.end()) {
463        output_sizes[i] = strings::StrCat(
464            "len(", attr_expressions_[arg.type_list_attr()], ")");
465      } else {
466        output_sizes[i] = strings::StrCat("len(", iter->second, ")");
467      }
468      strings::StrAppend(&num_outputs_expr, output_sizes[i]);
469    } else {
470      ++num_fixed_outputs;
471    }
472  }
473  if (num_fixed_outputs > 0) {
474    if (!num_outputs_expr.empty()) {
475      strings::StrAppend(&num_outputs_expr, " + ");
476    }
477    strings::StrAppend(&num_outputs_expr, num_fixed_outputs);
478  } else if (num_outputs_expr.empty()) {
479    num_outputs_expr = "0";
480  }
481
482  bool eager_allowed = true;
483  string ref_arg;
484  for (const auto& arg : op_def_.input_arg()) {
485    if (arg.is_ref()) {
486      eager_allowed = false;
487      ref_arg = arg.name();
488    }
489  }
490  for (const auto& arg : op_def_.output_arg()) {
491    if (arg.is_ref()) {
492      eager_allowed = false;
493      ref_arg = arg.name();
494    }
495  }
496
497  if (eager_allowed) {
498    AddEagerInferredAttrs();
499    AddEagerInputCasts();
500    strings::StrAppend(&result_, "    _inputs_flat = ", inputs, "\n");
501    AddEagerAttrs();
502    AddEagerExecute(num_outputs_expr);
503  } else {
504    strings::StrAppend(&result_,
505                       "    raise RuntimeError(\n"
506                       "        \"",
507                       op_name_, " op does not support eager execution. ",
508                       "Arg '", ref_arg, "'' is a ref.\")\n");
509  }
510
511  if (num_outs_ > 0) {
512    strings::StrAppend(&result_, "  _execute.record_gradient(\n", "      \"",
513                       op_def_.name(),
514                       "\", _inputs_flat, _attrs, _result, _ctx, name)\n");
515    if (num_outs_ == 1 && !output_sizes[0].empty()) {
516      // Single list result.
517    } else if (num_outs_ == 1) {
518      // Execute returns a single-element list which we need to destructure.
519      strings::StrAppend(&result_, "  _result, = _result\n");
520    } else {
521      // Have multiple outputs, so we will need to reformat the return
522      // value of execute() to be a list with one entry per op output
523      // (that entry will be a list of tensors if that output is of list
524      // type).
525      // For list outputs, convert the right subrange of _result into a list.
526      Unflatten("  ", output_sizes, "_result", &result_);
527      // Convert to a named tuple.
528      strings::StrAppend(&result_, "  _result = _", op_def_.name(),
529                         "Output._make(_result)\n");
530    }
531  }
532  strings::StrAppend(&result_, "  return _result\n\n");
533  return prelude_ + result_;
534}
535
536void GenEagerPythonOp::ExpectListArg(const string& arg_name) {
537  strings::StrAppend(&result_, "  if not isinstance(", arg_name,
538                     ", (list, tuple)):\n"
539                     "    raise TypeError(\n"
540                     "        \"Expected list for '",
541                     arg_name,
542                     "' argument to \"\n"
543                     "        \"'",
544                     op_name_, "' Op, not %r.\" % ", arg_name, ")\n");
545}
546
// Emits eager-mode code that computes each inferred "type"/"list(type)"
// attr from the actual input tensors, casting those inputs to eager
// tensors in the process, and records the resulting Python variable in
// attr_expressions_.
void GenEagerPythonOp::AddEagerInferredAttrs() {
  // Figure out values for inferred attrs, and cast to eager tensors.
  for (int i = 0; i < op_def_.attr_size(); ++i) {
    const auto& attr(op_def_.attr(i));
    auto arg_list = attr_to_args_.find(attr.name());
    if (arg_list != attr_to_args_.end()) {
      if (attr.type() == "type") {
        std::vector<string> output_sizes;
        const string flattened =
            FlattenInputs(&arg_list->second, &output_sizes);
        // args_to_matching_eager returns (dtype, converted_tensors).
        string conversion = strings::StrCat("_execute.args_to_matching_eager(",
                                            flattened, ", _ctx");
        if (attr.has_default_value()) {
          strings::StrAppend(
              &conversion, ", ",
              python_op_gen_internal::AttrValueToPython(
                  attr.type(), attr.default_value(), "_dtypes."));
        }
        strings::StrAppend(&conversion, ")");
        const string var_name = AttrVarName(attr.name(), &attr_expressions_);
        if (output_sizes.size() == 1) {
          // Avoid creating a temporary variable in the case where
          // we can easily assign to the right value directly.
          const string inputs_var = param_names_[arg_list->second.front()];
          if (output_sizes.front().empty()) {
            // Solo input: destructure the one-element tensor list.
            strings::StrAppend(&result_, "    ", var_name, ", (", inputs_var,
                               ",) = ", conversion, "\n");
          } else {
            // List input: assign the converted list back directly.
            strings::StrAppend(&result_, "    ", var_name, ", ", inputs_var,
                               " = ", conversion, "\n");
          }
        } else {
          const string inputs_var = strings::StrCat("_inputs_", attr.name());
          strings::StrAppend(&result_, "    ", var_name, ", ", inputs_var,
                             " = ", conversion, "\n");
          // Convert from a flat list of eager tensors back to the
          // parameter variables.
          Unflatten("    ", output_sizes, inputs_var, &result_);
          std::vector<string> p;
          for (int j : arg_list->second) {
            p.emplace_back(param_names_[j]);
          }
          strings::StrAppend(&result_, "    ", VectorToTuple(p), " = ",
                             inputs_var, "\n");
        }
        // Execute expects the enum value, not the DType object.
        strings::StrAppend(&result_, "    ", var_name, " = ", var_name,
                           ".as_datatype_enum\n");
      } else if (attr.type() == "list(type)") {
        // NOTE: We ignore default values for these attrs, since it is
        // unclear how you would use it, and the one use case is
        // parse_single_sequence_example which only needs it for
        // backwards compatibility.
        const string var_name = AttrVarName(attr.name(), &attr_expressions_);
        string inputs_var;
        string conversion;
        if (arg_list->second.size() > 1) {
          // If you have more than one list(tensor) argument, their types
          // have to match.
          std::vector<string> lists;
          for (auto iter = arg_list->second.begin();
               iter != arg_list->second.end(); ++iter) {
            lists.push_back(param_names_[*iter]);
          }
          inputs_var = VectorToTuple(lists);
          conversion = "_execute.args_to_mixed_eager_tensors";
        } else {
          // For one list(tensor) argument, we just convert every
          // element of the list to an eager tensor.
          inputs_var = param_names_[arg_list->second.front()];
          conversion = "_execute.convert_to_mixed_eager_tensors";
        }
        strings::StrAppend(&result_, "    ", var_name, ", ", inputs_var, " = ",
                           conversion, "(", inputs_var, ", _ctx)\n");
        // Map each DType to its enum value for execute().
        strings::StrAppend(&result_, "    ", var_name,
                           " = [_t.as_datatype_enum for _t in ", var_name,
                           "]\n");
      }
    }
  }
}
627
628void GenEagerPythonOp::AddEagerInputCasts() {
629  // Cast remaining args to eager tensors
630  for (int i = 0; i < op_def_.input_arg_size(); ++i) {
631    const auto& arg(op_def_.input_arg(i));
632    if (!arg.type_attr().empty() || !arg.type_list_attr().empty()) continue;
633    const string& param = param_names_[i];
634    const string fn = arg.number_attr().empty() ? "" : "n_";
635    const string dtype =
636        python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes.");
637    strings::StrAppend(&result_, "    ", param, " = _ops.convert_", fn,
638                       "to_tensor(", param, ", ", dtype, ")\n");
639  }
640}
641
642void GenEagerPythonOp::AddEagerAttrs() {
643  // Compute eager attrs
644  if (op_def_.attr_size() > 0) {
645    string attr_values;
646    for (int i = 0; i < op_def_.attr_size(); ++i) {
647      if (i > 0) strings::StrAppend(&attr_values, ", ");
648      const auto& attr_name(op_def_.attr(i).name());
649      strings::StrAppend(&attr_values, "\"", attr_name, "\", ",
650                         attr_expressions_[attr_name]);
651    }
652    strings::StrAppend(&attr_values, ")");
653    strings::StrAppend(
654        &result_, WordWrap("    _attrs = (", attr_values, kRightMargin), "\n");
655  } else {
656    strings::StrAppend(&result_, "    _attrs = None\n");
657  }
658}
659
660void GenEagerPythonOp::AddEagerExecute(const string& num_outputs_expr) {
661  const string return_prefix = "    _result = _execute.execute(";
662  const string return_args = strings::StrCat(
663      "b\"", op_def_.name(), "\", ", num_outputs_expr,
664      ", inputs=_inputs_flat, attrs=_attrs, ctx=_ctx, name=name)");
665  strings::StrAppend(&result_,
666                     // Wrap the arguments, and indent to the (.
667                     WordWrap(return_prefix, return_args, kRightMargin), "\n");
668}
669
// Generates the complete Python module source for `ops`: file header,
// imports, one wrapper function per op (hidden ops prefixed with "_"),
// shape registrations, and the _InitOpDefLibrary bootstrap that embeds a
// description-stripped serialized copy of the op list.
string GetEagerPythonOps(const OpList& ops,
                         const std::vector<string>& hidden_ops,
                         bool require_shapes,
                         const string& source_file_name = "") {
  string result;
  // Header
  // TODO(josh11b): Mention the library for which wrappers are being generated.
  strings::StrAppend(&result, R"("""Python wrappers around TensorFlow ops.

This file is MACHINE GENERATED! Do not edit.
)");

  // Mention the original source file so someone tracing back through generated
  // Python code will know where to look next.
  if (!source_file_name.empty()) {
    strings::StrAppend(&result, "Original C++ source file: ");
    strings::StrAppend(&result, source_file_name);
    strings::StrAppend(&result, "\n");
  }

  strings::StrAppend(&result, R"("""

import collections as _collections

from tensorflow.python.eager import execute as _execute
from tensorflow.python.eager import context as _context
from tensorflow.python.eager import core as _core
from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.python.framework import tensor_shape as _tensor_shape

from tensorflow.core.framework import op_def_pb2 as _op_def_pb2
# Needed to trigger the call to _set_call_cpp_shape_fn.
from tensorflow.python.framework import common_shapes as _common_shapes
from tensorflow.python.framework import op_def_registry as _op_def_registry
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import op_def_library as _op_def_library

)");

  // We'll make a copy of ops that filters out descriptions.
  OpList cleaned_ops;
  auto out = cleaned_ops.mutable_op();
  out->Reserve(ops.op_size());
  for (const auto& op_def : ops.op()) {
    bool is_hidden = false;
    for (const string& hidden : hidden_ops) {
      if (op_def.name() == hidden) {
        is_hidden = true;
        break;
      }
    }

    string function_name;
    python_op_gen_internal::GenerateLowerCaseOpName(op_def.name(),
                                                    &function_name);
    // Hidden ops are still generated, but with a private "_" name.
    if (is_hidden) function_name = strings::StrCat("_", function_name);

    // When users create custom python wrappers, they may link in the
    // default op registry by accident, and because they can't
    // enumerate all 'hidden' symbols, this guard is to prevent
    // instantiating a python reserved word in their wrapper.
    if (python_op_gen_internal::IsPythonReserved(function_name)) {
      continue;
    }

    strings::StrAppend(&result, GetEagerPythonOp(op_def, function_name));

    if (!require_shapes) {
      strings::StrAppend(&result, "_ops.RegisterShape(\"", op_def.name(),
                         "\")(None)\n\n");
    }

    // Keep a description-stripped copy for the embedded OpList below.
    auto added = out->Add();
    *added = op_def;
    RemoveNonDeprecationDescriptionsFromOpDef(added);
  }

  result.append(R"(def _InitOpDefLibrary(op_list_proto_bytes):
  op_list = _op_def_pb2.OpList()
  op_list.ParseFromString(op_list_proto_bytes)
  _op_def_registry.register_op_list(op_list)
  op_def_lib = _op_def_library.OpDefLibrary()
  op_def_lib.add_op_list(op_list)
  return op_def_lib
)");

  // Append the cleaned OpList as a "# "-prefixed text-proto comment for
  // human readers, then as escaped serialized bytes for _InitOpDefLibrary.
  result.append("# ");
  auto ops_text = ProtoDebugString(cleaned_ops);
  str_util::StripTrailingWhitespace(&ops_text);
  result.append(str_util::StringReplace(ops_text, "\n", "\n# ", true));
  result.append("\n");
  strings::Appendf(&result, "_op_def_lib = _InitOpDefLibrary(b\"%s\")\n",
                   str_util::CEscape(cleaned_ops.SerializeAsString()).c_str());
  return result;
}
765
766}  // namespace
767
768void PrintEagerPythonOps(const OpList& ops,
769                         const std::vector<string>& hidden_ops,
770                         bool require_shapes, const string& source_file_name) {
771  printf("%s",
772         GetEagerPythonOps(ops, hidden_ops, require_shapes, source_file_name)
773             .c_str());
774}
775
776string GetEagerPythonWrappers(const char* op_list_buf, size_t op_list_len) {
777  string op_list_str(op_list_buf, op_list_len);
778  OpList ops;
779  ops.ParseFromString(op_list_str);
780  return GetEagerPythonOps(ops, {}, false);
781}
782
783}  // namespace tensorflow
784