/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"

#include <unistd.h>
#include <algorithm>
#include <atomic>
#include <deque>
#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <unordered_map>
#include <vector>

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/regexp.h"

using ::tensorflow::Env;
using ::tensorflow::WriteStringToFile;
using ::tensorflow::gtl::nullopt;
using ::tensorflow::gtl::optional;
using ::tensorflow::io::JoinPath;
using ::tensorflow::str_util::Join;
using ::tensorflow::str_util::StringReplace;
using ::tensorflow::strings::StrAppend;
using ::tensorflow::strings::StrCat;

namespace xla {
namespace hlo_graph_dumper {
namespace {

// Helpers for Printf and Appendf.
template <typename T>
struct PrintfConvert {
  const T& operator()(const T& t) const { return t; }
};
template <>
struct PrintfConvert<string> {
  const char* operator()(const string& s) const { return s.c_str(); }
};

// Like tensorflow::strings::Printf/Appendf, but you don't need to call c_str()
// on strings.
template <typename... Ts>
string Printf(const char* fmt, const Ts&... ts) {
  return tensorflow::strings::Printf(fmt, PrintfConvert<Ts>()(ts)...);
}
template <typename... Ts>
void Appendf(string* s, const char* fmt, const Ts&... ts) {
  tensorflow::strings::Appendf(s, fmt, PrintfConvert<Ts>()(ts)...);
}
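
// For illustration (not part of the original file): with the PrintfConvert
// specialization above, string arguments can be passed to %s without .c_str().
//
//   string op = "multiply";
//   string msg = Printf("opcode=%s, operand count=%d", op, 2);
//   // msg == "opcode=multiply, operand count=2"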

// Used to indicate how we should treat a given HloInstruction in the graph:
// should we treat it as normal, hide it, and so on?
enum NodeFilterResult {
  kNormalNode,
  kHideNode,
  // Make the node easy to find in the final graph.
  kHighlightNode,
  // "Gray out" the node to indicate that some of its operands have been
  // omitted.
  kSomeOperandsOmitted,
  // Style the node the same as kSomeOperandsOmitted, but also don't connect it
  // to its operands, even if they're present in the graph.
  kOmitNodeOperands,
  // Same style as kSomeOperandsOmitted, but used to indicate that some of the
  // node's *users* have been omitted.
  kSomeUsersOmitted,
};

// NodeFilter is essentially a map from HloInstruction*s to NodeFilterResult.
// It lets callers tell the graph-drawing routines which nodes they want to be
// shown, hidden, or highlighted.
class NodeFilter {
 public:
  NodeFilter() : filter_([](const HloInstruction*) { return kNormalNode; }) {}

  explicit NodeFilter(
      std::function<NodeFilterResult(const HloInstruction* instr)> filter)
      : filter_(std::move(filter)) {}

  bool Show(const HloInstruction* instr) const {
    return filter_(instr) != kHideNode;
  }
  bool Highlight(const HloInstruction* instr) const {
    return filter_(instr) == kHighlightNode;
  }
  bool OmitOperands(const HloInstruction* instr) const {
    return filter_(instr) == kOmitNodeOperands;
  }
  bool SomeOrAllOperandsOmitted(const HloInstruction* instr) const {
    auto result = filter_(instr);
    return result == kOmitNodeOperands || result == kSomeOperandsOmitted;
  }
  bool Deemphasized(const HloInstruction* instr) const {
    auto result = filter_(instr);
    return result == kOmitNodeOperands || result == kSomeOperandsOmitted ||
           result == kSomeUsersOmitted;
  }

  bool ShowFusionSubcomputation(const HloInstruction* instr) const {
    CHECK_EQ(instr->opcode(), HloOpcode::kFusion);
    return Show(instr) && !SomeOrAllOperandsOmitted(instr);
  }

 private:
  std::function<NodeFilterResult(const HloInstruction* instr)> filter_;
};

// Node color schemes, used by NodeColorAttributes.
enum ColorScheme {
  kBlue,
  kBrown,
  kDarkBlue,
  kDarkGreen,
  kDarkRed,
  kGray,
  kGreen,
  kOrange,
  kPurple,
  kRed,
  kWhite,
  kYellow,

  // Causes the node's border to be a dashed line, and its content to be gray
  // text on a white background, suggesting that this is an "unimportant" node.
  kDashedBorder,
};

// Given a ColorScheme, returns an attribute string for a node of that color.
// Sets the node's style and fill/stroke/text colors.
//
// Colors are from https://material.io/color.
string NodeColorAttributes(ColorScheme color) {
  using std::make_tuple;

  const char *style, *fill_color, *stroke_color, *font_color;
  std::tie(style, fill_color, stroke_color, font_color) = [color] {
    switch (color) {
      case kBlue:
        return make_tuple("filled", "#bbdefb", "#8aacc8", "black");
      case kBrown:
        return make_tuple("filled", "#bcaaa4", "#8c7b75", "black");
      case kDarkBlue:
        return make_tuple("filled", "#1565c0", "#003c8f", "white");
      case kDarkGreen:
        return make_tuple("filled", "#2e7d32", "#005005", "white");
      case kDarkRed:
        return make_tuple("filled", "#b71c1c", "#7f0000", "white");
      case kGray:
        return make_tuple("filled", "#cfd8dc", "#9ea7aa", "black");
      case kGreen:
        return make_tuple("filled", "#c8e6c9", "#97b498", "black");
      case kOrange:
        return make_tuple("filled", "#ffe0b2", "#cbae82", "black");
      case kPurple:
        return make_tuple("filled", "#e1bee7", "#af8eb5", "black");
      case kRed:
        return make_tuple("filled", "#ffcdd2", "#cb9ca1", "black");
      case kWhite:
        return make_tuple("filled", "white", "black", "black");
      case kYellow:
        return make_tuple("filled", "#fff9c4", "#cbc693", "black");
      case kDashedBorder:
        // "filled,dashed" looks the same as "dashed", since we have a white
        // background.  But we use "filled,dashed" so that when you hover over
        // any part of the node (not just the text inside the node), our css
        // :hover rule is triggered.
        return make_tuple("filled,dashed", "white", "#757575", "#757575");
    }
  }();

  return Printf(
      R"(style="%s", fontcolor="%s", color="%s", fillcolor="%s")", style,
      font_color, stroke_color, fill_color);
}
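
// For example, from the table above, NodeColorAttributes(kBlue) produces:
//
//   style="filled", fontcolor="black", color="#8aacc8", fillcolor="#bbdefb"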

// Replaces <> with &lt;&gt;, so that this string is safe(er) for use in a
// graphviz HTML-like string.
string HtmlLikeStringSanitize(tensorflow::StringPiece s) {
  return StringReplace(StringReplace(s, "<", "&lt;", /*replace_all=*/true), ">",
                       "&gt;", /*replace_all=*/true);
}
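
// For example, HtmlLikeStringSanitize("param0 <= param1") returns
// "param0 &lt;= param1".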

// Tries to generate a human-readable one-word description of the given
// computation.
//
// Currently we support:
//
//   "return param0 + param1;"      --> "add"
//   "return param0 * param1;"      --> "multiply"
//   "return min(param0, param1);"  --> "min"
//   "return max(param0, param1);"  --> "max"
//   "return param0 <= param1;"     --> "less-or-equal"
//   "return param0 >= param1;"     --> "greater-or-equal"
//   "return param0 >  param1;"     --> "greater-than"
//   "return param0 <  param1;"     --> "less-than"
//   "return param0 == param1;"     --> "equal-to"
//   "return param0 != param1;"     --> "not-equal-to"
//
// where param0 and param1 are effective scalars.  For the ops that are
// commutative, we also support them with param0 and param1 swapped.
//
// This is useful primarily for reduce and map nodes.  These take a
// subcomputation which is almost always one of the above, and pattern matching
// it to a short string lets us tell the user what the subcomputation is without
// drawing it as a graph.
optional<string> MatchTrivialComputation(const HloComputation* computation) {
  if (computation->instruction_count() != 3) {
    return nullopt;
  }

  HloInstruction* root = computation->root_instruction();
  if (root->operand_count() != 2) {
    return nullopt;
  }

  // Check that both of the operands to the root are parameters.
  const HloInstruction* operand0 = root->operand(0);
  const HloInstruction* operand1 = root->operand(1);
  if (operand0->opcode() != HloOpcode::kParameter ||
      operand1->opcode() != HloOpcode::kParameter) {
    return nullopt;
  }
  // Check that the two operands of root are param0 and param1, in either
  // order.
  auto n0 = operand0->parameter_number();
  auto n1 = operand1->parameter_number();
  if (!(n0 == 0 && n1 == 1) && !(n1 == 0 && n0 == 1)) {
    return nullopt;
  }

  // If the params are reversed, check that the operation being performed is
  // commutative.
  if (n0 == 1) {
    switch (root->opcode()) {
      case HloOpcode::kLe:
      case HloOpcode::kGe:
      case HloOpcode::kGt:
      case HloOpcode::kLt:
        return nullopt;
      default:
        break;
    }
  }

  // Check that the root and params are all effective scalars.
  if (!ShapeUtil::IsEffectiveScalar(root->shape()) ||
      !ShapeUtil::IsEffectiveScalar(operand0->shape()) ||
      !ShapeUtil::IsEffectiveScalar(operand1->shape())) {
    return nullopt;
  }

  // If we recognize the root's opcode, we've successfully pattern-matched!
  switch (root->opcode()) {
    case HloOpcode::kAdd:
      return "add";
    case HloOpcode::kMultiply:
      return "multiply";
    case HloOpcode::kMinimum:
      return "min";
    case HloOpcode::kMaximum:
      return "max";
    case HloOpcode::kLe:
      return "less-or-equal";
    case HloOpcode::kGe:
      return "greater-or-equal";
    case HloOpcode::kGt:
      return "greater-than";
    case HloOpcode::kLt:
      return "less-than";
    case HloOpcode::kEq:
      return "equal-to";
    case HloOpcode::kNe:
      return "not-equal-to";
    default:
      return nullopt;
  }
}

// Encapsulates logic for dumping an HLO module to DOT (i.e. graphviz syntax).
class HloDotDumper {
 public:
  HloDotDumper(const HloComputation* computation, tensorflow::StringPiece label,
               const DebugOptions& debug_options, bool show_metadata,
               const HloExecutionProfile* profile, NodeFilter filter)
      : computation_(computation),
        label_(label.ToString()),
        debug_options_(debug_options),
        show_metadata_(show_metadata),
        profile_(profile),
        filter_(std::move(filter)) {}

  string Dump();

 private:
  // Returns the dot graph identifier for the given instruction.
  string InstructionId(const HloInstruction* instruction) {
    return StrCat(reinterpret_cast<uint64>(instruction));
  }

  // Returns the dot graph identifier for the given computation.
  string SubcomputationId(const HloComputation* computation) {
    return StrCat("cluster_", reinterpret_cast<uint64>(computation));
  }

  // Generates graph header/footer.  These should be called *after* dumping all
  // of the instructions and subcomputations for the graph, as they both use
  // data generated while dumping the graph.
  string Header();
  string Footer();

  bool ShouldShowSubcomputation(const HloComputation* subcomp);
  bool ShouldShowFusionSubcomputation(const HloInstruction* instr);

  // We omit some nodes from the graph, instead drawing them inlined into the
  // nodes that use them.
  bool ShouldMergeIntoUsers(const HloInstruction* instr) const;

  string DumpSubcomputation(const HloComputation* subcomp,
                            const HloInstruction* parent_instr);
  string DumpComputation(const HloComputation* comp);
  string DumpRootTag();
  string DumpInstruction(const HloInstruction* instr);
  ColorScheme GetInstructionColor(const HloInstruction* instr);
  string GetInstructionNodeShape(const HloInstruction* instr);
  string GetInstructionNodeLabel(const HloInstruction* instr);
  string GetInstructionNodeMetadata(const HloInstruction* instr);
  string GetInstructionNodeExtraInfo(const HloInstruction* instr);
  string GetInstructionNodeInlinedOperands(const HloInstruction* instr);
  void AddInstructionIncomingEdges(const HloInstruction* instr);

  // For most instructions, GetNodeForEdge(instr) returns instr.
  //
  // The exception is fusion nodes.  For these, we walk up the chain of nested
  // fusion nodes starting at instr until we reach a node that either (a) isn't
  // a fusion node, or (b) is a fusion node for which
  // ShouldShowFusionSubcomputation is false.
  //
  // We do this because fusion nodes are expanded inline -- if
  // ShouldShowFusionSubcomputation is true, the fusion node won't be present in
  // the graph.
  //
  // In general when you want to draw an edge from A to B, you should actually
  // draw an edge from GetNodeForEdge(A) to GetNodeForEdge(B).
  const HloInstruction* GetNodeForEdge(const HloInstruction* instr);

  // If instr has just one computation and it's trivial (e.g. "return param0 +
  // param1"), returns a string you can put into the node's body that names the
  // subcomputation, e.g. "Subcomputation: <b>add</b>".
  string GetInstructionTrivialComputationStr(const HloInstruction* instr);

  const HloComputation* computation_;  // never null
  const string label_;                 // overall name for the graph
  const DebugOptions& debug_options_;
  const bool show_metadata_;
  const HloExecutionProfile* profile_;  // may be null
  const NodeFilter filter_;

  // Each HloInstruction dumped gets a monotonically-increasing node ID.  This
  // must start at 1, because that's where graphviz's accounting starts.
  int64 next_node_id_ = 1;
  std::unordered_map<const HloInstruction*, int64> node_ids_;

  // The "root" tag doesn't have an associated HloInstruction pointer, so we
  // need to store it outside the map.
  int64 root_node_id_;

  // Each (from, to) edge gets a monotonically-increasing ID.  This is a
  // multimap because it's possible for the same edge to appear multiple times
  // in the graph (e.g. x^2 may be represented as mul(x, x)).
  int64 next_edge_id_ = 1;
  std::unordered_multimap<
      std::pair<const HloInstruction*, const HloInstruction*>, int64,
      tensorflow::hash<std::pair<const HloInstruction*, const HloInstruction*>>>
      edge_ids_;

  // Each HloComputation that's emitted gets a monotonically-increasing ID.
  int64 next_cluster_id_ = 1;
  std::unordered_map<const HloComputation*, int64> cluster_ids_;

  // Edges to print from Footer().  Edges come at the end because graphviz is
  // unhappy if an edge from a subcomputation to a node in the outer computation
  // appears before both the inner computation and the destination node are
  // defined.
  std::vector<string> edges_;

  // When coloring by sharding information, we map each sharding's string
  // representation to a color, assigning colors round-robin from ColorScheme.
  std::unordered_map<string, ColorScheme> sharding_colors_;
  int64 next_shard_color_ = 0;
};

string HloDotDumper::Dump() {
  string body;
  StrAppend(&body, DumpComputation(computation_));
  StrAppend(&body, DumpRootTag());

  // By contract, Header() and Footer() have to be called after we've dumped all
  // our instructions, because they use state generated during that process.
  string g = Header();
  StrAppend(&g, body);
  StrAppend(&g, Footer());
  return g;
}

string HloDotDumper::Header() {
  const char* fmt = R"(digraph G {
rankdir = TB;
compound = true;
label = <<b>%s</b>>;
labelloc = t;
// Disable the tooltip.  Interestingly, "" doesn't work!
tooltip = " ";
// DOT graphs accept a stylesheet as a URI.  So naturally, an inline
// stylesheet is a data URI!
stylesheet="
  data:text/css,
  @import url(https://fonts.googleapis.com/css?family=Roboto:400,700);
  svg text {
    font-family: 'Roboto';
    font-size: 12px;
  }

  %s
"

)";

  VLOG(3) << "Generating Header";

  string graph_label =
      StrCat(label_, "<br/>Computation ", computation_->name());
  if (computation_->IsFusionComputation()) {
    StrAppend(&graph_label,
              StrCat(" (in fusion instruction ",
                     computation_->FusionInstruction()->name(), ")"));
  }
  if (profile_ != nullptr) {
    auto cycles = profile_->total_cycles_executed(*computation_);
    Appendf(&graph_label, "<br/>total cycles = %lld (%s)", cycles,
            tensorflow::strings::HumanReadableNum(cycles));
  }

  // Create CSS rules that say, when you hover over the given node or cluster,
  // turn the given edge the given color.
  //
  // We rely on a few properties of how graphviz generates SVGs:
  //
  //  - Nodes are named "nodeN", where N corresponds to the 1-based index of
  //    the node in our DOT (i.e. the first node in the DOT is "node1", etc.).
  //    Edges are similarly named "edgeN", and clusters are named "clustN".
  //  - Nodes come before their in- and out-edges in the SVG.  We need this
  //    because the "X ~ Y" CSS selector finds a sibling of X that *comes
  //    after X in the DOM* and matches Y.
  std::vector<string> edge_css_rules;
  const char* kBlue = "#1976d2";
  const char* kRed = "#d32f2f";
  for (const auto& kv : edge_ids_) {
    const HloInstruction* from_node = kv.first.first;
    const HloInstruction* to_node = kv.first.second;
    int64 edge_id = kv.second;

    auto add_hover_css_rule = [&](string elem_type, int64 elem_id,
                                  const char* color) {
      // One could imagine other ways of writing this CSS rule that involve
      // less duplication, but this way seems to be relatively performant.
      edge_css_rules.push_back(
          Printf("  #%s%d:hover ~ #edge%lld text { fill: %s; }\n"
                 "  #%s%d:hover ~ #edge%lld path { "
                 "stroke: %s; stroke-width: .2em; }\n"
                 "  #%s%d:hover ~ #edge%lld polygon { "
                 "fill: %s; stroke: %s; stroke-width: .2em; }\n",
                 elem_type, elem_id, edge_id, color,  //
                 elem_type, elem_id, edge_id, color,  //
                 elem_type, elem_id, edge_id, color, color));
    };
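
    // For illustration (hypothetical IDs): for node 3 and edge 7,
    // add_hover_css_rule("node", 3, kBlue) emits rules along the lines of
    //   #node3:hover ~ #edge7 text { fill: #1976d2; }
    //   #node3:hover ~ #edge7 path { stroke: #1976d2; stroke-width: .2em; }
    //   #node3:hover ~ #edge7 polygon { fill: #1976d2; stroke: #1976d2; ... }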

    // The "to_node" value may be null, indicating that this edge points to the
    // "root" tag rather than to a normal node.
    int64 from_node_id =
        tensorflow::gtl::FindWithDefault(node_ids_, from_node, -1);
    if (from_node_id == -1) {
      LOG(FATAL) << from_node->name() << " was added to edges but not to nodes";
    }
    int64 to_node_id =
        to_node ? tensorflow::gtl::FindWithDefault(node_ids_, to_node, -1)
                : root_node_id_;
    if (to_node != nullptr && to_node_id == -1) {
      LOG(FATAL) << to_node->name() << " was added to edges but not to nodes";
    }

    add_hover_css_rule("node", from_node_id, kBlue);
    add_hover_css_rule("node", to_node_id, kRed);

    if (to_node) {
      VLOG(3) << "Adding css for edge " << edge_id << " from node "
              << from_node->name() << " to node " << to_node->name();
    } else {
      VLOG(3) << "Adding css for edge " << edge_id << " from node "
              << from_node->name() << " to root tag";
    }

    // If this edge crosses a fusion cluster boundary, highlight it when the
    // cluster is hovered over.
    if (to_node) {
      if (from_node->IsFused() &&
          from_node->parent()->root_instruction() == from_node) {
        int64 cluster_id = cluster_ids_.at(from_node->parent());
        add_hover_css_rule("clust", cluster_id, kBlue);
      }
      if (to_node->IsFused() && to_node->opcode() == HloOpcode::kParameter) {
        int64 cluster_id = cluster_ids_.at(to_node->parent());
        add_hover_css_rule("clust", cluster_id, kRed);
      }
    }
  }

  return Printf(fmt, graph_label, Join(edge_css_rules, "\n"));
}

string HloDotDumper::Footer() { return StrCat(Join(edges_, "\n"), "\n}"); }

bool HloDotDumper::ShouldShowFusionSubcomputation(const HloInstruction* instr) {
  CHECK_EQ(instr->opcode(), HloOpcode::kFusion);
  return ShouldShowSubcomputation(instr->fused_instructions_computation());
}

bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) {
  if (subcomp->IsFusionComputation()) {
    const HloInstruction* fusion = subcomp->FusionInstruction();
    if (!filter_.Show(fusion) || filter_.SomeOrAllOperandsOmitted(fusion)) {
      return false;
    }
  }

  // Don't show trivial subcomputations on non-fusion nodes -- these are inlined
  // into the graph.
  if (!subcomp->IsFusionComputation() && MatchTrivialComputation(subcomp)) {
    return false;
  }

  // Show the subcomputation if we're showing any of its members.
  return std::any_of(
      subcomp->instructions().begin(), subcomp->instructions().end(),
      [&](const HloInstruction* instr) { return filter_.Show(instr); });
}

string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp,
                                        const HloInstruction* parent_instr) {
  VLOG(2) << "Dumping subcomputation " << subcomp->name();
  const char* computation_fmt = R"(subgraph %s {
%s
label = <%s>;
labelloc = t;
tooltip = " ";
%s
}  // %s

)";

  cluster_ids_[subcomp] = next_cluster_id_++;

  string id = SubcomputationId(subcomp);

  string subcomp_label, style;
  if (parent_instr->opcode() == HloOpcode::kFusion) {
    subcomp_label = Printf("Fused expression for <b>%s</b><br/>%s",
                           HtmlLikeStringSanitize(parent_instr->name()),
                           HtmlLikeStringSanitize(parent_instr->ToCategory()));
    string extra_info = GetInstructionNodeExtraInfo(parent_instr);
    if (!extra_info.empty()) {
      StrAppend(&subcomp_label, "<br/>", extra_info);
    }

    // Subcomputation's fill/stroke color is light/dark red/gray, depending on
    // whether or not the subcomputation's fusion node is highlighted.
    bool highlight = filter_.Highlight(parent_instr);
    const char* fillcolor = highlight ? "#ffcdd2" : "#f5f5f5";
    const char* strokecolor = highlight ? "#b71c1c" : "#c2c2c2";
    style =
613        Printf(R"(style="rounded,filled,bold"; fillcolor="%s"; color="%s;")",
               fillcolor, strokecolor);
  } else {
    subcomp_label = Printf("Subcomputation for <b>%s</b><br/>%s",
                           HtmlLikeStringSanitize(parent_instr->name()),
                           HtmlLikeStringSanitize(subcomp->name()));
    style = "style=rounded; color=black;";
  }

  string comp_body = DumpComputation(subcomp);

  // Add an edge from the subcomputation to its parent node.  If subcomp
  // belongs to a fusion node, it's drawn in place of the fusion instruction,
  // so there's no need to link those.
  if (parent_instr->opcode() != HloOpcode::kFusion) {
    const HloInstruction* from = GetNodeForEdge(subcomp->root_instruction());
    VLOG(2) << "Edge: from " << from->name() << " to " << parent_instr->name()
            << " as " << next_edge_id_;
    edge_ids_.insert({{from, parent_instr}, next_edge_id_++});
    const char* edge_fmt =
        R"(%s -> %s [ltail="%s", style="dashed" tooltip="%s -> %s"];)";
    edges_.push_back(Printf(
        edge_fmt, InstructionId(from), InstructionId(parent_instr),
        SubcomputationId(subcomp), subcomp->name(), parent_instr->name()));
  }

  string computation =
      Printf(computation_fmt, id, style, subcomp_label, comp_body, id);

  return computation;
}

string HloDotDumper::DumpComputation(const HloComputation* comp) {
  string g;
  for (const auto* instr : comp->instructions()) {
    if (!filter_.Show(instr)) {
      continue;
    }

    // Dump subcomputations within instr.
    for (const HloComputation* subcomp : instr->called_computations()) {
      if (ShouldShowSubcomputation(subcomp)) {
        StrAppend(&g, DumpSubcomputation(subcomp, instr));
      }
    }

    StrAppend(&g, DumpInstruction(instr));
  }
  return g;
}

string HloDotDumper::DumpRootTag() {
  const HloInstruction* from = GetNodeForEdge(computation_->root_instruction());

  // We don't display constants as separate nodes, so if the root is a
  // constant we don't add a root tag or edge for it.
  if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant) {
    return "";
  }

  auto from_id = InstructionId(from);

  // The ID of the root computation is otherwise unused, so it makes a good ID
  // to use for the root-tag node.  However, the edge_ids_ map requires a
  // HloInstruction* pointer for the 'to' value, so we use a NULL value there
  // (rather than a pointer type-cast) to make it obvious if it is erroneously
  // dereferenced.
  HloInstruction* to = nullptr;
  auto to_id = SubcomputationId(computation_);

  string node_body = "ROOT";
  string node_shape = "circle";
  ColorScheme color = kBrown;

  VLOG(2) << "Adding root tag as node " << next_node_id_;
  root_node_id_ = next_node_id_++;

  VLOG(2) << "Adding edge from " << from->name() << " to root tag as "
          << next_edge_id_;
  edge_ids_.insert({{from, to}, next_edge_id_++});
  edges_.push_back(Printf(R"(%s -> %s [tooltip=" "];)", from_id, to_id));

  return Printf(R"(%s [label=<%s>, shape=%s, tooltip=" ", %s];)"
                "\n",
                to_id, node_body, node_shape, NodeColorAttributes(color));
}

bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const {
  // If a node:
  //
  //  - is a tuple-shaped parameter,
  //  - is not a parameter to a fusion node,
  //  - has more than kMinUsersToOmit users shown, and
  //  - all of the shown users are get-tuple-elements,
  //
  // then we omit it from the graph, merging it with its users.
  //
  // This helps us handle the common case where a while loop body has one big
  // tuple-shaped parameter.
  const int kMinUsersToOmit = 3;
  return instr->opcode() == HloOpcode::kParameter &&
         ShapeUtil::IsTuple(instr->shape()) && !instr->IsFused() &&
         std::count_if(instr->users().begin(), instr->users().end(),
                       [&](const HloInstruction* user) {
                         return filter_.Show(user);
                       }) > kMinUsersToOmit &&
         std::all_of(instr->users().begin(), instr->users().end(),
                     [&](const HloInstruction* user) {
                       return !filter_.Show(user) ||
                              user->opcode() == HloOpcode::kGetTupleElement;
                     });
}

string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
  // We don't display constants as separate nodes; they're merged into their
  // users.
  if (instr->opcode() == HloOpcode::kConstant) {
    return "";
  }
  // Skip this node if it's merged into its users.
  if (ShouldMergeIntoUsers(instr)) {
    return "";
  }
  // Omit the fusion node if its subcomputation is drawn, since the
  // subcomputation will be drawn inline.
  if (instr->opcode() == HloOpcode::kFusion &&
      ShouldShowFusionSubcomputation(instr)) {
    return "";
  }

  VLOG(2) << "Adding node " << instr->name() << " as " << next_node_id_;
  node_ids_[instr] = next_node_id_++;

  ColorScheme color = GetInstructionColor(instr);
  string node_shape = GetInstructionNodeShape(instr);
  string node_label = GetInstructionNodeLabel(instr);
  string node_metadata = GetInstructionNodeMetadata(instr);
  string extra_info = GetInstructionNodeExtraInfo(instr);
  string inlined_constants = GetInstructionNodeInlinedOperands(instr);
  string trivial_subcomputation = GetInstructionTrivialComputationStr(instr);
  AddInstructionIncomingEdges(instr);

  if (!debug_options_.xla_hlo_graph_sharding_color()) {
    // Override the node's styling if it should be (de-)emphasized.
    if (filter_.Deemphasized(instr)) {
      color = kDashedBorder;
    }
    if (filter_.Highlight(instr)) {
      node_shape = "diamond";
      color = kDarkRed;
    }
  }
  // Build the text that will be displayed inside the node.
  string node_body = node_label;
  for (const string& s :
       {trivial_subcomputation, node_metadata, extra_info, inlined_constants}) {
    if (!s.empty()) {
      StrAppend(&node_body, "<br/>", s);
    }
  }

  return Printf(R"(%s [label=<%s>, shape=%s, tooltip=" ", %s];)"
                "\n",
                InstructionId(instr), node_body, node_shape,
                NodeColorAttributes(color));
}

string HloDotDumper::GetInstructionNodeInlinedOperands(
    const HloInstruction* instr) {
  auto stringify_constant = [](const HloInstruction* constant) {
    const auto& shape = constant->shape();

    // Print the literal value of constants with <= K elements.
    optional<int64> elem_count;
    if (!ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape)) {
      elem_count = 1;
      for (int64 dim : shape.dimensions()) {
        *elem_count *= dim;
      }
    }
    if (elem_count.has_value() && *elem_count <= 8) {
      return Printf("%s (%s)", constant->literal().ToString(),
                    ShapeUtil::HumanString(constant->shape()));
    }

    // Otherwise, print e.g. "%constant.42 (s32[100])".
    string constant_name;
    if (tensorflow::StringPiece(constant->name()).starts_with("constant")) {
      constant_name = constant->name();
    } else {
      constant_name = StrCat("constant ", constant->name());
    }
    return Printf("%s %s", constant_name,
                  ShapeUtil::HumanString(constant->shape()));
  };

  // Special case: If instr is a parameter to a fusion node, check whether the
  // corresponding operand to the fusion node is a constant.
  if (instr->opcode() == HloOpcode::kParameter && instr->IsFused()) {
    const HloInstruction* fusion = instr->parent()->FusionInstruction();
    const HloInstruction* operand = fusion->operand(instr->parameter_number());
    if (operand->opcode() != HloOpcode::kConstant) {
      return "";
    }
    return StrCat("<b>constant</b> ", stringify_constant(operand));
  }

  std::vector<string> lines;
  for (int64 i = 0; i < instr->operand_count(); ++i) {
    const HloInstruction* operand = instr->operand(i);
    optional<string> operand_str;
    if (operand->opcode() == HloOpcode::kConstant) {
      operand_str = stringify_constant(operand);
    } else if (ShouldMergeIntoUsers(operand)) {
      // Special case: If the operand is a parameter, use its parameter number
      // rather than its name, because that's generally how people think of the
      // node.
      if (operand->opcode() == HloOpcode::kParameter) {
        operand_str = Printf("Parameter %lld", operand->parameter_number());
      } else {
        operand_str = operand->name();
      }
    }

    if (operand_str) {
      if (instr->operand_count() > 1) {
        lines.push_back(Printf("<b>operand %lld</b> = %s", i, *operand_str));
      } else {
        lines.push_back(Printf("<b>operand</b> = %s", *operand_str));
      }
    }
  }
  return Join(lines, "<br/>");
}

ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
  if (debug_options_.xla_hlo_graph_sharding_color()) {
    if (!instr->has_sharding()) {
      return kDashedBorder;
    }
    string shard_str = instr->sharding().ToString();
    auto it = sharding_colors_.find(shard_str);
    if (it != sharding_colors_.end()) {
      return it->second;
    }
    ColorScheme color = static_cast<ColorScheme>(
        kBlue + (next_shard_color_++ % (kDashedBorder - kBlue)));
    sharding_colors_.emplace(shard_str, color);
    return color;
  }
  const auto kParameterColor = kOrange;

  // Special case: If this instruction has a parameter merged into it, paint it
  // the same color as a parameter.
  if (std::any_of(instr->operands().begin(), instr->operands().end(),
                  [&](const HloInstruction* operand) {
                    return operand->opcode() == HloOpcode::kParameter &&
                           ShouldMergeIntoUsers(operand);
                  })) {
    return kParameterColor;
  }

  // Pick different colors or shapes for instructions which are particularly
  // expensive (eg, dot) and those which are unusual in some way or unique
  // (eg, parameter).
  switch (instr->opcode()) {
    case HloOpcode::kAbs:
    case HloOpcode::kAdd:
    case HloOpcode::kAnd:
    case HloOpcode::kAtan2:
    case HloOpcode::kBitcastConvert:
    case HloOpcode::kCeil:
    case HloOpcode::kClamp:
    case HloOpcode::kComplex:
    case HloOpcode::kConvert:
    case HloOpcode::kCos:
    case HloOpcode::kDivide:
    case HloOpcode::kEq:
    case HloOpcode::kExp:
    case HloOpcode::kFloor:
    case HloOpcode::kGe:
    case HloOpcode::kGt:
    case HloOpcode::kImag:
    case HloOpcode::kIsFinite:
    case HloOpcode::kLe:
    case HloOpcode::kLog:
    case HloOpcode::kLt:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
    case HloOpcode::kNe:
    case HloOpcode::kNegate:
    case HloOpcode::kNot:
    case HloOpcode::kOr:
    case HloOpcode::kPower:
    case HloOpcode::kReal:
    case HloOpcode::kRemainder:
    case HloOpcode::kRng:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kSlice:
    case HloOpcode::kSort:
    case HloOpcode::kSubtract:
    case HloOpcode::kTanh:
      // De-emphasize scalar-shaped elementwise ops -- they're generally
      // uninteresting.
      if (ShapeUtil::IsEffectiveScalar(instr->shape())) {
        return kWhite;
      }
      return kYellow;
    case HloOpcode::kBitcast:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kTrace:
    case HloOpcode::kTuple:
      return kWhite;
    case HloOpcode::kBroadcast:
      // De-emphasize nodes which broadcast a scalar within a fusion node --
      // these are essentially free.
      if (instr->IsFused() &&
          ShapeUtil::IsEffectiveScalar(instr->operand(0)->shape())) {
        return kWhite;
      }
      return kGreen;
    case HloOpcode::kConcatenate:
    case HloOpcode::kCopy:
    case HloOpcode::kDynamicSlice:
    case HloOpcode::kGather:
    case HloOpcode::kPad:
    case HloOpcode::kReshape:
    case HloOpcode::kReverse:
    case HloOpcode::kSelect:
    case HloOpcode::kTranspose:
      // De-emphasize scalar-shaped data movement ops and all data movement ops
      // inside fusion nodes, both of which are essentially free.
      if (ShapeUtil::IsEffectiveScalar(instr->shape()) || instr->IsFused()) {
        return kWhite;
      }
      return kGreen;
    case HloOpcode::kDynamicUpdateSlice:
      // Unlike the data-movement ops above, dynamic-update-slice is not ~free
      // inside of fusion nodes, so we de-emphasize it only if it's
      // scalar-shaped.
      if (ShapeUtil::IsEffectiveScalar(instr->shape())) {
        return kWhite;
      }
      return kGreen;
    case HloOpcode::kConvolution:
    case HloOpcode::kDot:
    case HloOpcode::kFft:
      return kDarkBlue;
    case HloOpcode::kReducePrecision:
      return kRed;
    case HloOpcode::kParameter:
      return kParameterColor;
    case HloOpcode::kBatchNormGrad:
    case HloOpcode::kBatchNormInference:
    case HloOpcode::kBatchNormTraining:
    case HloOpcode::kReduce:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kSelectAndScatter:
      return kPurple;
    case HloOpcode::kFusion:
    case HloOpcode::kMap:
      return kGray;
    case HloOpcode::kCrossReplicaSum:
    case HloOpcode::kInfeed:
    case HloOpcode::kOutfeed:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
      return kBrown;
    case HloOpcode::kCall:
    case HloOpcode::kConditional:
    case HloOpcode::kCustomCall:
    case HloOpcode::kHostCompute:
    case HloOpcode::kWhile:
      return kDarkGreen;
    case HloOpcode::kConstant:
      LOG(FATAL) << "Constants don't get their own nodes in the graph.";
  }
}

string HloDotDumper::GetInstructionNodeShape(const HloInstruction* instr) {
  // Give while loops a different shape so they're easier to pick out.
  switch (instr->opcode()) {
    case HloOpcode::kWhile:
      return "ellipse";
    default:
      return "rect";
  }
}

string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) {
  // If we have a parameter, put the param number in the name.
  if (instr->opcode() == HloOpcode::kParameter) {
    return Printf("<b>Parameter %lld</b>", instr->parameter_number());
  }

  // The HLO instruction name usually contains the opcode, e.g. "%add.42" is
  // an add instruction.  In this case we render just the name.
  if (tensorflow::StringPiece(instr->name())
          .starts_with(HloOpcodeString(instr->opcode()))) {
    return Printf("<b>%s</b>", HtmlLikeStringSanitize(instr->name()));
  }
  string extended_opcode =
      StrCat(HloOpcodeString(instr->opcode()),
             instr->opcode() != HloOpcode::kFusion
                 ? ""
                 : StrCat(":", xla::ToString(instr->fusion_kind())));
  // If the name does not contain the opcode, render both.
  return Printf("<b>%s</b><br/>%s", HtmlLikeStringSanitize(extended_opcode),
                HtmlLikeStringSanitize(instr->name()));
}
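
// For illustration (hypothetical instruction names): an add instruction named
// "add.42" renders as just "<b>add.42</b>", while one named "my_sum" renders
// as "<b>add</b><br/>my_sum".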

string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) {
  if (!show_metadata_) {
    return "";
  }

  std::vector<string> lines;
  if (!instr->metadata().op_name().empty()) {
    lines.push_back(HtmlLikeStringSanitize(instr->metadata().op_name()));
  }
  if (!instr->metadata().op_type().empty()) {
    lines.push_back(Printf(
        "op_type: %s", HtmlLikeStringSanitize(instr->metadata().op_type())));
  }
  if (!instr->metadata().source_file().empty() &&
      instr->metadata().source_line() != 0) {
1047    lines.push_back(Printf("op_type: %s", instr->metadata().source_file(),
1048                           instr->metadata().source_line()));
  }

  return Join(lines, "<br/>");
}

string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
  std::vector<string> lines;

  // Get the instruction's extra attributes excluding the names of its
  // subcomputations, since those are drawn explicitly in the graph.
  for (const auto& line : instr->ExtraAttributesToString(
           HloPrintOptions().set_print_subcomputation_references(false))) {
    lines.push_back(HtmlLikeStringSanitize(line));
  }

  // Show the shape and layout of the instruction, unless it's an inlined fusion
  // node -- there the shape and layout are present in the output node.
  if (instr->opcode() != HloOpcode::kFusion ||
      !ShouldShowFusionSubcomputation(instr)) {
    // Show layout of instructions with more than one dimension.  Don't show
    // layout on tuples or tensors with just one dimension (which only have one
    // possible layout) to avoid visual noise.
    bool shape_is_multidim = false;
    ShapeUtil::ForEachSubshape(instr->shape(),
                               [&](const Shape& s, const ShapeIndex&) {
                                 shape_is_multidim |= s.dimensions_size() > 1;
                               });
    string instr_shape;
    if (instr->opcode() != HloOpcode::kTuple && shape_is_multidim) {
      instr_shape = ShapeUtil::HumanStringWithLayout(instr->shape());
    } else {
      instr_shape = ShapeUtil::HumanString(instr->shape());
    }

    // Some instructions have giant tuples as their shapes, so truncate the
    // HLO's shape to kMaxShapeLen characters.
    constexpr int kMaxShapeLen = 64;
    if (instr_shape.length() > kMaxShapeLen) {
      instr_shape = StrCat(
          tensorflow::StringPiece(instr_shape).substr(0, kMaxShapeLen - 3),
          "...");
    }
    lines.push_back(instr_shape);
  }
  if (debug_options_.xla_hlo_graph_addresses()) {
    lines.push_back(Printf("[%p]", instr));
  }
  if (profile_ != nullptr) {
    double hlo_cycles_executed = profile_->GetCyclesTakenBy(*instr);
    double total_cycles_executed =
        profile_->total_cycles_executed(*instr->parent());
    if (hlo_cycles_executed > 0 && total_cycles_executed > 0) {
      lines.push_back(
          Printf("%% of cycles executed=%.2f",
                 100 * hlo_cycles_executed / total_cycles_executed));
    }
  }
  return Join(lines, "<br/>");
}

void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) {
  auto add_edge = [&](const HloInstruction* from, const HloInstruction* to,
                      int64 operand_num, bool control_edge = false) {
    from = GetNodeForEdge(from);

    if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant ||
        ShouldMergeIntoUsers(from)) {
      return;
    }
    VLOG(2) << "Adding edge from " << from->name() << " to " << to->name()
            << " as " << next_edge_id_;
    edge_ids_.insert({{from, to}, next_edge_id_++});

    string edge_label;
    if (instr->operand_count() > 1 && !control_edge) {
      edge_label = Printf(R"( headlabel="%lld", labeldistance=2)", operand_num);
    } else if (control_edge) {
      edge_label = "style=\"dotted\" color=\"gray\" label=\"ctrl\"";
    }
    const char* kEdgeFmt = R"(%s -> %s [tooltip="%s -> %s" %s];)";
    edges_.push_back(Printf(kEdgeFmt, InstructionId(from), InstructionId(to),
                            from->name(), to->name(), edge_label));
  };
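
  // For illustration (hypothetical names and IDs): an edge from operand 1 of a
  // two-operand node might be emitted via kEdgeFmt as
  //   140001 -> 140002 [tooltip="param.1 -> add.2"  headlabel="1", labeldistance=2];
  // where the numbers are the pointer-derived instruction IDs.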

  // Add edges from instr's operands to instr.  Parameters within fusion
  // expressions are handled specially -- we draw an edge from the corresponding
  // operand on the fusion node itself to the parameter.
  if (instr->opcode() == HloOpcode::kParameter && instr->IsFused()) {
    // Only add the edge if this is not the outermost computation; otherwise it
    // will lead from a node we're not drawing.
    if (instr->parent() != computation_) {
      const HloInstruction* fusion = instr->parent()->FusionInstruction();
      add_edge(fusion->operand(instr->parameter_number()), instr,
               /*operand_num=*/0);
    }
  } else {
    for (int64 i = 0; i < instr->operand_count(); ++i) {
      add_edge(instr->operand(i), instr, i);
    }
    for (const HloInstruction* pred : instr->control_predecessors()) {
      add_edge(pred, instr, /*operand_num=*/0, /*control_edge=*/true);
    }
  }
}

string HloDotDumper::GetInstructionTrivialComputationStr(
    const HloInstruction* instr) {
  // called_computations() on a fusion node "inherits" any called computations
  // of the fused root, which isn't what we want.  Just ignore fusion nodes
  // here; they're handled separately.
  if (instr->opcode() == HloOpcode::kFusion) {
    return "";
  }

  std::vector<string> lines;
  for (int64 i = 0; i < instr->called_computations().size(); ++i) {
    optional<string> computation_type =
        MatchTrivialComputation(instr->called_computations()[i]);
    if (!computation_type) {
      continue;
    }
    if (instr->called_computations().size() == 1) {
      lines.push_back(Printf("Subcomputation: <b>%s</b>",
                             HtmlLikeStringSanitize(*computation_type)));
    } else {
      lines.push_back(Printf("Subcomputation %lld: <b>%s</b>", i,
                             HtmlLikeStringSanitize(*computation_type)));
    }
  }
  return Join(lines, "<br/>");
}

const HloInstruction* HloDotDumper::GetNodeForEdge(
    const HloInstruction* instr) {
  while (instr->opcode() == HloOpcode::kFusion &&
         ShouldShowFusionSubcomputation(instr)) {
    instr = instr->fused_expression_root();
  }
  return instr;
}

class GraphRendererRegistry {
 public:
  void AddRenderer(GraphRendererInterface* graph_renderer) {
    tensorflow::mutex_lock lock(mu_);
    graph_renderer_ = graph_renderer;
  }

  GraphRendererInterface* GetDefaultRenderer() {
    tensorflow::mutex_lock lock(mu_);
    return graph_renderer_;
  }

  static GraphRendererRegistry* Default() {
    static GraphRendererRegistry* registry = new GraphRendererRegistry();
    return registry;
  }

 private:
  tensorflow::mutex mu_;
  GraphRendererInterface* graph_renderer_ = nullptr;
};

}  // namespace

Registrar::Registrar(GraphRendererInterface* dumper) {
  GraphRendererRegistry::Default()->AddRenderer(dumper);
}

namespace {

// Gets a NodeFilter that includes roughly all instructions whose distance from
// root is <= radius.
NodeFilter MakeNodeFilter(const HloInstruction* root, int64 radius) {
  // First, find the neighborhood of nodes with distance from root <= radius.
  // These nodes are our initial set of "normal" nodes.
  std::unordered_map<const HloInstruction*, NodeFilterResult> nodes;
  std::deque<std::pair<const HloInstruction*, /*depth*/ int64>> worklist;
  worklist.push_back({root, 0});
  while (!worklist.empty()) {
    const HloInstruction* instr;
    int64 depth;
    std::tie(instr, depth) = worklist.front();
    worklist.pop_front();

    nodes[instr] = kNormalNode;
    if (depth == radius) {
      continue;
    }

    // Traverse into instr's operands.
    //
    // Don't traverse into tuples' operands unless the tuple is the root.
    // Usually a tuple is the bottommost node in the graph, and so its operands
    // are not interesting to the graph at hand.
    if (instr == root || instr->opcode() != HloOpcode::kTuple) {
      for (const HloInstruction* operand : instr->operands()) {
        if (!nodes.count(operand)) {
          worklist.push_back({operand, depth + 1});
        }
      }
    }

    // Traverse into instr's nested computations.
    for (const HloComputation* computation : instr->called_computations()) {
      worklist.push_back({computation->root_instruction(), depth + 1});
    }

    // Traverse into instr's users, unless:
    //
    //  - there are a ton of them, in which case they're probably not
    //    interesting (and anyway, rendering them all would make the graph
    //    unreadable), or
    //  - instr is a constant, in which case its users are probably not
    //    interesting.
    if (instr->opcode() == HloOpcode::kConstant) {
      continue;
    }
    constexpr int kMaxUsersToRender = 16;
    if (instr->user_count() > kMaxUsersToRender) {
      // If we're going to skip this node's users, style it as such.
      nodes[instr] = kSomeUsersOmitted;
      continue;
    }
    for (const HloInstruction* user : instr->users()) {
      if (!nodes.count(user)) {
        worklist.push_back({user, depth + 1});
      }
    }
  }

  auto is_displayed = [&](const HloInstruction* instr) {
    // Constants are displayed inline with their users; they're never omitted.
    // Nodes in subcomputations are always shown.
    return nodes.count(instr) > 0 || instr->opcode() == HloOpcode::kConstant ||
           instr->parent() != root->parent();
  };

  // Make a second pass over 'nodes' to fix up the NodeFilterResults now that we
  // know which nodes will be included in the graph.
  for (auto& kv : nodes) {
    const HloInstruction* instr = kv.first;
    NodeFilterResult& filter_result = kv.second;
    const auto& operands = instr->operands();

    if (std::any_of(operands.begin(), operands.end(), is_displayed) &&
        !std::all_of(operands.begin(), operands.end(), is_displayed)) {
      // Mark nodes with some operands omitted appropriately.
      filter_result = kSomeOperandsOmitted;
    } else if (!operands.empty() &&
               std::none_of(operands.begin(), operands.end(), is_displayed)) {
      // Mark nodes with *all* operands omitted appropriately.
      filter_result = kOmitNodeOperands;
    }

    // Promote nodes with type kSomeUsersOmitted to kNormalNode if all of their
    // users made it into the graph.
    if (filter_result == kSomeUsersOmitted &&
        std::all_of(instr->users().begin(), instr->users().end(),
                    is_displayed)) {
      filter_result = kNormalNode;
    }
  }

  // Highlight the root node.
  nodes[root] = kHighlightNode;

  return NodeFilter([=](const HloInstruction* instr) {
    auto it = nodes.find(instr);
    if (it != nodes.end()) {
      return it->second;
    }
    // Show all nodes in subcomputations.
    if (instr->parent() != root->parent()) {
      return kNormalNode;
    }
    return kHideNode;
  });
}

string SaveGraph(const string& graph,
                 GraphRendererInterface::GraphKind graph_kind,
                 const string& dest_path) {
  static std::atomic<int> output_num(0);
  string file_extension;
  switch (graph_kind) {
    case GraphRendererInterface::DOT_GRAPH:
      file_extension = ".dot";
      break;
    case GraphRendererInterface::TF_GRAPHDEF:
      file_extension = ".pbtxt";
      break;
  }
  string path = JoinPath(dest_path, StrCat("hlo_graph_", output_num++, "."));
  auto status = Status::OK();
  auto env = tensorflow::Env::Default();
  if (!env->CreateUniqueFileName(&path, file_extension)) {
    status =
        Status(tensorflow::error::Code::UNKNOWN,
               StrCat("Failed to create temporary file to dump HLO graph: ",
                      strerror(errno)));
  } else {
    status = tensorflow::WriteStringToFile(env, path, graph);
  }
  if (!status.ok()) {
    LOG(WARNING) << "Saving HLO graph failed: " << status;
  }
  return path;
}

string ExportGraph(const string& graph,
                   GraphRendererInterface::GraphKind graph_kind,
                   const DebugOptions& debug_options) {
  string path = debug_options.xla_hlo_graph_path();
  if (!path.empty()) {
    return SaveGraph(graph, graph_kind, path);
  } else {
    auto graph_renderer =
        GraphRendererRegistry::Default()->GetDefaultRenderer();
    CHECK(graph_renderer != nullptr)
        << "No registered renderer for the HLO graph. "
           "Use --xla_hlo_graph_path=PATH to export to local file system";
    return graph_renderer->RenderGraph(graph, graph_kind, debug_options);
  }
}

}  // namespace

string DumpGraph(const HloComputation& computation, const string& label,
                 const DebugOptions& debug_options,
                 const HloExecutionProfile* hlo_execution_profile,
                 bool show_metadata) {
  GraphRendererInterface::GraphKind graph_kind;
  string graph;
  if (debug_options.xla_hlo_dump_as_graphdef()) {
    HloTfGraphBuilder builder(debug_options);
    TF_CHECK_OK(builder.AddComputation(computation));
    CHECK(tensorflow::protobuf::TextFormat::PrintToString(builder.GetGraphDef(),
                                                          &graph));
    graph_kind = GraphRendererInterface::TF_GRAPHDEF;
  } else {
    graph = HloDotDumper(&computation, label, debug_options, show_metadata,
                         hlo_execution_profile, NodeFilter())
                .Dump();
    graph_kind = GraphRendererInterface::DOT_GRAPH;
  }

  string graph_url = ExportGraph(graph, graph_kind, debug_options);
  LOG(INFO) << "computation " << computation.name() << " [" << label
            << "]: " << graph_url;
  return graph_url;
}
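
// A minimal usage sketch (not part of the original file; assumes a valid
// HloModule* module and its configured DebugOptions):
//
//   string url = xla::hlo_graph_dumper::DumpGraph(
//       *module->entry_computation(), "my_label",
//       module->config().debug_options(),
//       /*hlo_execution_profile=*/nullptr, /*show_metadata=*/false);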

string DumpNeighborhoodAround(const HloInstruction& node, int radius,
                              bool show_metadata) {
  auto debug_options = node.GetModule()->config().debug_options();
  string label =
      StrCat("Neighborhood of ", radius, " nodes around ", node.name());
  NodeFilter filter = MakeNodeFilter(&node, radius);
  string graph =
      HloDotDumper(node.parent(), label, debug_options, show_metadata,
                   /*profile=*/nullptr, filter)
          .Dump();
  return ExportGraph(graph, GraphRendererInterface::DOT_GRAPH, debug_options);
}

void DumpText(const HloModule& module, const string& label,
              const string& directory_path, bool do_prefix) {
  Env* env = Env::Default();
  TF_CHECK_OK(env->RecursivelyCreateDir(directory_path));
  string prefix = StrCat(env->NowMicros());
  string filename =
      do_prefix ? StrCat(prefix, "-", label, ".txt") : StrCat(label, ".txt");
  string path = JoinPath(directory_path, filename);
  TF_CHECK_OK(WriteStringToFile(
      env, path,
      module.ToString(HloPrintOptions().set_print_large_constants(true))));
  LOG(INFO) << "dumping module '" << module.name() << "' to " << path;
}

string MaybeDumpHloModule(const HloModule& module, const string& label,
                          const HloExecutionProfile* profile) {
  const DebugOptions& debug_options = module.config().debug_options();
  VLOG(2) << "MaybeDumpHloModule called on module " << module.name()
          << " with generate_hlo_graph regex \""
          << debug_options.xla_generate_hlo_graph() << "\"";
  string graph_url;
  if (!debug_options.xla_generate_hlo_graph().empty() &&
      RE2::PartialMatch(module.name(),
                        debug_options.xla_generate_hlo_graph())) {
    graph_url =
        DumpGraph(*module.entry_computation(), label, debug_options, profile);
  }
  if (!debug_options.xla_log_hlo_text().empty() &&
      RE2::PartialMatch(module.name(), debug_options.xla_log_hlo_text())) {
    LOG(INFO) << "HLO for module " << module.name();
    LOG(INFO) << "Label: " << label;
    XLA_LOG_LINES(2, module.ToString());
  }
  if (!debug_options.xla_generate_hlo_text_to().empty()) {
    DumpText(module, label, debug_options.xla_generate_hlo_text_to());
  }
  return graph_url;
}

}  // namespace hlo_graph_dumper
}  // namespace xla