/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"

#include <atomic>
#include <deque>
#include <limits>
#include <unordered_map>
#include <unordered_set>

#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {

const char* const kXlaClusterAttr = "_XlaCluster";
const char* const kXlaOutsideCompilationAttr = "_XlaOutsideCompilation";

namespace {

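// Returns true if an XLA kernel for 'node' is registered on the JIT device
// 'jit_device_type'.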
bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) {
  // There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient
  // is really a kind of function call and will be handled by
  // IsCompilableCall().
  if (node.type_string() == "SymbolicGradient") return false;
  return FindKernelDef(jit_device_type, node.def(), nullptr, nullptr).ok();
}

// Make sure we don't recurse infinitely on recursive functions.
const int kMaxRecursionDepth = 10;

bool IsCompilableCall(const NodeDef& call_def,
                      const DeviceType& jit_device_type, int depth,
                      FunctionLibraryRuntime* lib_runtime);

// Tests whether 'while_node' is a completely compilable loop.
// Every operator in the condition and body functions must be compilable for a
// while loop to be compilable.
bool IsCompilableWhile(const Node& while_node,
                       const DeviceType& jit_device_type, int depth,
                       FunctionLibraryRuntime* lib_runtime) {
  VLOG(2) << "Loop marking: " << while_node.type_string();

  const NameAttrList* name_attr;
  NodeDef call;
  Status status;
  status = GetNodeAttr(while_node.attrs(), "cond", &name_attr);
  if (!status.ok()) {
    VLOG(2) << "Missing 'cond' attribute on While node.";
    return false;
  }
  const string cond_func = name_attr->name();
  call.set_name("while_cond");
  call.set_op(cond_func);
  *call.mutable_attr() = name_attr->attr();
  if (!IsCompilableCall(call, jit_device_type, depth + 1, lib_runtime)) {
    VLOG(2) << "Can't compile loop condition: " << cond_func;
    return false;
  }
  status = GetNodeAttr(while_node.attrs(), "body", &name_attr);
  if (!status.ok()) {
    VLOG(2) << "Missing 'body' attribute on While node.";
    return false;
  }
  const string body_func = name_attr->name();
  call.set_name("while_body");
  call.set_op(body_func);
  *call.mutable_attr() = name_attr->attr();
  if (!IsCompilableCall(call, jit_device_type, depth + 1, lib_runtime)) {
    VLOG(2) << "Can't compile loop body: " << body_func;
    return false;
  }
  VLOG(2) << "Loop is compilable.";
  return true;
}

// Tests whether 'call_def' is a call to a completely compilable function.
// Every operator in the function must be compilable for a function to be
// compilable.
bool IsCompilableCall(const NodeDef& call_def,
                      const DeviceType& jit_device_type, int depth,
                      FunctionLibraryRuntime* lib_runtime) {
  VLOG(2) << "Function marking: " << call_def.op();

  if (depth > kMaxRecursionDepth) {
    VLOG(2) << "Function depth limit exceeded";
    return false;
  }

  FunctionLibraryRuntime::Handle handle;
  Status status =
      lib_runtime->Instantiate(call_def.op(), AttrSlice(call_def), &handle);
  if (!status.ok()) {
    VLOG(2) << "Could not instantiate " << call_def.op() << ": " << status;
    return false;
  }
  const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);
  CHECK(fbody);
  const FunctionDef& fdef = fbody->fdef;
  bool noinline = false;
  if (GetNodeAttr(AttrSlice(&fdef.attr()), "_noinline", &noinline).ok() &&
      noinline) {
    // The underlying mechanism that calls non-inlined functions uses
    // LocalExecutor, which interacts poorly with the LocalExecutor used by
    // tf2xla to translate the TF graph into XLA.  So we avoid this for now.
    //
    // TODO(b/36139787): Create a mechanism to set inlining hints.
    VLOG(2) << "Can't compile noinline function: " << fdef.DebugString();
    return false;
  }

  for (Node* node : fbody->graph->op_nodes()) {
    if (node->type_string() == "_Arg" || node->type_string() == "_Retval")
      continue;
    if (node->type_string() == "While") {
      // Handle functional While loop (not in open source build).
      if (!IsCompilableWhile(*node, jit_device_type, depth + 1, lib_runtime))
        return false;
      continue;
    }
    if (!HasXLAKernel(*node, jit_device_type) &&
        !IsCompilableCall(node->def(), jit_device_type, depth + 1,
                          lib_runtime)) {
      VLOG(2) << "Function marking failed: unsupported op " << node->name()
              << ": " << node->def().ShortDebugString();
      return false;
    }
  }
  VLOG(2) << "Function is compilable: " << call_def.op();
  return true;
}

// Returns the DeviceType corresponding to 'device'.
Status DeviceTypeOfDevice(const string& device, DeviceType* device_type) {
  DeviceNameUtils::ParsedName parsed;
  if (!DeviceNameUtils::ParseFullName(device, &parsed)) {
    return errors::Internal("Malformed assigned device '", device, "'");
  }
  *device_type = DeviceType(parsed.type);
  return Status::OK();
}

// Tests whether `node` has a DT_RESOURCE typed input or output.
bool HasResourceInputOrOutput(const Node& node) {
  return std::find(node.input_types().begin(), node.input_types().end(),
                   DT_RESOURCE) != node.input_types().end() ||
         std::find(node.output_types().begin(), node.output_types().end(),
                   DT_RESOURCE) != node.output_types().end();
}

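// Nodes are compared by ID so that iteration over an OrderedNodeSet is
// deterministic.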
struct NodeCompare {
  bool operator()(const Node* a, const Node* b) { return a->id() < b->id(); }
};
using OrderedNodeSet = std::set<Node*, NodeCompare>;

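// Collects in 'candidates' the nodes of 'graph' that can be compiled by XLA
// for their assigned device. 'is_compilable_fn', if provided, can veto a node.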
Status FindCompilationCandidates(
    const Graph& graph, FunctionLibraryDefinition* flib_def, Env* env,
    const std::function<bool(const Node*, const DeviceType&)>& is_compilable_fn,
    OrderedNodeSet* candidates) {
  OptimizerOptions opts;
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
      new ProcessFunctionLibraryRuntime(nullptr, env, TF_GRAPH_DEF_VERSION,
                                        flib_def, opts));
  FunctionLibraryRuntime* lib_runtime =
      pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);

  for (Node* node : graph.op_nodes()) {
    VLOG(2) << "FindCompilationCandidates(): Processing "
            << node->DebugString();

    DeviceType device_type("");
    TF_RETURN_IF_ERROR(
        DeviceTypeOfDevice(node->assigned_device_name(), &device_type));

    if (is_compilable_fn && !is_compilable_fn(node, device_type)) continue;

    const XlaOpRegistry::DeviceRegistration* registration;
    CHECK(
        XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration));
    DeviceType jit_device_type(registration->compilation_device_name);
    if (!HasXLAKernel(*node, jit_device_type) &&
        !IsCompilableCall(node->def(), jit_device_type, 0, lib_runtime)) {
      VLOG(2) << "Compilation rejected node: unsupported op " << node->name()
              << ": " << node->type_string();
      continue;
    }
    if (!registration->compile_resource_ops &&
        HasResourceInputOrOutput(*node)) {
      VLOG(2) << "Compilation rejected node: resource input/output "
              << node->name() << ": " << node->type_string();
      continue;
    }
    if (node->type_string() == "While" &&
        !IsCompilableWhile(*node, jit_device_type, 0, lib_runtime)) {
      continue;
    }
    // _Arg nodes in a top-level function represent feeds.
    // Do not compile them.
    if (node->type_string() == "_Arg") {
      VLOG(2) << "Skipping jit compilation for '_Arg'-typed node "
              << node->DebugString();
      continue;
    }
    // _Retval nodes in a top-level function represent fetches.
    // Do not compile them.
    if (node->type_string() == "_Retval") {
      VLOG(2) << "Compilation rejected node: return value " << node->name()
              << ": " << node->type_string();
      continue;
    }
    candidates->insert(node);
  }
  return Status::OK();
}

struct Cluster {
  // Identifies the node that represents this cluster in the cycle detection
  // graph.
  int representative = -1;
};

// Returns a string describing how an edge from src to dst would
// create a cycle.
string DescribeCycle(const GraphCycles& cycles, const Graph& graph, int src,
                     int dst) {
  int32 max_path_size = graph.num_node_ids() + 1;
  std::vector<int32> path(max_path_size);
  int32 path_size = cycles.FindPath(dst, src, max_path_size, path.data());
  if (path_size == 0) {
    return "";
  }

  auto node_name = [&cycles, &graph](int node_id) {
    auto* node = graph.FindNodeId(node_id);
    if (node == nullptr) {
      return string("(null)");
    }
    return node->name();
  };

  string description;
  strings::StrAppend(&description, "Edge from ", node_name(src), " to ",
                     node_name(dst), " would create a cycle.\n");
  path.resize(path_size);
  for (int32 node_id : path) {
    string ascii_art;
    if (node_id == dst) {
      ascii_art = "+-> ";
    } else if (node_id != src) {
      ascii_art = "|   ";
    } else {
      ascii_art = "+-- ";
    }
    strings::StrAppend(&description, ascii_art, node_name(node_id), "\n");
  }
  return description;
}

}  // anonymous namespace

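// Returns true if the function call described by 'ndef' can be compiled in
// its entirety by XLA on the JIT device corresponding to flr's device.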
bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) {
  Device* device = flr->device();
  const XlaOpRegistry::DeviceRegistration* registration;
  CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(),
                                            &registration));
  DeviceType jit_device_type(registration->compilation_device_name);
  return IsCompilableCall(ndef, jit_device_type, 0, flr);
}

Status MarkForCompilationPass::Run(
    const GraphOptimizationPassOptions& options) {
  // TODO(phawkins): precompute the "GetCompilationDevice" properties of each
  // device ahead of time.
  OptimizerOptions::GlobalJitLevel global_jit_level =
      options.session_options->config.graph_options()
          .optimizer_options()
          .global_jit_level();
  if (global_jit_level == OptimizerOptions::DEFAULT) {
    // To set compilation to be on by default, change the following line.
    global_jit_level = OptimizerOptions::OFF;
  }
  legacy_flags::MarkForCompilationPassFlags* flags =
      legacy_flags::GetMarkForCompilationPassFlags();
  if (flags->tf_xla_auto_jit == -1 ||
      (1 <= flags->tf_xla_auto_jit && flags->tf_xla_auto_jit <= 2)) {
    // If the flag tf_xla_auto_jit is a valid, non-zero setting, it overrides
    // the setting in ConfigProto.
    global_jit_level =
        static_cast<OptimizerOptions::GlobalJitLevel>(flags->tf_xla_auto_jit);
  }
  bool cpu_global_jit = flags->tf_xla_cpu_global_jit;
  VLOG(1) << "flags->tf_xla_cpu_global_jit = " << flags->tf_xla_cpu_global_jit;
  const FunctionLibraryDefinition* fld = options.flib_def;

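  // Decides whether 'node', placed on 'device_type', should be considered for
  // compilation; passed to RunImpl below.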
  auto is_compilable = [global_jit_level, cpu_global_jit, fld](
                           const Node* node, const DeviceType& device_type) {
    const XlaOpRegistry::DeviceRegistration* registration;
    if (!XlaOpRegistry::GetCompilationDevice(device_type.type(),
                                             &registration)) {
      return false;
    }

    // Don't compile control trigger nodes. We won't preserve their deadness
    // semantics correctly, so it's safest not to compile them.
    if (node->IsControlTrigger()) return false;

    // If this device requires a JIT, we must say yes.
    if (registration->requires_compilation) return true;

    // If there is a _XlaCompile annotation, use its value.
    bool compile = false;
    Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile);
    if (status.ok()) return compile;

    status = fld->GetAttr(*node, kXlaCompileAttr, &compile);
    if (status.ok()) return compile;

    // Otherwise use the value of global_jit_level.
    // Ignore enable_jit_by_default if global JIT compilation for CPU is
    // explicitly requested via the tf_xla_cpu_global_jit flag.
    bool ignore_registration = cpu_global_jit && device_type == DEVICE_CPU;
    return (ignore_registration || registration->enable_jit_by_default) &&
           global_jit_level > 0;
  };
  return RunImpl(options, is_compilable);
}

// Is 'node' an operator that consumes only the shape of its input, not the
// data itself?
static bool IsShapeConsumerOp(const Node& node) {
  return node.type_string() == "Shape" || node.type_string() == "Rank" ||
         node.type_string() == "Size";
}

// Sequence number generator to ensure clusters have unique names.
static std::atomic<int64> cluster_sequence_num;

Status MarkForCompilationPass::RunImpl(
    const GraphOptimizationPassOptions& options,
    const std::function<bool(const Node*, const DeviceType&)>&
        is_compilable_fn) {
  VLOG(1) << "MarkForCompilationPass::Run";

  // Make sure that kernels have been registered on the JIT device.
  XlaOpRegistry::RegisterCompilationKernels();

  Graph* graph = options.graph->get();

  OrderedNodeSet compilation_candidates;
  TF_RETURN_IF_ERROR(FindCompilationCandidates(
      *graph, options.flib_def,
      (options.session_options != nullptr) ? options.session_options->env
                                           : Env::Default(),
      is_compilable_fn, &compilation_candidates));

  GraphCycles cycles;
  for (int i = 0; i < graph->num_node_ids(); ++i) {
    // We rely on the node IDs in the cycle detection graph being consecutive
    // integers starting from 0.
    CHECK_EQ(i, cycles.NewNode());
  }

  // Compute the loop structure of the graph.
  std::vector<ControlFlowInfo> control_flow_info;
  TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &control_flow_info));

  // The clustering code must avoid adding cycles to the graph to prevent
  // deadlock. However, the graph may contain loops, which would trigger the
  // cycle detection code. To handle loops, we alter the structure of the cycle
  // detection graph, disconnecting each loop from the enclosing graph.
  // Specifically, we:
  // * add a new "frame" node for each loop.
  // * replace edges to "Enter" nodes, and edges from "Exit" nodes with edges
  //   to/from the corresponding frame node. In essence, we collapse the loop
  //   into a single node for the purpose of cycle detection in the enclosing
  //   graph.
  // * the body of the loop should now be disconnected from the rest of the
  //   graph; we make it acyclic by breaking loop backedges (edges outgoing
  //   from "NextIteration" nodes).

  // Map from frame name strings to node IDs in the cycle detection graph.
  std::unordered_map<string, int> frame_nodes;

  // Get the cycle graph node ID for frame 'frame_name', or add one if none
  // exists.
  auto GetOrAddFrameNodeId = [&frame_nodes, &cycles](const string& frame_name) {
    int& frame_id = frame_nodes.emplace(frame_name, -1).first->second;
    if (frame_id < 0) {
      // The emplace succeeded; we have not allocated a frame node yet.
      frame_id = cycles.NewNode();
    }
    return frame_id;
  };

  for (Edge const* edge : graph->edges()) {
    if (edge->dst()->IsEnter()) {
      // Lift edges to an "Enter" node to the corresponding frame node.
      const string& frame_name =
          control_flow_info[edge->dst()->id()].frame_name;
      int dst = GetOrAddFrameNodeId(frame_name);
      if (!cycles.InsertEdge(edge->src()->id(), dst)) {
        return errors::Internal(
            "Cycle detected when adding enter->frame edge: ",
            DescribeCycle(cycles, *graph, edge->src()->id(), dst));
      }
      continue;
    }
    if (edge->src()->IsExit()) {
      // Lift edges from an "Exit" node to the corresponding frame node.
      const string& frame_name =
          control_flow_info[edge->src()->id()].frame_name;
      int src = GetOrAddFrameNodeId(frame_name);
      if (!cycles.InsertEdge(src, edge->dst()->id())) {
        return errors::Internal(
            "Cycle detected when adding frame->exit edge: ",
            DescribeCycle(cycles, *graph, src, edge->dst()->id()));
      }
      // Drop the original edge.
      continue;
    }
    if (edge->src()->IsNextIteration()) {
      // Break loop back-edges.
      continue;
    }
    if (!cycles.InsertEdge(edge->src()->id(), edge->dst()->id())) {
      // This should never happen. All cycles in the graph should contain
      // a control flow operator.
      return errors::Internal(
          "Found cycle in graph without control flow operator during XLA "
          "compilation: ",
          DescribeCycle(cycles, *graph, edge->src()->id(), edge->dst()->id()));
    }
  }

  // Each compilation candidate belongs to a cluster. The cluster's
  // representative names the node in the 'cycles' graph that represents the
  // cluster.
  std::vector<UnionFind<Cluster>> clusters(graph->num_node_ids());
  std::deque<UnionFind<Cluster>*> worklist;
  for (Node* node : compilation_candidates) {
    Cluster& cluster = clusters[node->id()].Get();
    cluster.representative = node->id();
    worklist.push_back(&clusters[node->id()]);
  }

  legacy_flags::MarkForCompilationPassFlags* flags =
      legacy_flags::GetMarkForCompilationPassFlags();

  // Repeatedly contract edges between clusters that are on the same device,
  // provided the contraction would not create a cycle.
  while (!worklist.empty()) {
    int from = worklist.front()->Get().representative;
    worklist.pop_front();

    Node* node_from = graph->FindNodeId(from);
    if (node_from->IsControlFlow()) {
      // Control flow nodes aren't compilation candidates and should never
      // appear.
      return errors::Internal(
          "Found control flow node in clustering worklist: ",
          node_from->type_string());
    }
    string from_scope;
    string to_scope;
    for (int to : cycles.Successors(from)) {
      if (to >= graph->num_node_ids()) {
        // Node is a "frame" node that is present only in the cycle detection
        // graph. No clustering is possible.
        continue;
      }
      Node* node_to = graph->FindNodeId(to);
      if (compilation_candidates.find(node_to) ==
          compilation_candidates.cend()) {
        continue;
      }
      if (node_from->assigned_device_name() !=
          node_to->assigned_device_name()) {
        continue;
      }
      // Look for an _XlaScope on both nodes.  If both nodes have a
      // scope and the scopes do not match, do not cluster along this
      // edge.  If even one of the nodes lacks an _XlaScope attribute,
      // then it is treated as a "bridge" and a cluster may be created
      // along it.  We may want to restrict this behavior to require
      // all nodes marked with _XlaCompile=true to also have a
      // _XlaScope property set (and raise an error otherwise); but
      // for now we don't do this.
      if (GetNodeAttr(node_from->attrs(), kXlaScopeAttr, &from_scope).ok() &&
          GetNodeAttr(node_to->attrs(), kXlaScopeAttr, &to_scope).ok() &&
          from_scope != to_scope) {
        continue;
      }

      // Ops that consume shapes cannot be the root of a cluster. This is an
      // optimization.
      if (clusters[from].Size() == 1 && IsShapeConsumerOp(*node_from)) {
        continue;
      }

      // Don't exceed the maximum cluster size.
      if (clusters[from].Size() + clusters[to].Size() >
          flags->tf_xla_max_cluster_size) {
        continue;
      }

      // If contracting the edge would create a cycle, bail out.
      // However, just because we can't merge the clusters now does not mean
      // we won't be able to merge them in the future.
      // e.g., if we have edges 1->2, 2->3 and 1->3, we cannot contract edge
      // 1->3. But if we first contract 1->2 then we can later contract 1->3.
      if (!cycles.ContractEdge(from, to)) continue;

      // Merge the clusters. ContractEdge uses 'from' as the number of the
      // merged node, so make sure 'from' is the chosen representative.
      clusters[from].Merge(&clusters[to]);

      worklist.push_back(&clusters[from]);
      break;
    }
  }

  // Count the number of elements in each cluster.
  std::vector<int> cluster_sizes(graph->num_node_ids());
  for (const Node* n : compilation_candidates) {
    int cluster = clusters[n->id()].Get().representative;
    cluster_sizes[cluster]++;
  }

  // Names for each cluster.
  std::unordered_map<int, string> cluster_names;

  // Mark clusters for compilation that:
  // * are placed on a device that requires compilation (an XlaDevice),
  // * are explicitly marked for compilation (_XlaCompile=true), or
  // * have at least flags->tf_xla_min_cluster_size elements (applicable only
  //   if compilation is enabled, otherwise there will be no such candidates).
  const int min_cluster_size = flags->tf_xla_min_cluster_size;
  for (Node* n : compilation_candidates) {
    int cluster = clusters[n->id()].Get().representative;

    // Compile if the user marked this node _XlaCompile=true
    bool compile_attr = false;
    bool marked_for_compilation = false;
    if (GetNodeAttr(n->attrs(), kXlaCompileAttr, &compile_attr).ok()) {
      marked_for_compilation = compile_attr;
    } else if (options.flib_def->GetAttr(*n, kXlaCompileAttr, &compile_attr)
                   .ok()) {
      marked_for_compilation = compile_attr;
    }

    // Compile if this operator is placed on a device that requires
    // compilation.
    DeviceType device_type("");
    TF_RETURN_IF_ERROR(
        DeviceTypeOfDevice(n->assigned_device_name(), &device_type));
    const XlaOpRegistry::DeviceRegistration* registration;
    XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration);

    // Or compile if this is a cluster of >= min_cluster_size compilable
    // operators.
    if (cluster_sizes[cluster] >= min_cluster_size || marked_for_compilation ||
        registration->requires_compilation) {
      string& name = cluster_names[cluster];

      if (name.empty()) {
        name = strings::StrCat("cluster_", cluster_sequence_num++);
      }
      n->AddAttr(kXlaClusterAttr, name);
      VLOG(3) << "Assigning node " << n->name() << " to cluster " << name;
    }
  }

  if (flags->tf_xla_clustering_debug) {
    dump_graph::DumpGraphToFile("mark_for_compilation", **options.graph,
                                options.flib_def);
  }
  return Status::OK();
}

}  // namespace tensorflow