1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/common_runtime/function.h"
17
18#include <deque>
19#include <vector>
20
21#include "tensorflow/core/common_runtime/device.h"
22#include "tensorflow/core/common_runtime/executor.h"
23#include "tensorflow/core/common_runtime/graph_optimizer.h"
24#include "tensorflow/core/common_runtime/memory_types.h"
25#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
26#include "tensorflow/core/framework/function.h"
27#include "tensorflow/core/framework/node_def.pb.h"
28#include "tensorflow/core/framework/node_def_util.h"
29#include "tensorflow/core/framework/op.h"
30#include "tensorflow/core/framework/op_kernel.h"
31#include "tensorflow/core/framework/versions.pb.h"
32#include "tensorflow/core/graph/algorithm.h"
33#include "tensorflow/core/graph/control_flow.h"
34#include "tensorflow/core/graph/gradients.h"
35#include "tensorflow/core/graph/graph_constructor.h"
36#include "tensorflow/core/graph/optimizer_cse.h"
37#include "tensorflow/core/lib/gtl/map_util.h"
38#include "tensorflow/core/platform/macros.h"
39
40// See core/kernels/function_ops.cc for related kernels.
41
42namespace tensorflow {
43
// A few string constant used throughout this module.
//
// TODO(zhifengc): Dedup some of these constants into
// framework/function.h
// Op type of the nodes that feed a function's arguments into its body.
static constexpr const char* const kArgOp = "_Arg";
// Op type of the nodes that capture a function's return values.
static constexpr const char* const kRetOp = "_Retval";
// Special function name used to request a symbolic gradient (see the
// kGradientOp handling in Instantiate() below).
static constexpr const char* const kGradientOp =
    FunctionLibraryDefinition::kGradientOp;
// Name prefix used by the Add* helpers when generating fresh node names.
static constexpr const char* const kNodeLabel = "Func";
// Attr under which a SymbolicGradient node stores the function "f" to
// differentiate.
static constexpr const char* const kFuncAttr =
    FunctionLibraryDefinition::kFuncAttr;
55
56// Represents the index-th output of a node.
57struct Endpoint {
58  Node* node;
59  int index;
60
61  // Returns the string name represents this endpoint.
62  string name() const {
63    if (index == 0) {
64      return node->name();
65    } else {
66      return strings::StrCat(node->name(), ":", index);
67    }
68  }
69
70  DataType dtype() const { return node->output_type(index); }
71};
72
73struct EndpointHash {
74  uint64 operator()(const Endpoint& x) const {
75    return Hash64(reinterpret_cast<const char*>(&x.node), sizeof(Node*),
76                  x.index);
77  }
78};
79
80struct EndpointEq {
81  bool operator()(const Endpoint& x, const Endpoint& y) const {
82    return (x.node == y.node) && (x.index == y.index);
83  }
84};
85
86// The following Add* routines are used to add a few graph nodes while
87// functions are transformed.
88static Node* AddNoOp(Graph* g) {
89  NodeDef ndef;
90  ndef.set_name(g->NewName(kNodeLabel));
91  ndef.set_op("NoOp");
92  Status s;
93  Node* ret = g->AddNode(ndef, &s);
94  TF_CHECK_OK(s);
95  return ret;
96}
97
98static Node* AddIdentity(Graph* g, Endpoint input) {
99  DCHECK_LT(0, input.dtype());
100  NodeDef ndef;
101  ndef.set_name(g->NewName(kNodeLabel));
102  ndef.set_op("Identity");
103  ndef.add_input(input.name());
104  AddNodeAttr("T", BaseType(input.dtype()), &ndef);
105  Status s;
106  Node* ret = g->AddNode(ndef, &s);
107  TF_CHECK_OK(s);
108  g->AddEdge(input.node, input.index, ret, 0);
109  return ret;
110}
111
112static Node* AddArg(Graph* g, DataType dtype, int index) {
113  DCHECK_LT(0, dtype);
114  DCHECK_LT(dtype, DT_FLOAT_REF);
115  NodeDef ndef;
116  ndef.set_name(g->NewName(kNodeLabel));
117  ndef.set_op(kArgOp);
118  AddNodeAttr("T", dtype, &ndef);
119  AddNodeAttr("index", index, &ndef);
120  Status s;
121  Node* ret = g->AddNode(ndef, &s);
122  TF_CHECK_OK(s);
123  return ret;
124}
125
126static Node* AddRet(Graph* g, Endpoint input, int index) {
127  DCHECK_LT(0, input.dtype());
128  DCHECK_LT(input.dtype(), DT_FLOAT_REF);
129  NodeDef ndef;
130  ndef.set_name(g->NewName(kNodeLabel));
131  ndef.set_op(kRetOp);
132  ndef.add_input(input.name());
133  AddNodeAttr("T", input.dtype(), &ndef);
134  AddNodeAttr("index", index, &ndef);
135  Status s;
136  Node* ret = g->AddNode(ndef, &s);
137  TF_CHECK_OK(s);
138  g->AddEdge(input.node, input.index, ret, 0);
139  return ret;
140}
141
// Per-device implementation of FunctionLibraryRuntime. Instantiated
// functions are cached as Item objects (graph + executor), keyed by a
// device-local handle obtained from the ProcessFunctionLibraryRuntime
// (`parent_`). Calls targeting other devices are forwarded to `parent_`.
class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
 public:
  FunctionLibraryRuntimeImpl(const DeviceMgr* dmgr, Env* env, Device* device,
                             int graph_def_version,
                             const FunctionLibraryDefinition* lib_def,
                             const OptimizerOptions& optimizer_options,
                             CustomKernelCreator custom_kernel_creator,
                             ProcessFunctionLibraryRuntime* parent);

  ~FunctionLibraryRuntimeImpl() override;

  // Instantiates `function_name` with `attrs` (or forwards to `parent_` if
  // `options.target` names a different device) and returns a handle in
  // `*handle`. Re-instantiating an existing function bumps its ref count.
  Status Instantiate(const string& function_name, AttrSlice attrs,
                     const InstantiateOptions& options,
                     Handle* handle) override;

  // Drops one reference on `handle`; the underlying Item is destroyed when
  // the count reaches zero.
  Status ReleaseHandle(Handle handle) override;

  // Returns the instantiated body for `handle`, or nullptr if the handle is
  // not known on this device.
  const FunctionBody* GetFunctionBody(Handle handle) override;

  // Creates a kernel for `ndef`, resolving function names against the base
  // library.
  Status CreateKernel(const NodeDef& ndef, OpKernel** kernel) override;

  void Run(const Options& opts, Handle handle, gtl::ArraySlice<Tensor> args,
           std::vector<Tensor>* rets, DoneCallback done) override;
  // NOTE(mrry): This overload is currently only implemented for local function
  // execution.
  // TODO(b/70346412): Implement support for remote function execution when
  // passing a call frame.
  void Run(const Options& opts, Handle handle, CallFrameInterface* frame,
           DoneCallback done) override;

  bool IsStateful(const string& function) override;

  const FunctionLibraryDefinition* GetFunctionLibraryDefinition()
      const override {
    return base_lib_def_;
  }

  Device* device() override { return device_; }
  Env* env() override { return env_; }
  int graph_def_version() override { return graph_def_version_; }

  string DebugString(Handle h) override;

  Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
               std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
               FunctionLibraryRuntime** out_flr) override;

 private:
  typedef FunctionLibraryRuntimeImpl ME;

  const DeviceMgr* const device_mgr_;
  Device* const device_;
  Env* const env_;
  const int graph_def_version_;
  // The process-wide function library; overlay libraries (per-instantiation)
  // are carried on Item::overlay_lib instead.
  const FunctionLibraryDefinition* const base_lib_def_;
  GraphOptimizer optimizer_;
  // Optional hook consulted first by CreateKernel().
  const CustomKernelCreator custom_kernel_creator_;
  // Name of `device_`, or kDefaultFLRDevice when device_ == nullptr.
  const string device_name_;

  // Cached lambdas used when building function bodies and executors against
  // the base library.
  std::function<Status(const string&, const OpDef**)> get_func_sig_;
  std::function<Status(const NodeDef&, OpKernel**)> create_kernel_;

  // Guards next_handle_ and items_.
  mutable mutex mu_;

  int next_handle_ GUARDED_BY(mu_);

  // The instantiated and transformed function is encoded as a Graph
  // object, and an executor is created for the graph.
  struct Item : public core::RefCounted {
    const Graph* graph = nullptr;                            // Owned by exec.
    const FunctionLibraryDefinition* overlay_lib = nullptr;  // Not owned.
    FunctionBody* func_graph = nullptr;
    Executor* exec = nullptr;

    ~Item() override {
      delete this->func_graph;
      delete this->exec;
    }
  };
  // Instantiated functions, keyed by device-local handle.
  std::unordered_map<Handle, Item*> items_ GUARDED_BY(mu_);

  ProcessFunctionLibraryRuntime* parent_ = nullptr;  // not owned.

  Status CreateKernel(const NodeDef& ndef,
                      const FunctionLibraryDefinition* lib_def,
                      OpKernel** kernel);
  Status FunctionDefToBody(const FunctionDef& fdef, AttrSlice attrs,
                           const FunctionLibraryDefinition* lib_def,
                           FunctionBody** fbody);
  Status CreateItem(Handle handle, Item** item);
  Status GetOrCreateItem(Handle handle, Item** item);
  Status InstantiateSymbolicGradient(const NameAttrList& func,
                                     const FunctionLibraryDefinition* lib_def,
                                     FunctionBody** g_body);
  // True if `options.target` resolves to this FLR's device (or is empty).
  bool IsLocalTarget(const InstantiateOptions& options);
  AttrValueMap FixAttrs(const AttrSlice& attrs);
  // Executes `handle` on behalf of a caller on another device: receives args
  // through the rendezvous, runs, then sends rets back.
  void RunRemote(const Options& opts, Handle handle,
                 gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
                 Executor::Args* exec_args, Item* item, DoneCallback done);

  TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryRuntimeImpl);
};
244
// Constructor. `device` may be nullptr (e.g. a process-level FLR), in which
// case device_name_ falls back to kDefaultFLRDevice. The two cached lambdas
// close over `this`, which is safe because they never outlive the object.
FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
    const DeviceMgr* dmgr, Env* env, Device* device, int graph_def_version,
    const FunctionLibraryDefinition* lib_def,
    const OptimizerOptions& optimizer_options,
    CustomKernelCreator custom_kernel_creator,
    ProcessFunctionLibraryRuntime* parent)
    : device_mgr_(dmgr),
      device_(device),
      env_(env),
      graph_def_version_(graph_def_version),
      base_lib_def_(lib_def),
      optimizer_(optimizer_options),
      custom_kernel_creator_(std::move(custom_kernel_creator)),
      device_name_(device_ == nullptr
                       ? ProcessFunctionLibraryRuntime::kDefaultFLRDevice
                       : device_->name()),
      next_handle_(0),
      parent_(parent) {
  // Default op-signature lookup resolves against the base library.
  get_func_sig_ = [this](const string& op, const OpDef** sig) {
    return base_lib_def_->LookUpOpDef(op, sig);
  };
  // Default kernel factory; CreateItem() substitutes an overlay-aware one
  // when an overlay library is in use.
  create_kernel_ = [this](const NodeDef& ndef, OpKernel** kernel) {
    return CreateKernel(ndef, kernel);
  };
}
270
271FunctionLibraryRuntimeImpl::~FunctionLibraryRuntimeImpl() {
272  // The most common patterns of FLR usage don't require the caller to
273  // explicitly release handles. As a result, we try to unref each item until
274  // it's erased.
275  for (auto item : items_) {
276    if (item.second) {
277      while (!item.second->Unref()) {
278      }
279    }
280  }
281}
282
// An asynchronous op kernel which executes an instantiated function
// defined in a library.
class CallOp : public AsyncOpKernel {
 public:
  // `handle` identifies the already-instantiated function to invoke.
  CallOp(FunctionLibraryRuntime::Handle handle, OpKernelConstruction* ctx)
      : AsyncOpKernel(ctx), handle_(handle) {}

  ~CallOp() override {}

  // Forwards the kernel's inputs to FunctionLibraryRuntime::Run and copies
  // the function's return values to the kernel's outputs when it completes.
  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    FunctionLibraryRuntime* lib = ctx->function_library();
    OP_REQUIRES_ASYNC(ctx, lib != nullptr,
                      errors::Internal("No function library is provided."),
                      done);
    // Propagate the enclosing step's execution context into the callee.
    FunctionLibraryRuntime::Options opts;
    opts.step_id = ctx->step_id();
    opts.rendezvous = ctx->rendezvous();
    opts.cancellation_manager = ctx->cancellation_manager();
    opts.step_container = ctx->step_container();
    opts.stats_collector = ctx->stats_collector();
    opts.runner = ctx->runner();
    std::vector<Tensor> args;
    args.reserve(ctx->num_inputs());
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      args.push_back(ctx->input(i));
    }
    // Heap-allocated so it outlives this stack frame; the completion
    // callback below deletes it.
    std::vector<Tensor>* rets = new std::vector<Tensor>;
    lib->Run(opts, handle_, args, rets,
             [ctx, done, rets](const Status& status) {
               if (!status.ok()) {
                 ctx->SetStatus(status);
               } else {
                 const int ret_size = static_cast<int>(rets->size());
                 CHECK_EQ(ret_size, ctx->num_outputs());
                 for (int i = 0; i < ret_size; ++i) {
                   ctx->set_output(i, (*rets)[i]);
                 }
               }
               delete rets;
               done();
             });
  }

 private:
  FunctionLibraryRuntime::Handle handle_;

  TF_DISALLOW_COPY_AND_ASSIGN(CallOp);
};
331
332const FunctionBody* FunctionLibraryRuntimeImpl::GetFunctionBody(Handle h) {
333  LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, h);
334  if (local_handle == kInvalidLocalHandle) {
335    LOG(ERROR) << "Could not find Handle: " << h
336               << " on device: " << device_name_;
337    return nullptr;
338  }
339
340  mutex_lock l(mu_);
341  CHECK_EQ(1, items_.count(local_handle));
342  return items_[local_handle]->func_graph;
343}
344
// Convenience overload: creates a kernel for `ndef`, resolving function
// names against the base (process-wide) library.
Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef,
                                                OpKernel** kernel) {
  return CreateKernel(ndef, base_lib_def_, kernel);
}
349
// Creates a kernel for `ndef`. Resolution order:
//   1. custom_kernel_creator_, if one was supplied (best effort);
//   2. the registered kernel, when `ndef.op()` is a primitive op;
//   3. otherwise, instantiates the function named by `ndef.op()` and wraps
//      it in a CallOp.
Status FunctionLibraryRuntimeImpl::CreateKernel(
    const NodeDef& ndef, const FunctionLibraryDefinition* lib_def,
    OpKernel** kernel) {
  // If a custom kernel creator is given, try that.
  Status s;
  if (custom_kernel_creator_) {
    std::unique_ptr<OpKernel> ret;
    s = custom_kernel_creator_(this, ndef, &ret);
    if (s.ok()) {
      *kernel = ret.release();
      return s;
    } else {
      VLOG(2) << "Custom creator error: " << s;
      // Falls through.
      s = Status::OK();
    }
  }

  if (lib_def->Find(ndef.op()) == nullptr) {
    // A primitive operation. Creates the registered kernel.
    return CreateNonCachedKernel(device_, this, ndef, graph_def_version_,
                                 kernel);
  }

  // Try to instantiate this function for the func/attr. Maybe it's
  // cached already.
  InstantiateOptions options;
  if (lib_def != base_lib_def_) {
    options.overlay_lib = lib_def;
  }
  Handle handle;
  TF_RETURN_IF_ERROR(
      Instantiate(ndef.op(), AttrSlice(&ndef.attr()), options, &handle));

  const FunctionBody* fbody = GetFunctionBody(handle);
  CHECK_NOTNULL(fbody);

  // TODO(zhifengc): For now, we assume int32 and resources are always on host
  // memory and other types are always on device memory. We should do type
  // inference over function body to derive the correct input/output memory
  // types.
  MemoryTypeVector input_memory_types;
  for (const auto& t : fbody->arg_types) {
    input_memory_types.push_back(
        (t == DT_INT32 || t == DT_RESOURCE) ? HOST_MEMORY : DEVICE_MEMORY);
  }
  MemoryTypeVector output_memory_types;
  for (const auto& t : fbody->ret_types) {
    output_memory_types.push_back(t == DT_INT32 ? HOST_MEMORY : DEVICE_MEMORY);
  }

  // Constructs a CallOp kernel for running the instantiated function.
  auto device_type = DeviceType(device_->attributes().device_type());
  OpKernelConstruction construction(
      device_type, device_, device_->GetAllocator(AllocatorAttributes()), &ndef,
      &fbody->fdef.signature(), this, fbody->arg_types, input_memory_types,
      fbody->ret_types, output_memory_types, graph_def_version_, &s);
  // `construction` reports failures through `s`; on failure the partially
  // constructed kernel must be destroyed before returning the error.
  *kernel = new CallOp(handle, &construction);
  if (!s.ok()) {
    delete *kernel;
  }
  return s;
}
413
414Status FunctionLibraryRuntimeImpl::FunctionDefToBody(
415    const FunctionDef& fdef, AttrSlice attrs,
416    const FunctionLibraryDefinition* lib_def, FunctionBody** fbody) {
417  if (lib_def == base_lib_def_) {
418    return FunctionDefToBodyHelper(fdef, attrs, lib_def, get_func_sig_, fbody);
419  } else {
420    auto get_func_sig = [lib_def](const string& op, const OpDef** sig) {
421      return lib_def->LookUpOpDef(op, sig);
422    };
423    return FunctionDefToBodyHelper(fdef, attrs, lib_def, get_func_sig, fbody);
424  }
425}
426
// Builds the body of the symbolic gradient of `func`. For primitive ops the
// gradient function comes from the gradient registry; for user-defined
// functions the function itself is instantiated and differentiated via
// SymbolicGradient(). On success `*g_body` is owned by the caller.
Status FunctionLibraryRuntimeImpl::InstantiateSymbolicGradient(
    const NameAttrList& func, const FunctionLibraryDefinition* lib_def,
    FunctionBody** g_body) {
  const FunctionDef* fdef = lib_def->Find(func.name());
  if (fdef == nullptr) {
    // f is a primitive op.
    gradient::Creator creator;
    TF_RETURN_IF_ERROR(gradient::GetOpGradientCreator(func.name(), &creator));
    if (creator == nullptr) {
      return errors::InvalidArgument("No gradient is defined for ",
                                     func.name());
    }
    FunctionDef grad_fdef;
    // TODO(josh11b): Should filter out the attrs from func that aren't used
    // by the gradient function.
    TF_RETURN_IF_ERROR(creator(AttrSlice(&func.attr()), &grad_fdef));
    TF_RETURN_IF_ERROR(
        FunctionDefToBody(grad_fdef, AttrSlice(&func.attr()), lib_def, g_body));
  } else {
    // f is a user-defined function.
    InstantiateOptions options;
    if (lib_def != base_lib_def_) {
      options.overlay_lib = lib_def;
    }
    Handle f_handle;
    TF_RETURN_IF_ERROR(
        Instantiate(func.name(), AttrSlice(&func.attr()), options, &f_handle));
    const FunctionBody* f_body = GetFunctionBody(f_handle);
    CHECK_NOTNULL(f_body);
    *g_body = SymbolicGradient(*f_body);
  }
  return Status::OK();
}
460
461bool FunctionLibraryRuntimeImpl::IsLocalTarget(
462    const InstantiateOptions& options) {
463  if (device_ == nullptr) return true;
464  if (options.target.empty()) return true;
465  Device* target_device;
466  if (!device_mgr_->LookupDevice(options.target, &target_device).ok()) {
467    return false;
468  }
469  return target_device == device_;
470}
471
// Instantiates `function_name` with `attrs`, returning a handle in
// `*handle`. Non-local targets are forwarded to `parent_`. The function
// body is built outside of mu_; the handle table is then re-checked under
// the lock so a concurrent instantiation of the same key wins cleanly
// (double-checked registration).
Status FunctionLibraryRuntimeImpl::Instantiate(
    const string& function_name, AttrSlice attrs,
    const InstantiateOptions& options, Handle* handle) {
  if (!IsLocalTarget(options)) {
    return parent_->Instantiate(function_name, attrs, options, handle);
  }

  // Since this is a local target, ensure that the local `device_name_` appears
  // in the canonical key.
  InstantiateOptions options_copy(options);
  options_copy.target = device_name_;
  const string key = Canonicalize(function_name, attrs, options_copy);
  // Fast path: already instantiated; just add a reference.
  *handle = parent_->GetHandle(key);
  if (*handle != kInvalidHandle) {
    mutex_lock l(mu_);
    items_[parent_->GetHandleOnDevice(device_name_, *handle)]->Ref();
    return Status::OK();
  }

  Status s;
  const FunctionLibraryDefinition* lib_def =
      options.overlay_lib ? options.overlay_lib : base_lib_def_;
  FunctionBody* fbody = nullptr;
  if (function_name == kGradientOp) {
    // SymbolicGradient: the function to differentiate is in attr "f".
    const AttrValue* f = attrs.Find(kFuncAttr);
    if (f == nullptr) {
      return errors::InvalidArgument("SymbolicGradient is missing attr: f");
    }
    const auto& func = f->func();
    if (func.name() == kGradientOp) {
      return errors::InvalidArgument("Can't take gradient of SymbolicGradient");
    }
    // Prefer an explicitly registered gradient function, if any.
    const string grad = lib_def->FindGradient(func.name());
    if (!grad.empty()) {
      return Instantiate(grad, AttrSlice(&func.attr()), options, handle);
    }
    TF_RETURN_IF_ERROR(InstantiateSymbolicGradient(func, lib_def, &fbody));
  } else {
    const FunctionDef* fdef = lib_def->Find(function_name);
    if (fdef == nullptr) {
      return errors::NotFound("Function ", function_name, " is not defined.");
    }
    TF_RETURN_IF_ERROR(FunctionDefToBody(*fdef, attrs, lib_def, &fbody));
  }

  {
    mutex_lock l(mu_);
    // Re-check: another thread may have registered `key` while we were
    // building the body above. If so, discard our copy and take a ref.
    *handle = parent_->GetHandle(key);
    if (*handle != kInvalidHandle) {
      delete fbody;
      items_[parent_->GetHandleOnDevice(device_name_, *handle)]->Ref();
    } else {
      *handle = parent_->AddHandle(key, device_name_, next_handle_);
      Item* item = new Item;
      item->func_graph = fbody;  // Item takes ownership of fbody.
      item->overlay_lib = options.overlay_lib;
      items_.insert({next_handle_, item});
      next_handle_++;
    }
  }
  return Status::OK();
}
534
// Drops one reference on `handle`. When the last reference is dropped the
// Item (and its graph/executor) is destroyed and the handle is removed from
// the parent runtime. Handles for other devices are forwarded to `parent_`.
Status FunctionLibraryRuntimeImpl::ReleaseHandle(Handle handle) {
  if (!parent_->IsInstantiatedOnDevice(device_name_, handle)) {
    return parent_->ReleaseHandle(handle);
  }

  LocalHandle h = parent_->GetHandleOnDevice(device_name_, handle);
  mutex_lock l(mu_);
  CHECK_EQ(1, items_.count(h));
  Item* item = items_[h];
  // Unref() returning true means this was the last reference.
  if (item->Unref()) {
    items_.erase(h);
    TF_RETURN_IF_ERROR(parent_->RemoveHandle(handle));
  }
  return Status::OK();
}
550
551void DumpGraph(StringPiece label, const Graph* g) {
552  // TODO(zhifengc): Change Graph to record #nodes.
553  VLOG(1) << "Graph " << label << " #nodes " << g->num_nodes() << " #edges "
554          << g->num_edges();
555  if (VLOG_IS_ON(2)) {
556    for (const auto& line : str_util::Split(DebugString(g), '\n')) {
557      VLOG(2) << "|| " << line;
558    }
559  }
560}
561
562void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g) {
563  OptimizerOptions opts;
564  opts.set_do_common_subexpression_elimination(true);
565  opts.set_do_function_inlining(true);
566  opts.set_do_constant_folding(true);
567  GraphOptimizer optimizer(opts);
568  optimizer.Optimize(lib, lib->env(), lib->device(), g, /*shape_map=*/nullptr);
569}
570
571namespace {
572// Removes all stateless nodes that do not contribute to a return
573// value from the function body.  Unlike `RemoveDeadNodes()`, which is
574// triggered by `OptimizerOptions.do_function_inlining`, this pass
575// ignores the SINK node, from which (by definition) all nodes are
576// reverse reachable.
577void PruneFunctionBody(Graph* g) {
578  VLOG(2) << "Pruning function body";
579  std::unordered_set<const Node*> nodes;
580  for (auto n : g->nodes()) {
581    // NOTE(mrry): "_Retval" nodes are stateful, and so will be added
582    // to the seed set of `nodes`.
583    // TODO(mrry): Investigate whether the `n->IsControlFlow()` test is
584    // still needed. It would be preferable to prune entire loops and/or
585    // conditionals if they are not used in the graph.
586    if (n->IsControlFlow() || n->op_def().is_stateful()) {
587      nodes.insert(n);
588    }
589  }
590  bool changed = PruneForReverseReachability(g, std::move(nodes));
591  if (changed) {
592    FixupSourceAndSinkEdges(g);
593  }
594}
595}  // namespace
596
// Builds the optimized graph and executor for `*item` (already registered
// in items_). The heavy work happens outside mu_; if a concurrent call wins
// the race and installs an executor first, ours is discarded.
Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) {
  const FunctionBody* fbody;
  const FunctionLibraryDefinition* lib_def;
  {
    mutex_lock l(mu_);
    fbody = (*item)->func_graph;
    lib_def = (*item)->overlay_lib;
  }
  if (!lib_def) {
    lib_def = base_lib_def_;
  }
  // Work on a copy so the stored FunctionBody graph stays pristine.
  std::unique_ptr<Graph> g(new Graph(lib_def));
  CopyGraph(*fbody->graph, g.get());

  PruneFunctionBody(g.get());
  optimizer_.Optimize(this, env(), device(), &g, /*shape_map=*/nullptr);
  TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device()->device_type()),
                                       device()->name(), g.get()));

  // Creates an executor based on the g.  This must be done without
  // holding mu_ because create_kernel_ calls back into the library.
  LocalExecutorParams params;
  params.device = device_;
  params.function_library = this;
  if (lib_def == base_lib_def_) {
    params.create_kernel = create_kernel_;
  } else {
    // Kernels inside an overlaid function must resolve ops against the
    // overlay library.
    params.create_kernel = [this, lib_def](const NodeDef& ndef,
                                           OpKernel** kernel) {
      return CreateKernel(ndef, lib_def, kernel);
    };
  }
  params.delete_kernel = [](OpKernel* kernel) {
    DeleteNonCachedKernel(kernel);
  };
  // NewLocalExecutor takes ownership of the graph; keep a raw pointer so we
  // can record it on the item below.
  Graph* graph = g.get();
  Executor* exec;
  TF_RETURN_IF_ERROR(NewLocalExecutor(params, std::move(g), &exec));

  {
    // Guard item since it is already inserted in items_.
    mutex_lock l(mu_);
    if ((*item)->exec) {
      // Another thread beat us to it; drop our executor.
      delete exec;
    } else {
      (*item)->graph = graph;
      (*item)->exec = exec;
    }
  }
  return Status::OK();
}
648
649Status FunctionLibraryRuntimeImpl::GetOrCreateItem(Handle handle, Item** item) {
650  LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
651  {
652    mutex_lock l(mu_);
653    if (items_.count(local_handle) == 0) {
654      return errors::NotFound("Function handle ", handle,
655                              " is not valid. Likely an internal error.");
656    }
657    *item = items_[local_handle];
658    if ((*item)->exec != nullptr) {
659      return Status::OK();
660    }
661  }
662  // NOTE: We need to call CreateItem out of mu_ because creating an
663  // executor needs to call CreateKernel.
664  return CreateItem(handle, item);
665}
666
// Executes `handle` on behalf of a caller located on another device: the
// arguments arrive through the rendezvous, the function runs locally, and
// the return values are sent back through the rendezvous. Takes ownership
// of `exec_args` (deleted on every exit path). `done` is invoked exactly
// once with the final status.
void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
                                           gtl::ArraySlice<Tensor> args,
                                           std::vector<Tensor>* rets,
                                           Executor::Args* exec_args,
                                           Item* item, DoneCallback done) {
  DCHECK(exec_args->call_frame == nullptr);
  string target_device = parent_->GetDeviceName(handle);
  string source_device = opts.source_device;
  Rendezvous* rendezvous = opts.rendezvous;
  DeviceContext* device_context;
  Status s = parent_->GetDeviceContext(target_device, &device_context);
  if (!s.ok()) {
    delete exec_args;
    done(s);
    return;
  }
  // Incarnation numbers disambiguate rendezvous keys across device restarts.
  int64 src_incarnation, target_incarnation;
  s = parent_->GetDeviceIncarnation(source_device, &src_incarnation);
  s.Update(parent_->GetDeviceIncarnation(target_device, &target_incarnation));
  if (!s.ok()) {
    delete exec_args;
    done(s);
    return;
  }

  const FunctionBody* fbody = GetFunctionBody(handle);
  // Heap-allocated: must outlive this stack frame; the callbacks below own
  // and delete it.
  FunctionCallFrame* frame =
      new FunctionCallFrame(fbody->arg_types, fbody->ret_types);
  exec_args->call_frame = frame;
  if (!s.ok()) {
    delete frame;
    delete exec_args;
    done(s);
    return;
  }

  // The ProcFLR sends the arguments to the function from the source_device to
  // the target_device. So here we receive those arguments. Similarly, when the
  // computation is done and stored in *rets, we send the return values back
  // to the source_device (caller) so that the ProcFLR can receive them later.
  std::vector<Tensor>* remote_args = new std::vector<Tensor>;
  ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
      source_device, target_device, "arg_", src_incarnation, args.size(),
      device_context, {}, rendezvous, remote_args,
      [frame, remote_args, item, source_device, target_device,
       target_incarnation, rendezvous, device_context, rets, done,
       exec_args](const Status& status) {
        // Stage 1: arguments received; load them into the call frame.
        Status s = status;
        if (s.ok()) {
          s = frame->SetArgs(*remote_args);
        }
        if (!s.ok()) {
          delete frame;
          delete remote_args;
          delete exec_args;
          done(s);
          return;
        }
        // Stage 2: run the function.
        item->exec->RunAsync(
            *exec_args, [item, frame, rets, done, source_device, target_device,
                         target_incarnation, rendezvous, device_context,
                         remote_args, exec_args](const Status& status) {
              // Stage 3: extract return values and ship them back.
              Status s = status;
              if (s.ok()) {
                s = frame->ConsumeRetvals(rets);
              }
              delete frame;
              if (!s.ok()) {
                delete remote_args;
                delete exec_args;
                done(s);
                return;
              }
              s = ProcessFunctionLibraryRuntime::SendTensors(
                  target_device, source_device, "ret_", target_incarnation,
                  *rets, device_context, {}, rendezvous);
              delete remote_args;
              delete exec_args;
              done(s);
            });
      });
}
749
// Runs the instantiated function `handle` on `args`, storing results in
// `*rets`, and calls `done` exactly once with the final status. Remote
// handles are forwarded to `parent_`; opts.remote_execution routes through
// RunRemote().
void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
                                     gtl::ArraySlice<Tensor> args,
                                     std::vector<Tensor>* rets,
                                     DoneCallback done) {
  if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) {
    done(errors::Cancelled(""));
    return;
  }
  Options run_opts = opts;
  if (opts.create_rendezvous) {
    // Caller asked us to create the rendezvous; wrap `done` so it is
    // unreffed exactly once when the run completes.
    Rendezvous* rendezvous = new IntraProcessRendezvous(device_mgr_);
    run_opts.rendezvous = rendezvous;
    run_opts.create_rendezvous = false;
    done = [done, rendezvous](const Status& status) {
      rendezvous->Unref();
      done(status);
    };
  }
  // Handles instantiated elsewhere are executed by the parent runtime.
  if (!parent_->IsInstantiatedOnDevice(device_name_, handle)) {
    parent_->Run(run_opts, handle, args, rets, done);
    return;
  }

  DCHECK(run_opts.runner != nullptr);

  // Heap-allocated: the completion callbacks below own and delete it.
  Executor::Args* exec_args = new Executor::Args;
  // Inherit the step_id from the caller.
  exec_args->step_id = run_opts.step_id;
  exec_args->rendezvous = run_opts.rendezvous;
  exec_args->stats_collector = run_opts.stats_collector;
  exec_args->cancellation_manager = run_opts.cancellation_manager;
  exec_args->step_container = run_opts.step_container;
  exec_args->runner = *run_opts.runner;

  Item* item = nullptr;
  Status s = GetOrCreateItem(handle, &item);
  if (!s.ok()) {
    delete exec_args;
    done(s);
    return;
  }

  if (run_opts.remote_execution) {
    // NOTE(mrry): `RunRemote()` will set `exec_args->call_frame` for us.
    RunRemote(run_opts, handle, args, rets, exec_args, item, done);
    return;
  }

  const FunctionBody* fbody = GetFunctionBody(handle);
  // Heap-allocated so it survives until the done callback runs.
  FunctionCallFrame* frame =
      new FunctionCallFrame(fbody->arg_types, fbody->ret_types);
  exec_args->call_frame = frame;
  s = frame->SetArgs(args);
  if (!s.ok()) {
    delete frame;
    delete exec_args;
    done(s);
    return;
  }

  item->exec->RunAsync(
      // Executor args
      *exec_args,
      // Done callback.
      [item, frame, rets, done, exec_args](const Status& status) {
        Status s = status;
        if (s.ok()) {
          // Move the function's return values out of the call frame.
          s = frame->ConsumeRetvals(rets);
        }
        delete frame;
        delete exec_args;
        done(s);
      });
}
824
// Runs `handle` using a caller-supplied call frame for args/rets. Only
// local execution is supported for this overload (see the TODO on the
// declaration). `done` is invoked exactly once.
void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
                                     CallFrameInterface* frame,
                                     DoneCallback done) {
  if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) {
    done(errors::Cancelled(""));
    return;
  }
  if (!parent_->IsInstantiatedOnDevice(device_name_, handle) ||
      opts.remote_execution) {
    done(errors::Unimplemented("Remote calling with CallFrameInterface"));
    return;
  }

  Options run_opts = opts;
  if (opts.create_rendezvous) {
    // Caller asked us to create the rendezvous; wrap `done` so it is
    // unreffed exactly once when the run completes.
    Rendezvous* rendezvous = new IntraProcessRendezvous(device_mgr_);
    run_opts.rendezvous = rendezvous;
    run_opts.create_rendezvous = false;
    done = std::bind(
        [rendezvous](DoneCallback done,
                     // Begin unbound arguments.
                     const Status& status) {
          rendezvous->Unref();
          done(status);
        },
        std::move(done), std::placeholders::_1);
  }

  Item* item = nullptr;
  Status s = GetOrCreateItem(handle, &item);
  if (!s.ok()) {
    done(s);
    return;
  }
  DCHECK(run_opts.runner != nullptr);

  // Heap-allocated: the completion callback below owns and deletes it.
  Executor::Args* exec_args = new Executor::Args;
  // Inherit the step_id from the caller.
  exec_args->step_id = run_opts.step_id;
  exec_args->rendezvous = run_opts.rendezvous;
  exec_args->stats_collector = run_opts.stats_collector;
  exec_args->cancellation_manager = run_opts.cancellation_manager;
  exec_args->step_container = run_opts.step_container;
  exec_args->runner = *run_opts.runner;
  exec_args->call_frame = frame;  // Not owned; caller retains ownership.

  item->exec->RunAsync(
      // Executor args
      *exec_args,
      // Done callback.
      std::bind(
          [item, frame, exec_args](DoneCallback done,
                                   // Start unbound arguments.
                                   const Status& status) {
            delete exec_args;
            done(status);
          },
          std::move(done), std::placeholders::_1));
}
884
885bool FunctionLibraryRuntimeImpl::IsStateful(const string& func) {
886  const OpDef* op_def;
887  const Status s = base_lib_def_->LookUpOpDef(func, &op_def);
888  return s.ok() && op_def->is_stateful();
889}
890
891string FunctionLibraryRuntimeImpl::DebugString(Handle handle) {
892  Item* item = nullptr;
893  Status s = GetOrCreateItem(handle, &item);
894  if (s.ok()) {
895    return tensorflow::DebugString(item->graph);
896  } else {
897    return s.ToString();
898  }
899}
900
901Status FunctionLibraryRuntimeImpl::Clone(
902    std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
903    std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr,
904    FunctionLibraryRuntime** out_flr) {
905  TF_RETURN_IF_ERROR(
906      parent_->Clone(env_, graph_def_version_, optimizer_.options(),
907                     custom_kernel_creator_, out_lib_def, out_pflr));
908  *out_flr = (*out_pflr)->GetFLR(device_->name());
909  if (out_flr != nullptr) {
910    return Status::OK();
911  } else {
912    return errors::Internal("Cloning FunctionLibraryRuntime failed.");
913  }
914}
915
namespace {

// Process-wide storage for the default CustomKernelCreator. Access to
// the stored callback is serialized by 'mu', so Set()/Get() may be
// called concurrently from any thread.
struct CustomCreatorSingleton {
  mutex mu;
  CustomKernelCreator custom_creator = nullptr;

  // Replaces the stored creator callback.
  void Set(CustomKernelCreator cb) {
    mutex_lock l(mu);
    custom_creator = std::move(cb);
  }

  // Returns a copy of the stored creator callback; may be nullptr if
  // none was registered.
  CustomKernelCreator Get() {
    mutex_lock l(mu);
    return custom_creator;
  }
};

// Returns the process-wide singleton. Intentionally never deleted so it
// remains usable during program teardown.
CustomCreatorSingleton* GetCustomCreatorSingleton() {
  static CustomCreatorSingleton* ccs = new CustomCreatorSingleton;
  return ccs;
}

}  // namespace
939
// Registers 'cb' as the process-wide default CustomKernelCreator, used
// by the NewFunctionLibraryRuntime() overload that does not take a
// creator explicitly.
void RegisterDefaultCustomKernelCreator(CustomKernelCreator cb) {
  GetCustomCreatorSingleton()->Set(std::move(cb));
}
943
944std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
945    const DeviceMgr* device_mgr, Env* env, Device* device,
946    int graph_def_version, const FunctionLibraryDefinition* lib_def,
947    const OptimizerOptions& optimizer_options,
948    CustomKernelCreator custom_kernel_creator,
949    ProcessFunctionLibraryRuntime* parent) {
950  return std::unique_ptr<FunctionLibraryRuntime>(new FunctionLibraryRuntimeImpl(
951      device_mgr, env, device, graph_def_version, lib_def, optimizer_options,
952      std::move(custom_kernel_creator), parent));
953}
954
// Convenience overload that uses the process-wide default kernel
// creator registered via RegisterDefaultCustomKernelCreator() (which
// may be nullptr if none was registered).
std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
    const DeviceMgr* device_mgr, Env* env, Device* device,
    int graph_def_version, const FunctionLibraryDefinition* lib_def,
    const OptimizerOptions& optimizer_options,
    ProcessFunctionLibraryRuntime* parent) {
  return NewFunctionLibraryRuntime(device_mgr, env, device, graph_def_version,
                                   lib_def, optimizer_options,
                                   GetCustomCreatorSingleton()->Get(), parent);
}
964
965bool RemoveDeadNodes(Graph* g) {
966  VLOG(2) << "Removing dead nodes";
967  std::unordered_set<const Node*> nodes;
968  for (auto n : g->nodes()) {
969    if (n->IsSource() || n->IsSink() || n->IsControlFlow() ||
970        n->op_def().is_stateful()) {
971      nodes.insert(n);
972    }
973  }
974  return PruneForReverseReachability(g, std::move(nodes));
975}
976
977namespace {
978// If 'edges' contains only 1 non-control edge, returns it. Otherwise,
979// returns a nullptr.
980const Edge* GetTheOnlyDataEdge(const EdgeSet& edges) {
981  const Edge* ret = nullptr;
982  for (const Edge* e : edges) {
983    if (e->IsControlEdge() || ret) {
984      // Don't touch it if there is a control edge.
985      return nullptr;
986    }
987    if (IsRefType(e->src()->output_type(e->src_output()))) {
988      // Don't touch it if the identity node is effectively de-reffing
989      // a ref.
990      return nullptr;
991    }
992    if (IsRecv(e->src()) || IsSwitch(e->src())) {
993      // Don't touch it if the identity is introduced for control flow.
994      // Recv disables all its successors if it receives a dead signal.
995      // When Recv has an outgoing control edge, the current executor
996      // would not disable the destination. The current solution (see
997      // graph_partition.cc) is to add an identity after Recv and change
998      // the control edge to be from this identity node. So the identity
999      // can't be removed.
1000      return nullptr;
1001    }
1002    ret = e;
1003  }
1004  return ret;
1005}
1006}  // end namespace
1007
1008bool RemoveIdentityNodes(Graph* g) {
1009  VLOG(2) << "Removing identity nodes";
1010  bool removed_any = false;
1011  gtl::InlinedVector<Node*, 8> matches;
1012  for (Node* n : g->nodes()) {
1013    if (!n->IsIdentity()) continue;
1014    if (!GetTheOnlyDataEdge(n->in_edges())) continue;
1015
1016    // Some identity nodes are used as sink nodes to give names to output
1017    // tensors. These nodes are not going to be executed unless they are in the
1018    // fetch set. But if they are in the fetch set we don't want to remove them.
1019    if (n->out_edges().empty()) continue;
1020
1021    matches.push_back(n);
1022  }
1023  if (!matches.empty()) {
1024    for (Node* n : matches) {
1025      const Edge* in = GetTheOnlyDataEdge(n->in_edges());
1026      for (const Edge* out : n->out_edges()) {
1027        if (out->IsControlEdge()) {
1028          g->AddControlEdge(in->src(), out->dst());
1029        } else {
1030          g->AddEdge(in->src(), in->src_output(), out->dst(), out->dst_input());
1031        }
1032      }
1033      VLOG(2) << "Remove Identity: " << n->DebugString();
1034      g->RemoveNode(n);
1035      removed_any = true;
1036    }
1037  }
1038  return removed_any;
1039}
1040
// Replaces each _ListToArray/_ArrayToList node in 'g' with one Identity
// node per input/output pair, preserving control dependencies via NoOp
// anchor nodes. Returns true iff any converter node was removed.
bool RemoveListArrayConverter(Graph* g) {
  VLOG(2) << "Removing list array converter";
  gtl::InlinedVector<Node*, 8> matches;
  for (Node* n : g->nodes()) {
    if ((n->type_string() == "_ListToArray") ||
        (n->type_string() == "_ArrayToList")) {
      matches.push_back(n);
    }
  }
  bool removed_any = false;
  if (!matches.empty()) {
    for (Node* n : matches) {
      // The rewrite only makes sense when inputs and outputs pair up
      // one-to-one.
      if (n->num_inputs() != n->num_outputs()) {
        continue;  // Not expected. Skip.
      }
      // identity_nodes[i] will forward input i to output i.
      gtl::InlinedVector<Node*, 8> identity_nodes(n->num_inputs(), nullptr);

      // Process input edges first.
      Node* input_control_node = nullptr;
      for (const Edge* e : n->in_edges()) {
        if (e->IsControlEdge()) {
          if (input_control_node == nullptr) {
            // If node "n" has any control dependencies, adds a no-op
            // node (input_control_node) which the additional Identity
            // nodes depends on and the input_control_node depends on
            // the node "n"s control dependencies.
            input_control_node = AddNoOp(g);
          }
          g->AddControlEdge(e->src(), input_control_node);
        } else {
          const int index = e->dst_input();
          Node** id_node = &identity_nodes[index];
          if (*id_node != nullptr) {
            LOG(ERROR)
                << "RemoveListArrayConverter unexpected duplicated input: "
                << e->dst_input();
            return removed_any;
          }
          *id_node = AddIdentity(g, {e->src(), e->src_output()});
        }
      }

      // If node "n" has any control dependencies, the added identity
      // nodes should have control dependencies on input_control_node.
      if (input_control_node != nullptr) {
        for (Node* id : identity_nodes) {
          g->AddControlEdge(input_control_node, id);
        }
      }

      Node* output_control_node = nullptr;
      for (const Edge* e : n->out_edges()) {
        if (e->IsControlEdge()) {
          if (output_control_node == nullptr) {
            // If node "n" is control-depended upon by other nodes,
            // adds a no-op node (output_control_node) which those
            // nodes will depend on and output_control_node depends on
            // all Identity nodes.
            output_control_node = AddNoOp(g);
          }
          g->AddControlEdge(output_control_node, e->dst());
        } else {
          // Reconnect output i's consumers to the i-th identity node.
          Node* id_node = identity_nodes[e->src_output()];
          if (id_node == nullptr) {
            LOG(ERROR) << "RemoveListArrayConverter unexpected missing input: "
                       << e->src_output();
            return removed_any;
          }
          CHECK(id_node);
          g->AddEdge(id_node, 0, e->dst(), e->dst_input());
        }
      }

      // If any nodes have control dependencies on node "n", those
      // nodes should have control dependencies on
      // output_control_node.
      if (output_control_node != nullptr) {
        for (Node* id : identity_nodes) {
          g->AddControlEdge(id, output_control_node);
        }
      }

      g->RemoveNode(n);
      removed_any = true;
    }
  }
  return removed_any;
}
1129
1130// Returns true iff the function '*fbody' can be inlined at 'node'
1131// based on the type signature of 'node' and 'fbody'.
1132static bool ValidateInlining(const Node* node, const FunctionBody* fbody) {
1133  if (static_cast<size_t>(node->num_inputs()) != fbody->arg_types.size()) {
1134    return false;
1135  }
1136  if (static_cast<size_t>(node->num_inputs()) != fbody->arg_nodes.size()) {
1137    return false;
1138  }
1139  if (static_cast<size_t>(node->num_outputs()) != fbody->ret_types.size()) {
1140    return false;
1141  }
1142  if (static_cast<size_t>(node->num_outputs()) != fbody->ret_nodes.size()) {
1143    return false;
1144  }
1145  for (int i = 0; i < node->num_inputs(); ++i) {
1146    if (node->input_type(i) != fbody->arg_types[i]) return false;
1147  }
1148  for (int i = 0; i < node->num_outputs(); ++i) {
1149    if (node->output_type(i) != fbody->ret_types[i]) return false;
1150  }
1151  return true;
1152}
1153
// Given a "caller" in "graph", which is a function call of a function
// to "fbody", replaces the "caller" with fbody->graph and connects
// edges properly. If the signatures do not match (see ValidateInlining)
// this is a no-op apart from a warning.
void InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
                        Node* caller, const FunctionBody* fbody) {
  if (!ValidateInlining(caller, fbody)) {
    LOG(WARNING) << "Inlining mismatch: " << caller->DebugString() << " vs. "
                 << DebugString(fbody->graph);
    return;
  }

  // Input edges. For data edges coming into "caller", we first compute the
  // <src>:<src_output> for the i-th input in "inputs".
  // If "caller" has any input control dependencies, we add a NoOp
  // node "input_control_node", which depends on "caller"'s control inputs.
  std::vector<Endpoint> inputs(caller->num_inputs());
  Node* input_control_node = nullptr;
  for (const Edge* e : caller->in_edges()) {
    if (e->IsControlEdge()) {
      if (input_control_node == nullptr) {
        input_control_node = AddNoOp(g);
      }
      g->AddControlEdge(e->src(), input_control_node);
    } else {
      inputs[e->dst_input()] = {e->src(), e->src_output()};
    }
  }

  // Duplicate fbody->graph into 'g'.  First, we copy the nodes of
  // fbody->graph into 'g' except the source and sink nodes.  We copy
  // edges among nodes in 'fbody->graph'.
  //
  // If 'x' is a node in fbody->graph and its copy in 'g' is 'y', we
  // remember 'y' in node_map[x->id()].
  std::vector<Node*> node_map(fbody->graph->num_node_ids());
  Status s;
  for (Node* n : fbody->graph->op_nodes()) {
    NodeDef ndef = n->def();
    // Prefix the clone's name with the caller's name to keep node names
    // unique, and place it on the caller's device.
    ndef.set_name(strings::StrCat(caller->name(), "/", ndef.name()));
    ndef.set_device(caller->def().device());
    Node* clone = g->AddNode(ndef, &s);
    TF_CHECK_OK(s);
    node_map[n->id()] = clone;

    // If there is an input control node, and one of:
    // a) the node has no data or control inputs, or
    // b) the node is a function call or SymbolicGradient,
    // then add a control edge from the input control node to the clone.
    //
    // We must not execute any nodes if the original function call would not
    // have executed. This is especially critical when the function call is
    // inside a control-flow construct like tf.cond(). Case (a) ensures that
    // such nodes do not run.
    //
    // The purpose of case (b) is to ensure that instances of case (a) created
    // by further inlining steps also receive the control dependency.
    if (input_control_node) {
      bool has_inputs = false;
      for (const Edge* e : n->in_edges()) {
        if (!e->src()->IsSource()) {
          has_inputs = true;
          break;
        }
      }
      if (!has_inputs || flib_def.Find(clone->type_string()) != nullptr ||
          clone->type_string() == "SymbolicGradient") {
        g->AddControlEdge(input_control_node, clone);
      }
    }
  }
  // Copy the internal edges; edges touching source/sink are skipped since
  // those nodes were not cloned.
  for (const Edge* e : fbody->graph->edges()) {
    if (e->src()->IsSource() || e->src()->IsSink() || e->dst()->IsSource() ||
        e->dst()->IsSink()) {
      continue;
    }
    Node* src_copy = node_map[e->src()->id()];
    Node* dst_copy = node_map[e->dst()->id()];
    g->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
  }

  // Connect input edges.
  //
  // We create one Identity node for each input. Then, we connect inputs[i] to
  // the i-th identity node added. The nodes that previously connected
  // to the j-th output of i-th arg node are reconnected to the i-th
  // identity node.
  //
  // The added identity nodes depend on "input_control_node".
  for (std::size_t i = 0; i < fbody->arg_nodes.size(); ++i) {
    Node* arg = node_map[fbody->arg_nodes[i]->id()];
    Node* n = AddIdentity(g, inputs[i]);
    if (input_control_node) {
      g->AddControlEdge(input_control_node, n);
    }
    for (const Edge* e : arg->out_edges()) {
      if (e->IsControlEdge()) {
        g->AddControlEdge(n, e->dst());
      } else {
        g->AddEdge(n, 0, e->dst(), e->dst_input());
      }
    }
    node_map[fbody->arg_nodes[i]->id()] = n;
    g->RemoveNode(arg);  // 'arg' is disconnected.
  }

  // Connect output edges.
  //
  // For i-th return node in fbody->graph, we add in "g" an identity
  // node (outputs[i-th]). We then reconnect every incoming edge into
  // the i-th return node to the added identity node.
  //
  // For every data edge coming out of "callee"s i-th output, we
  // reconnect it to the i-th identity added above.
  //
  // If "callee" is control-depended upon by any other nodes, we add a
  // NoOp node "output_control_node". "output_control_node" depends on
  // all identity nodes added above. And nodes previously depend on
  // "callee" is changed to depend on "output_control_node".
  std::vector<Node*> outputs(caller->num_outputs());
  for (std::size_t i = 0; i < fbody->ret_nodes.size(); ++i) {
    Node* ret = node_map[fbody->ret_nodes[i]->id()];
    Endpoint data;  // Data input for the ret node.
    for (const Edge* e : ret->in_edges()) {
      if (!e->IsControlEdge()) {
        data = {e->src(), e->src_output()};
        break;
      }
    }
    // Every ret node must have exactly one data input.
    CHECK(data.node != nullptr);
    Node* n = AddIdentity(g, data);
    outputs[i] = n;
    // Preserve the ret node's incoming control dependencies.
    for (const Edge* e : ret->in_edges()) {
      if (e->IsControlEdge()) {
        g->AddControlEdge(e->src(), n);
      }
    }
    g->RemoveNode(ret);  // 'ret' is disconnected.
  }
  Node* output_control_node = nullptr;
  for (const Edge* e : caller->out_edges()) {
    if (e->IsControlEdge()) {
      if (output_control_node == nullptr) {
        output_control_node = AddNoOp(g);
        for (Node* n : outputs) {
          g->AddControlEdge(n, output_control_node);
        }
      }
      g->AddControlEdge(output_control_node, e->dst());
    } else {
      g->AddEdge(outputs[e->src_output()], 0, e->dst(), e->dst_input());
    }
  }
  g->RemoveNode(caller);  // 'caller' is replaced with inlined nodes.
}
1308
1309bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph) {
1310  std::vector<std::pair<Node*, const FunctionBody*>> candidates;
1311  const FunctionLibraryDefinition* fld = lib->GetFunctionLibraryDefinition();
1312  for (Node* node : graph->nodes()) {
1313    VLOG(3) << "Expanding " << node->DebugString();
1314    bool noinline;
1315    if (fld->GetAttr(*node, kNoInlineAttr, &noinline).ok() && noinline) {
1316      VLOG(3) << "noinline: " << node->DebugString();
1317      continue;
1318    }
1319    FunctionLibraryRuntime::Handle handle;
1320    Status s = lib->Instantiate(node->type_string(), node->attrs(), &handle);
1321    if (!s.ok()) {
1322      // Either "node" is a primitive op, or the instantiation failed.
1323      if (errors::IsNotFound(s)) {
1324        VLOG(3) << "ExpandInlineFunctions " << s;
1325      } else {
1326        LOG(ERROR) << "ExpandInlineFunctions " << s;
1327      }
1328      continue;
1329    }
1330    const FunctionBody* fbody = lib->GetFunctionBody(handle);
1331    CHECK_NOTNULL(fbody);
1332    candidates.push_back({node, fbody});
1333  }
1334  for (const auto& p : candidates) {
1335    InlineFunctionBody(*fld, graph, p.first, p.second);
1336  }
1337  return !candidates.empty();
1338}
1339
1340string NewName(const Node* n, bool pretty) {
1341  if (pretty) {
1342    return strings::StrCat(n->type_string(), n->id());
1343  } else {
1344    return strings::StrCat("n", n->id());
1345  }
1346}
1347
// Serializes 'g' into '*gdef', renaming nodes to unique, stable names
// (see NewName).
// TODO(zhifengc): Maybe this should be the default Graph::AsGraphDef.
// and stash the original NodeDef name as an attr for documentation
// purpose.
void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty) {
  // We visit nodes in forward topological sort order, which is a
  // possible execution order of the graph.
  gtl::InlinedVector<const Edge*, 4> inputs;
  gdef->Clear();
  gdef->mutable_versions()->CopyFrom(g->versions());

  // Seed the reverse DFS with nodes that have no outgoing edges.
  std::vector<Node*> start_nodes;
  for (Node* n : g->nodes()) {
    if (n->out_edges().empty()) {
      start_nodes.push_back(n);
    }
  }

  ReverseDFSFrom(*g, start_nodes, nullptr, [gdef, pretty, &inputs](Node* n) {
    if (!n->IsOp()) return;
    NodeDef* ndef = gdef->add_node();
    // node->name() is merely NodeDef::name, which are not guaranteed
    // to be unique and stable after optimization rewrites. Therefore,
    // we use "n<node id>" instead.
    ndef->set_name(NewName(n, pretty));
    ndef->set_op(n->type_string());
    for (const auto& attr : n->attrs()) {
      (*ndef->mutable_attr())[attr.first] = attr.second;
    }
    // Gather the input edges: data edges at their input index, control
    // edges appended afterwards.
    inputs.clear();
    inputs.resize(n->num_inputs());
    for (const Edge* e : n->in_edges()) {
      if (e->IsControlEdge()) {
        inputs.push_back(e);
      } else {
        if (inputs[e->dst_input()] == nullptr) {
          inputs[e->dst_input()] = e;
        } else {
          LOG(WARNING) << "Malformed graph node. multiple input edges: "
                       << n->DebugString();
        }
      }
    }
    for (const Edge* e : inputs) {
      if (e == nullptr) {
        // Missing data input; emit a placeholder name.
        ndef->add_input("unknown");
        continue;
      }
      const string srcname = NewName(e->src(), pretty);
      if (!e->src()->IsOp()) {
      } else if (e->IsControlEdge()) {
        ndef->add_input(strings::StrCat("^", srcname));
      } else if (e->src_output() == 0) {
        ndef->add_input(srcname);
      } else {
        ndef->add_input(strings::StrCat(srcname, ":", e->src_output()));
      }
    }
  });
}
1407
// Returns a human-readable dump of 'g' by serializing it (with stable
// node names) and formatting the resulting GraphDef.
string DebugString(const Graph* g) {
  GraphDef gdef;
  ToGraphDef(g, &gdef);
  return DebugString(gdef);
}
1413
// Builds a FunctionBody around an instantiated function graph 'g'
// (ownership of 'g' is taken). Scans the graph for _Arg/_Retval nodes
// and records them in arg_nodes/ret_nodes, ordered by their "index"
// attribute.
FunctionBody::FunctionBody(const FunctionDef& f, DataTypeSlice arg_t,
                           DataTypeSlice ret_t, Graph* g)
    : fdef(f),
      graph(g),
      arg_types(arg_t.begin(), arg_t.end()),
      ret_types(ret_t.begin(), ret_t.end()) {
  this->arg_nodes.resize(arg_types.size());
  this->ret_nodes.resize(ret_types.size());
  for (Node* n : this->graph->op_nodes()) {
    gtl::InlinedVector<Node*, 4>* node_vec;
    if (n->type_string() == kRetOp) {
      node_vec = &this->ret_nodes;
    } else if (n->type_string() == kArgOp) {
      node_vec = &this->arg_nodes;
    } else {
      continue;  // Not an arg/ret node.
    }
    // Each _Arg/_Retval carries an "index" attr giving its position in
    // the function signature; use it to place the node in its slot.
    int index;
    TF_CHECK_OK(GetNodeAttr(n->attrs(), "index", &index));
    CHECK_LE(0, index);
    CHECK_LT(index, node_vec->size());
    (*node_vec)[index] = n;
  }
}
1438
// A FunctionBody owns its graph.
FunctionBody::~FunctionBody() { delete this->graph; }
1440
// Helper that constructs the symbolic-gradient body of a function:
//   FunctionBody* grad = SymbolicGradientHelper(fbody).Compute();
// Compute() may be called at most once per helper instance; the caller
// owns the returned body.
class SymbolicGradientHelper {
 public:
  explicit SymbolicGradientHelper(const FunctionBody& f) : fbody_(&f) {}

  // Deletes the in-progress gradient body if Compute() never released it.
  ~SymbolicGradientHelper() { delete gbody_; }

  FunctionBody* Compute();

 private:
  const FunctionBody* fbody_;  // Not owned.
  FunctionBody* gbody_ = nullptr;  // Owned until released by Compute().

  // Makes a copy of fbody_ in gbody_.
  void Copy();

  TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientHelper);
};
1458
1459void SymbolicGradientHelper::Copy() {
1460  const Graph& src = *(fbody_->graph);
1461  gbody_->graph = new Graph(src.op_registry());
1462  Graph* dst = gbody_->graph;
1463
1464  std::vector<Node*> node_map(src.num_node_ids());
1465
1466  // Copy the nodes.
1467  node_map[src.source_node()->id()] = dst->source_node();
1468  node_map[src.sink_node()->id()] = dst->sink_node();
1469  for (Node* n : src.op_nodes()) {
1470    node_map[n->id()] = dst->CopyNode(n);
1471  }
1472
1473  // Copy the edges.
1474  for (const Edge* e : src.edges()) {
1475    Node* src_copy = node_map[e->src()->id()];
1476    Node* dst_copy = node_map[e->dst()->id()];
1477    dst->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
1478  }
1479
1480  // Save inputs in copied graph.
1481  CHECK_EQ(fbody_->arg_types.size(), fbody_->arg_nodes.size());
1482  gbody_->arg_types = fbody_->arg_types;
1483  for (std::size_t i = 0; i < fbody_->arg_nodes.size(); ++i) {
1484    gbody_->arg_nodes.push_back(node_map[fbody_->arg_nodes[i]->id()]);
1485  }
1486
1487  // Save outputs in copied graph.
1488  CHECK_EQ(fbody_->ret_types.size(), fbody_->ret_nodes.size());
1489  gbody_->ret_types = fbody_->ret_types;
1490  for (std::size_t i = 0; i < fbody_->ret_nodes.size(); ++i) {
1491    gbody_->ret_nodes.push_back(node_map[fbody_->ret_nodes[i]->id()]);
1492  }
1493}
1494
// Builds the gradient function body: clones fbody_, appends one extra
// arg (dy_i) per original return value, runs AddSymbolicGradients(),
// and replaces the return nodes with the gradients w.r.t. each original
// argument. Ownership of the result is transferred to the caller.
FunctionBody* SymbolicGradientHelper::Compute() {
  CHECK(gbody_ == nullptr);
  gbody_ = new FunctionBody;

  // Copy fbody_ into gbody_.
  Copy();

  Graph* g = gbody_->graph;

  const int num_y = static_cast<int>(gbody_->ret_nodes.size());

  // Populate 'y_node_outputs_' with node function body outputs.
  // Populate 'y_grad_nodes' with initial gradient nodes for each return node
  // of the original function body (these will be 'arg' nodes in the function
  // gradient body).
  std::vector<NodeOut> y_node_outputs;
  y_node_outputs.reserve(num_y);
  std::vector<NodeOut> y_grad_node_outputs;
  y_grad_node_outputs.reserve(num_y);
  for (int i = 0; i < num_y; ++i) {
    Node* y = gbody_->ret_nodes[i];
    y_node_outputs.push_back({y, 0});
    DCHECK_EQ(y->type_string(), kRetOp);
    const DataType dtype = y->input_type(0);
    // dy_i becomes a new trailing argument of the gradient function.
    const int index = static_cast<int>(gbody_->arg_nodes.size());
    Node* dy = AddArg(g, dtype, index);
    gbody_->arg_types.push_back(dtype);
    gbody_->arg_nodes.push_back(dy);
    y_grad_node_outputs.push_back({dy, 0});
  }

  // Populate 'x_nodes' with function args (excluding 'y_grad_node_outputs').
  const size_t num_x = fbody_->arg_nodes.size();
  std::vector<NodeOut> x_node_outputs;
  x_node_outputs.reserve(num_x);
  for (size_t i = 0; i < fbody_->arg_nodes.size(); ++i) {
    x_node_outputs.push_back({gbody_->arg_nodes[i], 0});
  }

  // Call AddSymbolicGradients which will add nodes to graph 'g' that
  // compute the function gradient (adding an entry in 'x_grad_node_outputs'
  // for each node in 'x_node_outputs').
  std::vector<NodeOut> x_grad_node_outputs;
  TF_CHECK_OK(AddSymbolicGradients(y_node_outputs, x_node_outputs,
                                   y_grad_node_outputs, &x_grad_node_outputs,
                                   g));

  // Remove the old return nodes from the function body.
  for (Node* n : gbody_->ret_nodes) {
    g->RemoveNode(n);
  }
  // The gradient function returns one tensor per original argument.
  gbody_->ret_types = fbody_->arg_types;
  gbody_->ret_nodes.clear();
  // Add new return nodes to the function gradient body for each node
  // in 'x_grad_nodes'.
  const int arg_types_size = static_cast<int>(fbody_->arg_types.size());
  for (int i = 0; i < arg_types_size; ++i) {
    Endpoint grad = {x_grad_node_outputs[i].node, x_grad_node_outputs[i].index};
    Node* ret = AddRet(g, grad, i);
    gbody_->ret_nodes.push_back(ret);
  }

  // Transfer ownership to the caller; reset gbody_ so the destructor
  // does not delete the released body.
  auto ret = gbody_;
  gbody_ = nullptr;
  return ret;
}
1561
// Returns a new function body that computes the symbolic gradient of
// 'f'. The caller owns the returned FunctionBody.
FunctionBody* SymbolicGradient(const FunctionBody& f) {
  return SymbolicGradientHelper(f).Compute();
}
1565
1566Status FunctionDefToBodyHelper(
1567    const FunctionDef& fdef, const AttrSlice& attrs,
1568    const FunctionLibraryDefinition* const lib_def,
1569    const std::function<Status(const string&, const OpDef**)>& get_func_sig,
1570    FunctionBody** fbody) {
1571  // Instantiates the function template into a graph def.
1572  InstantiationResult result;
1573  TF_RETURN_IF_ERROR(InstantiateFunction(fdef, attrs, get_func_sig, &result));
1574
1575  std::unique_ptr<Graph> graph(new Graph(lib_def));
1576  GraphConstructorOptions opts;
1577  opts.allow_internal_ops = true;
1578  opts.expect_device_spec = false;
1579  TF_RETURN_IF_ERROR(ConvertNodeDefsToGraph(opts, result.nodes, graph.get()));
1580
1581  // Call BuildControlFlowInfo to validate that this function body has
1582  // well-formed control flow.
1583  // NOTE(skyewm): this is usually done in Partition(), but we don't partition
1584  // function bodies. This should be removed if function bodies ever go through
1585  // the Partition() path.
1586  std::vector<ControlFlowInfo> dummy;
1587  TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph.get(), &dummy));
1588
1589  *fbody = new FunctionBody(fdef, result.arg_types, result.ret_types,
1590                            graph.release());
1591  return Status::OK();
1592}
1593
1594}  // end namespace tensorflow
1595