1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" 17 18#include <unistd.h> 19#include <algorithm> 20#include <atomic> 21#include <deque> 22#include <map> 23#include <memory> 24#include <string> 25#include <tuple> 26#include <unordered_map> 27#include <vector> 28 29#include "tensorflow/compiler/xla/layout_util.h" 30#include "tensorflow/compiler/xla/literal_util.h" 31#include "tensorflow/compiler/xla/service/hlo_module.h" 32#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" 33#include "tensorflow/compiler/xla/shape_util.h" 34#include "tensorflow/compiler/xla/types.h" 35#include "tensorflow/compiler/xla/window_util.h" 36#include "tensorflow/core/lib/core/status.h" 37#include "tensorflow/core/lib/gtl/map_util.h" 38#include "tensorflow/core/lib/gtl/optional.h" 39#include "tensorflow/core/lib/io/path.h" 40#include "tensorflow/core/lib/strings/numbers.h" 41#include "tensorflow/core/lib/strings/str_util.h" 42#include "tensorflow/core/lib/strings/strcat.h" 43#include "tensorflow/core/lib/strings/stringprintf.h" 44#include "tensorflow/core/platform/env.h" 45#include "tensorflow/core/platform/protobuf.h" 46#include "tensorflow/core/platform/regexp.h" 47 48using ::tensorflow::Env; 49using ::tensorflow::WriteStringToFile; 50using ::tensorflow::gtl::nullopt; 51using 
::tensorflow::gtl::optional;
using ::tensorflow::io::JoinPath;
using ::tensorflow::str_util::Join;
using ::tensorflow::str_util::StringReplace;
using ::tensorflow::strings::StrAppend;
using ::tensorflow::strings::StrCat;

namespace xla {
namespace hlo_graph_dumper {
namespace {

// Helpers for Printf and Appendf below: pass most argument types through
// unchanged, but convert std::string to const char* so callers don't have to
// call c_str() themselves.
template <typename T>
struct PrintfConvert {
  const T& operator()(const T& t) const { return t; }
};
template <>
struct PrintfConvert<string> {
  const char* operator()(const string& s) const { return s.c_str(); }
};

// Like tensorflow::strings::Printf/Appendf, but you don't need to call c_str()
// on strings.
template <typename... Ts>
string Printf(const char* fmt, const Ts&... ts) {
  return tensorflow::strings::Printf(fmt, PrintfConvert<Ts>()(ts)...);
}
template <typename... Ts>
void Appendf(string* s, const char* fmt, const Ts&... ts) {
  tensorflow::strings::Appendf(s, fmt, PrintfConvert<Ts>()(ts)...);
}

// Used to indicate how we should treat a given HloInstruction in the graph:
// should we treat it like normal, hide it, and so on?
enum NodeFilterResult {
  kNormalNode,
  kHideNode,
  // Make the node easy to find in the final graph.
  kHighlightNode,
  // "Gray out" the node to indicate that some of its operands have been
  // omitted.
  kSomeOperandsOmitted,
  // Style the node the same as kSomeOperandsOmitted, but also don't connect it
  // to its operands, even if they're present in the graph.
  kOmitNodeOperands,
  // Same style as kSomeOperandsOmitted, but used to indicate that some of the
  // node's *users* have been omitted.
  kSomeUsersOmitted,
};

// NodeFilter is essentially a map from HloInstruction*s to NodeFilterResult.
// It lets callers tell the graph-drawing routines which nodes they want to be
// shown, hidden, or highlighted.
class NodeFilter {
 public:
  // The default filter shows every node as a normal node.
  NodeFilter() : filter_([](const HloInstruction*) { return kNormalNode; }) {}

  explicit NodeFilter(
      std::function<NodeFilterResult(const HloInstruction* instr)> filter)
      : filter_(std::move(filter)) {}

  // True unless the filter explicitly hides this instruction.
  bool Show(const HloInstruction* instr) const {
    return filter_(instr) != kHideNode;
  }
  bool Highlight(const HloInstruction* instr) const {
    return filter_(instr) == kHighlightNode;
  }
  bool OmitOperands(const HloInstruction* instr) const {
    return filter_(instr) == kOmitNodeOperands;
  }
  // True if at least some of instr's operands are omitted from the graph.
  bool SomeOrAllOperandsOmitted(const HloInstruction* instr) const {
    auto result = filter_(instr);
    return result == kOmitNodeOperands || result == kSomeOperandsOmitted;
  }
  // True if the node should be drawn "grayed out" because some of its
  // operands or users are omitted.
  bool Deemphasized(const HloInstruction* instr) const {
    auto result = filter_(instr);
    return result == kOmitNodeOperands || result == kSomeOperandsOmitted ||
           result == kSomeUsersOmitted;
  }

  // A fusion node's subcomputation is drawn inline only when the fusion node
  // itself is shown with all of its operands.  Must only be called on fusion
  // instructions (CHECKed below).
  bool ShowFusionSubcomputation(const HloInstruction* instr) const {
    CHECK_EQ(instr->opcode(), HloOpcode::kFusion);
    return Show(instr) && !SomeOrAllOperandsOmitted(instr);
  }

 private:
  std::function<NodeFilterResult(const HloInstruction* instr)> filter_;
};

// Node color schemes, used by NodeColorAttributes.
enum ColorScheme {
  kBlue,
  kBrown,
  kDarkBlue,
  kDarkGreen,
  kDarkRed,
  kGray,
  kGreen,
  kOrange,
  kPurple,
  kRed,
  kWhite,
  kYellow,

  // Causes the node's border to be a dashed line, and its content to be gray
  // text on a white background, suggesting that this is an "unimportant" node.
  kDashedBorder,
};

// Given a ColorScheme, returns an attribute string for a node of that color.
// Sets the node's style and fill/stroke/text colors.
//
// Colors are from https://material.io/color.
164string NodeColorAttributes(ColorScheme color) { 165 using std::make_tuple; 166 167 const char *style, *fill_color, *stroke_color, *font_color; 168 std::tie(style, fill_color, stroke_color, font_color) = [color] { 169 switch (color) { 170 case kBlue: 171 return make_tuple("filled", "#bbdefb", "#8aacc8", "black"); 172 case kBrown: 173 return make_tuple("filled", "#bcaaa4", "#8c7b75", "black"); 174 case kDarkBlue: 175 return make_tuple("filled", "#1565c0", "#003c8f", "white"); 176 case kDarkGreen: 177 return make_tuple("filled", "#2e7d32", "#005005", "white"); 178 case kDarkRed: 179 return make_tuple("filled", "#b71c1c", "#7f0000", "white"); 180 case kGray: 181 return make_tuple("filled", "#cfd8dc", "#9ea7aa", "black"); 182 case kGreen: 183 return make_tuple("filled", "#c8e6c9", "#97b498", "black"); 184 case kOrange: 185 return make_tuple("filled", "#ffe0b2", "#cbae82", "black"); 186 case kPurple: 187 return make_tuple("filled", "#e1bee7", "#af8eb5", "black"); 188 case kRed: 189 return make_tuple("filled", "#ffcdd2", "#cb9ca1", "black"); 190 case kWhite: 191 return make_tuple("filled", "white", "black", "black"); 192 case kYellow: 193 return make_tuple("filled", "#fff9c4", "#cbc693", "black"); 194 case kDashedBorder: 195 // "filled,dashed" looks the same as "dashed", since we have a white 196 // background. But we use "filled,dashed" so that when you hover over 197 // any part of the node (not just the text inside the node), our css 198 // :hover rule is triggered. 199 return make_tuple("filled,dashed", "white", "#757575", "#757575"); 200 } 201 }(); 202 203 return Printf( 204 R"(style="%s", fontcolor="%s", color="%s", fillcolor="%s")", style, 205 font_color, stroke_color, fill_color); 206} 207 208// Replaces <> with <>, so that this string is safe(er) for use in a 209// graphviz HTML-like string. 
210string HtmlLikeStringSanitize(tensorflow::StringPiece s) { 211 return StringReplace(StringReplace(s, "<", "<", /*replace_all=*/true), ">", 212 ">", /*replace_all=*/true); 213} 214 215// Tries to generates a human-readable one-word description of the given 216// computation. 217// 218// Currently we support: 219// 220// "return param0 + param1;" --> "add" 221// "return param0 * param1;" --> "multiply" 222// "return min(param0, param1);" --> "min" 223// "return max(param0, param1);" --> "max" 224// "return param0 <= param1;" --> "less-or-equal" 225// "return param0 >= param1;" --> "greater-or-equal" 226// "return param0 > param1;" --> "greater-than" 227// "return param0 < param1;" --> "less-than" 228// "return param0 == param1;" --> "equal-to" 229// "return param0 != param1;" --> "not-equal-to" 230// 231// where param0 and param1 are effective scalars. For the ops that are 232// commutative, we also support them with param0 and param1 swapped. 233// 234// This is useful primarily for reduce and map nodes. These take a 235// subcomputation which is almost always one of the above, and pattern matching 236// it to a short string lets us tell the user what the subcomputation is without 237// drawing it as a graph. 238optional<string> MatchTrivialComputation(const HloComputation* computation) { 239 if (computation->instruction_count() != 3) { 240 return nullopt; 241 } 242 243 HloInstruction* root = computation->root_instruction(); 244 if (root->operand_count() != 2) { 245 return nullopt; 246 } 247 248 // Check that both of the operands to the root are parameters. 249 const HloInstruction* operand0 = root->operand(0); 250 const HloInstruction* operand1 = root->operand(1); 251 if (operand0->opcode() != HloOpcode::kParameter || 252 operand1->opcode() != HloOpcode::kParameter) { 253 return nullopt; 254 } 255 256 // Check that the two operands of root are param0 and param1. All of the 257 // opcodes we recognize are commutative, so we're OK with either order. 
258 auto n0 = operand0->parameter_number(); 259 auto n1 = operand1->parameter_number(); 260 if (!(n0 == 0 && n1 == 1) && !(n1 == 0 && n0 == 1)) { 261 return nullopt; 262 } 263 264 // If the params are reversed, check that the operation being performed is 265 // commutative. 266 if (n0 == 1) { 267 switch (root->opcode()) { 268 case HloOpcode::kLe: 269 case HloOpcode::kGe: 270 case HloOpcode::kGt: 271 case HloOpcode::kLt: 272 return nullopt; 273 default: 274 break; 275 } 276 } 277 278 // Check that the root and params are all effective scalars. 279 if (!ShapeUtil::IsEffectiveScalar(root->shape()) || 280 !ShapeUtil::IsEffectiveScalar(operand0->shape()) || 281 !ShapeUtil::IsEffectiveScalar(operand1->shape())) { 282 return nullopt; 283 } 284 285 // If we recognize the root's opcode, we've successfully pattern-matched! 286 switch (root->opcode()) { 287 case HloOpcode::kAdd: 288 return "add"; 289 case HloOpcode::kMultiply: 290 return "multiply"; 291 case HloOpcode::kMinimum: 292 return "min"; 293 case HloOpcode::kMaximum: 294 return "max"; 295 case HloOpcode::kLe: 296 return "less-or-equal"; 297 case HloOpcode::kGe: 298 return "greater-or-equal"; 299 case HloOpcode::kGt: 300 return "greater-than"; 301 case HloOpcode::kLt: 302 return "less-than"; 303 case HloOpcode::kEq: 304 return "equal-to"; 305 case HloOpcode::kNe: 306 return "not-equal-to"; 307 default: 308 return nullopt; 309 } 310} 311 312// Encapsulates logic for dumping an HLO module to DOT (i.e. graphviz syntax). 
class HloDotDumper {
 public:
  HloDotDumper(const HloComputation* computation, tensorflow::StringPiece label,
               const DebugOptions& debug_options, bool show_metadata,
               const HloExecutionProfile* profile, NodeFilter filter)
      : computation_(computation),
        label_(label.ToString()),
        debug_options_(debug_options),
        show_metadata_(show_metadata),
        profile_(profile),
        filter_(std::move(filter)) {}

  // Renders computation_ as a complete DOT (graphviz) graph string.
  string Dump();

 private:
  // Returns the dot graph identifier for the given instruction.
  string InstructionId(const HloInstruction* instruction) {
    return StrCat(reinterpret_cast<uint64>(instruction));
  }

  // Returns the dot graph identifier for the given computation.
  string SubcomputationId(const HloComputation* computation) {
    return StrCat("cluster_", reinterpret_cast<uint64>(computation));
  }

  // Generates graph header/footer.  These should be called *after* dumping all
  // of the instructions and subcomputations for the graph, as they both use
  // data generated while dumping the graph.
  string Header();
  string Footer();

  bool ShouldShowSubcomputation(const HloComputation* subcomp);
  bool ShouldShowFusionSubcomputation(const HloInstruction* instr);

  // We omit some nodes from the graph, instead drawing them inlined into the
  // nodes that use them.
  bool ShouldMergeIntoUsers(const HloInstruction* instr) const;

  string DumpSubcomputation(const HloComputation* subcomp,
                            const HloInstruction* parent_instr);
  string DumpComputation(const HloComputation* comp);
  string DumpRootTag();
  string DumpInstruction(const HloInstruction* instr);
  ColorScheme GetInstructionColor(const HloInstruction* instr);
  string GetInstructionNodeShape(const HloInstruction* instr);
  string GetInstructionNodeLabel(const HloInstruction* instr);
  string GetInstructionNodeMetadata(const HloInstruction* instr);
  string GetInstructionNodeExtraInfo(const HloInstruction* instr);
  string GetInstructionNodeInlinedOperands(const HloInstruction* instr);
  void AddInstructionIncomingEdges(const HloInstruction* instr);

  // For most instructions, GetNodeForEdge(instr) returns instr.
  //
  // The exception is fusion nodes.  For these, we walk up the chain of nested
  // fusion nodes starting at instr until we reach a node that either (a) isn't
  // a fusion node, or (b) is a fusion node for which
  // ShouldShowFusionSubcomputation is false.
  //
  // We do this because fusion nodes are expanded inline -- if
  // ShouldShowFusionSubcomputation is true, the fusion node won't be present in
  // the graph.
  //
  // In general when you want to draw an edge from A to B, you should actually
  // draw an edge from GetNodeForEdge(A) to GetNodeForEdge(B).
  const HloInstruction* GetNodeForEdge(const HloInstruction* instr);

  // If instr has just one computation and it's trivial (e.g. "return param0 +
  // param1"), returns a string you can put into the node's body that names the
  // subcomputation, e.g. "Subcomputation: <b>add</b>".
  string GetInstructionTrivialComputationStr(const HloInstruction* instr);

  const HloComputation* computation_;  // never null
  const string label_;                 // overall name for the graph
  const DebugOptions& debug_options_;
  const bool show_metadata_;
  const HloExecutionProfile* profile_;  // may be null
  const NodeFilter filter_;

  // Each HloInstruction dumped gets a monotonically-increasing node ID.  This
  // must start at 1, because that's where graphviz's accounting starts.
  int64 next_node_id_ = 1;
  std::unordered_map<const HloInstruction*, int64> node_ids_;

  // The "root" tag doesn't have an associated HloInstruction pointer, so we
  // need to store it outside the map.
  int64 root_node_id_;

  // Each (from, to) edge gets a monotonically-increasing ID.  This is a
  // multimap because it's possible for the same edge to appear multiple times
  // in the graph (e.g. x^2 may be represented as mul(x, x)).
  int64 next_edge_id_ = 1;
  std::unordered_multimap<
      std::pair<const HloInstruction*, const HloInstruction*>, int64,
      tensorflow::hash<std::pair<const HloInstruction*, const HloInstruction*>>>
      edge_ids_;

  // Each HloComputation that's emitted gets a monotonically-increasing ID.
  int64 next_cluster_id_ = 1;
  std::unordered_map<const HloComputation*, int64> cluster_ids_;

  // Edges to print from Footer().  Edges come at the end because graphviz is
  // unhappy if an edge from a subcomputation to a node in the outer computation
  // appears before both the inner computation and the destination node are
  // defined.
  std::vector<string> edges_;

  // When coloring by sharding information, we track the sharding string
  // representation to color association, by round-robin the color schemes.
  std::unordered_map<string, ColorScheme> sharding_colors_;
  int64 next_shard_color_ = 0;
};

string HloDotDumper::Dump() {
  // Dump the body first: Header() and Footer() consume state (node/edge IDs,
  // cluster IDs, pending edges) that is generated while dumping instructions.
  string body;
  StrAppend(&body, DumpComputation(computation_));
  StrAppend(&body, DumpRootTag());

  // By contract, Header() and Footer() have to be called after we've dumped all
  // our instructions, because they use state generated during that process.
  string g = Header();
  StrAppend(&g, body);
  StrAppend(&g, Footer());
  return g;
}

string HloDotDumper::Header() {
  // %s substitutions: graph label, then hover CSS rules embedded in the
  // data-URI stylesheet.
  const char* fmt = R"(digraph G {
rankdir = TB;
compound = true;
label = <<b>%s</b>>;
labelloc = t;
// Disable the tooltip. Interestingly, "" doesn't work!
tooltip = " ";
// DOT graphs accept a stylesheet as a URI. So naturally, an inline
// stylesheet is a data URI!
stylesheet="
  data:text/css,
  @import url(https://fonts.googleapis.com/css?family=Roboto:400,700);
  svg text {
    font-family: 'Roboto';
    font-size: 12px;
  }

  %s
"

)";

  VLOG(3) << "Generating Header";

  string graph_label =
      StrCat(label_, "<br/>Computation ", computation_->name());
  if (computation_->IsFusionComputation()) {
    StrAppend(&graph_label,
              StrCat(" (in fusion instruction ",
                     computation_->FusionInstruction()->name(), ")"));
  }
  if (profile_ != nullptr) {
    auto cycles = profile_->total_cycles_executed(*computation_);
    Appendf(&graph_label, "<br/>total cycles = %lld (%s)", cycles,
            tensorflow::strings::HumanReadableNum(cycles));
  }

  // Create CSS rules that say, when you hover over the given node or cluster,
  // turn the given edge the given color.
  //
  // We rely on a few properties of how graphviz generates SVGs:
  //
  //  - Nodes are named "nodeN", where N corresponds to the 1-based index of
  //    the node in our DOT (i.e. the first node in the DOT is "node1", etc.).
  //    Edges are similarly named "edgeN", and clusters are named "clustN".
  //  - Nodes come before their in- and out-edges in the SVG.  We need this
  //    because the "X ~ Y" CSS selector finds a sibling of X that *comes
  //    after X in the DOM* and matches Y.
  std::vector<string> edge_css_rules;
  const char* kBlue = "#1976d2";
  const char* kRed = "#d32f2f";
  for (const auto& kv : edge_ids_) {
    const HloInstruction* from_node = kv.first.first;
    const HloInstruction* to_node = kv.first.second;
    int64 edge_id = kv.second;

    auto add_hover_css_rule = [&](string elem_type, int64 elem_id,
                                  const char* color) {
      // One could imagine other ways of writing this CSS rule that involve
      // less duplication, but this way seems to be relatively performant.
      edge_css_rules.push_back(
          Printf(" #%s%d:hover ~ #edge%lld text { fill: %s; }\n"
                 " #%s%d:hover ~ #edge%lld path { "
                 "stroke: %s; stroke-width: .2em; }\n"
                 " #%s%d:hover ~ #edge%lld polygon { "
                 "fill: %s; stroke: %s; stroke-width: .2em; }\n",
                 elem_type, elem_id, edge_id, color,  //
                 elem_type, elem_id, edge_id, color,  //
                 elem_type, elem_id, edge_id, color, color));
    };

    // The "to_node" value may be a NULL, indicating that this points to the
    // "root" tag rather than a normal node.
    int64 from_node_id =
        tensorflow::gtl::FindWithDefault(node_ids_, from_node, -1);
    if (from_node_id == -1) {
      LOG(FATAL) << from_node->name() << " was added to edges but not to nodes";
    }
    int64 to_node_id =
        to_node ? tensorflow::gtl::FindWithDefault(node_ids_, to_node, -1)
                : root_node_id_;
    if (to_node != nullptr && to_node_id == -1) {
      LOG(FATAL) << to_node->name() << " was added to edges but not to nodes";
    }

    add_hover_css_rule("node", from_node_id, kBlue);
    add_hover_css_rule("node", to_node_id, kRed);

    if (to_node) {
      VLOG(3) << "Adding css for edge " << edge_id << " from node "
              << from_node->name() << " to node " << to_node->name();
    } else {
      VLOG(3) << "Adding css for edge " << edge_id << " from node "
              << from_node->name() << " to root tag";
    }

    // If this edge crosses a fusion cluster boundary, highlight it when the
    // cluster is hovered over.
    if (to_node) {
      if (from_node->IsFused() &&
          from_node->parent()->root_instruction() == from_node) {
        int64 cluster_id = cluster_ids_.at(from_node->parent());
        add_hover_css_rule("clust", cluster_id, kBlue);
      }
      if (to_node->IsFused() && to_node->opcode() == HloOpcode::kParameter) {
        int64 cluster_id = cluster_ids_.at(to_node->parent());
        add_hover_css_rule("clust", cluster_id, kRed);
      }
    }
  }

  return Printf(fmt, graph_label, Join(edge_css_rules, "\n"));
}

// Emits the deferred edges (see edges_) followed by the graph's closing brace.
string HloDotDumper::Footer() { return StrCat(Join(edges_, "\n"), "\n}"); }

bool HloDotDumper::ShouldShowFusionSubcomputation(const HloInstruction* instr) {
  CHECK_EQ(instr->opcode(), HloOpcode::kFusion);
  return ShouldShowSubcomputation(instr->fused_instructions_computation());
}

bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) {
  // A fusion subcomputation is drawn only when its fusion instruction itself
  // is shown with all of its operands.
  if (subcomp->IsFusionComputation()) {
    const HloInstruction* fusion = subcomp->FusionInstruction();
    if (!filter_.Show(fusion) || filter_.SomeOrAllOperandsOmitted(fusion)) {
      return false;
    }
  }

  // Don't show trivial subcomputations on non-fusion nodes -- these are
  // inlined into the graph.
570 if (!subcomp->IsFusionComputation() && MatchTrivialComputation(subcomp)) { 571 return false; 572 } 573 574 // Show the subcomputation if we're showing any of its members. 575 return std::any_of( 576 computation_->instructions().begin(), computation_->instructions().end(), 577 [&](const HloInstruction* instr) { return filter_.Show(instr); }); 578} 579 580string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp, 581 const HloInstruction* parent_instr) { 582 VLOG(2) << "Dumping subcomputation " << subcomp->name(); 583 const char* computation_fmt = R"(subgraph %s { 584%s 585label = <%s>; 586labelloc = t; 587tooltip = " "; 588%s 589} // %s 590 591)"; 592 593 cluster_ids_[subcomp] = next_cluster_id_++; 594 595 string id = SubcomputationId(subcomp); 596 597 string subcomp_label, style; 598 if (parent_instr->opcode() == HloOpcode::kFusion) { 599 subcomp_label = Printf("Fused expression for <b>%s</b><br/>%s", 600 HtmlLikeStringSanitize(parent_instr->name()), 601 HtmlLikeStringSanitize(parent_instr->ToCategory())); 602 string extra_info = GetInstructionNodeExtraInfo(parent_instr); 603 if (!extra_info.empty()) { 604 StrAppend(&subcomp_label, "<br/>", extra_info); 605 } 606 607 // Subcomputation's fill/stroke color is light/dark red/gray, depending on 608 // whether or not the subcomputation's fusion node is highlighted. 609 bool highlight = filter_.Highlight(parent_instr); 610 const char* fillcolor = highlight ? "#ffcdd2" : "#f5f5f5"; 611 const char* strokecolor = highlight ? "#b71c1c" : "#c2c2c2"; 612 style = 613 Printf(R"(style="rounded,filled,bold"; fillcolor="%s"; color="%s;")", 614 fillcolor, strokecolor); 615 } else { 616 subcomp_label = Printf("Subcomputation for <b>%s</b><br/>%s", 617 HtmlLikeStringSanitize(parent_instr->name()), 618 HtmlLikeStringSanitize(subcomp->name())); 619 style = "style=rounded; color=black;"; 620 } 621 622 string comp_body = DumpComputation(subcomp); 623 624 // Add an edge from the subcomputation to its parent node. 
  // If subcomp belongs to a fusion node, it's drawn in place of the fusion
  // instruction, so there's no need to link those.
  if (parent_instr->opcode() != HloOpcode::kFusion) {
    const HloInstruction* from = GetNodeForEdge(subcomp->root_instruction());
    VLOG(2) << "Edge: from " << from->name() << " to " << parent_instr->name()
            << " as " << next_edge_id_;
    edge_ids_.insert({{from, parent_instr}, next_edge_id_++});
    const char* edge_fmt =
        R"(%s -> %s [ltail="%s", style="dashed" tooltip="%s -> %s"];)";
    edges_.push_back(Printf(
        edge_fmt, InstructionId(from), InstructionId(parent_instr),
        SubcomputationId(subcomp), subcomp->name(), parent_instr->name()));
  }

  string computation =
      Printf(computation_fmt, id, style, subcomp_label, comp_body, id);

  return computation;
}

string HloDotDumper::DumpComputation(const HloComputation* comp) {
  string g;
  for (const auto* instr : comp->instructions()) {
    if (!filter_.Show(instr)) {
      continue;
    }

    // Dump subcomputations within instr.
    for (const HloComputation* subcomp : instr->called_computations()) {
      if (ShouldShowSubcomputation(subcomp)) {
        StrAppend(&g, DumpSubcomputation(subcomp, instr));
      }
    }

    StrAppend(&g, DumpInstruction(instr));
  }
  return g;
}

string HloDotDumper::DumpRootTag() {
  const HloInstruction* from = GetNodeForEdge(computation_->root_instruction());

  // We didn't display constants as separate nodes; so if the root is a
  // constant, we don't add root tag or edge for it.
  if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant) {
    return "";
  }

  auto from_id = InstructionId(from);

  // The ID of the root computation is otherwise unused, so it makes a good ID
  // to use for the root-tag node.  However, the edge_ids_ map requires a
  // HloInstruction* pointer for the 'to' value, so we use a NULL value there
  // (rather than a pointer type-cast) to make it obvious if it is erroneously
  // dereferenced.
  HloInstruction* to = nullptr;
  auto to_id = SubcomputationId(computation_);

  string node_body = "ROOT";
  string node_shape = "circle";
  ColorScheme color = kBrown;

  VLOG(2) << "Adding root tag as node " << next_node_id_;
  root_node_id_ = next_node_id_++;

  VLOG(2) << "Adding edge from " << from->name() << " to root tag as "
          << next_edge_id_;
  edge_ids_.insert({{from, to}, next_edge_id_++});
  edges_.push_back(Printf(R"(%s -> %s [tooltip=" "];)", from_id, to_id));

  return Printf(R"(%s [label=<%s>, shape=%s, tooltip=" ", %s];)"
                "\n",
                to_id, node_body, node_shape, NodeColorAttributes(color));
}

bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const {
  // If a node:
  //
  //  - is a tuple-shaped parameter,
  //  - is not a parameter to a fusion node,
  //  - has at least kMinUsersToOmit users shown, and
  //  - all of the shown users are get-tuple-elements,
  //
  // then we omit it from the graph, merging it with its users.
  //
  // This helps us handle the common case where a while loop body has one big
  // tuple-shaped parameter.
  const int kMinUsersToOmit = 3;
  // NOTE(review): the comparison is "> kMinUsersToOmit", i.e. strictly more
  // than 3 shown users are required -- confirm whether "at least" or "more
  // than" is the intended threshold.
  return instr->opcode() == HloOpcode::kParameter &&
         ShapeUtil::IsTuple(instr->shape()) && !instr->IsFused() &&
         std::count_if(instr->users().begin(), instr->users().end(),
                       [&](const HloInstruction* user) {
                         return filter_.Show(user);
                       }) > kMinUsersToOmit &&
         std::all_of(instr->users().begin(), instr->users().end(),
                     [&](const HloInstruction* user) {
                       return !filter_.Show(user) ||
                              user->opcode() == HloOpcode::kGetTupleElement;
                     });
}

string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
  // We don't display constants as separate nodes; they're merged into their
  // users.
  if (instr->opcode() == HloOpcode::kConstant) {
    return "";
  }
  // Skip this node if it's merged into its users.
  if (ShouldMergeIntoUsers(instr)) {
    return "";
  }
  // Omit the fusion node if its subcomputation is drawn, since the
  // subcomputation will be drawn inline.
  if (instr->opcode() == HloOpcode::kFusion &&
      ShouldShowFusionSubcomputation(instr)) {
    return "";
  }

  VLOG(2) << "Adding node " << instr->name() << " as " << next_node_id_;
  node_ids_[instr] = next_node_id_++;

  ColorScheme color = GetInstructionColor(instr);
  string node_shape = GetInstructionNodeShape(instr);
  string node_label = GetInstructionNodeLabel(instr);
  string node_metadata = GetInstructionNodeMetadata(instr);
  string extra_info = GetInstructionNodeExtraInfo(instr);
  string inlined_constants = GetInstructionNodeInlinedOperands(instr);
  string trivial_subcomputation = GetInstructionTrivialComputationStr(instr);
  AddInstructionIncomingEdges(instr);

  if (!debug_options_.xla_hlo_graph_sharding_color()) {
    // Override the node's styling if it should be (de-)emphasized.
    if (filter_.Deemphasized(instr)) {
      color = kDashedBorder;
    }
    if (filter_.Highlight(instr)) {
      node_shape = "diamond";
      color = kDarkRed;
    }
  }
  // Build the text that will be displayed inside the node.
  string node_body = node_label;
  for (const string& s :
       {trivial_subcomputation, node_metadata, extra_info, inlined_constants}) {
    if (!s.empty()) {
      StrAppend(&node_body, "<br/>", s);
    }
  }

  return Printf(R"(%s [label=<%s>, shape=%s, tooltip=" ", %s];)"
                "\n",
                InstructionId(instr), node_body, node_shape,
                NodeColorAttributes(color));
}

string HloDotDumper::GetInstructionNodeInlinedOperands(
    const HloInstruction* instr) {
  // Renders a constant either as its literal value (for small constants) or
  // as its name plus shape.
  auto stringify_constant = [](const HloInstruction* constant) {
    const auto& shape = constant->shape();

    // Print the literal value of constants with <= K elements.
    optional<int64> elem_count;
    if (!ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape)) {
      elem_count = 1;
      for (int64 dim : shape.dimensions()) {
        *elem_count *= dim;
      }
    }
    if (elem_count.has_value() && *elem_count <= 8) {
      return Printf("%s (%s)", constant->literal().ToString(),
                    ShapeUtil::HumanString(constant->shape()));
    }

    // Otherwise, print e.g. "%constant.42 (s32[100])".
    string constant_name;
    if (tensorflow::StringPiece(constant->name()).starts_with("constant")) {
      constant_name = constant->name();
    } else {
      constant_name = StrCat("constant ", constant->name());
    }
    return Printf("%s %s", constant_name,
                  ShapeUtil::HumanString(constant->shape()));
  };

  // Special case: If instr is a parameter to a fusion node, check whether the
  // corresponding operand to the fusion node is a constant.
  if (instr->opcode() == HloOpcode::kParameter && instr->IsFused()) {
    const HloInstruction* fusion = instr->parent()->FusionInstruction();
    const HloInstruction* operand = fusion->operand(instr->parameter_number());
    if (operand->opcode() != HloOpcode::kConstant) {
      return "";
    }
    return StrCat("<b>constant</b> ", stringify_constant(operand));
  }

  std::vector<string> lines;
  for (int64 i = 0; i < instr->operand_count(); ++i) {
    const HloInstruction* operand = instr->operand(i);
    optional<string> operand_str;
    if (operand->opcode() == HloOpcode::kConstant) {
      operand_str = stringify_constant(operand);
    } else if (ShouldMergeIntoUsers(operand)) {
      // Special case: If the operand is a parameter, use its parameter number
      // rather than its name, because that's generally how people think of the
      // node.
      if (operand->opcode() == HloOpcode::kParameter) {
        operand_str = Printf("Parameter %lld", operand->parameter_number());
      } else {
        operand_str = operand->name();
      }
    }

    if (operand_str) {
      if (instr->operand_count() > 1) {
        lines.push_back(Printf("<b>operand %lld</b> = %s", i, *operand_str));
      } else {
        lines.push_back(Printf("<b>operand</b> = %s", *operand_str));
      }
    }
  }
  return Join(lines, "<br/>");
}

ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
  // When coloring by sharding, each distinct sharding string gets its own
  // color, assigned round-robin from the palette (kBlue..kDashedBorder-1).
  if (debug_options_.xla_hlo_graph_sharding_color()) {
    if (!instr->has_sharding()) {
      return kDashedBorder;
    }
    string shard_str = instr->sharding().ToString();
    auto it = sharding_colors_.find(shard_str);
    if (it != sharding_colors_.end()) {
      return it->second;
    }
    ColorScheme color = static_cast<ColorScheme>(
        kBlue + (next_shard_color_++ % (kDashedBorder - kBlue)));
    sharding_colors_.emplace(shard_str, color);
    return color;
  }
  const auto kParameterColor = kOrange;

  // Special case: If this instruction has a parameter
merged into it, paint it 866 // the same color as a parameter. 867 if (std::any_of(instr->operands().begin(), instr->operands().end(), 868 [&](const HloInstruction* operand) { 869 return operand->opcode() == HloOpcode::kParameter && 870 ShouldMergeIntoUsers(operand); 871 })) { 872 return kParameterColor; 873 } 874 875 // Pick different colors or shapes for instructions which are particularly 876 // expensive (eg, dot) and those which are unusual in some way or unique 877 // (eg, parameter). 878 switch (instr->opcode()) { 879 case HloOpcode::kAbs: 880 case HloOpcode::kAdd: 881 case HloOpcode::kAnd: 882 case HloOpcode::kAtan2: 883 case HloOpcode::kBitcastConvert: 884 case HloOpcode::kCeil: 885 case HloOpcode::kClamp: 886 case HloOpcode::kComplex: 887 case HloOpcode::kConvert: 888 case HloOpcode::kCos: 889 case HloOpcode::kDivide: 890 case HloOpcode::kEq: 891 case HloOpcode::kExp: 892 case HloOpcode::kFloor: 893 case HloOpcode::kGe: 894 case HloOpcode::kGt: 895 case HloOpcode::kImag: 896 case HloOpcode::kIsFinite: 897 case HloOpcode::kLe: 898 case HloOpcode::kLog: 899 case HloOpcode::kLt: 900 case HloOpcode::kMaximum: 901 case HloOpcode::kMinimum: 902 case HloOpcode::kMultiply: 903 case HloOpcode::kNe: 904 case HloOpcode::kNegate: 905 case HloOpcode::kNot: 906 case HloOpcode::kOr: 907 case HloOpcode::kPower: 908 case HloOpcode::kReal: 909 case HloOpcode::kRemainder: 910 case HloOpcode::kRng: 911 case HloOpcode::kRoundNearestAfz: 912 case HloOpcode::kShiftLeft: 913 case HloOpcode::kShiftRightArithmetic: 914 case HloOpcode::kShiftRightLogical: 915 case HloOpcode::kSign: 916 case HloOpcode::kSin: 917 case HloOpcode::kSlice: 918 case HloOpcode::kSort: 919 case HloOpcode::kSubtract: 920 case HloOpcode::kTanh: 921 // De-emphasize scalar-shaped elementwise ops -- they're generally 922 // uninteresting. 
923 if (ShapeUtil::IsEffectiveScalar(instr->shape())) { 924 return kWhite; 925 } 926 return kYellow; 927 case HloOpcode::kBitcast: 928 case HloOpcode::kGetTupleElement: 929 case HloOpcode::kTrace: 930 case HloOpcode::kTuple: 931 return kWhite; 932 case HloOpcode::kBroadcast: 933 // De-emphasize nodes which broadcast a scalar within a fusion node -- 934 // these are essentially free. 935 if (instr->IsFused() && 936 ShapeUtil::IsEffectiveScalar(instr->operand(0)->shape())) { 937 return kWhite; 938 } 939 return kGreen; 940 case HloOpcode::kConcatenate: 941 case HloOpcode::kCopy: 942 case HloOpcode::kDynamicSlice: 943 case HloOpcode::kGather: 944 case HloOpcode::kPad: 945 case HloOpcode::kReshape: 946 case HloOpcode::kReverse: 947 case HloOpcode::kSelect: 948 case HloOpcode::kTranspose: 949 // De-emphasize scalar-shaped data movement ops and all data movement ops 950 // inside fusion nodes, both of which are essentially free. 951 if (ShapeUtil::IsEffectiveScalar(instr->shape()) || instr->IsFused()) { 952 return kWhite; 953 } 954 return kGreen; 955 case HloOpcode::kDynamicUpdateSlice: 956 // Unlike the data-movement ops above, dynamic-update-slice is not ~free 957 // inside of fusion nodes, so we de-emphasize it only if it's 958 // scalar-shaped. 
959 if (ShapeUtil::IsEffectiveScalar(instr->shape())) { 960 return kWhite; 961 } 962 return kGreen; 963 case HloOpcode::kConvolution: 964 case HloOpcode::kDot: 965 case HloOpcode::kFft: 966 return kDarkBlue; 967 case HloOpcode::kReducePrecision: 968 return kRed; 969 case HloOpcode::kParameter: 970 return kParameterColor; 971 case HloOpcode::kBatchNormGrad: 972 case HloOpcode::kBatchNormInference: 973 case HloOpcode::kBatchNormTraining: 974 case HloOpcode::kReduce: 975 case HloOpcode::kReduceWindow: 976 case HloOpcode::kSelectAndScatter: 977 return kPurple; 978 case HloOpcode::kFusion: 979 case HloOpcode::kMap: 980 return kGray; 981 case HloOpcode::kCrossReplicaSum: 982 case HloOpcode::kInfeed: 983 case HloOpcode::kOutfeed: 984 case HloOpcode::kRecv: 985 case HloOpcode::kRecvDone: 986 case HloOpcode::kSend: 987 case HloOpcode::kSendDone: 988 return kBrown; 989 case HloOpcode::kCall: 990 case HloOpcode::kConditional: 991 case HloOpcode::kCustomCall: 992 case HloOpcode::kHostCompute: 993 case HloOpcode::kWhile: 994 return kDarkGreen; 995 case HloOpcode::kConstant: 996 LOG(FATAL) << "Constants don't get their own nodes in the graph."; 997 } 998} 999 1000string HloDotDumper::GetInstructionNodeShape(const HloInstruction* instr) { 1001 // Give while loops a different shape so they're easier to pick out. 1002 switch (instr->opcode()) { 1003 case HloOpcode::kWhile: 1004 return "ellipse"; 1005 default: 1006 return "rect"; 1007 } 1008} 1009 1010string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) { 1011 // If we have a parameter, put the param number in the name. 1012 if (instr->opcode() == HloOpcode::kParameter) { 1013 return Printf("<b>Parameter %lld</b>", instr->parameter_number()); 1014 } 1015 1016 // The HLO instruction name contains usually the opcode, e.g. "%add.42" is 1017 // an add instruction. In this case we render just the name. 
1018 if (tensorflow::StringPiece(instr->name()) 1019 .starts_with(HloOpcodeString(instr->opcode()))) { 1020 return Printf("<b>%s</b>", HtmlLikeStringSanitize(instr->name())); 1021 } 1022 string extended_opcode = 1023 StrCat(HloOpcodeString(instr->opcode()), 1024 instr->opcode() != HloOpcode::kFusion 1025 ? "" 1026 : StrCat(":", xla::ToString(instr->fusion_kind()))); 1027 // If the name does not contain the opcode, render both. 1028 return Printf("<b>%s</b><br/>%s", HtmlLikeStringSanitize(extended_opcode), 1029 HtmlLikeStringSanitize(instr->name())); 1030} 1031 1032string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) { 1033 if (!show_metadata_) { 1034 return ""; 1035 } 1036 1037 std::vector<string> lines; 1038 if (!instr->metadata().op_name().empty()) { 1039 lines.push_back(HtmlLikeStringSanitize(instr->metadata().op_name())); 1040 } 1041 if (!instr->metadata().op_type().empty()) { 1042 lines.push_back(Printf( 1043 "op_type: %s", HtmlLikeStringSanitize(instr->metadata().op_type()))); 1044 } 1045 if (!instr->metadata().source_file().empty() && 1046 instr->metadata().source_line() != 0) { 1047 lines.push_back(Printf("op_type: %s", instr->metadata().source_file(), 1048 instr->metadata().source_line())); 1049 } 1050 1051 return Join(lines, "<br/>"); 1052} 1053 1054string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { 1055 std::vector<string> lines; 1056 1057 // Get the instruction's extra attributes excluding the names of its 1058 // subcomputations, since those are drawn explicitly in the graph. 1059 for (const auto& line : instr->ExtraAttributesToString( 1060 HloPrintOptions().set_print_subcomputation_references(false))) { 1061 lines.push_back(HtmlLikeStringSanitize(line)); 1062 } 1063 1064 // Show the shape and layout of the instruction, unless it's an inlined fusion 1065 // node -- there the shape and layout is present in the output node. 
1066 if (instr->opcode() != HloOpcode::kFusion || 1067 !ShouldShowFusionSubcomputation(instr)) { 1068 // Show layout of instructions with more than one dimension. Don't show 1069 // layout on tuples or tensors with just one dimension (which only have one 1070 // possible layout) to avoid visual noise. 1071 bool shape_is_multidim = false; 1072 ShapeUtil::ForEachSubshape(instr->shape(), 1073 [&](const Shape& s, const ShapeIndex&) { 1074 shape_is_multidim |= s.dimensions_size() > 1; 1075 }); 1076 string instr_shape; 1077 if (instr->opcode() != HloOpcode::kTuple && shape_is_multidim) { 1078 instr_shape = ShapeUtil::HumanStringWithLayout(instr->shape()); 1079 } else { 1080 instr_shape = ShapeUtil::HumanString(instr->shape()); 1081 } 1082 1083 // Some instructions have giant tuples as their shapes, so truncate the 1084 // HLO's shape to kMaxShapeLen characters. 1085 constexpr int kMaxShapeLen = 64; 1086 if (instr_shape.length() > kMaxShapeLen) { 1087 instr_shape = StrCat( 1088 tensorflow::StringPiece(instr_shape).substr(0, kMaxShapeLen - 3), 1089 "..."); 1090 } 1091 lines.push_back(instr_shape); 1092 } 1093 if (debug_options_.xla_hlo_graph_addresses()) { 1094 lines.push_back(Printf("[%p]", instr)); 1095 } 1096 if (profile_ != nullptr) { 1097 double hlo_cycles_executed = profile_->GetCyclesTakenBy(*instr); 1098 double total_cycles_executed = 1099 profile_->total_cycles_executed(*instr->parent()); 1100 if (hlo_cycles_executed > 0 && total_cycles_executed > 0) { 1101 lines.push_back( 1102 Printf("%% of cycles executed=%.2f", 1103 100 * hlo_cycles_executed / total_cycles_executed)); 1104 } 1105 } 1106 return Join(lines, "<br/>"); 1107} 1108 1109void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) { 1110 auto add_edge = [&](const HloInstruction* from, const HloInstruction* to, 1111 int64 operand_num, bool control_edge = false) { 1112 from = GetNodeForEdge(from); 1113 1114 if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant || 1115 
ShouldMergeIntoUsers(from)) { 1116 return; 1117 } 1118 VLOG(2) << "Adding edge from " << from->name() << " to " << to->name() 1119 << " as " << next_edge_id_; 1120 edge_ids_.insert({{from, to}, next_edge_id_++}); 1121 1122 string edge_label; 1123 if (instr->operand_count() > 1 && !control_edge) { 1124 edge_label = Printf(R"( headlabel="%lld", labeldistance=2)", operand_num); 1125 } else if (control_edge) { 1126 edge_label = "style=\"dotted\" color=\"gray\" label=\"ctrl\""; 1127 } 1128 const char* kEdgeFmt = R"(%s -> %s [tooltip="%s -> %s" %s];)"; 1129 edges_.push_back(Printf(kEdgeFmt, InstructionId(from), InstructionId(to), 1130 from->name(), to->name(), edge_label)); 1131 }; 1132 1133 // Add edges from instr's operands to instr. Parameters within fusion 1134 // expressions are handled specially -- we draw an edge from the corresponding 1135 // operand on the fusion node itself to the parameter. 1136 if (instr->opcode() == HloOpcode::kParameter && instr->IsFused()) { 1137 // Only add the edge if this is not the outermost computation; otherwise it 1138 // will lead from a node we're not drawing. 1139 if (instr->parent() != computation_) { 1140 const HloInstruction* fusion = instr->parent()->FusionInstruction(); 1141 add_edge(fusion->operand(instr->parameter_number()), instr, 1142 /*operand_num=*/0); 1143 } 1144 } else { 1145 for (int64 i = 0; i < instr->operand_count(); ++i) { 1146 add_edge(instr->operand(i), instr, i); 1147 } 1148 for (const HloInstruction* pred : instr->control_predecessors()) { 1149 add_edge(pred, instr, /*operand_num=*/0, /*control_edge=*/true); 1150 } 1151 } 1152} 1153 1154string HloDotDumper::GetInstructionTrivialComputationStr( 1155 const HloInstruction* instr) { 1156 // called_computations() on a fusion node "inherits" any called computations 1157 // of the fused root, which isn't what we want. Just ignore fusion nodes 1158 // here; they're handled separately. 
1159 if (instr->opcode() == HloOpcode::kFusion) { 1160 return ""; 1161 } 1162 1163 std::vector<string> lines; 1164 for (int64 i = 0; i < instr->called_computations().size(); ++i) { 1165 optional<string> computation_type = 1166 MatchTrivialComputation(instr->called_computations()[i]); 1167 if (!computation_type) { 1168 continue; 1169 } 1170 if (instr->called_computations().size() == 1) { 1171 lines.push_back(Printf("Subcomputation: <b>%s</b>", 1172 HtmlLikeStringSanitize(*computation_type))); 1173 } else { 1174 lines.push_back(Printf("Subcomputation %lld: <b>%s</b>", i, 1175 HtmlLikeStringSanitize(*computation_type))); 1176 } 1177 } 1178 return Join(lines, "<br/>"); 1179} 1180 1181const HloInstruction* HloDotDumper::GetNodeForEdge( 1182 const HloInstruction* instr) { 1183 while (instr->opcode() == HloOpcode::kFusion && 1184 ShouldShowFusionSubcomputation(instr)) { 1185 instr = instr->fused_expression_root(); 1186 } 1187 return instr; 1188} 1189 1190class GraphRendererRegistry { 1191 public: 1192 void AddRenderer(GraphRendererInterface* graph_renderer) { 1193 tensorflow::mutex_lock lock(mu_); 1194 graph_renderer_ = graph_renderer; 1195 } 1196 1197 GraphRendererInterface* GetDefaultRenderer() { 1198 tensorflow::mutex_lock lock(mu_); 1199 return graph_renderer_; 1200 } 1201 1202 static GraphRendererRegistry* Default() { 1203 static GraphRendererRegistry* registry = new GraphRendererRegistry(); 1204 return registry; 1205 } 1206 1207 private: 1208 tensorflow::mutex mu_; 1209 GraphRendererInterface* graph_renderer_ = nullptr; 1210}; 1211 1212} // namespace 1213 1214Registrar::Registrar(GraphRendererInterface* dumper) { 1215 GraphRendererRegistry::Default()->AddRenderer(dumper); 1216} 1217 1218namespace { 1219 1220// Gets a NodeFilter that includes roughly all instructions whose distance from 1221// root is <= radius. 1222NodeFilter MakeNodeFilter(const HloInstruction* root, int64 radius) { 1223 // First, find the neighborhood of nodes with distance from root <= radius. 
1224 // These nodes are our initial set of "normal" nodes. 1225 std::unordered_map<const HloInstruction*, NodeFilterResult> nodes; 1226 std::deque<std::pair<const HloInstruction*, /*depth*/ int64>> worklist; 1227 worklist.push_back({root, 0}); 1228 while (!worklist.empty()) { 1229 const HloInstruction* instr; 1230 int64 depth; 1231 std::tie(instr, depth) = worklist.front(); 1232 worklist.pop_front(); 1233 1234 nodes[instr] = kNormalNode; 1235 if (depth == radius) { 1236 continue; 1237 } 1238 1239 // Traverse into instr's operands. 1240 // 1241 // Don't traverse into tuples' operands unless the tuple is the root. 1242 // Usually a tuple is the bottommost node in the graph, and so its operands 1243 // are not interesting to the graph at hand. 1244 if (instr == root || instr->opcode() != HloOpcode::kTuple) { 1245 for (const HloInstruction* operand : instr->operands()) { 1246 if (!nodes.count(operand)) { 1247 worklist.push_back({operand, depth + 1}); 1248 } 1249 } 1250 } 1251 1252 // Traverse into instr's nested computations. 1253 for (const HloComputation* computation : instr->called_computations()) { 1254 worklist.push_back({computation->root_instruction(), depth + 1}); 1255 } 1256 1257 // Traverse into instr's users, unless: 1258 // 1259 // - there are a ton of them, in which case they're probably not 1260 // interesting (and anyway, rendering them all would make the graph 1261 // unreadable), or 1262 // - instr is a constant, in which case its users are probably not 1263 // interesting. 1264 if (instr->opcode() == HloOpcode::kConstant) { 1265 continue; 1266 } 1267 constexpr int kMaxUsersToRender = 16; 1268 if (instr->user_count() > kMaxUsersToRender) { 1269 // If we're going to skip this node's users, style it as such. 
1270 nodes[instr] = kSomeUsersOmitted; 1271 continue; 1272 } 1273 for (const HloInstruction* user : instr->users()) { 1274 if (!nodes.count(user)) { 1275 worklist.push_back({user, depth + 1}); 1276 } 1277 } 1278 } 1279 1280 auto is_displayed = [&](const HloInstruction* instr) { 1281 // Constants are displayed inline with their users; they're never omitted. 1282 // Nodes in subcomputations are always shown. 1283 return nodes.count(instr) > 0 || instr->opcode() == HloOpcode::kConstant || 1284 instr->parent() != root->parent(); 1285 }; 1286 1287 // Make a second pass over 'nodes' to fix up the NodeFilterResults now that we 1288 // know which nodes will be included in the graph. 1289 for (auto& kv : nodes) { 1290 const HloInstruction* instr = kv.first; 1291 NodeFilterResult& filter_result = kv.second; 1292 const auto& operands = instr->operands(); 1293 1294 if (std::any_of(operands.begin(), operands.end(), is_displayed) && 1295 !std::all_of(operands.begin(), operands.end(), is_displayed)) { 1296 // Mark nodes with some operands omitted appropriately. 1297 filter_result = kSomeOperandsOmitted; 1298 } else if (!operands.empty() && 1299 std::none_of(operands.begin(), operands.end(), is_displayed)) { 1300 // Mark nodes with *all* operands omitted appropriately. 1301 filter_result = kOmitNodeOperands; 1302 } 1303 1304 // Promote nodes with type kSomeUsersOmitted to kNormalNode if all of their 1305 // users made it into the graph. 1306 if (filter_result == kSomeUsersOmitted && 1307 std::all_of(instr->users().begin(), instr->users().end(), 1308 is_displayed)) { 1309 filter_result = kNormalNode; 1310 } 1311 } 1312 1313 // Highlight the root node. 1314 nodes[root] = kHighlightNode; 1315 1316 return NodeFilter([=](const HloInstruction* instr) { 1317 auto it = nodes.find(instr); 1318 if (it != nodes.end()) { 1319 return it->second; 1320 } 1321 // Show all nodes in subcomputations. 
1322 if (instr->parent() != root->parent()) { 1323 return kNormalNode; 1324 } 1325 return kHideNode; 1326 }); 1327} 1328 1329string SaveGraph(const string& graph, 1330 GraphRendererInterface::GraphKind graph_kind, 1331 const string& dest_path) { 1332 static std::atomic<int> output_num(0); 1333 string file_extension; 1334 switch (graph_kind) { 1335 case GraphRendererInterface::DOT_GRAPH: 1336 file_extension = ".dot"; 1337 break; 1338 case GraphRendererInterface::TF_GRAPHDEF: 1339 file_extension = ".pbtxt"; 1340 break; 1341 } 1342 string path = JoinPath(dest_path, StrCat("hlo_graph_", output_num++, ".")); 1343 auto status = Status::OK(); 1344 auto env = tensorflow::Env::Default(); 1345 if (!env->CreateUniqueFileName(&path, file_extension)) { 1346 status = 1347 Status(tensorflow::error::Code::UNKNOWN, 1348 StrCat("Failed to create temporary file to dump HLO graph: ", 1349 strerror(errno))); 1350 } else { 1351 status = tensorflow::WriteStringToFile(env, path, graph); 1352 } 1353 if (!status.ok()) { 1354 LOG(WARNING) << "Saving HLO graph failed: " << status; 1355 } 1356 return path; 1357} 1358 1359string ExportGraph(const string& graph, 1360 GraphRendererInterface::GraphKind graph_kind, 1361 const DebugOptions& debug_options) { 1362 string path = debug_options.xla_hlo_graph_path(); 1363 if (!path.empty()) { 1364 return SaveGraph(graph, graph_kind, path); 1365 } else { 1366 auto graph_renderer = 1367 GraphRendererRegistry::Default()->GetDefaultRenderer(); 1368 CHECK(graph_renderer != nullptr) 1369 << "No registered renderer for the HLO graph. 
" 1370 "Use --xla_hlo_graph_path=PATH to export to local file system"; 1371 return graph_renderer->RenderGraph(graph, graph_kind, debug_options); 1372 } 1373} 1374 1375} // namespace 1376 1377string DumpGraph(const HloComputation& computation, const string& label, 1378 const DebugOptions& debug_options, 1379 const HloExecutionProfile* hlo_execution_profile, 1380 bool show_metadata) { 1381 GraphRendererInterface::GraphKind graph_kind; 1382 string graph; 1383 if (debug_options.xla_hlo_dump_as_graphdef()) { 1384 HloTfGraphBuilder builder(debug_options); 1385 TF_CHECK_OK(builder.AddComputation(computation)); 1386 CHECK(tensorflow::protobuf::TextFormat::PrintToString(builder.GetGraphDef(), 1387 &graph)); 1388 graph_kind = GraphRendererInterface::TF_GRAPHDEF; 1389 } else { 1390 graph = HloDotDumper(&computation, label, debug_options, show_metadata, 1391 hlo_execution_profile, NodeFilter()) 1392 .Dump(); 1393 graph_kind = GraphRendererInterface::DOT_GRAPH; 1394 } 1395 1396 string graph_url = ExportGraph(graph, graph_kind, debug_options); 1397 LOG(INFO) << "computation " << computation.name() << " [" << label 1398 << "]: " << graph_url; 1399 return graph_url; 1400} 1401 1402string DumpNeighborhoodAround(const HloInstruction& node, int radius, 1403 bool show_metadata) { 1404 auto debug_options = node.GetModule()->config().debug_options(); 1405 string label = 1406 StrCat("Neighborhood of ", radius, " nodes around ", node.name()); 1407 NodeFilter filter = MakeNodeFilter(&node, radius); 1408 string graph = 1409 HloDotDumper(node.parent(), label, debug_options, show_metadata, 1410 /*profile=*/nullptr, filter) 1411 .Dump(); 1412 return ExportGraph(graph, GraphRendererInterface::DOT_GRAPH, debug_options); 1413} 1414 1415void DumpText(const HloModule& module, const string& label, 1416 const string& directory_path, bool do_prefix) { 1417 Env* env = Env::Default(); 1418 TF_CHECK_OK(env->RecursivelyCreateDir(directory_path)); 1419 string prefix = StrCat(env->NowMicros()); 1420 string 
filename = 1421 do_prefix ? StrCat(prefix, "-", label, ".txt") : StrCat(label, ".txt"); 1422 string path = JoinPath(directory_path, filename); 1423 TF_CHECK_OK(WriteStringToFile( 1424 env, path, 1425 module.ToString(HloPrintOptions().set_print_large_constants(true)))); 1426 LOG(INFO) << "dumping module '" << module.name() << "' to " << path; 1427} 1428 1429string MaybeDumpHloModule(const HloModule& module, const string& label, 1430 const HloExecutionProfile* profile) { 1431 const DebugOptions& debug_options = module.config().debug_options(); 1432 VLOG(2) << "MaybeDumpHloModule called on module " << module.name() 1433 << " with generate_hlo_graph regex \"" 1434 << debug_options.xla_generate_hlo_graph() << "\""; 1435 string graph_url; 1436 if (!debug_options.xla_generate_hlo_graph().empty() && 1437 RE2::PartialMatch(module.name(), 1438 debug_options.xla_generate_hlo_graph())) { 1439 graph_url = 1440 DumpGraph(*module.entry_computation(), label, debug_options, profile); 1441 } 1442 if (!debug_options.xla_log_hlo_text().empty() && 1443 RE2::PartialMatch(module.name(), debug_options.xla_log_hlo_text())) { 1444 LOG(INFO) << "HLO for module " << module.name(); 1445 LOG(INFO) << "Label: " << label; 1446 XLA_LOG_LINES(2, module.ToString()); 1447 } 1448 if (!debug_options.xla_generate_hlo_text_to().empty()) { 1449 DumpText(module, label, debug_options.xla_generate_hlo_text_to()); 1450 } 1451 return graph_url; 1452} 1453 1454} // namespace hlo_graph_dumper 1455} // namespace xla 1456