/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"

#include <atomic>
#include <deque>
#include <limits>
#include <unordered_map>
#include <unordered_set>

#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {

const char* const kXlaClusterAttr = "_XlaCluster";
const char* const kXlaOutsideCompilationAttr = "_XlaOutsideCompilation";

namespace {

bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) {
  // There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient
  // is really a kind of function call and will be handled by
  // IsCompilableCall().
  if (node.type_string() == "SymbolicGradient") return false;
  return FindKernelDef(jit_device_type, node.def(), nullptr, nullptr).ok();
}

// Make sure we don't recurse infinitely on recursive functions.
const int kMaxRecursionDepth = 10;

bool IsCompilableCall(const NodeDef& call_def,
                      const DeviceType& jit_device_type, int depth,
                      FunctionLibraryRuntime* lib_runtime);
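
// For orientation: a functional "While" node carries its condition and body
// as function-valued attributes. Schematically (an illustrative sketch, not
// an exact textproto), such a NodeDef looks like:
//
//   node {
//     name: "loop"
//     op: "While"
//     attr { key: "cond" value { func { name: "loop_cond" } } }
//     attr { key: "body" value { func { name: "loop_body" } } }
//   }
//
// IsCompilableWhile() below extracts "cond" and "body", synthesizes a call
// NodeDef for each, and defers to IsCompilableCall().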

// Tests whether 'while_node' is a completely compilable loop.
// Every operator in the condition and body functions must be compilable for a
// while loop to be compilable.
bool IsCompilableWhile(const Node& while_node,
                       const DeviceType& jit_device_type, int depth,
                       FunctionLibraryRuntime* lib_runtime) {
  VLOG(2) << "Loop marking: " << while_node.type_string();

  const NameAttrList* name_attr;
  NodeDef call;
  Status status;
  status = GetNodeAttr(while_node.attrs(), "cond", &name_attr);
  if (!status.ok()) {
    VLOG(2) << "Missing 'cond' attribute on While node.";
    return false;
  }
  const string cond_func = name_attr->name();
  call.set_name("while_cond");
  call.set_op(cond_func);
  *call.mutable_attr() = name_attr->attr();
  if (!IsCompilableCall(call, jit_device_type, depth + 1, lib_runtime)) {
    VLOG(2) << "Can't compile loop condition: " << cond_func;
    return false;
  }
  status = GetNodeAttr(while_node.attrs(), "body", &name_attr);
  if (!status.ok()) {
    VLOG(2) << "Missing 'body' attribute on While node.";
    return false;
  }
  const string body_func = name_attr->name();
  call.set_name("while_body");
  call.set_op(body_func);
  *call.mutable_attr() = name_attr->attr();
  if (!IsCompilableCall(call, jit_device_type, depth + 1, lib_runtime)) {
    VLOG(2) << "Can't compile loop body: " << body_func;
    return false;
  }
  VLOG(2) << "Loop is compilable.";
  return true;
}

// Tests whether 'call_def' is a call to a completely compilable function.
// Every operator in the function must be compilable for a function to be
// compilable.
bool IsCompilableCall(const NodeDef& call_def,
                      const DeviceType& jit_device_type, int depth,
                      FunctionLibraryRuntime* lib_runtime) {
  VLOG(2) << "Function marking: " << call_def.op();

  if (depth > kMaxRecursionDepth) {
    VLOG(2) << "Function depth limit exceeded";
    return false;
  }

  FunctionLibraryRuntime::Handle handle;
  Status status =
      lib_runtime->Instantiate(call_def.op(), AttrSlice(call_def), &handle);
  if (!status.ok()) {
    VLOG(2) << "Could not instantiate " << call_def.op() << ": " << status;
    return false;
  }
  const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);
  CHECK(fbody);
  const FunctionDef& fdef = fbody->fdef;
  bool noinline = false;
  if (GetNodeAttr(AttrSlice(&fdef.attr()), "_noinline", &noinline).ok() &&
      noinline) {
    // The underlying mechanism that calls non-inlined functions uses
    // LocalExecutor, which interacts poorly with the LocalExecutor used by
    // tf2xla to translate the TF graph into XLA. So we avoid this for now.
    //
    // TODO(b/36139787): Create a mechanism to set inlining hints.
    VLOG(2) << "Can't compile noinline function: " << fdef.DebugString();
    return false;
  }

  for (Node* node : fbody->graph->op_nodes()) {
    if (node->type_string() == "_Arg" || node->type_string() == "_Retval")
      continue;
    if (node->type_string() == "While") {
      // Handle functional While loop (not in open source build). Note that we
      // must keep scanning the remaining nodes: an early return here would
      // skip checking the rest of the function body.
      if (!IsCompilableWhile(*node, jit_device_type, depth + 1, lib_runtime)) {
        return false;
      }
      continue;
    }
    if (!HasXLAKernel(*node, jit_device_type) &&
        !IsCompilableCall(node->def(), jit_device_type, depth + 1,
                          lib_runtime)) {
      VLOG(2) << "Function marking failed: unsupported op " << node->name()
              << ": " << node->def().ShortDebugString();
      return false;
    }
  }
  VLOG(2) << "Function is compilable: " << call_def.op();
  return true;
}

// Returns the DeviceType corresponding to 'device'.
Status DeviceTypeOfDevice(const string& device, DeviceType* device_type) {
  DeviceNameUtils::ParsedName parsed;
  if (!DeviceNameUtils::ParseFullName(device, &parsed)) {
    return errors::Internal("Malformed assigned device '", device, "'");
  }
  *device_type = DeviceType(parsed.type);
  return Status::OK();
}
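
// For example, DeviceTypeOfDevice() above maps an assigned device name of
// "/job:localhost/replica:0/task:0/device:CPU:0" to DeviceType("CPU").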

// Tests whether `node` has a DT_RESOURCE typed input or output.
bool HasResourceInputOrOutput(const Node& node) {
  return std::find(node.input_types().begin(), node.input_types().end(),
                   DT_RESOURCE) != node.input_types().end() ||
         std::find(node.output_types().begin(), node.output_types().end(),
                   DT_RESOURCE) != node.output_types().end();
}

struct NodeCompare {
  bool operator()(const Node* a, const Node* b) { return a->id() < b->id(); }
};
using OrderedNodeSet = std::set<Node*, NodeCompare>;

Status FindCompilationCandidates(
    const Graph& graph, FunctionLibraryDefinition* flib_def, Env* env,
    const std::function<bool(const Node*, const DeviceType&)>& is_compilable_fn,
    OrderedNodeSet* candidates) {
  OptimizerOptions opts;
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
      new ProcessFunctionLibraryRuntime(nullptr, env, TF_GRAPH_DEF_VERSION,
                                        flib_def, opts));
  FunctionLibraryRuntime* lib_runtime =
      pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);

  for (Node* node : graph.op_nodes()) {
    VLOG(2) << "FindCompilationCandidates(): Processing "
            << node->DebugString();

    DeviceType device_type("");
    TF_RETURN_IF_ERROR(
        DeviceTypeOfDevice(node->assigned_device_name(), &device_type));

    if (is_compilable_fn && !is_compilable_fn(node, device_type)) continue;

    const XlaOpRegistry::DeviceRegistration* registration;
    CHECK(XlaOpRegistry::GetCompilationDevice(device_type.type(),
                                              &registration));
    DeviceType jit_device_type(registration->compilation_device_name);
    if (!HasXLAKernel(*node, jit_device_type) &&
        !IsCompilableCall(node->def(), jit_device_type, 0, lib_runtime)) {
      VLOG(2) << "Compilation rejected node: unsupported op " << node->name()
              << ": " << node->type_string();
      continue;
    }
    if (!registration->compile_resource_ops &&
        HasResourceInputOrOutput(*node)) {
      VLOG(2) << "Compilation rejected node: resource input/output "
              << node->name() << ": " << node->type_string();
      continue;
    }
    if (node->type_string() == "While" &&
        !IsCompilableWhile(*node, jit_device_type, 0, lib_runtime)) {
      continue;
    }
    // _Arg nodes in a top-level function represent feeds.
    // Do not compile them.
    if (node->type_string() == "_Arg") {
      VLOG(2) << "Skipping jit compilation for '_Arg'-typed node "
              << node->DebugString();
      continue;
    }
    // _Retval nodes in a top-level function represent fetches.
    // Do not compile them.
    if (node->type_string() == "_Retval") {
      VLOG(2) << "Compilation rejected node: return value " << node->name()
              << ": " << node->type_string();
      continue;
    }
    candidates->insert(node);
  }
  return Status::OK();
}

struct Cluster {
  // Identifies the node that represents this cluster in the cycle detection
  // graph.
  int representative = -1;
};

// Returns a string describing how an edge from src to dst would
// create a cycle.
string DescribeCycle(const GraphCycles& cycles, const Graph& graph, int src,
                     int dst) {
  int32 max_path_size = graph.num_node_ids() + 1;
  std::vector<int32> path(max_path_size);
  int32 path_size = cycles.FindPath(dst, src, max_path_size, path.data());
  if (path_size == 0) {
    return "";
  }

  auto node_name = [&cycles, &graph](int node_id) {
    auto* node = graph.FindNodeId(node_id);
    if (node == nullptr) {
      return string("(null)");
    }
    return node->name();
  };

  string description;
  strings::StrAppend(&description, "Edge from ", node_name(src), " to ",
                     node_name(dst), " would create a cycle.\n");
  path.resize(path_size);
  for (int32 node_id : path) {
    string ascii_art;
    if (node_id == dst) {
      ascii_art = "+-> ";
    } else if (node_id != src) {
      ascii_art = "|   ";
    } else {
      ascii_art = "+-- ";
    }
    strings::StrAppend(&description, ascii_art, node_name(node_id), "\n");
  }
  return description;
}
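
// For example, with nodes named a..d, if adding an edge a->d would close the
// existing path d -> b -> c -> a, DescribeCycle(cycles, graph, a, d) produces:
//
//   Edge from a to d would create a cycle.
//   +-> d
//   |   b
//   |   c
//   +-- a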

}  // anonymous namespace

bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) {
  Device* device = flr->device();
  const XlaOpRegistry::DeviceRegistration* registration;
  CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(),
                                            &registration));
  DeviceType jit_device_type(registration->compilation_device_name);
  return IsCompilableCall(ndef, jit_device_type, 0, flr);
}

Status MarkForCompilationPass::Run(
    const GraphOptimizationPassOptions& options) {
  // TODO(phawkins): precompute the "GetCompilationDevice" properties of each
  // device ahead of time.
  OptimizerOptions::GlobalJitLevel global_jit_level =
      options.session_options->config.graph_options()
          .optimizer_options()
          .global_jit_level();
  if (global_jit_level == OptimizerOptions::DEFAULT) {
    // To set compilation to be on by default, change the following line.
    global_jit_level = OptimizerOptions::OFF;
  }
  legacy_flags::MarkForCompilationPassFlags* flags =
      legacy_flags::GetMarkForCompilationPassFlags();
  if (flags->tf_xla_auto_jit == -1 ||
      (1 <= flags->tf_xla_auto_jit && flags->tf_xla_auto_jit <= 2)) {
    // If the flag tf_xla_auto_jit is a valid, non-zero setting, it overrides
    // the setting in ConfigProto.
    global_jit_level =
        static_cast<OptimizerOptions::GlobalJitLevel>(flags->tf_xla_auto_jit);
  }
  bool cpu_global_jit = flags->tf_xla_cpu_global_jit;
  VLOG(1) << "flags->tf_xla_cpu_global_jit = " << flags->tf_xla_cpu_global_jit;
  const FunctionLibraryDefinition* fld = options.flib_def;

  auto is_compilable = [global_jit_level, cpu_global_jit, fld](
                           const Node* node, const DeviceType& device_type) {
    const XlaOpRegistry::DeviceRegistration* registration;
    if (!XlaOpRegistry::GetCompilationDevice(device_type.type(),
                                             &registration)) {
      return false;
    }

    // Don't compile control trigger nodes. We won't preserve their deadness
    // semantics correctly, so it's safest not to compile them.
    if (node->IsControlTrigger()) return false;

    // If this device requires a JIT, we must say yes.
    if (registration->requires_compilation) return true;

    // If there is a _XlaCompile annotation, use its value.
    bool compile = false;
    Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile);
    if (status.ok()) return compile;

    status = fld->GetAttr(*node, kXlaCompileAttr, &compile);
    if (status.ok()) return compile;

    // Otherwise use the value of global_jit_level.
    // Ignore enable_jit_by_default if global jit compilation for CPU
    // is explicitly requested via the tf_xla_cpu_global_jit flag.
    bool ignore_registration = cpu_global_jit && device_type == DEVICE_CPU;
    return (ignore_registration || registration->enable_jit_by_default) &&
           global_jit_level > 0;
  };
  return RunImpl(options, is_compilable);
}
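
// Note on precedence in the is_compilable predicate above: a device that
// requires compilation always wins; otherwise an explicit _XlaCompile
// attribute (first on the node, then on its function definition) decides;
// only then does the global JIT level apply. In ConfigProto terms,
// OptimizerOptions::GlobalJitLevel is OFF (-1), DEFAULT (0), ON_1 (1), or
// ON_2 (2), which is why tf_xla_auto_jit accepts exactly -1, 1, and 2.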

// Is 'node' an operator that consumes only the shape of its input, not the
// data itself?
static bool IsShapeConsumerOp(const Node& node) {
  return node.type_string() == "Shape" || node.type_string() == "Rank" ||
         node.type_string() == "Size";
}

// Sequence number generator to ensure clusters have unique names.
static std::atomic<int64> cluster_sequence_num;

Status MarkForCompilationPass::RunImpl(
    const GraphOptimizationPassOptions& options,
    const std::function<bool(const Node*, const DeviceType&)>&
        is_compilable_fn) {
  VLOG(1) << "MarkForCompilationPass::Run";

  // Make sure that kernels have been registered on the JIT device.
  XlaOpRegistry::RegisterCompilationKernels();

  Graph* graph = options.graph->get();

  OrderedNodeSet compilation_candidates;
  TF_RETURN_IF_ERROR(FindCompilationCandidates(
      *graph, options.flib_def,
      (options.session_options != nullptr) ? options.session_options->env
                                           : Env::Default(),
      is_compilable_fn, &compilation_candidates));

  GraphCycles cycles;
  for (int i = 0; i < graph->num_node_ids(); ++i) {
    // We rely on the node IDs in the cycle detection graph being consecutive
    // integers starting from 0.
    CHECK_EQ(i, cycles.NewNode());
  }

  // Compute the loop structure of the graph.
  std::vector<ControlFlowInfo> control_flow_info;
  TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &control_flow_info));

  // The clustering code must avoid adding cycles to the graph to prevent
  // deadlock. However, the graph may contain loops, which would trigger the
  // cycle detection code. To handle loops, we alter the structure of the cycle
  // detection graph, disconnecting each loop from the enclosing graph.
  // Specifically, we:
  // * add a new "frame" node for each loop.
  // * replace edges to "Enter" nodes, and edges from "Exit" nodes, with edges
  //   to/from the corresponding frame node. In essence, we collapse the loop
  //   into a single node for the purpose of cycle detection in the enclosing
  //   graph.
  // * the body of the loop should now be disconnected from the rest of the
  //   graph; we make it acyclic by breaking loop backedges (edges outgoing
  //   from "NextIteration" nodes).
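  //
  // For example (schematically), a while loop
  //
  //     x -> Enter -> Merge -> ... -> Exit -> y
  //                     ^              |
  //                     +- NextIteration
  //
  // appears in the cycle detection graph as
  //
  //     x -> frame -> y
  //
  // while Enter -> Merge -> ... -> Exit survives as a disconnected, acyclic
  // island (the NextIteration back-edge is dropped).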

  // Map from frame name strings to node IDs in the cycle detection graph.
  std::unordered_map<string, int> frame_nodes;

  // Get the cycle graph node ID for frame 'frame_name', or add one if none
  // exists.
  auto GetOrAddFrameNodeId = [&frame_nodes,
                              &cycles](const string& frame_name) {
    int& frame_id = frame_nodes.emplace(frame_name, -1).first->second;
    if (frame_id < 0) {
      // The emplace succeeded; we have not allocated a frame node yet.
      frame_id = cycles.NewNode();
    }
    return frame_id;
  };

  for (Edge const* edge : graph->edges()) {
    if (edge->dst()->IsEnter()) {
      // Lift edges to an "Enter" node to the corresponding frame node.
      const string& frame_name =
          control_flow_info[edge->dst()->id()].frame_name;
      int dst = GetOrAddFrameNodeId(frame_name);
      if (!cycles.InsertEdge(edge->src()->id(), dst)) {
        return errors::Internal(
            "Cycle detected when adding enter->frame edge: ",
            DescribeCycle(cycles, *graph, edge->src()->id(), dst));
      }
      continue;
    }
    if (edge->src()->IsExit()) {
      // Lift edges from an "Exit" node to the corresponding frame node.
      const string& frame_name =
          control_flow_info[edge->src()->id()].frame_name;
      int src = GetOrAddFrameNodeId(frame_name);
      if (!cycles.InsertEdge(src, edge->dst()->id())) {
        return errors::Internal(
            "Cycle detected when adding frame->exit edge: ",
            DescribeCycle(cycles, *graph, src, edge->dst()->id()));
      }
      // Drop the original edge.
      continue;
    }
    if (edge->src()->IsNextIteration()) {
      // Break loop back-edges.
      continue;
    }
    if (!cycles.InsertEdge(edge->src()->id(), edge->dst()->id())) {
      // This should never happen. All cycles in the graph should contain
      // a control flow operator.
      return errors::Internal(
          "Found cycle in graph without control flow operator during XLA "
          "compilation: ",
          DescribeCycle(cycles, *graph, edge->src()->id(), edge->dst()->id()));
    }
  }

  // Each compilation candidate belongs to a cluster. The cluster's
  // representative names the node in the 'cycles' graph that represents the
  // cluster.
  std::vector<UnionFind<Cluster>> clusters(graph->num_node_ids());
  std::deque<UnionFind<Cluster>*> worklist;
  for (Node* node : compilation_candidates) {
    Cluster& cluster = clusters[node->id()].Get();
    cluster.representative = node->id();
    worklist.push_back(&clusters[node->id()]);
  }

  legacy_flags::MarkForCompilationPassFlags* flags =
      legacy_flags::GetMarkForCompilationPassFlags();

  // Repeatedly contract edges between clusters that are on the same device,
  // provided the contraction would not create a cycle.
  while (!worklist.empty()) {
    int from = worklist.front()->Get().representative;
    worklist.pop_front();

    Node* node_from = graph->FindNodeId(from);
    if (node_from->IsControlFlow()) {
      // Control flow nodes aren't compilation candidates and should never
      // appear on the worklist.
      return errors::Internal(
          "Found control flow node in clustering worklist: ",
          node_from->type_string());
    }
    string from_scope;
    string to_scope;
    for (int to : cycles.Successors(from)) {
      if (to >= graph->num_node_ids()) {
        // Node is a "frame" node that is present only in the cycle detection
        // graph. No clustering is possible.
        continue;
      }
      Node* node_to = graph->FindNodeId(to);
      if (compilation_candidates.find(node_to) ==
          compilation_candidates.cend()) {
        continue;
      }
      if (node_from->assigned_device_name() !=
          node_to->assigned_device_name()) {
        continue;
      }
      // Look for an _XlaScope on both nodes. If both nodes have a
      // scope and the scopes do not match, do not cluster along this
      // edge. If even one of the nodes lacks an _XlaScope attribute,
      // then it is treated as a "bridge" and a cluster may be created
      // along it. We may want to restrict this behavior to require
      // all nodes marked with _XlaCompile=true to also have a
      // _XlaScope property set (and raise an error otherwise); but
      // for now we don't do this.
      if (GetNodeAttr(node_from->attrs(), kXlaScopeAttr, &from_scope).ok() &&
          GetNodeAttr(node_to->attrs(), kXlaScopeAttr, &to_scope).ok() &&
          from_scope != to_scope) {
        continue;
      }

      // Ops that consume shapes cannot be the root of a cluster. This is an
      // optimization.
      if (clusters[from].Size() == 1 && IsShapeConsumerOp(*node_from)) {
        continue;
      }

      // Don't exceed the maximum cluster size.
      if (clusters[from].Size() + clusters[to].Size() >
          flags->tf_xla_max_cluster_size) {
        continue;
      }

      // If contracting the edge would create a cycle, bail out.
      // However, just because we can't merge the clusters now does not mean
      // we won't be able to merge them in the future.
      // e.g., if we have edges 1->2, 2->3 and 1->3, we cannot contract edge
      // 1->3. But if we first contract 1->2 then we can later contract 1->3.
      if (!cycles.ContractEdge(from, to)) continue;

      // Merge the clusters. ContractEdge uses 'from' as the number of the
      // merged node, so make sure 'from' is the chosen representative.
      clusters[from].Merge(&clusters[to]);

      worklist.push_back(&clusters[from]);
      break;
    }
  }
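
  // Note that after each successful contraction the merged cluster is pushed
  // back onto the worklist and its successor scan restarts; this revisiting
  // is what eventually makes the 1->3 edge in the example above contractible
  // once 1->2 has been contracted.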

  // Count the number of elements in each cluster.
  std::vector<int> cluster_sizes(graph->num_node_ids());
  for (const Node* n : compilation_candidates) {
    int cluster = clusters[n->id()].Get().representative;
    cluster_sizes[cluster]++;
  }

  // Names for each cluster.
  std::unordered_map<int, string> cluster_names;

  // Mark clusters for compilation that:
  // * are placed on a device that requires compilation (an XlaDevice),
  // * are explicitly marked for compilation (_XlaCompile=true), or
  // * have more than flags->tf_xla_min_cluster_size elements (applicable only
  //   if compilation is enabled; otherwise there will be no such candidates).
  const int min_cluster_size = flags->tf_xla_min_cluster_size;
  for (Node* n : compilation_candidates) {
    int cluster = clusters[n->id()].Get().representative;

    // Compile if the user marked this node _XlaCompile=true.
    bool compile_attr = false;
    bool marked_for_compilation = false;
    if (GetNodeAttr(n->attrs(), kXlaCompileAttr, &compile_attr).ok()) {
      marked_for_compilation = compile_attr;
    } else if (options.flib_def->GetAttr(*n, kXlaCompileAttr, &compile_attr)
                   .ok()) {
      marked_for_compilation = compile_attr;
    }

    // Compile if this operator is placed on a device that requires
    // compilation.
    DeviceType device_type("");
    TF_RETURN_IF_ERROR(
        DeviceTypeOfDevice(n->assigned_device_name(), &device_type));
    const XlaOpRegistry::DeviceRegistration* registration;
    XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration);
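    // `registration` is guaranteed non-null here: every compilation candidate
    // already passed the CHECKed GetCompilationDevice() lookup in
    // FindCompilationCandidates().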

    // Or compile if this is a cluster of >= min_cluster_size compilable
    // operators.
    if (cluster_sizes[cluster] >= min_cluster_size || marked_for_compilation ||
        registration->requires_compilation) {
      string& name = cluster_names[cluster];

      if (name.empty()) {
        name = strings::StrCat("cluster_", cluster_sequence_num++);
      }
      n->AddAttr(kXlaClusterAttr, name);
      VLOG(3) << "Assigning node " << n->name() << " to cluster " << name;
    }
  }

  if (flags->tf_xla_clustering_debug) {
    dump_graph::DumpGraphToFile("mark_for_compilation", **options.graph,
                                options.flib_def);
  }
  return Status::OK();
}

}  // namespace tensorflow