1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
17
18#include <functional>
19#include <memory>
20#include <numeric>
21#include <string>
22#include <unordered_map>
23#include <vector>
24
25#include "tensorflow/compiler/jit/graph_to_functiondef.h"
26#include "tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h"
27#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
28#include "tensorflow/compiler/tf2xla/const_analysis.h"
29#include "tensorflow/compiler/tf2xla/dump_graph.h"
30#include "tensorflow/compiler/xla/status_macros.h"
31#include "tensorflow/core/common_runtime/function.h"
32#include "tensorflow/core/common_runtime/optimization_registry.h"
33#include "tensorflow/core/common_runtime/shape_refiner.h"
34#include "tensorflow/core/framework/function.h"
35#include "tensorflow/core/framework/graph_def_util.h"
36#include "tensorflow/core/framework/node_def_builder.h"
37#include "tensorflow/core/framework/node_def_util.h"
38#include "tensorflow/core/graph/algorithm.h"
39#include "tensorflow/core/graph/graph.h"
40#include "tensorflow/core/graph/graph_def_builder.h"
41#include "tensorflow/core/graph/tensor_id.h"
42#include "tensorflow/core/lib/gtl/flatset.h"
43#include "tensorflow/core/lib/gtl/map_util.h"
44#include "tensorflow/core/lib/hash/hash.h"
45#include "tensorflow/core/lib/strings/str_util.h"
46#include "tensorflow/core/lib/strings/strcat.h"
47#include "tensorflow/core/public/session_options.h"
48#include "tensorflow/core/public/version.h"
49#include "tensorflow/core/util/device_name_utils.h"
50
51namespace tensorflow {
52
53const char* const kXlaCompiledKernelAttr = "_XlaCompiledKernel";
54const char* const kXlaNumConstantArgsAttr = "_XlaNumConstantArgs";
55const char* const kXlaNumResourceArgsAttr = "_XlaNumResourceArgs";
56
57namespace {
58
59bool AreAllParentsConst(const Node& n,
60                        const gtl::FlatSet<const Node*>& runtime_const_nodes) {
61  if (n.type_string() == "GuaranteeConst" || n.type_string() == "Const") {
62    // If the current node is itself a cast-to-const, no need
63    // to look at the incoming edges.
64    return true;
65  }
66
67  bool all_parents_const = true;
68  bool atleast_one_non_control_edge = false;
69  for (const Edge* in : n.in_edges()) {
70    atleast_one_non_control_edge =
71        atleast_one_non_control_edge || !in->IsControlEdge();
72    if (!in->IsControlEdge() && runtime_const_nodes.count(in->src()) == 0) {
73      all_parents_const = false;
74      break;
75    }
76  }
77  return all_parents_const && atleast_one_non_control_edge;
78}
79
80void MarkGuaranteedConstants(
81    const Graph& graph,
82    const std::vector<std::pair<const Node*, Node*>>& src_arg_pairs) {
83  gtl::FlatSet<const Node*> guaranteed_const_nodes;
84  std::vector<const Node*> srcs;
85  srcs.reserve(src_arg_pairs.size());
86  for (const auto& src_arg : src_arg_pairs) {
87    srcs.push_back(src_arg.first);
88  }
89  ReverseDFSFrom(graph, srcs, /*enter=*/nullptr,
90                 /*leave=*/[&guaranteed_const_nodes](const Node* n) {
91                   // TODO(vinuraja): Doesn't work in the presence of loops.
92                   if (AreAllParentsConst(*n, guaranteed_const_nodes)) {
93                     guaranteed_const_nodes.insert(n);
94                   }
95                 });
96
97  for (auto& src_arg : src_arg_pairs) {
98    if (guaranteed_const_nodes.count(src_arg.first) != 0) {
99      VLOG(1) << "Guaranteed const found: " << src_arg.first->DebugString();
100      src_arg.second->AddAttr("_is_guaranteed_constant", true);
101    }
102  }
103}
104
105// A node/slot pair.
106// TODO(phawkins): is there a common definition of this?
107struct NodeSlot {
108  NodeSlot() : node(nullptr), slot(-1), dtype(DT_INVALID) {}
109  NodeSlot(const Node* node, int slot)
110      : node(node), slot(slot), dtype(DT_INVALID) {}
111  NodeSlot(const Node* node, int slot, DataType dtype)
112      : node(node), slot(slot), dtype(dtype) {}
113
114  const Node* node;
115  int slot;
116
117  // Optional: used to record the destination type of a source NodeSlot in case
118  // the source output is a Ref type that is cast to a Tensor at the
119  // destination.
120  DataType dtype;
121
122  bool operator==(const NodeSlot& other) const {
123    return node == other.node && slot == other.slot && dtype == other.dtype;
124  }
125
126  // Leave dtype out of the hash since there are never two NodeSlots with the
127  // same node and slot and different dtypes.
128  struct Hasher {
129    uint64 operator()(NodeSlot const& s) const {
130      return Hash64Combine(std::hash<const Node*>()(s.node),
131                           std::hash<int>()(s.slot));
132    }
133  };
134
135  struct PairHasher {
136    uint64 operator()(std::pair<NodeSlot, NodeSlot> const& s) const {
137      return Hash64Combine(Hasher()(s.first), Hasher()(s.second));
138    }
139  };
140};
141
142// TODO(phawkins) add a canonical copy of these operator names and refactor
143// everything to use it.
144static const char* const kArgOp = "_Arg";
145static const char* const kRetValOp = "_Retval";
146static const char* const kHostComputeOp = "_XlaHostCompute";
147static const char* const kSendFromHostOp = "_XlaSendFromHost";
148static const char* const kRecvAtHostOp = "_XlaRecvAtHost";
149
150class Encapsulator {
151 public:
152  Encapsulator(string group_attribute, string outside_compilation_attribute,
153               Graph const* graph_in)
154      : group_attribute_(std::move(group_attribute)),
155        outside_compilation_attribute_(
156            std::move(outside_compilation_attribute)),
157        graph_in_(graph_in) {}
158
159  // Find subgraphs marked with 'group_attribute', and build a new
160  // subgraph, one for each value of 'group_attribute'.
161  Status SplitIntoSubgraphs();
162
163  // Build a FunctionDef for each subgraph, and add it 'library'. The values of
164  // the 'group_attribute' annotations become the function names.
165  // If 'reuse_existing_functions' is set, use an existing function with the
166  // same name, if any.
167  // If 'rewrite_subgraph_fn' is set, it is applied to each subgraph before
168  // function conversion.
169  Status BuildFunctionDefs(const RewriteSubgraphFn& rewrite_subgraph_fn,
170                           bool reuse_existing_functions,
171                           FunctionLibraryDefinition* library);
172
173  // Write a copy of the input graph to 'graph_out', where the subgraphs are
174  // replaced with calls to the new functions.
175  Status BuildOutputGraph(bool parallel_checking, Graph* graph_out,
176                          FunctionLibraryDefinition* library);
177
178 private:
179  // A subgraph of the input, all marked with a common 'group_attribute'
180  // value. A subgraph may contain multiple `outside_compilation' clusters.
181  //
182  // In the following simple example, A, B, ..., E are nodes in the original
183  // graph. The group attributes and outside_compilation attributes g and oc are
184  // each shown as either 0 or empty.
185  //
186  //  A  -->  B  -->  C  -->  D  -->  E
187  //  g:      g:0     g:0     g:0     g:
188  //  oc:     oc:     oc:0    oc:     oc:
189  //
190  // The example is rewritten to two graphs; one on the host and one to be
191  // compiled. The host graph is as follows. RAH is a RecvAtHost node receiving
192  // input from the compiled cluster, and SFH is a SendFromHost node sending
193  // input back to the compiled cluster. Dotted edges are control edges. A
194  // 'sequencing' node S is inserted, and both RAH and SFH are connected via S
195  // to E (and in general all nodes that depend on nodes in the compiled
196  // cluster) to ensure that they are not pruned.
197  //
198  //  A  -->  Call  -->  E
199  //                     ^
200  //                     .
201  //           ........> S
202  //       ....          ^
203  //     ..             .
204  //  RAH -->  C  --> SFH
205  //
206  // The compiled cluster is as follows. HC is a HostCompute node which is the
207  // source of a channel to the RAH node above and the destination of a channel
208  // from the SFH node above.
209  //
210  //  Arg  --> B  --> HC  --> D --> Retval
211  //
212  // The channels HC/RAH and SFH/HC each transmit multiple tensors, so there is
213  // at most one RAH and SFH in each outside_compilation cluster. This design is
214  // preferred over adding separate Arg/Retval nodes for each transmitted value
215  // because it allows optimizations to the host code that would like to limit
216  // communication between host and device and, e.g., raise only one interrupt
217  // per channel rather than one per transmitted value.
218  //
219  // The shapes of the outputs from the HC node in general cannot be determined
220  // until the shapes of its inputs are known at compile time, since e.g.,
221  // above, the shape of C's outputs aren't known until the shape of its inputs
222  // are known. If the shapes of the HC's outputs can be determined during the
223  // rewrite, they are stored in the node's 'shapes' attr. Otherwise a minimal
224  // graph is stored in the shape_inference_graph attr. This graph can be used
225  // when compiling the HC Op to determined the shape of the SFH inputs given
226  // the shapes of any ancestor RAH outputs. If it can be determined that the
227  // shape of the SFH inputs will not be inferrable even once the shapes of the
228  // RAH outputs are known, an error is returned by the rewriter.
229  class Subgraph {
230   public:
231    // Creates a graph to build the subgraph in, if it doesn't already exist,
232    // using the same op registry and versions as graph_in.
233    Node* MakeNodeImage(const Graph* graph_in, Node* node);
234
235    // Returns the graph the subgraph is being built in.
236    Graph* GetGraph() const;
237
238    // Builds a FunctionDef, and adds it to 'library'. The value of the
239    // 'group_attribute' annotations becomes the function name.  If
240    // 'reuse_existing_functions' is set, use an existing function with the same
241    // name, if any.  If 'rewrite_subgraph_fn' is set, it is applied to the
242    // subgraph before function conversion.
243    Status BuildFunctionDef(const string& name_in,
244                            const RewriteSubgraphFn& rewrite_subgraph_fn,
245                            bool reuse_existing_functions,
246                            FunctionLibraryDefinition* library);
247
248    // Adds the function call node to graph_out.
249    Status AddFunctionCallNode(
250        const std::unordered_map<const Node*, Node*>& node_images,
251        bool parallel_checking, Graph* graph_out);
252
253    // Adds _RecvAtHost and _SendFromHost nodes, where needed, to graph_out.
254    Status AddOutsideCompilationHostIONodes(
255        const string& subgraph_name,
256        const std::unordered_map<const Node*, Node*>& node_images,
257        Graph* graph_out);
258
259    // Returns the names of all the outside_compilation subgraphs in this
260    // Subgraph.
261    void GetOutsideCompilationSubgraphNames(std::vector<string>* names) const;
262
263    // Returns the Node that inputs to the function should be wired up to.
264    Node* GetCallNodeForInputs() const;
265
266    // Returns the Node that outputs to the function should be wired up to.
267    Node* GetCallNodeForOutputs() const;
268
269    // Returns the index of the arg that the dst of edge should connect to.
270    int GetArgIndexForEdge(const Edge* edge) const;
271
272    // Returns the index of the result that the src of edge should connect to.
273    int GetResultIndexForEdge(const Edge* edge) const;
274
275    // Returns the RecvAtHost node for an outside_compilation subgraph.
276    Node* GetRecvAtHostNode(
277        const string& outside_compilation_subgraph_name) const;
278
279    // Returns the output slot for the RecvAtHost node that corresponds to the
280    // source of edge in an outside_compilation subgraph.
281    int GetRecvAtHostSlot(const string& outside_compilation_subgraph_name,
282                          const Edge* edge) const;
283
284    // Returns the SendFromHost node for an outside_compilation subgraph.
285    Node* GetSendFromHostNode(
286        const string& outside_compilation_subgraph_name) const;
287
288    // Returns the input slot for the SendFromHost node that corresponds to the
289    // destination of edge in an outside_compilation subgraph.
290    int GetSendFromHostSlot(const string& outside_compilation_subgraph_name,
291                            const Edge* edge) const;
292
293    // Creates an _Arg node for the src node of edge, and add its index to
294    // args_by_src_, if none exists yet. Also adds its index to args_by_dst_,
295    // and adds the edge within the subgraph from the _Arg node to the image of
296    // the dst node.
297    Status RecordArg(const Edge* edge,
298                     const std::unordered_map<const Node*, Node*>& node_images,
299                     std::vector<std::pair<const Node*, Node*>>* src_arg_pairs);
300
301    // Creates a _Retval node for the src node of edge, and add it to results_,
302    // if none exists yet. If a new _Retval node is created, also adds the edge
303    // within the subgraph from the src to the _Retval node.
304    Status RecordResult(
305        const Edge* edge,
306        const std::unordered_map<const Node*, Node*>& node_images);
307
308    // Creates an outside_compilation subgraph for outside_compilation_id if
309    // none exists yet. Creates an entry for the src node of edge in the list of
310    // inputs for the outside_compilation subgraph, if none exists yet.
311    void RecordOutsideCompilationInputOrControl(
312        const string& outside_compilation_id, const Edge* edge);
313
314    // Creates an outside_compilation subgraph for outside_compilation_id if
315    // none exists yet. Creates an entry for the src node of edge in the list of
316    // outputs by src for the outside_compilation subgraph, if none exists
317    // yet. Creates an entry for the dst node of edge in the list of outputs by
318    // dst for the outside_compilation subgraph.
319    void RecordOutsideCompilationOutputOrControl(
320        const string& outside_compilation_id, const Edge* edge);
321
322    // Adds the HostCompute nodes for each outside_compilation subgraph.
323    Status AddHostComputes(
324        const string& subgraph_name,
325        const std::unordered_map<const Node*, Node*>& node_images);
326
327    // Creates the sequencer node if it doesn't exist, adding it to graph_out.
328    Status MakeSequencingNode(const string& subgraph_name, Graph* graph_out);
329
330    // If there is a sequencer node, adds a control edge from the sequencer to
331    // all the downstream nodes of call_node_outputs.
332    void ConnectSequencerToOutputs(Graph* graph_out);
333
334    Status AddShapeInferenceInfo(
335        const string& outside_compilation_subgraph_name,
336        const std::vector<TensorShapeProto>& shapes, GraphDef* inference_graph);
337
338    Status ReplaceFunctionDef(FunctionLibraryDefinition* library);
339
340   private:
341    struct OutsideCompilationSubgraph {
342      // Map from source (producer node/slot) tensors in the original graph to
343      // input index (slot number in the HostCompute/RecvAtHost nodes that will
344      // be created) for the outside_compilation subgraph.
345      std::unordered_map<NodeSlot, int, NodeSlot::Hasher> inputs;
346
347      // Set of nodes in the original graph that are the source of control edges
348      // that cross from the containing compiled subgraph into the
349      // outside_compilation subgraph. These are recorded by
350      // RecordOutsideCompilationInputOrControl while walking all the subgraph
351      // edges, and lifted control edges within the subgraph are added by
352      // AddSendsToOutsideCompilation once the _HostCompute node has been
353      // created. The matching control edge from _RecvAtHost to the
354      // destination is added by CopyEdgeToOutputGraph.
355      std::unordered_set<const Node*> control_inputs;
356
357      // Maps from source (producer node/slot) and destination (consumer
358      // node/slot) tensors in the original graph to output index (slot number
359      // in the SendFromHost/HostCompute nodes that will be created) for the
360      // outside_compilation subgraph.
361      std::unordered_map<NodeSlot, int, NodeSlot::Hasher> outputs_by_src;
362      std::unordered_map<NodeSlot, int, NodeSlot::Hasher> outputs_by_dst;
363
364      // Set of nodes in the original graph that are the destination of control
365      // edges that cross from the outside_compilation subgraph into the
366      // containing compiled subgraph. These are recorded by
367      // RecordOutsideCompilationOutputOrControl while walking all the subgraph
368      // edges, and lifted control edges within the subgraph are added by
369      // AddRecvsFromToOutsideCompilation once the _HostCompute node has been
370      // created. The matching control edge from the source to _SendFromHost to
371      // the destination is added by CopyEdgeToOutputGraph.
372      std::unordered_set<const Node*> control_outputs;
373
374      // Name of the _HostCompute node in the subgraph.
375      string host_compute_name;
376
377      // _RecvAtHost node in the output graph. Not owned.
378      Node* recv_at_host = nullptr;
379
380      // _SendFromHost node in the output graph. Not owned.
381      Node* send_from_host = nullptr;
382    };
383
384    // Builds a ParallelCheck op that compares the output of the original
385    // subgraph with the encapsulated subgraph.
386    Status BuildParallelCheckOp(
387        const std::unordered_map<const Node*, Node*>& node_images,
388        Graph* graph_out);
389
390    // Builds a _RecvAtHost node producing all the inputs of an
391    // outside_compilation subgraph and stores it in oc_subgraph.recv_at_host.
392    Status AddRecvAtHostNode(const string& subgraph_name,
393                             const string& oc_subgraph_name,
394                             OutsideCompilationSubgraph* oc_subgraph,
395                             Graph* graph_out);
396
397    // Builds a _SendFromHost node consuming all the outputs of an
398    // outside_compilation subgraph and stores it in oc_subgraph.send_from_host.
399    Status AddSendFromHostNode(
400        const std::unordered_map<const Node*, Node*>& node_images,
401        const string& subgraph_name, const string& oc_subgraph_name,
402        OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out);
403
404    // The subgraph extracted from the input graph, suitable for being turned
405    // into a FunctionDef. Inputs are fed by _Arg nodes, and outputs are
406    // returned by _Retval nodes.
407    std::unique_ptr<Graph> graph_;
408
409    // Which device are these nodes on? Used to assign a device to the call
410    // node.
411    string device_;
412
413    // NodeDef for the function call node.
414    NodeDef call_node_def_;
415
416    // Function call node(s) in the output graph. Not owned.
417    // If parallel_checking is enabled, 'call_node_inputs' is the function call
418    // node to which inputs should be fed, and 'call_node_outputs' is the
419    // parallel check op from which outputs should be read. If parallel checking
420    // is disabled, both point to the function call node.
421    Node* call_node_inputs_;
422    Node* call_node_outputs_;
423
424    // Maps from source (producer node/slot) and destination
425    // (consumer node/slot) tensors in the input graph to _Arg numbers in
426    // the subgraph. The source map is one-to-one, whereas the dest map may be
427    // many-to-one.
428    std::unordered_map<NodeSlot, int, NodeSlot::Hasher> args_by_src_;
429    std::unordered_map<NodeSlot, int, NodeSlot::Hasher> args_by_dst_;
430
431    // The _Arg nodes in the subgraph, in order by argument number.
432    std::vector<Node*> args_;
433
434    // Map from source tensor in the input graph to result #.
435    std::unordered_map<NodeSlot, int, NodeSlot::Hasher> results_;
436
437    // The outside_compilation clusters in this subgraph.
438    std::unordered_map<string, OutsideCompilationSubgraph>
439        outside_compilation_subgraphs_;
440
441    // NoOp node in the output graph that is sequenced after the call node and
442    // used to prevent host-side outside_compilation sends and recvs from being
443    // pruned.
444    Node* sequencer_ = nullptr;
445  };
446
447  // Returns the key attribute and outside_compilation attribute associated
448  // with a node in attr, and outside_compilation_attr, respectively. Sets
449  // either result to the empty string if the respective attribute is not
450  // found. Returns error status if there is an outside_compilation attribute
451  // and no key attribute,
452  Status GetFunctionNameAttr(Node const* node, string* attr,
453                             string* outside_compilation_attr) const;
454
455  // Copies edges local to a subgraph. Adds _Arg and _Retval nodes to
456  // subgraphs for data edges that cross subgraph boundaries.
457  Status CopySubgraphEdges(
458      const std::unordered_map<const Node*, Node*>& node_images,
459      std::vector<std::pair<const Node*, Node*>>* src_arg_pairs);
460
461  // Copies all marked nodes to a subgraph. Does nothing for unmarked nodes,
462  // or nodes marked outside_compilation.
463  Status CopySubgraphNodes(std::unordered_map<const Node*, Node*>* node_images);
464
465  // Copies all nodes that aren't in a compiled subgraph to the output graph.
466  Status CopyNodesToOutputGraph(
467      bool parallel_checking, Graph* graph_out,
468      std::unordered_map<const Node*, Node*>* node_images);
469
470  // Adds function call nodes for each compiled subgraph.
471  Status AddFunctionCallNodes(
472      const std::unordered_map<const Node*, Node*>& node_images,
473      bool parallel_checking, Graph* graph_out);
474
475  // Adds _RecvAtHost and _SendFromHost nodes, where needed, for all
476  // outside_compilation subgraphs.
477  Status AddOutsideCompilationHostIONodes(
478      const std::unordered_map<const Node*, Node*>& node_images,
479      Graph* graph_out);
480
481  // Finds the image of an edge source in the output graph. If the edge crosses
482  // a subgraph boundary it is the output of a call node, otherwise it is a node
483  // in the output graph.
484  Status FindOutputImageOfEdgeSrc(
485      const string& src_func_id, const string& src_outside_compilation_id,
486      const string& dst_func_id, const string& dst_outside_compilation_id,
487      const std::unordered_map<const Node*, Node*>& node_images,
488      const Node* original_src_node, Node** src_image);
489
490  // Finds an edge source slot in the output graph. If the edge crosses a
491  // subgraph boundary it is a slot on the output of a call node or a
492  // _RecvAtHost node, otherwise it is a slot on a node in the output graph.
493  int FindOutputSlotOfEdgeSrc(const string& src_func_id,
494                              const string& src_outside_compilation_id,
495                              const string& dst_func_id,
496                              const string& dst_outside_compilation_id,
497                              const Edge* edge);
498
499  // Finds the image of an edge destination in the output graph. If the edge
500  // crosses a subgraph boundary it is the input of a call node or a
501  // _SendFromHost node, otherwise it is a node in the output graph.
502  Status FindOutputImageOfEdgeDst(
503      const string& src_func_id, const string& src_outside_compilation_id,
504      const string& dst_func_id, const string& dst_outside_compilation_id,
505      const std::unordered_map<const Node*, Node*>& node_images,
506      const Node* original_dst_node, Node** dst_image);
507
508  // Finds an edge destination slot in the output graph. If the edge crosses a
509  // subgraph boundary it is a slot on the input of a call node or a
510  // _SendFromHost node, otherwise it is a slot on a node in the output graph.
511  int FindOutputSlotOfEdgeDst(const string& src_func_id,
512                              const string& src_outside_compilation_id,
513                              const string& dst_func_id,
514                              const string& dst_outside_compilation_id,
515                              const Edge* edge);
516
517  // Copies a single edge to the output graph. The edge is either entirely
518  // within the output graph, or crosses into or out of a compiled subgraph.
519  Status CopyEdgeToOutputGraph(
520      const Edge* edge, const string& src_func_id,
521      const string& src_outside_compilation_id, const string& dst_func_id,
522      const string& dst_outside_compilation_id,
523      const std::unordered_map<const Node*, Node*>& node_images,
524      bool parallel_checking, Graph* graph_out,
525      std::unordered_set<std::pair<NodeSlot, NodeSlot>, NodeSlot::PairHasher>*
526          edges_added);
527
528  // Adds all edges to the output graph.
529  Status AddEdgesToOutputGraph(
530      const std::unordered_map<const Node*, Node*>& node_images,
531      bool parallel_checking, Graph* graph_out);
532
533  // Constructs a minimal shape inference graph that can be used to determine
534  // the shape of send_node at the time that the subgraph is compiled.
535  // recv_at_host_nodes contains the names of all the recv_at_host nodes that
536  // send_node might depend on. These recv_at_host nodes have shapes that are
537  // not known during the rewrite pass, but will be known at compile time.
538  //
539  // If the shapes of all the inputs to send_node can be determined during the
540  // rewrite pass, on exit graphdef_out is empty and the shapes are returned in
541  // static_shape_out. Otherwise graphdef_out contains a graph that can be used
542  // for shape inference at compile time, where all the source nodes of the
543  // graph are either constants with known shapes, or nodes named in
544  // recv_at_host_nodes.
545  //
546  // A non-OK status is returned if neither of the above conditions can be
547  // satisfied, e.g., because send_node depends on a node that doesn't have a
548  // registered shape inference function.
549  Status DoStaticShapeInferenceForOutsideCompilationSend(
550      const Graph& graph_in, const ShapeRefiner& shape_refiner,
551      const std::unordered_set<string>& recv_at_host_nodes, Node* send_node,
552      FunctionLibraryDefinition* library,
553      std::vector<TensorShapeProto>* static_shape_out,
554      std::unique_ptr<GraphDef>* graphdef_out);
555
556  // Makes a copy of graph containing only nodes that are ancestors of at least
557  // one node in send_from_host_nodes and store it in pruned_graph. On exit
558  // nodes_images contains a mapping from nodes in graph to nodes in
559  // pruned_graph. All functions in the copied graph are inlined.
560  Status MakePrunedGraphCopyAndInline(
561      const Graph& graph, const std::vector<Node*>& sink_nodes,
562      std::unique_ptr<Graph>* pruned_graph,
563      std::unordered_map<const Node*, Node*>* node_images,
564      FunctionLibraryDefinition* library);
565
566  // Makes a copy of graph containing only nodes that are ancestors of a
567  // send_from_host node in an outside_compilation subgraph, and store it in
568  // pruned_graph. Also perform shape inference on the pruned graph, using
569  // shape_refiner. On exit node_images contains a mapping from nodes in graph
570  // to nodes in pruned_graph.
571  Status MakeGraphForOutsideCompilationSends(
572      const Graph& graph, std::unique_ptr<Graph>* pruned_graph,
573      ShapeRefiner* shape_refiner,
574      std::unordered_map<const Node*, Node*>* node_images,
575      FunctionLibraryDefinition* library);
576
577  // Performs static shape inference, as far as possible, for the send_from_host
578  // nodes in each outside_compilation subgraph. Where it is not possible to
579  // determine the shape statically, stores a serialized GraphDef in the
580  // HostCompute 'shape_inference_graph' attr, to be used at compile time for
581  // final inference. If the shapes are known statically they are stored in the
582  // HostCompute 'shapes' attr.
583  Status GetShapeInfoForOutsideCompilationSends(
584      Graph* graph_out, FunctionLibraryDefinition* library);
585
586  const string group_attribute_;
587  const string outside_compilation_attribute_;
588  const Graph* graph_in_;
589
590  std::unordered_map<string, Subgraph> subgraphs_;
591
592  TF_DISALLOW_COPY_AND_ASSIGN(Encapsulator);
593};
594
595Node* Encapsulator::Subgraph::GetCallNodeForInputs() const {
596  return call_node_inputs_;
597}
598
599Node* Encapsulator::Subgraph::GetCallNodeForOutputs() const {
600  return call_node_outputs_;
601}
602
603int Encapsulator::Subgraph::GetArgIndexForEdge(const Edge* edge) const {
604  return args_by_dst_.at(NodeSlot(edge->dst(), edge->dst_input()));
605}
606
607int Encapsulator::Subgraph::GetResultIndexForEdge(const Edge* edge) const {
608  return results_.at(NodeSlot(edge->src(), edge->src_output()));
609}
610
611Node* Encapsulator::Subgraph::GetRecvAtHostNode(
612    const string& outside_compilation_subgraph_name) const {
613  return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name)
614      .recv_at_host;
615}
616
617int Encapsulator::Subgraph::GetRecvAtHostSlot(
618    const string& outside_compilation_subgraph_name, const Edge* edge) const {
619  return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name)
620      .inputs.at(NodeSlot(edge->src(), edge->src_output()));
621}
622
623Node* Encapsulator::Subgraph::GetSendFromHostNode(
624    const string& outside_compilation_subgraph_name) const {
625  return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name)
626      .send_from_host;
627}
628
629int Encapsulator::Subgraph::GetSendFromHostSlot(
630    const string& outside_compilation_subgraph_name, const Edge* edge) const {
631  return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name)
632      .outputs_by_dst.at(NodeSlot(edge->dst(), edge->dst_input()));
633}
634
635Node* Encapsulator::Subgraph::MakeNodeImage(const Graph* graph_in, Node* node) {
636  if (!graph_) {
637    graph_.reset(new Graph(graph_in->op_registry()));
638    graph_->set_versions(graph_in->versions());
639  }
640
641  if (device_.empty()) {
642    device_ = node->assigned_device_name().empty()
643                  ? node->requested_device()
644                  : node->assigned_device_name();
645  }
646
647  return graph_->CopyNode(node);
648}
649
650Graph* Encapsulator::Subgraph::GetGraph() const { return graph_.get(); }
651
652Status Encapsulator::Subgraph::RecordArg(
653    const Edge* edge, const std::unordered_map<const Node*, Node*>& node_images,
654    std::vector<std::pair<const Node*, Node*>>* src_arg_pairs) {
655  Node* src_node = edge->src();
656  int src_slot = edge->src_output();
657  std::unordered_map<NodeSlot, int, NodeSlot::Hasher>::iterator iter;
658  bool inserted;
659  std::tie(iter, inserted) =
660      args_by_src_.emplace(NodeSlot(src_node, src_slot), args_by_src_.size());
661  int arg_index = iter->second;
662  if (inserted) {
663    NodeDef arg_def;
664    NodeDefBuilder builder(
665        strings::StrCat(src_node->name(), "_", src_slot, "_arg"), kArgOp);
666    DataType dtype = edge->dst()->input_type(edge->dst_input());
667    builder.Attr("T", dtype);
668    builder.Attr("index", arg_index);
669    Status s = builder.Finalize(&arg_def);
670    if (!s.ok()) return s;
671
672    Node* arg = graph_->AddNode(arg_def, &s);
673    if (!s.ok()) return s;
674
675    src_arg_pairs->push_back({src_node, arg});
676    args_.push_back(arg);
677  }
678  Node* dst_node = edge->dst();
679  Node* dst_image = node_images.at(dst_node);
680  int dst_slot = edge->dst_input();
681  args_by_dst_[NodeSlot(dst_node, dst_slot)] = arg_index;
682  graph_->AddEdge(args_[arg_index], 0, dst_image, dst_slot);
683  return Status::OK();
684}
685
686Status Encapsulator::Subgraph::RecordResult(
687    const Edge* edge,
688    const std::unordered_map<const Node*, Node*>& node_images) {
689  Node* src_node = edge->src();
690  Node* src_image = node_images.at(src_node);
691  int src_slot = edge->src_output();
692  std::unordered_map<NodeSlot, int, NodeSlot::Hasher>::iterator iter;
693  bool inserted;
694  std::tie(iter, inserted) =
695      results_.emplace(NodeSlot(src_node, src_slot), results_.size());
696  int ret_index = iter->second;
697  if (inserted) {
698    NodeDef ret_def;
699    NodeDefBuilder builder(
700        strings::StrCat(src_node->name(), "_", src_slot, "_retval"), kRetValOp);
701    DataType dtype = src_node->output_type(src_slot);
702    builder.Attr("T", dtype);
703    builder.Attr("index", ret_index);
704    builder.Input(src_image->name(), src_slot, dtype);
705    Status s = builder.Finalize(&ret_def);
706    if (!s.ok()) return s;
707    Node* ret = graph_->AddNode(ret_def, &s);
708    if (!s.ok()) return s;
709
710    graph_->AddEdge(src_image, src_slot, ret, 0);
711  }
712  return Status::OK();
713}
714
715void Encapsulator::Subgraph::RecordOutsideCompilationInputOrControl(
716    const string& outside_compilation_id, const Edge* edge) {
717  auto iter = outside_compilation_subgraphs_
718                  .emplace(outside_compilation_id, OutsideCompilationSubgraph())
719                  .first;
720  OutsideCompilationSubgraph& outside_subgraph = iter->second;
721  if (edge->IsControlEdge()) {
722    outside_subgraph.control_inputs.insert(edge->src());
723  } else {
724    int input_index = outside_subgraph.inputs.size();
725    outside_subgraph.inputs.emplace(NodeSlot(edge->src(), edge->src_output()),
726                                    input_index);
727  }
728}
729
730void Encapsulator::Subgraph::RecordOutsideCompilationOutputOrControl(
731    const string& outside_compilation_id, const Edge* edge) {
732  auto subgraph_iter =
733      outside_compilation_subgraphs_
734          .emplace(outside_compilation_id, OutsideCompilationSubgraph())
735          .first;
736  OutsideCompilationSubgraph& outside_subgraph = subgraph_iter->second;
737  if (edge->IsControlEdge()) {
738    outside_subgraph.control_outputs.insert(edge->dst());
739  } else {
740    DataType dtype = edge->dst()->input_type(edge->dst_input());
741    auto output_iter =
742        outside_subgraph.outputs_by_src
743            .emplace(NodeSlot(edge->src(), edge->src_output(), dtype),
744                     outside_subgraph.outputs_by_src.size())
745            .first;
746    int output_index = output_iter->second;
747    outside_subgraph.outputs_by_dst[NodeSlot(edge->dst(), edge->dst_input())] =
748        output_index;
749  }
750}
751
752Status Encapsulator::Subgraph::AddHostComputes(
753    const string& subgraph_name,
754    const std::unordered_map<const Node*, Node*>& node_images) {
755  for (auto& oc_subgraph_iter : outside_compilation_subgraphs_) {
756    const string& oc_subgraph_name = oc_subgraph_iter.first;
757    OutsideCompilationSubgraph& oc_subgraph = oc_subgraph_iter.second;
758    if (!oc_subgraph.inputs.empty() || !oc_subgraph.control_inputs.empty() ||
759        !oc_subgraph.outputs_by_src.empty() ||
760        !oc_subgraph.control_outputs.empty()) {
761      // Build a _HostCompute node.
762      std::vector<NodeDefBuilder::NodeOut> inputs(oc_subgraph.inputs.size());
763      std::vector<DataType> input_dtypes(oc_subgraph.inputs.size(), DT_INVALID);
764      std::vector<DataType> output_dtypes(oc_subgraph.outputs_by_src.size(),
765                                          DT_INVALID);
766
767      for (const auto& input_src : oc_subgraph.inputs) {
768        const Node* src_node = input_src.first.node;
769        Node* src_image = node_images.at(src_node);
770        int src_slot = input_src.first.slot;
771        int input_index = input_src.second;
772
773        DataType dtype = src_node->output_type(src_slot);
774        inputs[input_index].Reset(src_image->name(), src_slot, dtype);
775        input_dtypes[input_index] = dtype;
776      }
777
778      for (const auto& output : oc_subgraph.outputs_by_src) {
779        DataType dtype = output.first.dtype;
780        int output_index = output.second;
781        output_dtypes[output_index] = dtype;
782      }
783
784      NodeDef host_compute_def;
785      NodeDefBuilder builder(strings::StrCat("outside_compilation_",
786                                             oc_subgraph_name, "_host_compute"),
787                             kHostComputeOp);
788      builder.Input(inputs);
789      builder.Attr("Tinputs", input_dtypes);
790      builder.Attr("Toutputs", output_dtypes);
791      builder.Attr("key",
792                   strings::StrCat("host_compute_channel_", subgraph_name, "_",
793                                   oc_subgraph_name));
794      Status s = builder.Finalize(&host_compute_def);
795      if (!s.ok()) return s;
796
797      Node* host_compute = graph_->AddNode(host_compute_def, &s);
798      if (!s.ok()) return s;
799      oc_subgraph.host_compute_name = host_compute->name();
800
801      // Connect the _HostCompute node to its producers in the subgraph.
802      for (auto& input_src : oc_subgraph.inputs) {
803        const Node* src_node = input_src.first.node;
804        Node* src_image = node_images.at(src_node);
805        int src_slot = input_src.first.slot;
806        int input_index = input_src.second;
807        graph_->AddEdge(src_image, src_slot, host_compute, input_index);
808      }
809
810      // Connect the _HostCompute node to its control edge producers in the
811      // subgraph.
812      for (const auto& src_node : oc_subgraph.control_inputs) {
813        Node* src_image = node_images.at(src_node);
814        graph_->AddControlEdge(src_image, host_compute);
815      }
816
817      // Connect the consumers in the subgraph to the _HostCompute node.
818      for (const auto& output : oc_subgraph.outputs_by_dst) {
819        const Node* dst_node = output.first.node;
820        Node* dst_image = node_images.at(dst_node);
821        int dst_slot = output.first.slot;
822        int output_index = output.second;
823
824        graph_->AddEdge(host_compute, output_index, dst_image, dst_slot);
825      }
826
827      // Connect the control edge consumers in the subgraph to the _HostCompute
828      // node.
829      for (const auto& dst_node : oc_subgraph.control_outputs) {
830        Node* dst_image = node_images.at(dst_node);
831        graph_->AddControlEdge(host_compute, dst_image);
832      }
833    }
834  }
835
836  return Status::OK();
837}
838
839Status Encapsulator::Subgraph::MakeSequencingNode(const string& subgraph_name,
840                                                  Graph* graph_out) {
841  if (sequencer_ == nullptr) {
842    NodeDef seq_def;
843    NodeDefBuilder builder(strings::StrCat(subgraph_name, "_sequencer"),
844                           "NoOp");
845    Status s = builder.Finalize(&seq_def);
846    if (!s.ok()) return s;
847
848    sequencer_ = graph_out->AddNode(seq_def, &s);
849    if (!s.ok()) return s;
850    sequencer_->set_assigned_device_name(device_);
851  }
852  return Status::OK();
853}
854
855void Encapsulator::Subgraph::ConnectSequencerToOutputs(Graph* graph_out) {
856  if (sequencer_ != nullptr) {
857    std::unordered_set<Node*> output_dependencies;
858    for (Node* node : call_node_outputs_->out_nodes()) {
859      output_dependencies.insert(node);
860    }
861    for (Node* node : output_dependencies) {
862      graph_out->AddControlEdge(sequencer_, node);
863    }
864  }
865}
866
867Status Encapsulator::Subgraph::BuildFunctionDef(
868    const string& name_in, const RewriteSubgraphFn& rewrite_subgraph_fn,
869    bool reuse_existing_functions, FunctionLibraryDefinition* library) {
870  // name_in is copied here because name may be modified below if
871  // rewrite_subgraph_fn is true.
872  string name = name_in;
873  call_node_def_.set_op(name);
874  call_node_def_.set_name(name);
875  call_node_def_.set_device(device_);
876
877  if (rewrite_subgraph_fn) {
878    // Initialize the input and output permutations to the identity.
879    std::vector<int> input_permutation(args_by_src_.size());
880    std::iota(input_permutation.begin(), input_permutation.end(), 0);
881    std::vector<int> output_permutation(results_.size());
882    std::iota(output_permutation.begin(), output_permutation.end(), 0);
883
884    TF_RETURN_IF_ERROR(rewrite_subgraph_fn(
885        &graph_, &input_permutation, &output_permutation, &call_node_def_));
886
887    // Apply the input/output permutations to the 'args_by_...' and 'results_'
888    // mappings, so when we build edges in BuildOutputGraph() we
889    // connect them to the right input/output positions.
890    if (input_permutation.size() != args_by_src_.size()) {
891      return errors::InvalidArgument("Input permutation has incorrect size.");
892    }
893    if (output_permutation.size() != results_.size()) {
894      return errors::InvalidArgument("Output permutation has incorrect size.");
895    }
896    for (auto& arg : args_by_src_) {
897      arg.second = input_permutation[arg.second];
898    }
899    for (auto& arg : args_by_dst_) {
900      arg.second = input_permutation[arg.second];
901    }
902    for (auto& result : results_) {
903      result.second = output_permutation[result.second];
904    }
905
906    name = call_node_def_.op();
907  }
908
909  FunctionDef fdef;
910  TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, &fdef));
911
912  if (VLOG_IS_ON(1)) {
913    VLOG(2) << "Build function def " << name;
914    dump_graph::DumpGraphToFile(
915        strings::StrCat("encapsulate_fdef_graph_", name), *graph_, library);
916    dump_graph::DumpFunctionDefToFile(
917        strings::StrCat("encapsulate_fdef_", name), fdef);
918  }
919
920  if (!reuse_existing_functions || library->Find(name) == nullptr) {
921    TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef));
922  }
923  return Status::OK();
924}
925
926Status Encapsulator::Subgraph::AddShapeInferenceInfo(
927    const string& outside_compilation_subgraph_name,
928    const std::vector<TensorShapeProto>& shapes, GraphDef* inference_graph) {
929  OutsideCompilationSubgraph& oc_subgraph =
930      outside_compilation_subgraphs_.at(outside_compilation_subgraph_name);
931
932  Node* host_compute = nullptr;
933  for (Node* n : graph_->nodes()) {
934    if (n->name() == oc_subgraph.host_compute_name) {
935      host_compute = n;
936      break;
937    }
938  }
939  if (host_compute == nullptr) {
940    return errors::InvalidArgument(
941        "After rewriting subgraph ", outside_compilation_subgraph_name,
942        " there is no HostCompute Op for outside compilation subgraph ",
943        oc_subgraph.host_compute_name);
944  }
945
946  if (inference_graph == nullptr) {
947    host_compute->AddAttr("shape_inference_graph", "");
948    host_compute->AddAttr("shapes", shapes);
949  } else {
950    string serialized_graph;
951    if (!inference_graph->SerializeToString(&serialized_graph)) {
952      return errors::Internal(
953          "Failed to serialize graph for outside compilation subgraph ",
954          oc_subgraph.host_compute_name);
955    }
956    host_compute->AddAttr("shape_inference_graph", serialized_graph);
957    host_compute->AddAttr("shapes", std::vector<TensorShapeProto>());
958  }
959  return Status::OK();
960}
961
962Status Encapsulator::Subgraph::ReplaceFunctionDef(
963    FunctionLibraryDefinition* library) {
964  const string& name = call_node_def_.name();
965
966  FunctionDef fdef;
967  TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, &fdef));
968
969  if (VLOG_IS_ON(1)) {
970    VLOG(2) << "Replace function def " << name;
971    dump_graph::DumpGraphToFile(
972        strings::StrCat("replace_encapsulate_fdef_graph_", name), *graph_,
973        library);
974    dump_graph::DumpFunctionDefToFile(
975        strings::StrCat("replace_encapsulate_fdef_", name), fdef);
976  }
977
978  TF_RETURN_IF_ERROR(library->RemoveFunction(name));
979  TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef));
980  return Status::OK();
981}
982
983Status Encapsulator::Subgraph::BuildParallelCheckOp(
984    const std::unordered_map<const Node*, Node*>& node_images,
985    Graph* graph_out) {
986  // Build an index mapping output positions to node/slot pairs in the
987  // original graph.
988  std::vector<NodeSlot> results_by_num(results_.size());
989  for (const auto& entry : results_) {
990    results_by_num[entry.second] = entry.first;
991  }
992
993  // Build a parallel check NodeDef.
994  int num_results = results_by_num.size();
995  std::vector<DataType> result_dtypes(num_results);
996  std::vector<NodeDefBuilder::NodeOut> expected_outputs(num_results);
997  std::vector<NodeDefBuilder::NodeOut> actual_outputs(num_results);
998  for (int i = 0; i < num_results; ++i) {
999    const NodeSlot& node_slot = results_by_num[i];
1000    result_dtypes[i] = node_slot.node->output_type(node_slot.slot);
1001    expected_outputs[i] =
1002        NodeDefBuilder::NodeOut(node_images.at(node_slot.node)->name(),
1003                                node_slot.slot, result_dtypes[i]);
1004    actual_outputs[i] =
1005        NodeDefBuilder::NodeOut(call_node_def_.name(), i, result_dtypes[i]);
1006  }
1007  // Assign the parallel check op to a CPU on the same task as the cluster it is
1008  // checking.
1009  string device, dummy;
1010  if (!DeviceNameUtils::SplitDeviceName(
1011          call_node_inputs_->assigned_device_name(), &device, &dummy)) {
1012    return errors::InvalidArgument("Could not parse device name");
1013  }
1014  strings::StrAppend(&device, "/cpu:0");
1015
1016  NodeDef check_def;
1017  TF_RETURN_IF_ERROR(
1018      NodeDefBuilder(graph_out->NewName(strings::StrCat(call_node_def_.name(),
1019                                                        "_parallel_check")),
1020                     "ParallelCheck")
1021          .Device(device)
1022          .Attr("T", result_dtypes)
1023          .Input(expected_outputs)
1024          .Input(actual_outputs)
1025          .Finalize(&check_def));
1026
1027  Status s;
1028  Node* check_op = graph_out->AddNode(check_def, &s);
1029  if (!s.ok()) return s;
1030  check_op->set_assigned_device_name(device);
1031
1032  // TODO(phawkins): it seems redundant to call AddEdge as well as
1033  // pass Inputs to the NodeDefBuilder, but I have been unable to find a
1034  // way to avoid it.
1035  for (int i = 0; i < num_results; ++i) {
1036    const NodeSlot& node_slot = results_by_num[i];
1037    graph_out->AddEdge(node_images.at(node_slot.node), node_slot.slot, check_op,
1038                       i);
1039    graph_out->AddEdge(call_node_inputs_, i, check_op, num_results + i);
1040  }
1041
1042  call_node_outputs_ = check_op;
1043  return Status::OK();
1044}
1045
1046Status Encapsulator::Subgraph::AddFunctionCallNode(
1047    const std::unordered_map<const Node*, Node*>& node_images,
1048    bool parallel_checking, Graph* graph_out) {
1049  Status s;
1050  call_node_inputs_ = graph_out->AddNode(call_node_def_, &s);
1051  if (!s.ok()) return s;
1052
1053  // Copy the assigned device and the key_annotation over.
1054  call_node_inputs_->set_assigned_device_name(device_);
1055  call_node_outputs_ = call_node_inputs_;
1056
1057  if (parallel_checking) {
1058    TF_RETURN_IF_ERROR(BuildParallelCheckOp(node_images, graph_out));
1059  }
1060  return Status::OK();
1061}
1062
1063Status Encapsulator::Subgraph::AddRecvAtHostNode(
1064    const string& subgraph_name, const string& oc_subgraph_name,
1065    OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out) {
1066  std::vector<DataType> dtypes(oc_subgraph->inputs.size(), DT_INVALID);
1067
1068  for (const auto& input : oc_subgraph->inputs) {
1069    const Node* src_node = input.first.node;
1070    int src_slot = input.first.slot;
1071    int input_index = input.second;
1072
1073    DataType dtype = src_node->output_type(src_slot);
1074    dtypes[input_index] = dtype;
1075  }
1076
1077  NodeDef recv_def;
1078  NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name,
1079                                         "_", oc_subgraph_name, "_recv"),
1080                         kRecvAtHostOp);
1081  builder.Attr("Toutputs", dtypes);
1082  builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name,
1083                                      "_", oc_subgraph_name));
1084  Status s = builder.Finalize(&recv_def);
1085  if (!s.ok()) return s;
1086
1087  oc_subgraph->recv_at_host = graph_out->AddNode(recv_def, &s);
1088  if (!s.ok()) return s;
1089  oc_subgraph->recv_at_host->set_assigned_device_name(device_);
1090
1091  // Add a control dependency forcing the RecvAtHost to run before the subgraph
1092  // completes. This has no effect on execution order but prevents the
1093  // RecvAtHost being pruned.
1094  TF_RETURN_IF_ERROR(MakeSequencingNode(subgraph_name, graph_out));
1095  graph_out->AddControlEdge(oc_subgraph->recv_at_host, sequencer_);
1096
1097  return Status::OK();
1098}
1099
1100Status Encapsulator::Subgraph::AddSendFromHostNode(
1101    const std::unordered_map<const Node*, Node*>& node_images,
1102    const string& subgraph_name, const string& oc_subgraph_name,
1103    OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out) {
1104  std::vector<DataType> dtypes(oc_subgraph->outputs_by_src.size(), DT_INVALID);
1105  std::vector<NodeDefBuilder::NodeOut> inputs(
1106      oc_subgraph->outputs_by_src.size());
1107
1108  for (const auto& output : oc_subgraph->outputs_by_src) {
1109    const Node* src_node = output.first.node;
1110    Node* src_image = node_images.at(src_node);
1111    int src_slot = output.first.slot;
1112    int output_index = output.second;
1113
1114    DataType dtype = src_node->output_type(src_slot);
1115    dtypes[output_index] = dtype;
1116    inputs[output_index].Reset(src_image->name(), src_slot, dtype);
1117  }
1118
1119  NodeDef send_def;
1120  NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name,
1121                                         "_", oc_subgraph_name, "_send"),
1122                         kSendFromHostOp);
1123  builder.Attr("Tinputs", dtypes);
1124  builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name,
1125                                      "_", oc_subgraph_name));
1126  builder.Input(inputs);
1127  Status s = builder.Finalize(&send_def);
1128  if (!s.ok()) return s;
1129
1130  oc_subgraph->send_from_host = graph_out->AddNode(send_def, &s);
1131  if (!s.ok()) return s;
1132  oc_subgraph->send_from_host->set_assigned_device_name(device_);
1133
1134  // Add a control dependency forcing the SendFromHost to run before the
1135  // subgraph completes. This has no effect on execution order but prevents the
1136  // RecvAtHost being pruned.
1137  TF_RETURN_IF_ERROR(MakeSequencingNode(subgraph_name, graph_out));
1138  graph_out->AddControlEdge(oc_subgraph->send_from_host, sequencer_);
1139
1140  return Status::OK();
1141}
1142
1143Status Encapsulator::Subgraph::AddOutsideCompilationHostIONodes(
1144    const string& subgraph_name,
1145    const std::unordered_map<const Node*, Node*>& node_images,
1146    Graph* graph_out) {
1147  for (auto& outside_compilation_subgraph_entry :
1148       outside_compilation_subgraphs_) {
1149    const string& oc_name = outside_compilation_subgraph_entry.first;
1150    OutsideCompilationSubgraph& oc_subgraph =
1151        outside_compilation_subgraph_entry.second;
1152
1153    if (!oc_subgraph.inputs.empty() || !oc_subgraph.control_inputs.empty()) {
1154      TF_RETURN_IF_ERROR(
1155          AddRecvAtHostNode(subgraph_name, oc_name, &oc_subgraph, graph_out));
1156    }
1157
1158    if (!oc_subgraph.outputs_by_src.empty() ||
1159        !oc_subgraph.control_outputs.empty()) {
1160      TF_RETURN_IF_ERROR(AddSendFromHostNode(node_images, subgraph_name,
1161                                             oc_name, &oc_subgraph, graph_out));
1162    }
1163  }
1164  return Status::OK();
1165}
1166
1167void Encapsulator::Subgraph::GetOutsideCompilationSubgraphNames(
1168    std::vector<string>* names) const {
1169  for (auto& entry : outside_compilation_subgraphs_) {
1170    names->push_back(entry.first);
1171  }
1172}
1173
1174Status Encapsulator::GetFunctionNameAttr(
1175    Node const* node, string* attr, string* outside_compilation_attr) const {
1176  Status s = GetNodeAttr(node->attrs(), group_attribute_, attr);
1177  if (s.code() == error::Code::NOT_FOUND) {
1178    // Return empty attr if there's no group_attribute.
1179    attr->clear();
1180  } else {
1181    TF_RETURN_IF_ERROR(s);
1182  }
1183  bool has_group_attr = s.ok();
1184  s = GetNodeAttr(node->attrs(), outside_compilation_attribute_,
1185                  outside_compilation_attr);
1186  if (s.code() == error::Code::NOT_FOUND) {
1187    // Return empty attr if there's no outside_compilation attribute.
1188    outside_compilation_attr->clear();
1189  } else {
1190    TF_RETURN_IF_ERROR(s);
1191    if (!has_group_attr) {
1192      return errors::InvalidArgument(
1193          "Node ", node->name(), " has ", outside_compilation_attribute_,
1194          " attribute but no ", group_attribute_, " attribute.");
1195    }
1196  }
1197  return Status::OK();
1198}
1199
1200bool IsInSubgraph(const string& func_id, const string& outside_compilation_id) {
1201  return !func_id.empty() && outside_compilation_id.empty();
1202}
1203
1204Status Encapsulator::CopySubgraphNodes(
1205    std::unordered_map<const Node*, Node*>* node_images) {
1206  for (Node* node : graph_in_->op_nodes()) {
1207    string func_id;
1208    string outside_compilation_id;
1209    TF_RETURN_IF_ERROR(
1210        GetFunctionNameAttr(node, &func_id, &outside_compilation_id));
1211    if (!IsInSubgraph(func_id, outside_compilation_id)) continue;
1212
1213    Subgraph& subgraph = subgraphs_[func_id];
1214    Node* image = subgraph.MakeNodeImage(graph_in_, node);
1215    image->ClearAttr(group_attribute_);
1216    (*node_images)[node] = image;
1217  }
1218  return Status::OK();
1219}
1220
1221Status Encapsulator::CopySubgraphEdges(
1222    const std::unordered_map<const Node*, Node*>& node_images,
1223    std::vector<std::pair<const Node*, Node*>>* src_arg_pairs) {
1224  for (const Edge* edge : graph_in_->edges()) {
1225    string src_func_id;
1226    string src_outside_compilation_id;
1227    TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->src(), &src_func_id,
1228                                           &src_outside_compilation_id));
1229    string dst_func_id;
1230    string dst_outside_compilation_id;
1231    TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->dst(), &dst_func_id,
1232                                           &dst_outside_compilation_id));
1233    Node* src_image = gtl::FindWithDefault(node_images, edge->src(), nullptr);
1234    Node* dst_image = gtl::FindWithDefault(node_images, edge->dst(), nullptr);
1235
1236    // Copy edges that are local to a subgraph.
1237    if (IsInSubgraph(src_func_id, src_outside_compilation_id) &&
1238        IsInSubgraph(dst_func_id, dst_outside_compilation_id) &&
1239        src_func_id == dst_func_id) {
1240      Graph* g = subgraphs_[src_func_id].GetGraph();
1241      if (edge->IsControlEdge()) {
1242        g->AddControlEdge(src_image, dst_image);
1243      } else {
1244        g->AddEdge(src_image, edge->src_output(), dst_image, edge->dst_input());
1245      }
1246      continue;
1247    }
1248
1249    // Record 'src' as an output of its subgraph, if applicable.
1250    if (IsInSubgraph(src_func_id, src_outside_compilation_id)) {
1251      if (!edge->IsControlEdge()) {
1252        DataType dtype = edge->src()->output_type(edge->src_output());
1253        if (IsRefType(dtype)) {
1254          return errors::InvalidArgument(
1255              "Ref Tensors (e.g., Variables) are not supported as results: "
1256              "tensor ",
1257              edge->src()->name(), ":", edge->src_output());
1258        }
1259      }
1260
1261      Subgraph& src_subgraph = subgraphs_[src_func_id];
1262      if (src_func_id == dst_func_id) {
1263        // src is in the subgraph and dst is outside_compilation in the same
1264        // subgraph.
1265        src_subgraph.RecordOutsideCompilationInputOrControl(
1266            dst_outside_compilation_id, edge);
1267      } else {
1268        // Ignore control edges leaving the subgraph. We will lift them onto the
1269        // enclosing call operators in BuildOutputGraph().
1270        if (!edge->IsControlEdge()) {
1271          TF_RETURN_IF_ERROR(src_subgraph.RecordResult(edge, node_images));
1272        }
1273      }
1274    }
1275
1276    // Record 'dst' as an input of its subgraph, if applicable.
1277    if (IsInSubgraph(dst_func_id, dst_outside_compilation_id)) {
1278      // Look at the type of the destination not the source, since Ref output
1279      // Tensors can be automatically cast to non-Ref Tensors at the
1280      // destination.
1281      if (!edge->IsControlEdge()) {
1282        DataType dtype = edge->dst()->input_type(edge->dst_input());
1283        if (IsRefType(dtype)) {
1284          return errors::InvalidArgument(
1285              "Ref Tensors (e.g., Variables) are not supported as args: "
1286              "tensor ",
1287              edge->src()->name(), ":", edge->src_output());
1288        }
1289      }
1290
1291      Subgraph& dst_subgraph = subgraphs_[dst_func_id];
1292      if (src_func_id == dst_func_id) {
1293        // dst is in the subgraph and src is outside_compilation in the same
1294        // subgraph.
1295        dst_subgraph.RecordOutsideCompilationOutputOrControl(
1296            src_outside_compilation_id, edge);
1297      } else {
1298        // Ignore control edges entering the subgraph. We will lift them onto
1299        // the enclosing call operators in BuildOutputGraph().
1300        if (!edge->IsControlEdge()) {
1301          TF_RETURN_IF_ERROR(
1302              dst_subgraph.RecordArg(edge, node_images, src_arg_pairs));
1303        }
1304      }
1305    }
1306  }
1307  return Status::OK();
1308}
1309
1310Status Encapsulator::SplitIntoSubgraphs() {
1311  Status s;
1312
1313  // Map from input graph nodes to subgraph nodes.
1314  std::unordered_map<const Node*, Node*> node_images;
1315
1316  // Each entry of src_arg_pairs is a pair whose first element is a node in the
1317  // original graph that has an output edge in the subgraph, and whose second
1318  // element is the arg node in the subgraph that it sends to. The vector will
1319  // be filled in below in AddArgs.
1320  std::vector<std::pair<const Node*, Node*>> src_arg_pairs;
1321
1322  TF_RETURN_IF_ERROR(CopySubgraphNodes(&node_images));
1323  TF_RETURN_IF_ERROR(CopySubgraphEdges(node_images, &src_arg_pairs));
1324
1325  // For each subgraph, add the nodes that deal with inputs and outputs its
1326  // nested outside_compilation subgraphs. These could not be added earlier
1327  // during CopySubgraphEdges since we need to discover all the types of the
1328  // inputs and outputs for an outside_compilation subgraph before creating a
1329  // single input and output node for it.
1330  for (auto& entry : subgraphs_) {
1331    Subgraph& subgraph = entry.second;
1332    TF_RETURN_IF_ERROR(subgraph.AddHostComputes(entry.first, node_images));
1333  }
1334
1335  MarkGuaranteedConstants(*graph_in_, src_arg_pairs);
1336
1337  for (auto& entry : subgraphs_) {
1338    Subgraph& subgraph = entry.second;
1339    FixupSourceAndSinkEdges(subgraph.GetGraph());
1340  }
1341
1342  return s;
1343}
1344
1345Status Encapsulator::BuildFunctionDefs(
1346    const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions,
1347    FunctionLibraryDefinition* library) {
1348  for (auto& subgraph_entry : subgraphs_) {
1349    string name = subgraph_entry.first;
1350    Subgraph& subgraph = subgraph_entry.second;
1351    TF_RETURN_IF_ERROR(subgraph.BuildFunctionDef(
1352        name, rewrite_subgraph_fn, reuse_existing_functions, library));
1353  }
1354  return Status::OK();
1355}
1356
1357Status Encapsulator::CopyNodesToOutputGraph(
1358    bool parallel_checking, Graph* graph_out,
1359    std::unordered_map<const Node*, Node*>* node_images) {
1360  for (Node* node : graph_in_->op_nodes()) {
1361    string func_id;
1362    string outside_compilation_id;
1363    TF_RETURN_IF_ERROR(
1364        GetFunctionNameAttr(node, &func_id, &outside_compilation_id));
1365
1366    // Don't copy nodes that going to be encapsulated, unless parallel checking
1367    // is enabled.
1368    if (IsInSubgraph(func_id, outside_compilation_id) && !parallel_checking)
1369      continue;
1370
1371    Node* image = graph_out->CopyNode(node);
1372    if (!outside_compilation_id.empty()) {
1373      if (parallel_checking) {
1374        return errors::InvalidArgument(
1375            "Parallel checking is not supported when outside_compilation "
1376            "clusters are present.");
1377      }
1378      image->ClearAttr(group_attribute_);
1379      image->ClearAttr(outside_compilation_attribute_);
1380    }
1381    (*node_images)[node] = image;
1382  }
1383  (*node_images)[graph_in_->source_node()] = graph_out->source_node();
1384  (*node_images)[graph_in_->sink_node()] = graph_out->sink_node();
1385  return Status::OK();
1386}
1387
1388Status Encapsulator::AddFunctionCallNodes(
1389    const std::unordered_map<const Node*, Node*>& node_images,
1390    bool parallel_checking, Graph* graph_out) {
1391  for (auto& subgraph_entry : subgraphs_) {
1392    TF_RETURN_IF_ERROR(subgraph_entry.second.AddFunctionCallNode(
1393        node_images, parallel_checking, graph_out));
1394  }
1395  return Status::OK();
1396}
1397
1398Status Encapsulator::AddOutsideCompilationHostIONodes(
1399    const std::unordered_map<const Node*, Node*>& node_images,
1400    Graph* graph_out) {
1401  for (auto& subgraph_entry : subgraphs_) {
1402    const string& subgraph_name = subgraph_entry.first;
1403    Subgraph& subgraph = subgraph_entry.second;
1404    TF_RETURN_IF_ERROR(subgraph.AddOutsideCompilationHostIONodes(
1405        subgraph_name, node_images, graph_out));
1406  }
1407  return Status::OK();
1408}
1409
1410Status Encapsulator::FindOutputImageOfEdgeSrc(
1411    const string& src_func_id, const string& src_outside_compilation_id,
1412    const string& dst_func_id, const string& dst_outside_compilation_id,
1413    const std::unordered_map<const Node*, Node*>& node_images,
1414    const Node* original_src_node, Node** src_image) {
1415  if (IsInSubgraph(src_func_id, src_outside_compilation_id)) {
1416    if (dst_func_id == src_func_id) {
1417      // The edge is from a subgraph to an outside_compilation cluster in the
1418      // same subgraph so use the appropriate _RecvAtHost node in the output
1419      // graph.
1420      TF_RET_CHECK(!dst_outside_compilation_id.empty());
1421      *src_image = subgraphs_.at(src_func_id)
1422                       .GetRecvAtHostNode(dst_outside_compilation_id);
1423    } else {
1424      // The edge is from a subgraph to a regular node in the output graph so
1425      // use the subgraph's call node output.
1426      *src_image = subgraphs_.at(src_func_id).GetCallNodeForOutputs();
1427    }
1428  } else {
1429    // The source of the edge is in the output graph so use the node image in
1430    // the output graph.
1431    *src_image = node_images.at(original_src_node);
1432  }
1433  return Status::OK();
1434}
1435
1436int Encapsulator::FindOutputSlotOfEdgeSrc(
1437    const string& src_func_id, const string& src_outside_compilation_id,
1438    const string& dst_func_id, const string& dst_outside_compilation_id,
1439    const Edge* edge) {
1440  if (IsInSubgraph(src_func_id, src_outside_compilation_id)) {
1441    const Subgraph& src_subgraph = subgraphs_.at(src_func_id);
1442    if (src_func_id == dst_func_id) {
1443      // 'src' is in a subgraph and 'dst' is outside_compilation in the same
1444      // subgraph. Use the corresponding _RecvAtHost output instead.
1445      return src_subgraph.GetRecvAtHostSlot(dst_outside_compilation_id, edge);
1446    } else {
1447      // 'src' is in a subgraph and 'dst' is a regular node in the output
1448      // graph. Use the corresponding call output instead.
1449      return src_subgraph.GetResultIndexForEdge(edge);
1450    }
1451  } else {
1452    // The source of the edge is in the output graph so use the regular edge
1453    // slot.
1454    return edge->src_output();
1455  }
1456}
1457
1458Status Encapsulator::FindOutputImageOfEdgeDst(
1459    const string& src_func_id, const string& src_outside_compilation_id,
1460    const string& dst_func_id, const string& dst_outside_compilation_id,
1461    const std::unordered_map<const Node*, Node*>& node_images,
1462    const Node* original_dst_node, Node** dst_image) {
1463  if (IsInSubgraph(dst_func_id, dst_outside_compilation_id)) {
1464    if (src_func_id == dst_func_id) {
1465      // The edge is to a subgraph from an outside_compilation cluster in the
1466      // same subgraph so use the appropriate _SendFromHost node in the output
1467      // graph.
1468      TF_RET_CHECK(!src_outside_compilation_id.empty());
1469      *dst_image = subgraphs_.at(dst_func_id)
1470                       .GetSendFromHostNode(src_outside_compilation_id);
1471    } else {
1472      // The edge is to a subgraph from a regular node in the output graph so
1473      // use the subgraph's call node input.
1474      *dst_image = subgraphs_.at(dst_func_id).GetCallNodeForInputs();
1475    }
1476  } else {
1477    // The destination of the edge is in the output graph so use the node image
1478    // in the output graph.
1479    *dst_image = node_images.at(original_dst_node);
1480  }
1481  return Status::OK();
1482}
1483
1484int Encapsulator::FindOutputSlotOfEdgeDst(
1485    const string& src_func_id, const string& src_outside_compilation_id,
1486    const string& dst_func_id, const string& dst_outside_compilation_id,
1487    const Edge* edge) {
1488  if (IsInSubgraph(dst_func_id, dst_outside_compilation_id)) {
1489    const Subgraph& dst_subgraph = subgraphs_.at(dst_func_id);
1490    if (dst_func_id == src_func_id) {
1491      // 'dst' is in a subgraph and 'src' is outside_compilation in the same
1492      // subgraph. Use the corresponding _SendFromHost input instead.
1493      return dst_subgraph.GetSendFromHostSlot(src_outside_compilation_id, edge);
1494    } else {
1495      // 'dst' is in a subgraph and 'src' is a regular node in the output
1496      // graph. Use the corresponding call input instead.
1497      return dst_subgraph.GetArgIndexForEdge(edge);
1498    }
1499  } else {
1500    // The destination of the edge is in the output graph so use the regular
1501    // edge slot.
1502    return edge->dst_input();
1503  }
1504}
1505
1506Status Encapsulator::CopyEdgeToOutputGraph(
1507    const Edge* edge, const string& src_func_id,
1508    const string& src_outside_compilation_id, const string& dst_func_id,
1509    const string& dst_outside_compilation_id,
1510    const std::unordered_map<const Node*, Node*>& node_images,
1511    bool parallel_checking, Graph* graph_out,
1512    std::unordered_set<std::pair<NodeSlot, NodeSlot>, NodeSlot::PairHasher>*
1513        edges_added) {
1514  Node* src_image;
1515  TF_RETURN_IF_ERROR(FindOutputImageOfEdgeSrc(
1516      src_func_id, src_outside_compilation_id, dst_func_id,
1517      dst_outside_compilation_id, node_images, edge->src(), &src_image));
1518  Node* dst_image;
1519  TF_RETURN_IF_ERROR(FindOutputImageOfEdgeDst(
1520      src_func_id, src_outside_compilation_id, dst_func_id,
1521      dst_outside_compilation_id, node_images, edge->dst(), &dst_image));
1522
1523  // If this is a control edge then copy it and return. Lift control edges onto
1524  // the enclosing call operator.
1525  if (edge->IsControlEdge()) {
1526    // Add the control edge, if we have not already added it, using the images
1527    // determined above (potentially call operators or RecvAtHost/SendFromHost).
1528    if (edges_added->emplace(NodeSlot(src_image, -1), NodeSlot(dst_image, -1))
1529            .second) {
1530      graph_out->AddControlEdge(src_image, dst_image);
1531    }
1532
1533    // If parallel checking is enabled, also add a control edge to the
1534    // corresponding parallel check op.
1535    if (parallel_checking) {
1536      graph_out->AddControlEdge(src_image, node_images.at(edge->dst()));
1537    }
1538    return Status::OK();
1539  }
1540
1541  int src_output =
1542      FindOutputSlotOfEdgeSrc(src_func_id, src_outside_compilation_id,
1543                              dst_func_id, dst_outside_compilation_id, edge);
1544
1545  int dst_input =
1546      FindOutputSlotOfEdgeDst(src_func_id, src_outside_compilation_id,
1547                              dst_func_id, dst_outside_compilation_id, edge);
1548
1549  if (IsInSubgraph(dst_func_id, dst_outside_compilation_id) &&
1550      parallel_checking) {
1551    // If we are parallel checking, also feed the tensor as an input to the
1552    // corresponding parallel check subgraph.
1553    graph_out->AddEdge(src_image, src_output, node_images.at(edge->dst()),
1554                       edge->dst_input());
1555  }
1556
1557  // Add the edge, if we have not already added it.
1558  if (edges_added
1559          ->emplace(NodeSlot(src_image, src_output),
1560                    NodeSlot(dst_image, dst_input))
1561          .second) {
1562    graph_out->AddEdge(src_image, src_output, dst_image, dst_input);
1563  }
1564  return Status::OK();
1565}
1566
1567Status Encapsulator::AddEdgesToOutputGraph(
1568    const std::unordered_map<const Node*, Node*>& node_images,
1569    bool parallel_checking, Graph* graph_out) {
1570  // Set of edges already added to the output graph, represented as (src, dst)
1571  // pairs. We use the set to deduplicate edges; multiple edges in the input
1572  // graph may map to one edge in the output graph.
1573  std::unordered_set<std::pair<NodeSlot, NodeSlot>, NodeSlot::PairHasher>
1574      edges_added;
1575
1576  for (const Edge* edge : graph_in_->edges()) {
1577    string src_func_id;
1578    string src_outside_compilation_id;
1579    TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->src(), &src_func_id,
1580                                           &src_outside_compilation_id));
1581    string dst_func_id;
1582    string dst_outside_compilation_id;
1583    TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->dst(), &dst_func_id,
1584                                           &dst_outside_compilation_id));
1585
1586    // Ignore edges that are strictly contained within one subgraph, unless
1587    // we are constructing parallel check graphs.
1588    if (IsInSubgraph(src_func_id, src_outside_compilation_id) &&
1589        IsInSubgraph(dst_func_id, dst_outside_compilation_id) &&
1590        src_func_id == dst_func_id) {
1591      if (parallel_checking) {
1592        Node* src_image = node_images.at(edge->src());
1593        Node* dst_image = node_images.at(edge->dst());
1594        if (edge->IsControlEdge()) {
1595          graph_out->AddControlEdge(src_image, dst_image);
1596        } else {
1597          graph_out->AddEdge(src_image, edge->src_output(), dst_image,
1598                             edge->dst_input());
1599        }
1600      }
1601      continue;
1602    }
1603
1604    // We have an edge that crosses a cluster boundary or is entirely within the
1605    // unclustered graph.
1606    TF_RETURN_IF_ERROR(CopyEdgeToOutputGraph(
1607        edge, src_func_id, src_outside_compilation_id, dst_func_id,
1608        dst_outside_compilation_id, node_images, parallel_checking, graph_out,
1609        &edges_added));
1610  }
1611
1612  for (auto& subgraph_entry : subgraphs_) {
1613    Subgraph& subgraph = subgraph_entry.second;
1614    subgraph.ConnectSequencerToOutputs(graph_out);
1615  }
1616
1617  return Status::OK();
1618}
1619
1620namespace {
1621
1622// Adds a dummy Const node to graph_out. The "constant" has the type of
1623// data_type and the shape indicated in 'shape'. The dummy node is not a valid
1624// Const node because it does not have any value defined, but this doesn't
1625// matter because it will only be used subsequently for shape inference. (It
1626// would be possible to add a switch statement over data_type to create a value
1627// for the constant, but that would entail maintaining the logic as new types
1628// are added, and is not necessary.)
1629Node* AddDummyShapedNode(DataType data_type, const TensorShapeProto& shape,
1630                         Graph* graph_out) {
1631  TensorProto dummy_proto;
1632  dummy_proto.set_dtype(data_type);
1633  *dummy_proto.mutable_tensor_shape() = shape;
1634  // Don't set any value field in the proto, since it is only going to be used
1635  // for shape inference.
1636
1637  GraphDefBuilder::Options options(graph_out, /*status=*/nullptr);
1638  NodeBuilder node_builder(options.GetNameForOp("KnownShape"), "Const",
1639                           options.op_registry());
1640  node_builder.Attr("dtype", data_type).Attr("value", dummy_proto);
1641  return options.FinalizeBuilder(&node_builder);
1642}
1643
1644// Adds a copy of node_in to graph_out and adds the mapping to
1645// copied_node_images.
1646Status CopyShapeInferenceNodeToGraph(
1647    Node* node_in, const Node* send_node,
1648    const std::unordered_map<Node*, Node*>& dummy_node_images,
1649    FunctionLibraryDefinition* library,
1650    std::unordered_map<Node*, Node*>* copied_node_images, Graph* graph_out) {
1651  // Once all the ancestor nodes have been added to graph_out, add this node
1652  // and connect it to its ancestors.
1653  Node* node_out = graph_out->CopyNode(node_in);
1654  (*copied_node_images)[node_in] = node_out;
1655  // Don't bother to build the shape inference graph if there's a node with no
1656  // shape inference function, since it would just result in an error later at
1657  // compile time.
1658  const OpRegistrationData* op_reg_data;
1659  TF_RETURN_IF_ERROR(library->LookUp(node_in->type_string(), &op_reg_data));
1660  if (op_reg_data->shape_inference_fn == nullptr) {
1661    return errors::InvalidArgument(
1662        "Shape inference is not possible for outside_compilation "
1663        "SendFromHost node ",
1664        send_node->name(), " because it depends on node ", node_in->name(),
1665        " which does not have a shape inference function registered.");
1666  }
1667  // Add all the edges to the newly copied node.
1668  for (const Edge* in_edge : node_in->in_edges()) {
1669    if (!in_edge->IsControlEdge()) {
1670      Node* src = in_edge->src();
1671      const auto iter = dummy_node_images.find(src);
1672      if (iter == dummy_node_images.end()) {
1673        // The src is a copied node so use the original output port.
1674        graph_out->AddEdge((*copied_node_images)[in_edge->src()],
1675                           in_edge->src_output(), node_out,
1676                           in_edge->dst_input());
1677      } else {
1678        // The src is a dummy node so use output port 0.
1679        graph_out->AddEdge(iter->second, 0, node_out, in_edge->dst_input());
1680      }
1681    }
1682  }
1683  return Status::OK();
1684}
1685
1686}  // namespace
1687
1688Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend(
1689    const Graph& graph_in, const ShapeRefiner& shape_refiner,
1690    const std::unordered_set<string>& recv_at_host_nodes, Node* send_node,
1691    FunctionLibraryDefinition* library,
1692    std::vector<TensorShapeProto>* static_shape_out,
1693    std::unique_ptr<GraphDef>* graphdef_out) {
1694  // Maps from nodes in graph_in to nodes in graph_out.
1695  //
1696  // When an edge has fully defined shape the source node in graph_in is
1697  // replaced in graph_out by a dummy constant node. The mapping from nodes
1698  // in graph_in to dummy nodes is stored in dummy_node_images.
1699  //
1700  // When a node in graph_in has at least one ancestor that doesn't have fully
1701  // defined shape, it is copied into graph_out. The mapping from nodes in
1702  // graph_in to copied nodes is stored in copied_node_images.
1703  //
1704  // The two types of node are treated differently because, when adding edges to
1705  // graph_out, an output from a dummy node always uses port 0, whereas an
1706  // output from a copied node uses the same port that was used in graph_in.
1707  std::unordered_map<Node*, Node*> dummy_node_images;
1708  std::unordered_map<Node*, Node*> copied_node_images;
1709
1710  std::unique_ptr<Graph> graph_out(new Graph(graph_in.op_registry()));
1711  graph_out->set_versions(graph_in.versions());
1712  static_shape_out->resize(send_node->num_inputs());
1713
1714  // We don't use the standard ReverseDFS because we want to cut off traversal
1715  // whenever we find an output with fully defined shape.
1716  // TODO(misard) make this work properly in the presence of control flow.
1717  struct Work {
1718    Node* node;
1719    bool leave;  // Are we entering or leaving node?
1720  };
1721  std::vector<Work> stack({{send_node, false}});
1722  std::vector<bool> visited(graph_in.num_node_ids(), false);
1723  while (!stack.empty()) {
1724    Work w = stack.back();
1725    stack.pop_back();
1726    Node* n = w.node;
1727
1728    if (w.leave) {
1729      TF_RETURN_IF_ERROR(CopyShapeInferenceNodeToGraph(
1730          n, send_node, dummy_node_images, library, &copied_node_images,
1731          graph_out.get()));
1732    } else {
1733      if (visited[n->id()]) continue;
1734      visited[n->id()] = true;
1735
1736      // Arrange to revisit when all done with all inputs.
1737      stack.push_back(Work{n, true});
1738
1739      bool has_parent_with_unknown_shape = false;
1740      for (const Edge* in_edge : n->in_edges()) {
1741        if (!in_edge->IsControlEdge()) {
1742          Node* src_node = in_edge->src();
1743          int src_port = in_edge->src_output();
1744          shape_inference::InferenceContext* context =
1745              shape_refiner.GetContext(src_node);
1746          shape_inference::ShapeHandle shape = context->output(src_port);
1747          if (context->FullyDefined(shape)) {
1748            // This ancestor has known shape, so instead of adding it to the
1749            // stack, add a dummy node with that shape to graph_out and
1750            // continue.
1751            TensorShapeProto proto;
1752            context->ShapeHandleToProto(shape, &proto);
1753            dummy_node_images[src_node] = AddDummyShapedNode(
1754                src_node->output_type(src_port), proto, graph_out.get());
1755            if (n == send_node) {
1756              (*static_shape_out)[in_edge->dst_input()] = proto;
1757            }
1758          } else {
1759            if (!visited[src_node->id()]) {
1760              has_parent_with_unknown_shape = true;
1761              stack.push_back({src_node, false});
1762            }
1763          }
1764        }
1765      }
1766      if (!has_parent_with_unknown_shape) {
1767        if (n == send_node) {
1768          // The shapes of all the inputs to send_node are statically known. We
1769          // won't have to do any inference at compile time so return now: the
1770          // shapes were stored in static_shape_out above.
1771          graphdef_out->reset();
1772          return Status::OK();
1773        } else {
1774          // Any shape that is being processed is either the original send node
1775          // or has at least one output with statically-unknown shape. If the
1776          // latter and it doesn't have any inputs with statically-unknown
1777          // shape, then check that it is of the recv nodes that we can fill in
1778          // the shape of at run-time later. If it isn't one of those, then we
1779          // won't have any additional knowledge at compile time, so we already
1780          // know we won't be able to do shape inference and we can return an
1781          // error now.
1782          if (recv_at_host_nodes.find(n->name()) == recv_at_host_nodes.end()) {
1783            return errors::InvalidArgument(
1784                "Shape inference is not possible for outside_compilation "
1785                "SendFromHost node ",
1786                send_node->name(), " because shape of node ", n->name(),
1787                " will not be known at compilation time.");
1788          }
1789        }
1790      }
1791    }
1792  }
1793
1794  graphdef_out->reset(new GraphDef());
1795  graph_out->ToGraphDef(graphdef_out->get());
1796
1797  return Status::OK();
1798}
1799
1800Status Encapsulator::MakePrunedGraphCopyAndInline(
1801    const Graph& graph, const std::vector<Node*>& sink_nodes,
1802    std::unique_ptr<Graph>* pruned_graph,
1803    std::unordered_map<const Node*, Node*>* node_images,
1804    FunctionLibraryDefinition* library) {
1805  // First copy all ancestor nodes of sink_nodes into a new graph.
1806  pruned_graph->reset(new Graph(library));
1807  (*pruned_graph)->set_versions(graph.versions());
1808  ReverseDFSFrom(graph, sink_nodes,
1809                 /*enter=*/nullptr,
1810                 /*leave=*/[&](Node* n) {
1811                   if (!n->IsSource()) {
1812                     Node* copied = (*pruned_graph)->CopyNode(n);
1813                     node_images->emplace(n, copied);
1814                   }
1815                 });
1816
1817  // Add all the edges between copied nodes.
1818  for (auto entry : *node_images) {
1819    const Node* orig = entry.first;
1820    Node* image = entry.second;
1821    for (const Edge* out_edge : orig->out_edges()) {
1822      auto iter = node_images->find(out_edge->dst());
1823      if (iter != node_images->end()) {
1824        // The source and destination are both in the copied graph.
1825        (*pruned_graph)
1826            ->AddEdge(image, out_edge->src_output(), iter->second,
1827                      out_edge->dst_input());
1828      }
1829    }
1830  }
1831
1832  // Find all the function call nodes, and inline them.
1833  std::vector<Node*> function_nodes;
1834  for (auto node : (*pruned_graph)->nodes()) {
1835    const OpRegistrationData* op_reg_data;
1836    TF_RETURN_IF_ERROR(library->LookUp(node->type_string(), &op_reg_data));
1837    if (op_reg_data->is_function_op) {
1838      function_nodes.push_back(node);
1839    }
1840  }
1841  for (auto node : function_nodes) {
1842    VLOG(2) << "Inlining function " << node->name();
1843    const FunctionDef* fdef = library->Find(node->type_string());
1844    if (fdef == nullptr) {
1845      return errors::Internal("Failed to find function ", node->type_string(),
1846                              " in function library.");
1847    }
1848    FunctionBody* fbody = nullptr;
1849    TF_RETURN_IF_ERROR(
1850        FunctionDefToBodyHelper(*fdef, node->attrs(), library,
1851                                [library](const string& op, const OpDef** sig) {
1852                                  return library->LookUpOpDef(op, sig);
1853                                },
1854                                &fbody));
1855    InlineFunctionBody(*library, pruned_graph->get(), node, fbody);
1856    delete fbody;
1857  }
1858
1859  return Status::OK();
1860}
1861
1862Status Encapsulator::MakeGraphForOutsideCompilationSends(
1863    const Graph& graph, std::unique_ptr<Graph>* pruned_graph,
1864    ShapeRefiner* shape_refiner,
1865    std::unordered_map<const Node*, Node*>* node_images,
1866    FunctionLibraryDefinition* library) {
1867  // Find all the send_from_host nodes in all subgraphs, to use as roots for the
1868  // pruning.
1869  std::vector<Node*> send_from_host_nodes;
1870  for (auto& subgraph_entry : subgraphs_) {
1871    Subgraph& subgraph = subgraph_entry.second;
1872    std::vector<string> outside_compilation_names;
1873    subgraph.GetOutsideCompilationSubgraphNames(&outside_compilation_names);
1874    for (const auto& name : outside_compilation_names) {
1875      Node* send_node = subgraph.GetSendFromHostNode(name);
1876      if (send_node != nullptr) {
1877        send_from_host_nodes.push_back(send_node);
1878      }
1879    }
1880  }
1881
1882  // Make a copy of all the graph nodes needed to evaluate the send_from_host
1883  // nodes, inlining any functions as needed.
1884  TF_RETURN_IF_ERROR(MakePrunedGraphCopyAndInline(
1885      graph, send_from_host_nodes, pruned_graph, node_images, library));
1886
1887  // Perform shape inference on the pruned graph.
1888  shape_refiner->set_require_shape_inference_fns(false);
1889  FixupSourceAndSinkEdges(pruned_graph->get());
1890  std::vector<Node*> post_order;
1891  GetReversePostOrder(*(*pruned_graph), &post_order);
1892  for (auto node : post_order) {
1893    // Ignore the status returned by the shape_refiner. At this point we want
1894    // the best effort shapes, even if no shape function is registered for a
1895    // node.
1896    Status status = shape_refiner->AddNode(node);
1897    if (!status.ok()) {
1898      VLOG(1) << "Shape inference failed for node: " << status;
1899    }
1900  }
1901
1902  return Status::OK();
1903}
1904
1905Status Encapsulator::GetShapeInfoForOutsideCompilationSends(
1906    Graph* graph_out, FunctionLibraryDefinition* library) {
1907  std::unique_ptr<Graph> pruned_graph;
1908  ShapeRefiner shape_refiner(graph_out->versions(), graph_out->op_registry());
1909  std::unordered_map<const Node*, Node*> node_images;
1910  TF_RETURN_IF_ERROR(MakeGraphForOutsideCompilationSends(
1911      *graph_out, &pruned_graph, &shape_refiner, &node_images, library));
1912
1913  for (auto& subgraph_entry : subgraphs_) {
1914    Subgraph& subgraph = subgraph_entry.second;
1915    // Find all the recv_at_host nodes in this subgraph.
1916    std::vector<string> outside_compilation_names;
1917    subgraph.GetOutsideCompilationSubgraphNames(&outside_compilation_names);
1918    std::unordered_set<string> recv_at_host_names;
1919    for (const auto& name : outside_compilation_names) {
1920      Node* recv_node = subgraph.GetRecvAtHostNode(name);
1921      if (recv_node != nullptr) {
1922        recv_at_host_names.insert(recv_node->name());
1923      }
1924    }
1925    // For each send_from_host node, do as much shape inference as possible
1926    // without knowing the shape of the recv_at_host nodes, and store the
1927    // result, along with enough information to complete the job at compile time
1928    // once the recv_at_host shapes are known.
1929    for (const auto& name : outside_compilation_names) {
1930      Node* send_node = subgraph.GetSendFromHostNode(name);
1931      std::vector<TensorShapeProto> static_shape;
1932      std::unique_ptr<GraphDef> graphdef;
1933      if (send_node != nullptr) {
1934        TF_RETURN_IF_ERROR(DoStaticShapeInferenceForOutsideCompilationSend(
1935            *pruned_graph, shape_refiner, recv_at_host_names,
1936            node_images[send_node], library, &static_shape, &graphdef));
1937        if (graphdef == nullptr) {
1938          VLOG(2) << "Send node  " << send_node->name() << " shapes";
1939          for (int i = 0; i < static_shape.size(); ++i) {
1940            VLOG(2) << static_shape[i].DebugString();
1941          }
1942        } else {
1943          VLOG(2) << "Send node " << send_node->name() << " graph\n"
1944                  << graphdef->DebugString();
1945        }
1946      }
1947      TF_RETURN_IF_ERROR(
1948          subgraph.AddShapeInferenceInfo(name, static_shape, graphdef.get()));
1949    }
1950    if (!outside_compilation_names.empty()) {
1951      TF_RETURN_IF_ERROR(subgraph.ReplaceFunctionDef(library));
1952    }
1953  }
1954
1955  return Status::OK();
1956}
1957
1958Status Encapsulator::BuildOutputGraph(bool parallel_checking, Graph* graph_out,
1959                                      FunctionLibraryDefinition* library) {
1960  // Map from nodes in the input graph to nodes in the output graph.
1961  std::unordered_map<const Node*, Node*> node_images;
1962
1963  TF_RETURN_IF_ERROR(
1964      CopyNodesToOutputGraph(parallel_checking, graph_out, &node_images));
1965  TF_RETURN_IF_ERROR(
1966      AddFunctionCallNodes(node_images, parallel_checking, graph_out));
1967  TF_RETURN_IF_ERROR(AddOutsideCompilationHostIONodes(node_images, graph_out));
1968  TF_RETURN_IF_ERROR(
1969      AddEdgesToOutputGraph(node_images, parallel_checking, graph_out));
1970
1971  TF_RETURN_IF_ERROR(
1972      GetShapeInfoForOutsideCompilationSends(graph_out, library));
1973
1974  return Status::OK();
1975}
1976
1977}  // anonymous namespace
1978
1979Status EncapsulateSubgraphsInFunctions(
1980    string group_attribute, string outside_compilation_attribute,
1981    const Graph& graph_in, const RewriteSubgraphFn& rewrite_subgraph_fn,
1982    bool parallel_checking, bool reuse_existing_functions,
1983    std::unique_ptr<Graph>* graph_out, FunctionLibraryDefinition* library) {
1984  Status s;
1985
1986  Encapsulator encapsulator(std::move(group_attribute),
1987                            std::move(outside_compilation_attribute),
1988                            &graph_in);
1989  TF_RETURN_IF_ERROR(encapsulator.SplitIntoSubgraphs());
1990
1991  TF_RETURN_IF_ERROR(encapsulator.BuildFunctionDefs(
1992      rewrite_subgraph_fn, reuse_existing_functions, library));
1993
1994  std::unique_ptr<Graph> out(new Graph(library));
1995  out->set_versions(graph_in.versions());
1996  TF_RETURN_IF_ERROR(
1997      encapsulator.BuildOutputGraph(parallel_checking, out.get(), library));
1998
1999  *graph_out = std::move(out);
2000  return Status::OK();
2001}
2002
2003// Finds the types of the _Arg nodes, indexed by position.
2004static Status GetArgTypes(const Graph& graph, DataTypeVector* types) {
2005  for (Node* n : graph.op_nodes()) {
2006    if (n->type_string() == kArgOp) {
2007      int index;
2008      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
2009      if (index < 0 || index >= types->size()) {
2010        return errors::InvalidArgument("Invalid argument number");
2011      }
2012      (*types)[index] = n->output_type(0);
2013    }
2014  }
2015  return Status::OK();
2016}
2017
2018// Renumber the indices of _Arg nodes in a graph, according to
2019// 'permutation' that maps old indices to new indices.
2020static Status RenumberArguments(Graph* graph,
2021                                const std::vector<int>& permutation) {
2022  for (Node* n : graph->op_nodes()) {
2023    if (n->type_string() == kArgOp) {
2024      int index;
2025      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
2026      if (index < 0 || index >= permutation.size()) {
2027        return errors::InvalidArgument("Invalid argument number");
2028      }
2029      n->AddAttr("index", permutation[index]);
2030    }
2031  }
2032  return Status::OK();
2033}
2034
2035Status EncapsulateSubgraphsPass::Run(
2036    const GraphOptimizationPassOptions& options) {
2037  VLOG(1) << "EncapsulateSubgraphsPass::Run";
2038  legacy_flags::EncapsulateSubgraphsPassFlags* flags =
2039      legacy_flags::GetEncapsulateSubgraphsPassFlags();
2040  if (VLOG_IS_ON(1)) {
2041    dump_graph::DumpGraphToFile("before_encapsulate_subgraphs", **options.graph,
2042                                options.flib_def);
2043  }
2044
2045  std::unique_ptr<Graph> graph_out;
2046  FunctionLibraryDefinition* const library = options.flib_def;
2047
2048  OptimizerOptions opts;
2049  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
2050      new ProcessFunctionLibraryRuntime(nullptr, options.session_options->env,
2051                                        TF_GRAPH_DEF_VERSION, library, opts));
2052  FunctionLibraryRuntime* flr =
2053      pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
2054
2055  auto rewrite_subgraph = [flr](std::unique_ptr<Graph>* subgraph,
2056                                std::vector<int>* input_permutation,
2057                                std::vector<int>* output_permutation,
2058                                NodeDef* node) {
2059    // Optimize the subgraph.
2060    OptimizeGraph(flr, subgraph);
2061
2062    const int num_args = input_permutation->size();
2063    std::vector<bool> const_args(num_args);
2064    TF_RETURN_IF_ERROR(BackwardsConstAnalysis(**subgraph, &const_args));
2065
2066    DataTypeVector arg_types(num_args);
2067    TF_RETURN_IF_ERROR(GetArgTypes(**subgraph, &arg_types));
2068
2069    // Compute a permutation of the arguments such that the constant arguments
2070    // are first.
2071    const int num_consts =
2072        std::count(const_args.begin(), const_args.end(), true);
2073
2074    const int num_resources =
2075        std::count(arg_types.begin(), arg_types.end(), DT_RESOURCE);
2076    const int num_nonconsts = num_args - num_resources - num_consts;
2077    if (num_nonconsts < 0) {
2078      return errors::Internal("num_nonconsts should be >= 0, was ",
2079                              num_nonconsts);
2080    }
2081
2082    int const_pos = 0;
2083    int arg_pos = num_consts;
2084    int resource_pos = num_consts + num_nonconsts;
2085    for (int i = 0; i < num_args; ++i) {
2086      if (const_args[i]) {
2087        if (arg_types[i] == DT_RESOURCE) {
2088          return errors::Internal(
2089              "Resource arguments cannot be constant (argument ", i, ")");
2090        }
2091        (*input_permutation)[i] = const_pos;
2092        ++const_pos;
2093      } else if (arg_types[i] == DT_RESOURCE) {
2094        (*input_permutation)[i] = resource_pos;
2095        ++resource_pos;
2096      } else {
2097        (*input_permutation)[i] = arg_pos;
2098        ++arg_pos;
2099      }
2100    }
2101
2102    // Renumber argument nodes in the graph.
2103    TF_RETURN_IF_ERROR(RenumberArguments(subgraph->get(), *input_permutation));
2104
2105    // TODO(phawkins): add a forward is-constant analysis, similarly split
2106    // outputs into host-memory constants and device-memory non-constants.
2107
2108    AddNodeAttr(kXlaCompiledKernelAttr, true, node);
2109    AddNodeAttr(kXlaNumConstantArgsAttr, num_consts, node);
2110    AddNodeAttr(kXlaNumResourceArgsAttr, num_resources, node);
2111    return Status::OK();
2112  };
2113
2114  TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions(
2115      kXlaClusterAttr, kXlaOutsideCompilationAttr, **options.graph,
2116      rewrite_subgraph, flags->tf_xla_parallel_checking,
2117      /*reuse_existing_functions=*/false, &graph_out, library));
2118
2119  if (VLOG_IS_ON(1)) {
2120    dump_graph::DumpGraphToFile("after_encapsulate_subgraphs", *graph_out,
2121                                options.flib_def);
2122  }
2123
2124  *options.graph = std::move(graph_out);
2125  return Status::OK();
2126}
2127
2128bool IsXlaCompiledKernel(const Node& node) {
2129  bool is_compiled = false;
2130  bool has_compilation_attr =
2131      GetNodeAttr(node.attrs(), kXlaCompiledKernelAttr, &is_compiled).ok() &&
2132      is_compiled;
2133  return has_compilation_attr ? is_compiled : false;
2134}
2135
2136}  // namespace tensorflow
2137