16#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
18#include <functional>
19#include <memory>
20#include <numeric>
21#include <string>
22#include <unordered_map>
23#include <vector>
25#include "tensorflow/compiler/jit/graph_to_functiondef.h"
26#include "tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h"
27#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
28#include "tensorflow/compiler/tf2xla/const_analysis.h"
29#include "tensorflow/compiler/tf2xla/dump_graph.h"
30#include "tensorflow/compiler/xla/status_macros.h"
31#include "tensorflow/core/common_runtime/function.h"
32#include "tensorflow/core/common_runtime/optimization_registry.h"
33#include "tensorflow/core/common_runtime/shape_refiner.h"
34#include "tensorflow/core/framework/function.h"
35#include "tensorflow/core/framework/graph_def_util.h"
36#include "tensorflow/core/framework/node_def_builder.h"
37#include "tensorflow/core/framework/node_def_util.h"
38#include "tensorflow/core/graph/algorithm.h"
39#include "tensorflow/core/graph/graph.h"
40#include "tensorflow/core/graph/graph_def_builder.h"
41#include "tensorflow/core/graph/tensor_id.h"
42#include "tensorflow/core/lib/gtl/flatset.h"
43#include "tensorflow/core/lib/gtl/map_util.h"
44#include "tensorflow/core/lib/hash/hash.h"
45#include "tensorflow/core/lib/strings/str_util.h"
46#include "tensorflow/core/lib/strings/strcat.h"
47#include "tensorflow/core/public/session_options.h"
48#include "tensorflow/core/public/version.h"
49#include "tensorflow/core/util/device_name_utils.h"
51namespace tensorflow {
// String constants naming XLA-related attributes. Their uses are outside this
// part of the file; presumably they are node attribute names set on the
// function call nodes produced by this pass -- confirm against the callers.
const char* const kXlaCompiledKernelAttr = "_XlaCompiledKernel";
const char* const kXlaNumConstantArgsAttr = "_XlaNumConstantArgs";
const char* const kXlaNumResourceArgsAttr = "_XlaNumResourceArgs";
57namespace {
59bool AreAllParentsConst(const Node& n,
60                        const gtl::FlatSet<const Node*>& runtime_const_nodes) {
61  if (n.type_string() == "GuaranteeConst" || n.type_string() == "Const") {
62    // If the current node is itself a cast-to-const, no need
63    // to look at the incoming edges.
64    return true;
65  }
67  bool all_parents_const = true;
68  bool atleast_one_non_control_edge = false;
69  for (const Edge* in : n.in_edges()) {
70    atleast_one_non_control_edge =
71        atleast_one_non_control_edge || !in->IsControlEdge();
72    if (!in->IsControlEdge() && runtime_const_nodes.count(in->src()) == 0) {
73      all_parents_const = false;
74      break;
75    }
76  }
77  return all_parents_const && atleast_one_non_control_edge;
80void MarkGuaranteedConstants(
81    const Graph& graph,
82    const std::vector<std::pair<const Node*, Node*>>& src_arg_pairs) {
83  gtl::FlatSet<const Node*> guaranteed_const_nodes;
84  std::vector<const Node*> srcs;
85  srcs.reserve(src_arg_pairs.size());
86  for (const auto& src_arg : src_arg_pairs) {
87    srcs.push_back(src_arg.first);
88  }
89  ReverseDFSFrom(graph, srcs, /*enter=*/nullptr,
90                 /*leave=*/[&guaranteed_const_nodes](const Node* n) {
91                   // TODO(vinuraja): Doesn't work in the presence of loops.
92                   if (AreAllParentsConst(*n, guaranteed_const_nodes)) {
93                     guaranteed_const_nodes.insert(n);
94                   }
95                 });
97  for (auto& src_arg : src_arg_pairs) {
98    if (guaranteed_const_nodes.count(src_arg.first) != 0) {
99      VLOG(1) << "Guaranteed const found: " << src_arg.first->DebugString();
100      src_arg.second->AddAttr("_is_guaranteed_constant", true);
101    }
102  }
105// A node/slot pair.
106// TODO(phawkins): is there a common definition of this?
107struct NodeSlot {
108  NodeSlot() : node(nullptr), slot(-1), dtype(DT_INVALID) {}
109  NodeSlot(const Node* node, int slot)
110      : node(node), slot(slot), dtype(DT_INVALID) {}
111  NodeSlot(const Node* node, int slot, DataType dtype)
112      : node(node), slot(slot), dtype(dtype) {}
114  const Node* node;
115  int slot;
117  // Optional: used to record the destination type of a source NodeSlot in case
118  // the source output is a Ref type that is cast to a Tensor at the
119  // destination.
120  DataType dtype;
122  bool operator==(const NodeSlot& other) const {
123    return node == other.node && slot == other.slot && dtype == other.dtype;
124  }
126  // Leave dtype out of the hash since there are never two NodeSlots with the
127  // same node and slot and different dtypes.
128  struct Hasher {
129    uint64 operator()(NodeSlot const& s) const {
130      return Hash64Combine(std::hash<const Node*>()(s.node),
131                           std::hash<int>()(s.slot));
132    }
133  };
135  struct PairHasher {
136    uint64 operator()(std::pair<NodeSlot, NodeSlot> const& s) const {
137      return Hash64Combine(Hasher()(s.first), Hasher()(s.second));
138    }
139  };
// Op type names used when building nodes in this file.
// TODO(phawkins) add a canonical copy of these operator names and refactor
// everything to use it.
static const char* const kArgOp = "_Arg";
static const char* const kRetValOp = "_Retval";
static const char* const kHostComputeOp = "_XlaHostCompute";
static const char* const kSendFromHostOp = "_XlaSendFromHost";
static const char* const kRecvAtHostOp = "_XlaRecvAtHost";
150class Encapsulator {
151 public:
152  Encapsulator(string group_attribute, string outside_compilation_attribute,
153               Graph const* graph_in)
154      : group_attribute_(std::move(group_attribute)),
155        outside_compilation_attribute_(
156            std::move(outside_compilation_attribute)),
157        graph_in_(graph_in) {}
159  // Find subgraphs marked with 'group_attribute', and build a new
160  // subgraph, one for each value of 'group_attribute'.
161  Status SplitIntoSubgraphs();
163  // Build a FunctionDef for each subgraph, and add it 'library'. The values of
164  // the 'group_attribute' annotations become the function names.
165  // If 'reuse_existing_functions' is set, use an existing function with the
166  // same name, if any.
167  // If 'rewrite_subgraph_fn' is set, it is applied to each subgraph before
168  // function conversion.
169  Status BuildFunctionDefs(const RewriteSubgraphFn& rewrite_subgraph_fn,
170                           bool reuse_existing_functions,
171                           FunctionLibraryDefinition* library);
173  // Write a copy of the input graph to 'graph_out', where the subgraphs are
174  // replaced with calls to the new functions.
175  Status BuildOutputGraph(bool parallel_checking, Graph* graph_out,
176                          FunctionLibraryDefinition* library);
178 private:
179  // A subgraph of the input, all marked with a common 'group_attribute'
180  // value. A subgraph may contain multiple `outside_compilation' clusters.
181  //
182  // In the following simple example, A, B, ..., E are nodes in the original
183  // graph. The group attributes and outside_compilation attributes g and oc are
184  // each shown as either 0 or empty.
185  //
186  //  A  -->  B  -->  C  -->  D  -->  E
187  //  g:      g:0     g:0     g:0     g:
188  //  oc:     oc:     oc:0    oc:     oc:
189  //
190  // The example is rewritten to two graphs; one on the host and one to be
191  // compiled. The host graph is as follows. RAH is a RecvAtHost node receiving
192  // input from the compiled cluster, and SFH is a SendFromHost node sending
193  // input back to the compiled cluster. Dotted edges are control edges. A
194  // 'sequencing' node S is inserted, and both RAH and SFH are connected via S
195  // to E (and in general all nodes that depend on nodes in the compiled
196  // cluster) to ensure that they are not pruned.
197  //
198  //  A  -->  Call  -->  E
199  //                     ^
200  //                     .
201  //           ........> S
202  //       ....          ^
203  //     ..             .
204  //  RAH -->  C  --> SFH
205  //
206  // The compiled cluster is as follows. HC is a HostCompute node which is the
207  // source of a channel to the RAH node above and the destination of a channel
208  // from the SFH node above.
209  //
210  //  Arg  --> B  --> HC  --> D --> Retval
211  //
212  // The channels HC/RAH and SFH/HC each transmit multiple tensors, so there is
213  // at most one RAH and SFH in each outside_compilation cluster. This design is
214  // preferred over adding separate Arg/Retval nodes for each transmitted value
215  // because it allows optimizations to the host code that would like to limit
216  // communication between host and device and, e.g., raise only one interrupt
217  // per channel rather than one per transmitted value.
218  //
219  // The shapes of the outputs from the HC node in general cannot be determined
220  // until the shapes of its inputs are known at compile time, since e.g.,
221  // above, the shape of C's outputs aren't known until the shape of its inputs
222  // are known. If the shapes of the HC's outputs can be determined during the
223  // rewrite, they are stored in the node's 'shapes' attr. Otherwise a minimal
224  // graph is stored in the shape_inference_graph attr. This graph can be used
225  // when compiling the HC Op to determined the shape of the SFH inputs given
226  // the shapes of any ancestor RAH outputs. If it can be determined that the
227  // shape of the SFH inputs will not be inferrable even once the shapes of the
228  // RAH outputs are known, an error is returned by the rewriter.
229  class Subgraph {
230   public:
231    // Creates a graph to build the subgraph in, if it doesn't already exist,
232    // using the same op registry and versions as graph_in.
233    Node* MakeNodeImage(const Graph* graph_in, Node* node);
235    // Returns the graph the subgraph is being built in.
236    Graph* GetGraph() const;
238    // Builds a FunctionDef, and adds it to 'library'. The value of the
239    // 'group_attribute' annotations becomes the function name.  If
240    // 'reuse_existing_functions' is set, use an existing function with the same
241    // name, if any.  If 'rewrite_subgraph_fn' is set, it is applied to the
242    // subgraph before function conversion.
243    Status BuildFunctionDef(const string& name_in,
244                            const RewriteSubgraphFn& rewrite_subgraph_fn,
245                            bool reuse_existing_functions,
246                            FunctionLibraryDefinition* library);
248    // Adds the function call node to graph_out.
249    Status AddFunctionCallNode(
250        const std::unordered_map<const Node*, Node*>& node_images,
251        bool parallel_checking, Graph* graph_out);
253    // Adds _RecvAtHost and _SendFromHost nodes, where needed, to graph_out.
254    Status AddOutsideCompilationHostIONodes(
255        const string& subgraph_name,
256        const std::unordered_map<const Node*, Node*>& node_images,
257        Graph* graph_out);
259    // Returns the names of all the outside_compilation subgraphs in this
260    // Subgraph.
261    void GetOutsideCompilationSubgraphNames(std::vector<string>* names) const;
263    // Returns the Node that inputs to the function should be wired up to.
264    Node* GetCallNodeForInputs() const;
266    // Returns the Node that outputs to the function should be wired up to.
267    Node* GetCallNodeForOutputs() const;
269    // Returns the index of the arg that the dst of edge should connect to.
270    int GetArgIndexForEdge(const Edge* edge) const;
272    // Returns the index of the result that the src of edge should connect to.
273    int GetResultIndexForEdge(const Edge* edge) const;
275    // Returns the RecvAtHost node for an outside_compilation subgraph.
276    Node* GetRecvAtHostNode(
277        const string& outside_compilation_subgraph_name) const;
279    // Returns the output slot for the RecvAtHost node that corresponds to the
280    // source of edge in an outside_compilation subgraph.
281    int GetRecvAtHostSlot(const string& outside_compilation_subgraph_name,
282                          const Edge* edge) const;
284    // Returns the SendFromHost node for an outside_compilation subgraph.
285    Node* GetSendFromHostNode(
286        const string& outside_compilation_subgraph_name) const;
288    // Returns the input slot for the SendFromHost node that corresponds to the
289    // destination of edge in an outside_compilation subgraph.
290    int GetSendFromHostSlot(const string& outside_compilation_subgraph_name,
291                            const Edge* edge) const;
293    // Creates an _Arg node for the src node of edge, and add its index to
294    // args_by_src_, if none exists yet. Also adds its index to args_by_dst_,
295    // and adds the edge within the subgraph from the _Arg node to the image of
296    // the dst node.
297    Status RecordArg(const Edge* edge,
298                     const std::unordered_map<const Node*, Node*>& node_images,
299                     std::vector<std::pair<const Node*, Node*>>* src_arg_pairs);
301    // Creates a _Retval node for the src node of edge, and add it to results_,
302    // if none exists yet. If a new _Retval node is created, also adds the edge
303    // within the subgraph from the src to the _Retval node.
304    Status RecordResult(
305        const Edge* edge,
306        const std::unordered_map<const Node*, Node*>& node_images);
308    // Creates an outside_compilation subgraph for outside_compilation_id if
309    // none exists yet. Creates an entry for the src node of edge in the list of
310    // inputs for the outside_compilation subgraph, if none exists yet.
311    void RecordOutsideCompilationInputOrControl(
312        const string& outside_compilation_id, const Edge* edge);
314    // Creates an outside_compilation subgraph for outside_compilation_id if
315    // none exists yet. Creates an entry for the src node of edge in the list of
316    // outputs by src for the outside_compilation subgraph, if none exists
317    // yet. Creates an entry for the dst node of edge in the list of outputs by
318    // dst for the outside_compilation subgraph.
319    void RecordOutsideCompilationOutputOrControl(
320        const string& outside_compilation_id, const Edge* edge);
322    // Adds the HostCompute nodes for each outside_compilation subgraph.
323    Status AddHostComputes(
324        const string& subgraph_name,
325        const std::unordered_map<const Node*, Node*>& node_images);
327    // Creates the sequencer node if it doesn't exist, adding it to graph_out.
328    Status MakeSequencingNode(const string& subgraph_name, Graph* graph_out);
330    // If there is a sequencer node, adds a control edge from the sequencer to
331    // all the downstream nodes of call_node_outputs.
332    void ConnectSequencerToOutputs(Graph* graph_out);
334    Status AddShapeInferenceInfo(
335        const string& outside_compilation_subgraph_name,
336        const std::vector<TensorShapeProto>& shapes, GraphDef* inference_graph);
338    Status ReplaceFunctionDef(FunctionLibraryDefinition* library);
340   private:
341    struct OutsideCompilationSubgraph {
342      // Map from source (producer node/slot) tensors in the original graph to
343      // input index (slot number in the HostCompute/RecvAtHost nodes that will
344      // be created) for the outside_compilation subgraph.
345      std::unordered_map<NodeSlot, int, NodeSlot::Hasher> inputs;
347      // Set of nodes in the original graph that are the source of control edges
348      // that cross from the containing compiled subgraph into the
349      // outside_compilation subgraph. These are recorded by
350      // RecordOutsideCompilationInputOrControl while walking all the subgraph
351      // edges, and lifted control edges within the subgraph are added by
352      // AddSendsToOutsideCompilation once the _HostCompute node has been
353      // created. The matching control edge from _RecvAtHost to the
354      // destination is added by CopyEdgeToOutputGraph.
355      std::unordered_set<const Node*> control_inputs;
357      // Maps from source (producer node/slot) and destination (consumer
358      // node/slot) tensors in the original graph to output index (slot number
359      // in the SendFromHost/HostCompute nodes that will be created) for the
360      // outside_compilation subgraph.
361      std::unordered_map<NodeSlot, int, NodeSlot::Hasher> outputs_by_src;
362      std::unordered_map<NodeSlot, int, NodeSlot::Hasher> outputs_by_dst;
364      // Set of nodes in the original graph that are the destination of control
365      // edges that cross from the outside_compilation subgraph into the
366      // containing compiled subgraph. These are recorded by
367      // RecordOutsideCompilationOutputOrControl while walking all the subgraph
368      // edges, and lifted control edges within the subgraph are added by
369      // AddRecvsFromToOutsideCompilation once the _HostCompute node has been
370      // created. The matching control edge from the source to _SendFromHost to
371      // the destination is added by CopyEdgeToOutputGraph.
372      std::unordered_set<const Node*> control_outputs;
374      // Name of the _HostCompute node in the subgraph.
375      string host_compute_name;
377      // _RecvAtHost node in the output graph. Not owned.
378      Node* recv_at_host = nullptr;
380      // _SendFromHost node in the output graph. Not owned.
381      Node* send_from_host = nullptr;
382    };
384    // Builds a ParallelCheck op that compares the output of the original
385    // subgraph with the encapsulated subgraph.
386    Status BuildParallelCheckOp(
387        const std::unordered_map<const Node*, Node*>& node_images,
388        Graph* graph_out);
390    // Builds a _RecvAtHost node producing all the inputs of an
391    // outside_compilation subgraph and stores it in oc_subgraph.recv_at_host.
392    Status AddRecvAtHostNode(const string& subgraph_name,
393                             const string& oc_subgraph_name,
394                             OutsideCompilationSubgraph* oc_subgraph,
395                             Graph* graph_out);
397    // Builds a _SendFromHost node consuming all the outputs of an
398    // outside_compilation subgraph and stores it in oc_subgraph.send_from_host.
399    Status AddSendFromHostNode(
400        const std::unordered_map<const Node*, Node*>& node_images,
401        const string& subgraph_name, const string& oc_subgraph_name,
402        OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out);
404    // The subgraph extracted from the input graph, suitable for being turned
405    // into a FunctionDef. Inputs are fed by _Arg nodes, and outputs are
406    // returned by _Retval nodes.
407    std::unique_ptr<Graph> graph_;
409    // Which device are these nodes on? Used to assign a device to the call
410    // node.
411    string device_;
413    // NodeDef for the function call node.
414    NodeDef call_node_def_;
416    // Function call node(s) in the output graph. Not owned.
417    // If parallel_checking is enabled, 'call_node_inputs' is the function call
418    // node to which inputs should be fed, and 'call_node_outputs' is the
419    // parallel check op from which outputs should be read. If parallel checking
420    // is disabled, both point to the function call node.
421    Node* call_node_inputs_;
422    Node* call_node_outputs_;
424    // Maps from source (producer node/slot) and destination
425    // (consumer node/slot) tensors in the input graph to _Arg numbers in
426    // the subgraph. The source map is one-to-one, whereas the dest map may be
427    // many-to-one.
428    std::unordered_map<NodeSlot, int, NodeSlot::Hasher> args_by_src_;
429    std::unordered_map<NodeSlot, int, NodeSlot::Hasher> args_by_dst_;
431    // The _Arg nodes in the subgraph, in order by argument number.
432    std::vector<Node*> args_;
434    // Map from source tensor in the input graph to result #.
435    std::unordered_map<NodeSlot, int, NodeSlot::Hasher> results_;
437    // The outside_compilation clusters in this subgraph.
438    std::unordered_map<string, OutsideCompilationSubgraph>
439        outside_compilation_subgraphs_;
441    // NoOp node in the output graph that is sequenced after the call node and
442    // used to prevent host-side outside_compilation sends and recvs from being
443    // pruned.
444    Node* sequencer_ = nullptr;
445  };
447  // Returns the key attribute and outside_compilation attribute associated
448  // with a node in attr, and outside_compilation_attr, respectively. Sets
449  // either result to the empty string if the respective attribute is not
450  // found. Returns error status if there is an outside_compilation attribute
451  // and no key attribute,
452  Status GetFunctionNameAttr(Node const* node, string* attr,
453                             string* outside_compilation_attr) const;
455  // Copies edges local to a subgraph. Adds _Arg and _Retval nodes to
456  // subgraphs for data edges that cross subgraph boundaries.
457  Status CopySubgraphEdges(
458      const std::unordered_map<const Node*, Node*>& node_images,
459      std::vector<std::pair<const Node*, Node*>>* src_arg_pairs);
461  // Copies all marked nodes to a subgraph. Does nothing for unmarked nodes,
462  // or nodes marked outside_compilation.
463  Status CopySubgraphNodes(std::unordered_map<const Node*, Node*>* node_images);
465  // Copies all nodes that aren't in a compiled subgraph to the output graph.
466  Status CopyNodesToOutputGraph(
467      bool parallel_checking, Graph* graph_out,
468      std::unordered_map<const Node*, Node*>* node_images);
470  // Adds function call nodes for each compiled subgraph.
471  Status AddFunctionCallNodes(
472      const std::unordered_map<const Node*, Node*>& node_images,
473      bool parallel_checking, Graph* graph_out);
475  // Adds _RecvAtHost and _SendFromHost nodes, where needed, for all
476  // outside_compilation subgraphs.
477  Status AddOutsideCompilationHostIONodes(
478      const std::unordered_map<const Node*, Node*>& node_images,
479      Graph* graph_out);
481  // Finds the image of an edge source in the output graph. If the edge crosses
482  // a subgraph boundary it is the output of a call node, otherwise it is a node
483  // in the output graph.
484  Status FindOutputImageOfEdgeSrc(
485      const string& src_func_id, const string& src_outside_compilation_id,
486      const string& dst_func_id, const string& dst_outside_compilation_id,
487      const std::unordered_map<const Node*, Node*>& node_images,
488      const Node* original_src_node, Node** src_image);
490  // Finds an edge source slot in the output graph. If the edge crosses a
491  // subgraph boundary it is a slot on the output of a call node or a
492  // _RecvAtHost node, otherwise it is a slot on a node in the output graph.
493  int FindOutputSlotOfEdgeSrc(const string& src_func_id,
494                              const string& src_outside_compilation_id,
495                              const string& dst_func_id,
496                              const string& dst_outside_compilation_id,
497                              const Edge* edge);
499  // Finds the image of an edge destination in the output graph. If the edge
500  // crosses a subgraph boundary it is the input of a call node or a
501  // _SendFromHost node, otherwise it is a node in the output graph.
502  Status FindOutputImageOfEdgeDst(
503      const string& src_func_id, const string& src_outside_compilation_id,
504      const string& dst_func_id, const string& dst_outside_compilation_id,
505      const std::unordered_map<const Node*, Node*>& node_images,
506      const Node* original_dst_node, Node** dst_image);
508  // Finds an edge destination slot in the output graph. If the edge crosses a
509  // subgraph boundary it is a slot on the input of a call node or a
510  // _SendFromHost node, otherwise it is a slot on a node in the output graph.
511  int FindOutputSlotOfEdgeDst(const string& src_func_id,
512                              const string& src_outside_compilation_id,
513                              const string& dst_func_id,
514                              const string& dst_outside_compilation_id,
515                              const Edge* edge);
517  // Copies a single edge to the output graph. The edge is either entirely
518  // within the output graph, or crosses into or out of a compiled subgraph.
519  Status CopyEdgeToOutputGraph(
520      const Edge* edge, const string& src_func_id,
521      const string& src_outside_compilation_id, const string& dst_func_id,
522      const string& dst_outside_compilation_id,
523      const std::unordered_map<const Node*, Node*>& node_images,
524      bool parallel_checking, Graph* graph_out,
525      std::unordered_set<std::pair<NodeSlot, NodeSlot>, NodeSlot::PairHasher>*
526          edges_added);
528  // Adds all edges to the output graph.
529  Status AddEdgesToOutputGraph(
530      const std::unordered_map<const Node*, Node*>& node_images,
531      bool parallel_checking, Graph* graph_out);
533  // Constructs a minimal shape inference graph that can be used to determine
534  // the shape of send_node at the time that the subgraph is compiled.
535  // recv_at_host_nodes contains the names of all the recv_at_host nodes that
536  // send_node might depend on. These recv_at_host nodes have shapes that are
537  // not known during the rewrite pass, but will be known at compile time.
538  //
539  // If the shapes of all the inputs to send_node can be determined during the
540  // rewrite pass, on exit graphdef_out is empty and the shapes are returned in
541  // static_shape_out. Otherwise graphdef_out contains a graph that can be used
542  // for shape inference at compile time, where all the source nodes of the
543  // graph are either constants with known shapes, or nodes named in
544  // recv_at_host_nodes.
545  //
546  // A non-OK status is returned if neither of the above conditions can be
547  // satisfied, e.g., because send_node depends on a node that doesn't have a
548  // registered shape inference function.
549  Status DoStaticShapeInferenceForOutsideCompilationSend(
550      const Graph& graph_in, const ShapeRefiner& shape_refiner,
551      const std::unordered_set<string>& recv_at_host_nodes, Node* send_node,
552      FunctionLibraryDefinition* library,
553      std::vector<TensorShapeProto>* static_shape_out,
554      std::unique_ptr<GraphDef>* graphdef_out);
556  // Makes a copy of graph containing only nodes that are ancestors of at least
557  // one node in send_from_host_nodes and store it in pruned_graph. On exit
558  // nodes_images contains a mapping from nodes in graph to nodes in
559  // pruned_graph. All functions in the copied graph are inlined.
560  Status MakePrunedGraphCopyAndInline(
561      const Graph& graph, const std::vector<Node*>& sink_nodes,
562      std::unique_ptr<Graph>* pruned_graph,
563      std::unordered_map<const Node*, Node*>* node_images,
564      FunctionLibraryDefinition* library);
566  // Makes a copy of graph containing only nodes that are ancestors of a
567  // send_from_host node in an outside_compilation subgraph, and store it in
568  // pruned_graph. Also perform shape inference on the pruned graph, using
569  // shape_refiner. On exit node_images contains a mapping from nodes in graph
570  // to nodes in pruned_graph.
571  Status MakeGraphForOutsideCompilationSends(
572      const Graph& graph, std::unique_ptr<Graph>* pruned_graph,
573      ShapeRefiner* shape_refiner,
574      std::unordered_map<const Node*, Node*>* node_images,
575      FunctionLibraryDefinition* library);
577  // Performs static shape inference, as far as possible, for the send_from_host
578  // nodes in each outside_compilation subgraph. Where it is not possible to
579  // determine the shape statically, stores a serialized GraphDef in the
580  // HostCompute 'shape_inference_graph' attr, to be used at compile time for
581  // final inference. If the shapes are known statically they are stored in the
582  // HostCompute 'shapes' attr.
583  Status GetShapeInfoForOutsideCompilationSends(
584      Graph* graph_out, FunctionLibraryDefinition* library);
586  const string group_attribute_;
587  const string outside_compilation_attribute_;
588  const Graph* graph_in_;
590  std::unordered_map<string, Subgraph> subgraphs_;
595Node* Encapsulator::Subgraph::GetCallNodeForInputs() const {
596  return call_node_inputs_;
599Node* Encapsulator::Subgraph::GetCallNodeForOutputs() const {
600  return call_node_outputs_;
603int Encapsulator::Subgraph::GetArgIndexForEdge(const Edge* edge) const {
604  return args_by_dst_.at(NodeSlot(edge->dst(), edge->dst_input()));
607int Encapsulator::Subgraph::GetResultIndexForEdge(const Edge* edge) const {
608  return results_.at(NodeSlot(edge->src(), edge->src_output()));
611Node* Encapsulator::Subgraph::GetRecvAtHostNode(
612    const string& outside_compilation_subgraph_name) const {
613  return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name)
614      .recv_at_host;
617int Encapsulator::Subgraph::GetRecvAtHostSlot(
618    const string& outside_compilation_subgraph_name, const Edge* edge) const {
619  return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name)
620      .inputs.at(NodeSlot(edge->src(), edge->src_output()));
623Node* Encapsulator::Subgraph::GetSendFromHostNode(
624    const string& outside_compilation_subgraph_name) const {
625  return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name)
626      .send_from_host;
629int Encapsulator::Subgraph::GetSendFromHostSlot(
630    const string& outside_compilation_subgraph_name, const Edge* edge) const {
631  return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name)
632      .outputs_by_dst.at(NodeSlot(edge->dst(), edge->dst_input()));
635Node* Encapsulator::Subgraph::MakeNodeImage(const Graph* graph_in, Node* node) {
636  if (!graph_) {
637    graph_.reset(new Graph(graph_in->op_registry()));
638    graph_->set_versions(graph_in->versions());
639  }
641  if (device_.empty()) {
642    device_ = node->assigned_device_name().empty()
643                  ? node->requested_device()
644                  : node->assigned_device_name();
645  }
647  return graph_->CopyNode(node);
650Graph* Encapsulator::Subgraph::GetGraph() const { return graph_.get(); }
652Status Encapsulator::Subgraph::RecordArg(
653    const Edge* edge, const std::unordered_map<const Node*, Node*>& node_images,
654    std::vector<std::pair<const Node*, Node*>>* src_arg_pairs) {
655  Node* src_node = edge->src();
656  int src_slot = edge->src_output();
657  std::unordered_map<NodeSlot, int, NodeSlot::Hasher>::iterator iter;
658  bool inserted;
659  std::tie(iter, inserted) =
660      args_by_src_.emplace(NodeSlot(src_node, src_slot), args_by_src_.size());
661  int arg_index = iter->second;
662  if (inserted) {
663    NodeDef arg_def;
664    NodeDefBuilder builder(
665        strings::StrCat(src_node->name(), "_", src_slot, "_arg"), kArgOp);
666    DataType dtype = edge->dst()->input_type(edge->dst_input());
667    builder.Attr("T", dtype);
668    builder.Attr("index", arg_index);
669    Status s = builder.Finalize(&arg_def);
670    if (!s.ok()) return s;
672    Node* arg = graph_->AddNode(arg_def, &s);
673    if (!s.ok()) return s;
675    src_arg_pairs->push_back({src_node, arg});
676    args_.push_back(arg);
677  }
678  Node* dst_node = edge->dst();
679  Node* dst_image = node_images.at(dst_node);
680  int dst_slot = edge->dst_input();
681  args_by_dst_[NodeSlot(dst_node, dst_slot)] = arg_index;
682  graph_->AddEdge(args_[arg_index], 0, dst_image, dst_slot);
683  return Status::OK();
686Status Encapsulator::Subgraph::RecordResult(
687    const Edge* edge,
688    const std::unordered_map<const Node*, Node*>& node_images) {
689  Node* src_node = edge->src();
690  Node* src_image = node_images.at(src_node);
691  int src_slot = edge->src_output();
692  std::unordered_map<NodeSlot, int, NodeSlot::Hasher>::iterator iter;
693  bool inserted;
694  std::tie(iter, inserted) =
695      results_.emplace(NodeSlot(src_node, src_slot), results_.size());
696  int ret_index = iter->second;
697  if (inserted) {
698    NodeDef ret_def;
699    NodeDefBuilder builder(
700        strings::StrCat(src_node->name(), "_", src_slot, "_retval"), kRetValOp);
701    DataType dtype = src_node->output_type(src_slot);
702    builder.Attr("T", dtype);
703    builder.Attr("index", ret_index);
704    builder.Input(src_image->name(), src_slot, dtype);
705    Status s = builder.Finalize(&ret_def);
706    if (!s.ok()) return s;
707    Node* ret = graph_->AddNode(ret_def, &s);
708    if (!s.ok()) return s;
710    graph_->AddEdge(src_image, src_slot, ret, 0);
711  }
712  return Status::OK();
715void Encapsulator::Subgraph::RecordOutsideCompilationInputOrControl(
716    const string& outside_compilation_id, const Edge* edge) {
717  auto iter = outside_compilation_subgraphs_
718                  .emplace(outside_compilation_id, OutsideCompilationSubgraph())
719                  .first;
720  OutsideCompilationSubgraph& outside_subgraph = iter->second;
721  if (edge->IsControlEdge()) {
722    outside_subgraph.control_inputs.insert(edge->src());
723  } else {
724    int input_index = outside_subgraph.inputs.size();
725    outside_subgraph.inputs.emplace(NodeSlot(edge->src(), edge->src_output()),
726                                    input_index);
727  }
730void Encapsulator::Subgraph::RecordOutsideCompilationOutputOrControl(
731    const string& outside_compilation_id, const Edge* edge) {
732  auto subgraph_iter =
733      outside_compilation_subgraphs_
734          .emplace(outside_compilation_id, OutsideCompilationSubgraph())
735          .first;
736  OutsideCompilationSubgraph& outside_subgraph = subgraph_iter->second;
737  if (edge->IsControlEdge()) {
738    outside_subgraph.control_outputs.insert(edge->dst());
739  } else {
740    DataType dtype = edge->dst()->input_type(edge->dst_input());
741    auto output_iter =
742        outside_subgraph.outputs_by_src
743            .emplace(NodeSlot(edge->src(), edge->src_output(), dtype),
744                     outside_subgraph.outputs_by_src.size())
745            .first;
746    int output_index = output_iter->second;
747    outside_subgraph.outputs_by_dst[NodeSlot(edge->dst(), edge->dst_input())] =
748        output_index;
749  }
752Status Encapsulator::Subgraph::AddHostComputes(
753    const string& subgraph_name,
754    const std::unordered_map<const Node*, Node*>& node_images) {
755  for (auto& oc_subgraph_iter : outside_compilation_subgraphs_) {
756    const string& oc_subgraph_name = oc_subgraph_iter.first;
757    OutsideCompilationSubgraph& oc_subgraph = oc_subgraph_iter.second;
758    if (!oc_subgraph.inputs.empty() || !oc_subgraph.control_inputs.empty() ||
759        !oc_subgraph.outputs_by_src.empty() ||
760        !oc_subgraph.control_outputs.empty()) {
761      // Build a _HostCompute node.
762      std::vector<NodeDefBuilder::NodeOut> inputs(oc_subgraph.inputs.size());
763      std::vector<DataType> input_dtypes(oc_subgraph.inputs.size(), DT_INVALID);
764      std::vector<DataType> output_dtypes(oc_subgraph.outputs_by_src.size(),
765                                          DT_INVALID);
767      for (const auto& input_src : oc_subgraph.inputs) {
768        const Node* src_node = input_src.first.node;
769        Node* src_image = node_images.at(src_node);
770        int src_slot = input_src.first.slot;
771        int input_index = input_src.second;
773        DataType dtype = src_node->output_type(src_slot);
774        inputs[input_index].Reset(src_image->name(), src_slot, dtype);
775        input_dtypes[input_index] = dtype;
776      }
778      for (const auto& output : oc_subgraph.outputs_by_src) {
779        DataType dtype = output.first.dtype;
780        int output_index = output.second;
781        output_dtypes[output_index] = dtype;
782      }
784      NodeDef host_compute_def;
785      NodeDefBuilder builder(strings::StrCat("outside_compilation_",
786                                             oc_subgraph_name, "_host_compute"),
787                             kHostComputeOp);
788      builder.Input(inputs);
789      builder.Attr("Tinputs", input_dtypes);
790      builder.Attr("Toutputs", output_dtypes);
791      builder.Attr("key",
792                   strings::StrCat("host_compute_channel_", subgraph_name, "_",
793                                   oc_subgraph_name));
794      Status s = builder.Finalize(&host_compute_def);
795      if (!s.ok()) return s;
797      Node* host_compute = graph_->AddNode(host_compute_def, &s);
798      if (!s.ok()) return s;
799      oc_subgraph.host_compute_name = host_compute->name();
801      // Connect the _HostCompute node to its producers in the subgraph.
802      for (auto& input_src : oc_subgraph.inputs) {
803        const Node* src_node = input_src.first.node;
804        Node* src_image = node_images.at(src_node);
805        int src_slot = input_src.first.slot;
806        int input_index = input_src.second;
807        graph_->AddEdge(src_image, src_slot, host_compute, input_index);
808      }
810      // Connect the _HostCompute node to its control edge producers in the
811      // subgraph.
812      for (const auto& src_node : oc_subgraph.control_inputs) {
813        Node* src_image = node_images.at(src_node);
814        graph_->AddControlEdge(src_image, host_compute);
815      }
817      // Connect the consumers in the subgraph to the _HostCompute node.
818      for (const auto& output : oc_subgraph.outputs_by_dst) {
819        const Node* dst_node = output.first.node;
820        Node* dst_image = node_images.at(dst_node);
821        int dst_slot = output.first.slot;
822        int output_index = output.second;
824        graph_->AddEdge(host_compute, output_index, dst_image, dst_slot);
825      }
827      // Connect the control edge consumers in the subgraph to the _HostCompute
828      // node.
829      for (const auto& dst_node : oc_subgraph.control_outputs) {
830        Node* dst_image = node_images.at(dst_node);
831        graph_->AddControlEdge(host_compute, dst_image);
832      }
833    }
834  }
836  return Status::OK();
839Status Encapsulator::Subgraph::MakeSequencingNode(const string& subgraph_name,
840                                                  Graph* graph_out) {
841  if (sequencer_ == nullptr) {
842    NodeDef seq_def;
843    NodeDefBuilder builder(strings::StrCat(subgraph_name, "_sequencer"),
844                           "NoOp");
845    Status s = builder.Finalize(&seq_def);
846    if (!s.ok()) return s;
848    sequencer_ = graph_out->AddNode(seq_def, &s);
849    if (!s.ok()) return s;
850    sequencer_->set_assigned_device_name(device_);
851  }
852  return Status::OK();
855void Encapsulator::Subgraph::ConnectSequencerToOutputs(Graph* graph_out) {
856  if (sequencer_ != nullptr) {
857    std::unordered_set<Node*> output_dependencies;
858    for (Node* node : call_node_outputs_->out_nodes()) {
859      output_dependencies.insert(node);
860    }
861    for (Node* node : output_dependencies) {
862      graph_out->AddControlEdge(sequencer_, node);
863    }
864  }
867Status Encapsulator::Subgraph::BuildFunctionDef(
868    const string& name_in, const RewriteSubgraphFn& rewrite_subgraph_fn,
869    bool reuse_existing_functions, FunctionLibraryDefinition* library) {
870  // name_in is copied here because name may be modified below if
871  // rewrite_subgraph_fn is true.
872  string name = name_in;
873  call_node_def_.set_op(name);
874  call_node_def_.set_name(name);
875  call_node_def_.set_device(device_);
877  if (rewrite_subgraph_fn) {
878    // Initialize the input and output permutations to the identity.
879    std::vector<int> input_permutation(args_by_src_.size());
880    std::iota(input_permutation.begin(), input_permutation.end(), 0);
881    std::vector<int> output_permutation(results_.size());
882    std::iota(output_permutation.begin(), output_permutation.end(), 0);
884    TF_RETURN_IF_ERROR(rewrite_subgraph_fn(
885        &graph_, &input_permutation, &output_permutation, &call_node_def_));
887    // Apply the input/output permutations to the 'args_by_...' and 'results_'
888    // mappings, so when we build edges in BuildOutputGraph() we
889    // connect them to the right input/output positions.
890    if (input_permutation.size() != args_by_src_.size()) {
891      return errors::InvalidArgument("Input permutation has incorrect size.");
892    }
893    if (output_permutation.size() != results_.size()) {
894      return errors::InvalidArgument("Output permutation has incorrect size.");
895    }
896    for (auto& arg : args_by_src_) {
897      arg.second = input_permutation[arg.second];
898    }
899    for (auto& arg : args_by_dst_) {
900      arg.second = input_permutation[arg.second];
901    }
902    for (auto& result : results_) {
903      result.second = output_permutation[result.second];
904    }
906    name = call_node_def_.op();
907  }
909  FunctionDef fdef;
910  TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, &fdef));
912  if (VLOG_IS_ON(1)) {
913    VLOG(2) << "Build function def " << name;
914    dump_graph::DumpGraphToFile(
915        strings::StrCat("encapsulate_fdef_graph_", name), *graph_, library);
916    dump_graph::DumpFunctionDefToFile(
917        strings::StrCat("encapsulate_fdef_", name), fdef);
918  }
920  if (!reuse_existing_functions || library->Find(name) == nullptr) {
921    TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef));
922  }
923  return Status::OK();
926Status Encapsulator::Subgraph::AddShapeInferenceInfo(
927    const string& outside_compilation_subgraph_name,
928    const std::vector<TensorShapeProto>& shapes, GraphDef* inference_graph) {
929  OutsideCompilationSubgraph& oc_subgraph =
930      outside_compilation_subgraphs_.at(outside_compilation_subgraph_name);
932  Node* host_compute = nullptr;
933  for (Node* n : graph_->nodes()) {
934    if (n->name() == oc_subgraph.host_compute_name) {
935      host_compute = n;
936      break;
937    }
938  }
939  if (host_compute == nullptr) {
940    return errors::InvalidArgument(
941        "After rewriting subgraph ", outside_compilation_subgraph_name,
942        " there is no HostCompute Op for outside compilation subgraph ",
943        oc_subgraph.host_compute_name);
944  }
946  if (inference_graph == nullptr) {
947    host_compute->AddAttr("shape_inference_graph", "");
948    host_compute->AddAttr("shapes", shapes);
949  } else {
950    string serialized_graph;
951    if (!inference_graph->SerializeToString(&serialized_graph)) {
952      return errors::Internal(
953          "Failed to serialize graph for outside compilation subgraph ",
954          oc_subgraph.host_compute_name);
955    }
956    host_compute->AddAttr("shape_inference_graph", serialized_graph);
957    host_compute->AddAttr("shapes", std::vector<TensorShapeProto>());
958  }
959  return Status::OK();
962Status Encapsulator::Subgraph::ReplaceFunctionDef(
963    FunctionLibraryDefinition* library) {
964  const string& name = call_node_def_.name();
966  FunctionDef fdef;
967  TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, &fdef));
969  if (VLOG_IS_ON(1)) {
970    VLOG(2) << "Replace function def " << name;
971    dump_graph::DumpGraphToFile(
972        strings::StrCat("replace_encapsulate_fdef_graph_", name), *graph_,
973        library);
974    dump_graph::DumpFunctionDefToFile(
975        strings::StrCat("replace_encapsulate_fdef_", name), fdef);
976  }
978  TF_RETURN_IF_ERROR(library->RemoveFunction(name));
979  TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef));
980  return Status::OK();
983Status Encapsulator::Subgraph::BuildParallelCheckOp(
984    const std::unordered_map<const Node*, Node*>& node_images,
985    Graph* graph_out) {
986  // Build an index mapping output positions to node/slot pairs in the
987  // original graph.
988  std::vector<NodeSlot> results_by_num(results_.size());
989  for (const auto& entry : results_) {
990    results_by_num[entry.second] = entry.first;
991  }
993  // Build a parallel check NodeDef.
994  int num_results = results_by_num.size();
995  std::vector<DataType> result_dtypes(num_results);
996  std::vector<NodeDefBuilder::NodeOut> expected_outputs(num_results);
997  std::vector<NodeDefBuilder::NodeOut> actual_outputs(num_results);
998  for (int i = 0; i < num_results; ++i) {
999    const NodeSlot& node_slot = results_by_num[i];
1000    result_dtypes[i] = node_slot.node->output_type(node_slot.slot);
1001    expected_outputs[i] =
1002        NodeDefBuilder::NodeOut(node_images.at(node_slot.node)->name(),
1003                                node_slot.slot, result_dtypes[i]);
1004    actual_outputs[i] =
1005        NodeDefBuilder::NodeOut(call_node_def_.name(), i, result_dtypes[i]);
1006  }
1007  // Assign the parallel check op to a CPU on the same task as the cluster it is
1008  // checking.
1009  string device, dummy;
1010  if (!DeviceNameUtils::SplitDeviceName(
1011          call_node_inputs_->assigned_device_name(), &device, &dummy)) {
1012    return errors::InvalidArgument("Could not parse device name");
1013  }
1014  strings::StrAppend(&device, "/cpu:0");
1016  NodeDef check_def;
1018      NodeDefBuilder(graph_out->NewName(strings::StrCat(call_node_def_.name(),
1019                                                        "_parallel_check")),
1020                     "ParallelCheck")
1021          .Device(device)
1022          .Attr("T", result_dtypes)
1023          .Input(expected_outputs)
1024          .Input(actual_outputs)
1025          .Finalize(&check_def));
1027  Status s;
1028  Node* check_op = graph_out->AddNode(check_def, &s);
1029  if (!s.ok()) return s;
1030  check_op->set_assigned_device_name(device);
1032  // TODO(phawkins): it seems redundant to call AddEdge as well as
1033  // pass Inputs to the NodeDefBuilder, but I have been unable to find a
1034  // way to avoid it.
1035  for (int i = 0; i < num_results; ++i) {
1036    const NodeSlot& node_slot = results_by_num[i];
1037    graph_out->AddEdge(node_images.at(node_slot.node), node_slot.slot, check_op,
1038                       i);
1039    graph_out->AddEdge(call_node_inputs_, i, check_op, num_results + i);
1040  }
1042  call_node_outputs_ = check_op;
1043  return Status::OK();
1046Status Encapsulator::Subgraph::AddFunctionCallNode(
1047    const std::unordered_map<const Node*, Node*>& node_images,
1048    bool parallel_checking, Graph* graph_out) {
1049  Status s;
1050  call_node_inputs_ = graph_out->AddNode(call_node_def_, &s);
1051  if (!s.ok()) return s;
1053  // Copy the assigned device and the key_annotation over.
1054  call_node_inputs_->set_assigned_device_name(device_);
1055  call_node_outputs_ = call_node_inputs_;
1057  if (parallel_checking) {
1058    TF_RETURN_IF_ERROR(BuildParallelCheckOp(node_images, graph_out));
1059  }
1060  return Status::OK();
1063Status Encapsulator::Subgraph::AddRecvAtHostNode(
1064    const string& subgraph_name, const string& oc_subgraph_name,
1065    OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out) {
1066  std::vector<DataType> dtypes(oc_subgraph->inputs.size(), DT_INVALID);
1068  for (const auto& input : oc_subgraph->inputs) {
1069    const Node* src_node = input.first.node;
1070    int src_slot = input.first.slot;
1071    int input_index = input.second;
1073    DataType dtype = src_node->output_type(src_slot);
1074    dtypes[input_index] = dtype;
1075  }
1077  NodeDef recv_def;
1078  NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name,
1079                                         "_", oc_subgraph_name, "_recv"),
1080                         kRecvAtHostOp);
1081  builder.Attr("Toutputs", dtypes);
1082  builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name,
1083                                      "_", oc_subgraph_name));
1084  Status s = builder.Finalize(&recv_def);
1085  if (!s.ok()) return s;
1087  oc_subgraph->recv_at_host = graph_out->AddNode(recv_def, &s);
1088  if (!s.ok()) return s;
1089  oc_subgraph->recv_at_host->set_assigned_device_name(device_);
1091  // Add a control dependency forcing the RecvAtHost to run before the subgraph
1092  // completes. This has no effect on execution order but prevents the
1093  // RecvAtHost being pruned.
1094  TF_RETURN_IF_ERROR(MakeSequencingNode(subgraph_name, graph_out));
1095  graph_out->AddControlEdge(oc_subgraph->recv_at_host, sequencer_);
1097  return Status::OK();
1100Status Encapsulator::Subgraph::AddSendFromHostNode(
1101    const std::unordered_map<const Node*, Node*>& node_images,
1102    const string& subgraph_name, const string& oc_subgraph_name,
1103    OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out) {
1104  std::vector<DataType> dtypes(oc_subgraph->outputs_by_src.size(), DT_INVALID);
1105  std::vector<NodeDefBuilder::NodeOut> inputs(
1106      oc_subgraph->outputs_by_src.size());
1108  for (const auto& output : oc_subgraph->outputs_by_src) {
1109    const Node* src_node = output.first.node;
1110    Node* src_image = node_images.at(src_node);
1111    int src_slot = output.first.slot;
1112    int output_index = output.second;
1114    DataType dtype = src_node->output_type(src_slot);
1115    dtypes[output_index] = dtype;
1116    inputs[output_index].Reset(src_image->name(), src_slot, dtype);
1117  }
1119  NodeDef send_def;
1120  NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name,
1121                                         "_", oc_subgraph_name, "_send"),
1122                         kSendFromHostOp);
1123  builder.Attr("Tinputs", dtypes);
1124  builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name,
1125                                      "_", oc_subgraph_name));
1126  builder.Input(inputs);
1127  Status s = builder.Finalize(&send_def);
1128  if (!s.ok()) return s;
1130  oc_subgraph->send_from_host = graph_out->AddNode(send_def, &s);
1131  if (!s.ok()) return s;
1132  oc_subgraph->send_from_host->set_assigned_device_name(device_);
1134  // Add a control dependency forcing the SendFromHost to run before the
1135  // subgraph completes. This has no effect on execution order but prevents the
1136  // RecvAtHost being pruned.
1137  TF_RETURN_IF_ERROR(MakeSequencingNode(subgraph_name, graph_out));
1138  graph_out->AddControlEdge(oc_subgraph->send_from_host, sequencer_);
1140  return Status::OK();
1143Status Encapsulator::Subgraph::AddOutsideCompilationHostIONodes(
1144    const string& subgraph_name,
1145    const std::unordered_map<const Node*, Node*>& node_images,
1146    Graph* graph_out) {
1147  for (auto& outside_compilation_subgraph_entry :
1148       outside_compilation_subgraphs_) {
1149    const string& oc_name = outside_compilation_subgraph_entry.first;
1150    OutsideCompilationSubgraph& oc_subgraph =
1151        outside_compilation_subgraph_entry.second;
1153    if (!oc_subgraph.inputs.empty() || !oc_subgraph.control_inputs.empty()) {
1155          AddRecvAtHostNode(subgraph_name, oc_name, &oc_subgraph, graph_out));
1156    }
1158    if (!oc_subgraph.outputs_by_src.empty() ||
1159        !oc_subgraph.control_outputs.empty()) {
1160      TF_RETURN_IF_ERROR(AddSendFromHostNode(node_images, subgraph_name,
1161                                             oc_name, &oc_subgraph, graph_out));
1162    }
1163  }
1164  return Status::OK();
1167void Encapsulator::Subgraph::GetOutsideCompilationSubgraphNames(
1168    std::vector<string>* names) const {
1169  for (auto& entry : outside_compilation_subgraphs_) {
1170    names->push_back(entry.first);
1171  }
1174Status Encapsulator::GetFunctionNameAttr(
1175    Node const* node, string* attr, string* outside_compilation_attr) const {
1176  Status s = GetNodeAttr(node->attrs(), group_attribute_, attr);
1177  if (s.code() == error::Code::NOT_FOUND) {
1178    // Return empty attr if there's no group_attribute.
1179    attr->clear();
1180  } else {
1181    TF_RETURN_IF_ERROR(s);
1182  }
1183  bool has_group_attr = s.ok();
1184  s = GetNodeAttr(node->attrs(), outside_compilation_attribute_,
1185                  outside_compilation_attr);
1186  if (s.code() == error::Code::NOT_FOUND) {
1187    // Return empty attr if there's no outside_compilation attribute.
1188    outside_compilation_attr->clear();
1189  } else {
1190    TF_RETURN_IF_ERROR(s);
1191    if (!has_group_attr) {
1192      return errors::InvalidArgument(
1193          "Node ", node->name(), " has ", outside_compilation_attribute_,
1194          " attribute but no ", group_attribute_, " attribute.");
1195    }
1196  }
1197  return Status::OK();
// Returns true iff 'func_id' is non-empty and 'outside_compilation_id' is
// empty, i.e. the ids describe a node that belongs to a function subgraph and
// is not marked for outside compilation.
bool IsInSubgraph(const string& func_id, const string& outside_compilation_id) {
  if (func_id.empty()) return false;
  return outside_compilation_id.empty();
}
1204Status Encapsulator::CopySubgraphNodes(
1205    std::unordered_map<const Node*, Node*>* node_images) {
1206  for (Node* node : graph_in_->op_nodes()) {
1207    string func_id;
1208    string outside_compilation_id;
1210        GetFunctionNameAttr(node, &func_id, &outside_compilation_id));
1211    if (!IsInSubgraph(func_id, outside_compilation_id)) continue;
1213    Subgraph& subgraph = subgraphs_[func_id];
1214    Node* image = subgraph.MakeNodeImage(graph_in_, node);
1215    image->ClearAttr(group_attribute_);
1216    (*node_images)[node] = image;
1217  }
1218  return Status::OK();
1221Status Encapsulator::CopySubgraphEdges(
1222    const std::unordered_map<const Node*, Node*>& node_images,
1223    std::vector<std::pair<const Node*, Node*>>* src_arg_pairs) {
1224  for (const Edge* edge : graph_in_->edges()) {
1225    string src_func_id;
1226    string src_outside_compilation_id;
1227    TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->src(), &src_func_id,
1228                                           &src_outside_compilation_id));
1229    string dst_func_id;
1230    string dst_outside_compilation_id;
1231    TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->dst(), &dst_func_id,
1232                                           &dst_outside_compilation_id));
1233    Node* src_image = gtl::FindWithDefault(node_images, edge->src(), nullptr);
1234    Node* dst_image = gtl::FindWithDefault(node_images, edge->dst(), nullptr);
1236    // Copy edges that are local to a subgraph.
1237    if (IsInSubgraph(src_func_id, src_outside_compilation_id) &&
1238        IsInSubgraph(dst_func_id, dst_outside_compilation_id) &&
1239        src_func_id == dst_func_id) {
1240      Graph* g = subgraphs_[src_func_id].GetGraph();
1241      if (edge->IsControlEdge()) {
1242        g->AddControlEdge(src_image, dst_image);
1243      } else {
1244        g->AddEdge(src_image, edge->src_output(), dst_image, edge->dst_input());
1245      }
1246      continue;
1247    }
1249    // Record 'src' as an output of its subgraph, if applicable.
1250    if (IsInSubgraph(src_func_id, src_outside_compilation_id)) {
1251      if (!edge->IsControlEdge()) {
1252        DataType dtype = edge->src()->output_type(edge->src_output());
1253        if (IsRefType(dtype)) {
1254          return errors::InvalidArgument(
1255              "Ref Tensors (e.g., Variables) are not supported as results: "
1256              "tensor ",
1257              edge->src()->name(), ":", edge->src_output());
1258        }
1259      }
1261      Subgraph& src_subgraph = subgraphs_[src_func_id];
1262      if (src_func_id == dst_func_id) {
1263        // src is in the subgraph and dst is outside_compilation in the same
1264        // subgraph.
1265        src_subgraph.RecordOutsideCompilationInputOrControl(
1266            dst_outside_compilation_id, edge);
1267      } else {
1268        // Ignore control edges leaving the subgraph. We will lift them onto the
1269        // enclosing call operators in BuildOutputGraph().
1270        if (!edge->IsControlEdge()) {
1271          TF_RETURN_IF_ERROR(src_subgraph.RecordResult(edge, node_images));
1272        }
1273      }
1274    }
1276    // Record 'dst' as an input of its subgraph, if applicable.
1277    if (IsInSubgraph(dst_func_id, dst_outside_compilation_id)) {
1278      // Look at the type of the destination not the source, since Ref output
1279      // Tensors can be automatically cast to non-Ref Tensors at the
1280      // destination.
1281      if (!edge->IsControlEdge()) {
1282        DataType dtype = edge->dst()->input_type(edge->dst_input());
1283        if (IsRefType(dtype)) {
1284          return errors::InvalidArgument(
1285              "Ref Tensors (e.g., Variables) are not supported as args: "
1286              "tensor ",
1287              edge->src()->name(), ":", edge->src_output());
1288        }
1289      }
1291      Subgraph& dst_subgraph = subgraphs_[dst_func_id];
1292      if (src_func_id == dst_func_id) {
1293        // dst is in the subgraph and src is outside_compilation in the same
1294        // subgraph.
1295        dst_subgraph.RecordOutsideCompilationOutputOrControl(
1296            src_outside_compilation_id, edge);
1297      } else {
1298        // Ignore control edges entering the subgraph. We will lift them onto
1299        // the enclosing call operators in BuildOutputGraph().
1300        if (!edge->IsControlEdge()) {
1301          TF_RETURN_IF_ERROR(
1302              dst_subgraph.RecordArg(edge, node_images, src_arg_pairs));
1303        }
1304      }
1305    }
1306  }
1307  return Status::OK();
1310Status Encapsulator::SplitIntoSubgraphs() {
1311  Status s;
1313  // Map from input graph nodes to subgraph nodes.
1314  std::unordered_map<const Node*, Node*> node_images;
1316  // Each entry of src_arg_pairs is a pair whose first element is a node in the
1317  // original graph that has an output edge in the subgraph, and whose second
1318  // element is the arg node in the subgraph that it sends to. The vector will
1319  // be filled in below in AddArgs.
1320  std::vector<std::pair<const Node*, Node*>> src_arg_pairs;
1322  TF_RETURN_IF_ERROR(CopySubgraphNodes(&node_images));
1323  TF_RETURN_IF_ERROR(CopySubgraphEdges(node_images, &src_arg_pairs));
1325  // For each subgraph, add the nodes that deal with inputs and outputs its
1326  // nested outside_compilation subgraphs. These could not be added earlier
1327  // during CopySubgraphEdges since we need to discover all the types of the
1328  // inputs and outputs for an outside_compilation subgraph before creating a
1329  // single input and output node for it.
1330  for (auto& entry : subgraphs_) {
1331    Subgraph& subgraph = entry.second;
1332    TF_RETURN_IF_ERROR(subgraph.AddHostComputes(entry.first, node_images));
1333  }
1335  MarkGuaranteedConstants(*graph_in_, src_arg_pairs);
1337  for (auto& entry : subgraphs_) {
1338    Subgraph& subgraph = entry.second;
1339    FixupSourceAndSinkEdges(subgraph.GetGraph());
1340  }
1342  return s;
1345Status Encapsulator::BuildFunctionDefs(
1346    const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions,
1347    FunctionLibraryDefinition* library) {
1348  for (auto& subgraph_entry : subgraphs_) {
1349    string name = subgraph_entry.first;
1350    Subgraph& subgraph = subgraph_entry.second;
1351    TF_RETURN_IF_ERROR(subgraph.BuildFunctionDef(
1352        name, rewrite_subgraph_fn, reuse_existing_functions, library));
1353  }
1354  return Status::OK();
1357Status Encapsulator::CopyNodesToOutputGraph(
1358    bool parallel_checking, Graph* graph_out,
1359    std::unordered_map<const Node*, Node*>* node_images) {
1360  for (Node* node : graph_in_->op_nodes()) {
1361    string func_id;
1362    string outside_compilation_id;
1364        GetFunctionNameAttr(node, &func_id, &outside_compilation_id));
1366    // Don't copy nodes that going to be encapsulated, unless parallel checking
1367    // is enabled.
1368    if (IsInSubgraph(func_id, outside_compilation_id) && !parallel_checking)
1369      continue;
1371    Node* image = graph_out->CopyNode(node);
1372    if (!outside_compilation_id.empty()) {
1373      if (parallel_checking) {
1374        return errors::InvalidArgument(
1375            "Parallel checking is not supported when outside_compilation "
1376            "clusters are present.");
1377      }
1378      image->ClearAttr(group_attribute_);
1379      image->ClearAttr(outside_compilation_attribute_);
1380    }
1381    (*node_images)[node] = image;
1382  }
1383  (*node_images)[graph_in_->source_node()] = graph_out->source_node();
1384  (*node_images)[graph_in_->sink_node()] = graph_out->sink_node();
1385  return Status::OK();
1388Status Encapsulator::AddFunctionCallNodes(
1389    const std::unordered_map<const Node*, Node*>& node_images,
1390    bool parallel_checking, Graph* graph_out) {
1391  for (auto& subgraph_entry : subgraphs_) {
1392    TF_RETURN_IF_ERROR(subgraph_entry.second.AddFunctionCallNode(
1393        node_images, parallel_checking, graph_out));
1394  }
1395  return Status::OK();
1398Status Encapsulator::AddOutsideCompilationHostIONodes(
1399    const std::unordered_map<const Node*, Node*>& node_images,
1400    Graph* graph_out) {
1401  for (auto& subgraph_entry : subgraphs_) {
1402    const string& subgraph_name = subgraph_entry.first;
1403    Subgraph& subgraph = subgraph_entry.second;
1404    TF_RETURN_IF_ERROR(subgraph.AddOutsideCompilationHostIONodes(
1405        subgraph_name, node_images, graph_out));
1406  }
1407  return Status::OK();
1410Status Encapsulator::FindOutputImageOfEdgeSrc(
1411    const string& src_func_id, const string& src_outside_compilation_id,
1412    const string& dst_func_id, const string& dst_outside_compilation_id,
1413    const std::unordered_map<const Node*, Node*>& node_images,
1414    const Node* original_src_node, Node** src_image) {
1415  if (IsInSubgraph(src_func_id, src_outside_compilation_id)) {
1416    if (dst_func_id == src_func_id) {
1417      // The edge is from a subgraph to an outside_compilation cluster in the
1418      // same subgraph so use the appropriate _RecvAtHost node in the output
1419      // graph.
1420      TF_RET_CHECK(!dst_outside_compilation_id.empty());
1421      *src_image = subgraphs_.at(src_func_id)
1422                       .GetRecvAtHostNode(dst_outside_compilation_id);
1423    } else {
1424      // The edge is from a subgraph to a regular node in the output graph so
1425      // use the subgraph's call node output.
1426      *src_image = subgraphs_.at(src_func_id).GetCallNodeForOutputs();
1427    }
1428  } else {
1429    // The source of the edge is in the output graph so use the node image in
1430    // the output graph.
1431    *src_image = node_images.at(original_src_node);
1432  }
1433  return Status::OK();
1436int Encapsulator::FindOutputSlotOfEdgeSrc(
1437    const string& src_func_id, const string& src_outside_compilation_id,
1438    const string& dst_func_id, const string& dst_outside_compilation_id,
1439    const Edge* edge) {
1440  if (IsInSubgraph(src_func_id, src_outside_compilation_id)) {
1441    const Subgraph& src_subgraph = subgraphs_.at(src_func_id);
1442    if (src_func_id == dst_func_id) {
1443      // 'src' is in a subgraph and 'dst' is outside_compilation in the same
1444      // subgraph. Use the corresponding _RecvAtHost output instead.
1445      return src_subgraph.GetRecvAtHostSlot(dst_outside_compilation_id, edge);
1446    } else {
1447      // 'src' is in a subgraph and 'dst' is a regular node in the output
1448      // graph. Use the corresponding call output instead.
1449      return src_subgraph.GetResultIndexForEdge(edge);
1450    }
1451  } else {
1452    // The source of the edge is in the output graph so use the regular edge
1453    // slot.
1454    return edge->src_output();
1455  }
1458Status Encapsulator::FindOutputImageOfEdgeDst(
1459    const string& src_func_id, const string& src_outside_compilation_id,
1460    const string& dst_func_id, const string& dst_outside_compilation_id,
1461    const std::unordered_map<const Node*, Node*>& node_images,
1462    const Node* original_dst_node, Node** dst_image) {
1463  if (IsInSubgraph(dst_func_id, dst_outside_compilation_id)) {
1464    if (src_func_id == dst_func_id) {
1465      // The edge is to a subgraph from an outside_compilation cluster in the
1466      // same subgraph so use the appropriate _SendFromHost node in the output
1467      // graph.
1468      TF_RET_CHECK(!src_outside_compilation_id.empty());
1469      *dst_image = subgraphs_.at(dst_func_id)
1470                       .GetSendFromHostNode(src_outside_compilation_id);
1471    } else {
1472      // The edge is to a subgraph from a regular node in the output graph so
1473      // use the subgraph's call node input.
1474      *dst_image = subgraphs_.at(dst_func_id).GetCallNodeForInputs();
1475    }
1476  } else {
1477    // The destination of the edge is in the output graph so use the node image
1478    // in the output graph.
1479    *dst_image = node_images.at(original_dst_node);
1480  }
1481  return Status::OK();
1484int Encapsulator::FindOutputSlotOfEdgeDst(
1485    const string& src_func_id, const string& src_outside_compilation_id,
1486    const string& dst_func_id, const string& dst_outside_compilation_id,
1487    const Edge* edge) {
1488  if (IsInSubgraph(dst_func_id, dst_outside_compilation_id)) {
1489    const Subgraph& dst_subgraph = subgraphs_.at(dst_func_id);
1490    if (dst_func_id == src_func_id) {
1491      // 'dst' is in a subgraph and 'src' is outside_compilation in the same
1492      // subgraph. Use the corresponding _SendFromHost input instead.
1493      return dst_subgraph.GetSendFromHostSlot(src_outside_compilation_id, edge);
1494    } else {
1495      // 'dst' is in a subgraph and 'src' is a regular node in the output
1496      // graph. Use the corresponding call input instead.
1497      return dst_subgraph.GetArgIndexForEdge(edge);
1498    }
1499  } else {
1500    // The destination of the edge is in the output graph so use the regular
1501    // edge slot.
1502    return edge->dst_input();
1503  }
1506Status Encapsulator::CopyEdgeToOutputGraph(
1507    const Edge* edge, const string& src_func_id,
1508    const string& src_outside_compilation_id, const string& dst_func_id,
1509    const string& dst_outside_compilation_id,
1510    const std::unordered_map<const Node*, Node*>& node_images,
1511    bool parallel_checking, Graph* graph_out,
1512    std::unordered_set<std::pair<NodeSlot, NodeSlot>, NodeSlot::PairHasher>*
1513        edges_added) {
1514  Node* src_image;
1515  TF_RETURN_IF_ERROR(FindOutputImageOfEdgeSrc(
1516      src_func_id, src_outside_compilation_id, dst_func_id,
1517      dst_outside_compilation_id, node_images, edge->src(), &src_image));
1518  Node* dst_image;
1519  TF_RETURN_IF_ERROR(FindOutputImageOfEdgeDst(
1520      src_func_id, src_outside_compilation_id, dst_func_id,
1521      dst_outside_compilation_id, node_images, edge->dst(), &dst_image));
1523  // If this is a control edge then copy it and return. Lift control edges onto
1524  // the enclosing call operator.
1525  if (edge->IsControlEdge()) {
1526    // Add the control edge, if we have not already added it, using the images
1527    // determined above (potentially call operators or RecvAtHost/SendFromHost).
1528    if (edges_added->emplace(NodeSlot(src_image, -1), NodeSlot(dst_image, -1))
1529            .second) {
1530      graph_out->AddControlEdge(src_image, dst_image);
1531    }
1533    // If parallel checking is enabled, also add a control edge to the
1534    // corresponding parallel check op.
1535    if (parallel_checking) {
1536      graph_out->AddControlEdge(src_image, node_images.at(edge->dst()));
1537    }
1538    return Status::OK();
1539  }
1541  int src_output =
1542      FindOutputSlotOfEdgeSrc(src_func_id, src_outside_compilation_id,
1543                              dst_func_id, dst_outside_compilation_id, edge);
1545  int dst_input =
1546      FindOutputSlotOfEdgeDst(src_func_id, src_outside_compilation_id,
1547                              dst_func_id, dst_outside_compilation_id, edge);
1549  if (IsInSubgraph(dst_func_id, dst_outside_compilation_id) &&
1550      parallel_checking) {
1551    // If we are parallel checking, also feed the tensor as an input to the
1552    // corresponding parallel check subgraph.
1553    graph_out->AddEdge(src_image, src_output, node_images.at(edge->dst()),
1554                       edge->dst_input());
1555  }
1557  // Add the edge, if we have not already added it.
1558  if (edges_added
1559          ->emplace(NodeSlot(src_image, src_output),
1560                    NodeSlot(dst_image, dst_input))
1561          .second) {
1562    graph_out->AddEdge(src_image, src_output, dst_image, dst_input);
1563  }
1564  return Status::OK();
1567Status Encapsulator::AddEdgesToOutputGraph(
1568    const std::unordered_map<const Node*, Node*>& node_images,
1569    bool parallel_checking, Graph* graph_out) {
1570  // Set of edges already added to the output graph, represented as (src, dst)
1571  // pairs. We use the set to deduplicate edges; multiple edges in the input
1572  // graph may map to one edge in the output graph.
1573  std::unordered_set<std::pair<NodeSlot, NodeSlot>, NodeSlot::PairHasher>
1574      edges_added;
1576  for (const Edge* edge : graph_in_->edges()) {
1577    string src_func_id;
1578    string src_outside_compilation_id;
1579    TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->src(), &src_func_id,
1580                                           &src_outside_compilation_id));
1581    string dst_func_id;
1582    string dst_outside_compilation_id;
1583    TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->dst(), &dst_func_id,
1584                                           &dst_outside_compilation_id));
1586    // Ignore edges that are strictly contained within one subgraph, unless
1587    // we are constructing parallel check graphs.
1588    if (IsInSubgraph(src_func_id, src_outside_compilation_id) &&
1589        IsInSubgraph(dst_func_id, dst_outside_compilation_id) &&
1590        src_func_id == dst_func_id) {
1591      if (parallel_checking) {
1592        Node* src_image = node_images.at(edge->src());
1593        Node* dst_image = node_images.at(edge->dst());
1594        if (edge->IsControlEdge()) {
1595          graph_out->AddControlEdge(src_image, dst_image);
1596        } else {
1597          graph_out->AddEdge(src_image, edge->src_output(), dst_image,
1598                             edge->dst_input());
1599        }
1600      }
1601      continue;
1602    }
1604    // We have an edge that crosses a cluster boundary or is entirely within the
1605    // unclustered graph.
1606    TF_RETURN_IF_ERROR(CopyEdgeToOutputGraph(
1607        edge, src_func_id, src_outside_compilation_id, dst_func_id,
1608        dst_outside_compilation_id, node_images, parallel_checking, graph_out,
1609        &edges_added));
1610  }
1612  for (auto& subgraph_entry : subgraphs_) {
1613    Subgraph& subgraph = subgraph_entry.second;
1614    subgraph.ConnectSequencerToOutputs(graph_out);
1615  }
1617  return Status::OK();
1620namespace {
1622// Adds a dummy Const node to graph_out. The "constant" has the type of
1623// data_type and the shape indicated in 'shape'. The dummy node is not a valid
1624// Const node because it does not have any value defined, but this doesn't
1625// matter because it will only be used subsequently for shape inference. (It
1626// would be possible to add a switch statement over data_type to create a value
1627// for the constant, but that would entail maintaining the logic as new types
1628// are added, and is not necessary.)
1629Node* AddDummyShapedNode(DataType data_type, const TensorShapeProto& shape,
1630                         Graph* graph_out) {
1631  TensorProto dummy_proto;
1632  dummy_proto.set_dtype(data_type);
1633  *dummy_proto.mutable_tensor_shape() = shape;
1634  // Don't set any value field in the proto, since it is only going to be used
1635  // for shape inference.
1637  GraphDefBuilder::Options options(graph_out, /*status=*/nullptr);
1638  NodeBuilder node_builder(options.GetNameForOp("KnownShape"), "Const",
1639                           options.op_registry());
1640  node_builder.Attr("dtype", data_type).Attr("value", dummy_proto);
1641  return options.FinalizeBuilder(&node_builder);
1644// Adds a copy of node_in to graph_out and adds the mapping to
1645// copied_node_images.
1646Status CopyShapeInferenceNodeToGraph(
1647    Node* node_in, const Node* send_node,
1648    const std::unordered_map<Node*, Node*>& dummy_node_images,
1649    FunctionLibraryDefinition* library,
1650    std::unordered_map<Node*, Node*>* copied_node_images, Graph* graph_out) {
1651  // Once all the ancestor nodes have been added to graph_out, add this node
1652  // and connect it to its ancestors.
1653  Node* node_out = graph_out->CopyNode(node_in);
1654  (*copied_node_images)[node_in] = node_out;
1655  // Don't bother to build the shape inference graph if there's a node with no
1656  // shape inference function, since it would just result in an error later at
1657  // compile time.
1658  const OpRegistrationData* op_reg_data;
1659  TF_RETURN_IF_ERROR(library->LookUp(node_in->type_string(), &op_reg_data));
1660  if (op_reg_data->shape_inference_fn == nullptr) {
1661    return errors::InvalidArgument(
1662        "Shape inference is not possible for outside_compilation "
1663        "SendFromHost node ",
1664        send_node->name(), " because it depends on node ", node_in->name(),
1665        " which does not have a shape inference function registered.");
1666  }
1667  // Add all the edges to the newly copied node.
1668  for (const Edge* in_edge : node_in->in_edges()) {
1669    if (!in_edge->IsControlEdge()) {
1670      Node* src = in_edge->src();
1671      const auto iter = dummy_node_images.find(src);
1672      if (iter == dummy_node_images.end()) {
1673        // The src is a copied node so use the original output port.
1674        graph_out->AddEdge((*copied_node_images)[in_edge->src()],
1675                           in_edge->src_output(), node_out,
1676                           in_edge->dst_input());
1677      } else {
1678        // The src is a dummy node so use output port 0.
1679        graph_out->AddEdge(iter->second, 0, node_out, in_edge->dst_input());
1680      }
1681    }
1682  }
1683  return Status::OK();
1686}  // namespace
1688Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend(
1689    const Graph& graph_in, const ShapeRefiner& shape_refiner,
1690    const std::unordered_set<string>& recv_at_host_nodes, Node* send_node,
1691    FunctionLibraryDefinition* library,
1692    std::vector<TensorShapeProto>* static_shape_out,
1693    std::unique_ptr<GraphDef>* graphdef_out) {
1694  // Maps from nodes in graph_in to nodes in graph_out.
1695  //
1696  // When an edge has fully defined shape the source node in graph_in is
1697  // replaced in graph_out by a dummy constant node. The mapping from nodes
1698  // in graph_in to dummy nodes is stored in dummy_node_images.
1699  //
1700  // When a node in graph_in has at least one ancestor that doesn't have fully
1701  // defined shape, it is copied into graph_out. The mapping from nodes in
1702  // graph_in to copied nodes is stored in copied_node_images.
1703  //
1704  // The two types of node are treated differently because, when adding edges to
1705  // graph_out, an output from a dummy node always uses port 0, whereas an
1706  // output from a copied node uses the same port that was used in graph_in.
1707  std::unordered_map<Node*, Node*> dummy_node_images;
1708  std::unordered_map<Node*, Node*> copied_node_images;
1710  std::unique_ptr<Graph> graph_out(new Graph(graph_in.op_registry()));
1711  graph_out->set_versions(graph_in.versions());
1712  static_shape_out->resize(send_node->num_inputs());
1714  // We don't use the standard ReverseDFS because we want to cut off traversal
1715  // whenever we find an output with fully defined shape.
1716  // TODO(misard) make this work properly in the presence of control flow.
1717  struct Work {
1718    Node* node;
1719    bool leave;  // Are we entering or leaving node?
1720  };
1721  std::vector<Work> stack({{send_node, false}});
1722  std::vector<bool> visited(graph_in.num_node_ids(), false);
1723  while (!stack.empty()) {
1724    Work w = stack.back();
1725    stack.pop_back();
1726    Node* n = w.node;
1728    if (w.leave) {
1729      TF_RETURN_IF_ERROR(CopyShapeInferenceNodeToGraph(
1730          n, send_node, dummy_node_images, library, &copied_node_images,
1731          graph_out.get()));
1732    } else {
1733      if (visited[n->id()]) continue;
1734      visited[n->id()] = true;
1736      // Arrange to revisit when all done with all inputs.
1737      stack.push_back(Work{n, true});
1739      bool has_parent_with_unknown_shape = false;
1740      for (const Edge* in_edge : n->in_edges()) {
1741        if (!in_edge->IsControlEdge()) {
1742          Node* src_node = in_edge->src();
1743          int src_port = in_edge->src_output();
1744          shape_inference::InferenceContext* context =
1745              shape_refiner.GetContext(src_node);
1746          shape_inference::ShapeHandle shape = context->output(src_port);
1747          if (context->FullyDefined(shape)) {
1748            // This ancestor has known shape, so instead of adding it to the
1749            // stack, add a dummy node with that shape to graph_out and
1750            // continue.
1751            TensorShapeProto proto;
1752            context->ShapeHandleToProto(shape, &proto);
1753            dummy_node_images[src_node] = AddDummyShapedNode(
1754                src_node->output_type(src_port), proto, graph_out.get());
1755            if (n == send_node) {
1756              (*static_shape_out)[in_edge->dst_input()] = proto;
1757            }
1758          } else {
1759            if (!visited[src_node->id()]) {
1760              has_parent_with_unknown_shape = true;
1761              stack.push_back({src_node, false});
1762            }
1763          }
1764        }
1765      }
1766      if (!has_parent_with_unknown_shape) {
1767        if (n == send_node) {
1768          // The shapes of all the inputs to send_node are statically known. We
1769          // won't have to do any inference at compile time so return now: the
1770          // shapes were stored in static_shape_out above.
1771          graphdef_out->reset();
1772          return Status::OK();
1773        } else {
1774          // Any shape that is being processed is either the original send node
1775          // or has at least one output with statically-unknown shape. If the
1776          // latter and it doesn't have any inputs with statically-unknown
1777          // shape, then check that it is of the recv nodes that we can fill in
1778          // the shape of at run-time later. If it isn't one of those, then we
1779          // won't have any additional knowledge at compile time, so we already
1780          // know we won't be able to do shape inference and we can return an
1781          // error now.
1782          if (recv_at_host_nodes.find(n->name()) == recv_at_host_nodes.end()) {
1783            return errors::InvalidArgument(
1784                "Shape inference is not possible for outside_compilation "
1785                "SendFromHost node ",
1786                send_node->name(), " because shape of node ", n->name(),
1787                " will not be known at compilation time.");
1788          }
1789        }
1790      }
1791    }
1792  }
1794  graphdef_out->reset(new GraphDef());
1795  graph_out->ToGraphDef(graphdef_out->get());
1797  return Status::OK();
1800Status Encapsulator::MakePrunedGraphCopyAndInline(
1801    const Graph& graph, const std::vector<Node*>& sink_nodes,
1802    std::unique_ptr<Graph>* pruned_graph,
1803    std::unordered_map<const Node*, Node*>* node_images,
1804    FunctionLibraryDefinition* library) {
1805  // First copy all ancestor nodes of sink_nodes into a new graph.
1806  pruned_graph->reset(new Graph(library));
1807  (*pruned_graph)->set_versions(graph.versions());
1808  ReverseDFSFrom(graph, sink_nodes,
1809                 /*enter=*/nullptr,
1810                 /*leave=*/[&](Node* n) {
1811                   if (!n->IsSource()) {
1812                     Node* copied = (*pruned_graph)->CopyNode(n);
1813                     node_images->emplace(n, copied);
1814                   }
1815                 });
1817  // Add all the edges between copied nodes.
1818  for (auto entry : *node_images) {
1819    const Node* orig = entry.first;
1820    Node* image = entry.second;
1821    for (const Edge* out_edge : orig->out_edges()) {
1822      auto iter = node_images->find(out_edge->dst());
1823      if (iter != node_images->end()) {
1824        // The source and destination are both in the copied graph.
1825        (*pruned_graph)
1826            ->AddEdge(image, out_edge->src_output(), iter->second,
1827                      out_edge->dst_input());
1828      }
1829    }
1830  }
1832  // Find all the function call nodes, and inline them.
1833  std::vector<Node*> function_nodes;
1834  for (auto node : (*pruned_graph)->nodes()) {
1835    const OpRegistrationData* op_reg_data;
1836    TF_RETURN_IF_ERROR(library->LookUp(node->type_string(), &op_reg_data));
1837    if (op_reg_data->is_function_op) {
1838      function_nodes.push_back(node);
1839    }
1840  }
1841  for (auto node : function_nodes) {
1842    VLOG(2) << "Inlining function " << node->name();
1843    const FunctionDef* fdef = library->Find(node->type_string());
1844    if (fdef == nullptr) {
1845      return errors::Internal("Failed to find function ", node->type_string(),
1846                              " in function library.");
1847    }
1848    FunctionBody* fbody = nullptr;
1850        FunctionDefToBodyHelper(*fdef, node->attrs(), library,
1851                                [library](const string& op, const OpDef** sig) {
1852                                  return library->LookUpOpDef(op, sig);
1853                                },
1854                                &fbody));
1855    InlineFunctionBody(*library, pruned_graph->get(), node, fbody);
1856    delete fbody;
1857  }
1859  return Status::OK();
1862Status Encapsulator::MakeGraphForOutsideCompilationSends(
1863    const Graph& graph, std::unique_ptr<Graph>* pruned_graph,
1864    ShapeRefiner* shape_refiner,
1865    std::unordered_map<const Node*, Node*>* node_images,
1866    FunctionLibraryDefinition* library) {
1867  // Find all the send_from_host nodes in all subgraphs, to use as roots for the
1868  // pruning.
1869  std::vector<Node*> send_from_host_nodes;
1870  for (auto& subgraph_entry : subgraphs_) {
1871    Subgraph& subgraph = subgraph_entry.second;
1872    std::vector<string> outside_compilation_names;
1873    subgraph.GetOutsideCompilationSubgraphNames(&outside_compilation_names);
1874    for (const auto& name : outside_compilation_names) {
1875      Node* send_node = subgraph.GetSendFromHostNode(name);
1876      if (send_node != nullptr) {
1877        send_from_host_nodes.push_back(send_node);
1878      }
1879    }
1880  }
1882  // Make a copy of all the graph nodes needed to evaluate the send_from_host
1883  // nodes, inlining any functions as needed.
1884  TF_RETURN_IF_ERROR(MakePrunedGraphCopyAndInline(
1885      graph, send_from_host_nodes, pruned_graph, node_images, library));
1887  // Perform shape inference on the pruned graph.
1888  shape_refiner->set_require_shape_inference_fns(false);
1889  FixupSourceAndSinkEdges(pruned_graph->get());
1890  std::vector<Node*> post_order;
1891  GetReversePostOrder(*(*pruned_graph), &post_order);
1892  for (auto node : post_order) {
1893    // Ignore the status returned by the shape_refiner. At this point we want
1894    // the best effort shapes, even if no shape function is registered for a
1895    // node.
1896    Status status = shape_refiner->AddNode(node);
1897    if (!status.ok()) {
1898      VLOG(1) << "Shape inference failed for node: " << status;
1899    }
1900  }
1902  return Status::OK();
1905Status Encapsulator::GetShapeInfoForOutsideCompilationSends(
1906    Graph* graph_out, FunctionLibraryDefinition* library) {
1907  std::unique_ptr<Graph> pruned_graph;
1908  ShapeRefiner shape_refiner(graph_out->versions(), graph_out->op_registry());
1909  std::unordered_map<const Node*, Node*> node_images;
1910  TF_RETURN_IF_ERROR(MakeGraphForOutsideCompilationSends(
1911      *graph_out, &pruned_graph, &shape_refiner, &node_images, library));
1913  for (auto& subgraph_entry : subgraphs_) {
1914    Subgraph& subgraph = subgraph_entry.second;
1915    // Find all the recv_at_host nodes in this subgraph.
1916    std::vector<string> outside_compilation_names;
1917    subgraph.GetOutsideCompilationSubgraphNames(&outside_compilation_names);
1918    std::unordered_set<string> recv_at_host_names;
1919    for (const auto& name : outside_compilation_names) {
1920      Node* recv_node = subgraph.GetRecvAtHostNode(name);
1921      if (recv_node != nullptr) {
1922        recv_at_host_names.insert(recv_node->name());
1923      }
1924    }
1925    // For each send_from_host node, do as much shape inference as possible
1926    // without knowing the shape of the recv_at_host nodes, and store the
1927    // result, along with enough information to complete the job at compile time
1928    // once the recv_at_host shapes are known.
1929    for (const auto& name : outside_compilation_names) {
1930      Node* send_node = subgraph.GetSendFromHostNode(name);
1931      std::vector<TensorShapeProto> static_shape;
1932      std::unique_ptr<GraphDef> graphdef;
1933      if (send_node != nullptr) {
1934        TF_RETURN_IF_ERROR(DoStaticShapeInferenceForOutsideCompilationSend(
1935            *pruned_graph, shape_refiner, recv_at_host_names,
1936            node_images[send_node], library, &static_shape, &graphdef));
1937        if (graphdef == nullptr) {
1938          VLOG(2) << "Send node  " << send_node->name() << " shapes";
1939          for (int i = 0; i < static_shape.size(); ++i) {
1940            VLOG(2) << static_shape[i].DebugString();
1941          }
1942        } else {
1943          VLOG(2) << "Send node " << send_node->name() << " graph\n"
1944                  << graphdef->DebugString();
1945        }
1946      }
1948          subgraph.AddShapeInferenceInfo(name, static_shape, graphdef.get()));
1949    }
1950    if (!outside_compilation_names.empty()) {
1951      TF_RETURN_IF_ERROR(subgraph.ReplaceFunctionDef(library));
1952    }
1953  }
1955  return Status::OK();
1958Status Encapsulator::BuildOutputGraph(bool parallel_checking, Graph* graph_out,
1959                                      FunctionLibraryDefinition* library) {
1960  // Map from nodes in the input graph to nodes in the output graph.
1961  std::unordered_map<const Node*, Node*> node_images;
1964      CopyNodesToOutputGraph(parallel_checking, graph_out, &node_images));
1966      AddFunctionCallNodes(node_images, parallel_checking, graph_out));
1967  TF_RETURN_IF_ERROR(AddOutsideCompilationHostIONodes(node_images, graph_out));
1969      AddEdgesToOutputGraph(node_images, parallel_checking, graph_out));
1972      GetShapeInfoForOutsideCompilationSends(graph_out, library));
1974  return Status::OK();
1977}  // anonymous namespace
1979Status EncapsulateSubgraphsInFunctions(
1980    string group_attribute, string outside_compilation_attribute,
1981    const Graph& graph_in, const RewriteSubgraphFn& rewrite_subgraph_fn,
1982    bool parallel_checking, bool reuse_existing_functions,
1983    std::unique_ptr<Graph>* graph_out, FunctionLibraryDefinition* library) {
1984  Status s;
1986  Encapsulator encapsulator(std::move(group_attribute),
1987                            std::move(outside_compilation_attribute),
1988                            &graph_in);
1989  TF_RETURN_IF_ERROR(encapsulator.SplitIntoSubgraphs());
1991  TF_RETURN_IF_ERROR(encapsulator.BuildFunctionDefs(
1992      rewrite_subgraph_fn, reuse_existing_functions, library));
1994  std::unique_ptr<Graph> out(new Graph(library));
1995  out->set_versions(graph_in.versions());
1997      encapsulator.BuildOutputGraph(parallel_checking, out.get(), library));
1999  *graph_out = std::move(out);
2000  return Status::OK();
2003// Finds the types of the _Arg nodes, indexed by position.
2004static Status GetArgTypes(const Graph& graph, DataTypeVector* types) {
2005  for (Node* n : graph.op_nodes()) {
2006    if (n->type_string() == kArgOp) {
2007      int index;
2008      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
2009      if (index < 0 || index >= types->size()) {
2010        return errors::InvalidArgument("Invalid argument number");
2011      }
2012      (*types)[index] = n->output_type(0);
2013    }
2014  }
2015  return Status::OK();
2018// Renumber the indices of _Arg nodes in a graph, according to
2019// 'permutation' that maps old indices to new indices.
2020static Status RenumberArguments(Graph* graph,
2021                                const std::vector<int>& permutation) {
2022  for (Node* n : graph->op_nodes()) {
2023    if (n->type_string() == kArgOp) {
2024      int index;
2025      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
2026      if (index < 0 || index >= permutation.size()) {
2027        return errors::InvalidArgument("Invalid argument number");
2028      }
2029      n->AddAttr("index", permutation[index]);
2030    }
2031  }
2032  return Status::OK();
2035Status EncapsulateSubgraphsPass::Run(
2036    const GraphOptimizationPassOptions& options) {
2037  VLOG(1) << "EncapsulateSubgraphsPass::Run";
2038  legacy_flags::EncapsulateSubgraphsPassFlags* flags =
2039      legacy_flags::GetEncapsulateSubgraphsPassFlags();
2040  if (VLOG_IS_ON(1)) {
2041    dump_graph::DumpGraphToFile("before_encapsulate_subgraphs", **options.graph,
2042                                options.flib_def);
2043  }
2045  std::unique_ptr<Graph> graph_out;
2046  FunctionLibraryDefinition* const library = options.flib_def;
2048  OptimizerOptions opts;
2049  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
2050      new ProcessFunctionLibraryRuntime(nullptr, options.session_options->env,
2051                                        TF_GRAPH_DEF_VERSION, library, opts));
2052  FunctionLibraryRuntime* flr =
2053      pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
2055  auto rewrite_subgraph = [flr](std::unique_ptr<Graph>* subgraph,
2056                                std::vector<int>* input_permutation,
2057                                std::vector<int>* output_permutation,
2058                                NodeDef* node) {
2059    // Optimize the subgraph.
2060    OptimizeGraph(flr, subgraph);
2062    const int num_args = input_permutation->size();
2063    std::vector<bool> const_args(num_args);
2064    TF_RETURN_IF_ERROR(BackwardsConstAnalysis(**subgraph, &const_args));
2066    DataTypeVector arg_types(num_args);
2067    TF_RETURN_IF_ERROR(GetArgTypes(**subgraph, &arg_types));
2069    // Compute a permutation of the arguments such that the constant arguments
2070    // are first.
2071    const int num_consts =
2072        std::count(const_args.begin(), const_args.end(), true);
2074    const int num_resources =
2075        std::count(arg_types.begin(), arg_types.end(), DT_RESOURCE);
2076    const int num_nonconsts = num_args - num_resources - num_consts;
2077    if (num_nonconsts < 0) {
2078      return errors::Internal("num_nonconsts should be >= 0, was ",
2079                              num_nonconsts);
2080    }
2082    int const_pos = 0;
2083    int arg_pos = num_consts;
2084    int resource_pos = num_consts + num_nonconsts;
2085    for (int i = 0; i < num_args; ++i) {
2086      if (const_args[i]) {
2087        if (arg_types[i] == DT_RESOURCE) {
2088          return errors::Internal(
2089              "Resource arguments cannot be constant (argument ", i, ")");
2090        }
2091        (*input_permutation)[i] = const_pos;
2092        ++const_pos;
2093      } else if (arg_types[i] == DT_RESOURCE) {
2094        (*input_permutation)[i] = resource_pos;
2095        ++resource_pos;
2096      } else {
2097        (*input_permutation)[i] = arg_pos;
2098        ++arg_pos;
2099      }
2100    }
2102    // Renumber argument nodes in the graph.
2103    TF_RETURN_IF_ERROR(RenumberArguments(subgraph->get(), *input_permutation));
2105    // TODO(phawkins): add a forward is-constant analysis, similarly split
2106    // outputs into host-memory constants and device-memory non-constants.
2108    AddNodeAttr(kXlaCompiledKernelAttr, true, node);
2109    AddNodeAttr(kXlaNumConstantArgsAttr, num_consts, node);
2110    AddNodeAttr(kXlaNumResourceArgsAttr, num_resources, node);
2111    return Status::OK();
2112  };
2114  TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions(
2115      kXlaClusterAttr, kXlaOutsideCompilationAttr, **options.graph,
2116      rewrite_subgraph, flags->tf_xla_parallel_checking,
2117      /*reuse_existing_functions=*/false, &graph_out, library));
2119  if (VLOG_IS_ON(1)) {
2120    dump_graph::DumpGraphToFile("after_encapsulate_subgraphs", *graph_out,
2121                                options.flib_def);
2122  }
2124  *options.graph = std::move(graph_out);
2125  return Status::OK();
2128bool IsXlaCompiledKernel(const Node& node) {
2129  bool is_compiled = false;
2130  bool has_compilation_attr =
2131      GetNodeAttr(node.attrs(), kXlaCompiledKernelAttr, &is_compiled).ok() &&
2132      is_compiled;
2133  return has_compilation_attr ? is_compiled : false;
2136}  // namespace tensorflow