1/* Copyright 2016 The TensorFlow Authors All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/profiler/internal/tfprof_graph.h"
17
18#include <stdio.h>
19#include <utility>
20
21#include "tensorflow/core/lib/strings/strcat.h"
22#include "tensorflow/core/lib/strings/stringprintf.h"
23#include "tensorflow/core/platform/regexp.h"
24#include "tensorflow/core/profiler/internal/tfprof_constants.h"
25#include "tensorflow/core/profiler/internal/tfprof_tensor.h"
26
27namespace tensorflow {
28namespace tfprof {
29GraphNode* TFGraph::CreateParentNode(const string& name) {
30  node_defs_.push_back(std::unique_ptr<NodeDef>(new NodeDef()));
31  node_defs_.back()->set_name(name);
32  node_defs_.back()->set_op(kTFGraphParent);
33  parent_nodes_[name] = std::unique_ptr<TFGraphNode>(
34      new TFGraphNode(node_defs_.back().get(), -1, nullptr));
35  nodes_map_[name] =
36      std::unique_ptr<GraphNode>(new GraphNode(parent_nodes_[name].get()));
37  return nodes_map_[name].get();
38}
39
40void TFGraph::AddNode(TFGraphNode* node) {
41  string name = node->name();
42  nodes_map_[name] = std::unique_ptr<GraphNode>(new GraphNode(node));
43}
44
45void TFGraph::Build() {
46  if (root_) return;
47
48  std::set<string> nonroots;
49  // Filter out the root nodes (node not input of any other node).
50  for (auto it = nodes_map_.begin(); it != nodes_map_.end(); it++) {
51    GraphNode* node = it->second.get();
52    const std::map<int, string>& inputs = node->node->inputs();
53    for (auto inputs_it = inputs.cbegin(); inputs_it != inputs.cend();
54         inputs_it++) {
55      nonroots.insert(inputs_it->second);
56      auto child_it = nodes_map_.find(inputs_it->second);
57      if (child_it != nodes_map_.end()) {
58        node->children.push_back(child_it->second.get());
59      }
60    }
61  }
62  std::vector<GraphNode*> roots;
63  for (auto it = nodes_map_.begin(); it != nodes_map_.end(); it++) {
64    if (nonroots.find(it->first) == nonroots.end()) {
65      roots.push_back(it->second.get());
66    }
67  }
68  root_ = CreateParentNode(kTFProfRoot);
69  root_->children.insert(root_->children.end(), roots.begin(), roots.end());
70}
71
72const ShowNode* TFGraph::ShowInternal(const Options& opts, Timeline* timeline) {
73  root_->ResetTotalStats();
74  root_->show_children.clear();
75
76  if (opts.output_type == kOutput[3]) {
77    fprintf(stderr, "Only 'code' view supports pprof output now.\n");
78    return root_;
79  }
80  if (timeline && timeline->step() < 0) {
81    // TODO(xpan): Maybe pick a default step for users.
82    fprintf(stderr,
83            "Must specify -step option to generate timeline in graph view.\n");
84    return root_;
85  }
86  // 1. Account and aggregate the stats based on the graph structure.
87  // Returns a graph consists of accounted nodes.
88  std::set<string> visits;
89  std::vector<GraphNode*> roots = Account(root_->children, opts, &visits);
90  for (GraphNode* n : roots) {
91    root_->AggregateTotalStats(n);
92  }
93
94  // 2. Trim the nodes before start_name_regexes.
95  if (opts.start_name_regexes.size() != 1 ||
96      opts.start_name_regexes[0] != ".*") {
97    visits.clear();
98    roots = SearchRoot(roots, opts.start_name_regexes, &visits);
99  }
100
101  // 3. Trim the nodes not matching show/hide/trim_name_regexes.
102  // If account_displayed_op_only=true, redo the accounting.
103  visits.clear();
104  root_->show_children.assign(roots.begin(), roots.end());
105  GraphNode* root = PrintGraph({root_}, opts, 1, 0, &visits)[0];
106
107  // 4. Prepare output based on the final graphs.
108  root->formatted_str = FormatLegend(opts) + root->formatted_str;
109  Format(root->show_children, &root->formatted_str, root->mutable_proto());
110
111  if (timeline) {
112    timeline->GenerateGraphTimeline(root->show_children);
113  }
114  return root;
115}
116
117std::vector<GraphNode*> TFGraph::SearchRoot(
118    const std::vector<GraphNode*>& roots, const std::vector<string>& regexes,
119    std::set<string>* visited) {
120  std::vector<GraphNode*> res;
121  if (roots.empty()) {
122    return res;
123  }
124  for (GraphNode* root : roots) {
125    if (visited->find(root->name()) != visited->end()) continue;
126    visited->insert(root->name());
127    // If the parent is a start point, don't search its children.
128    // Note that its children can still be added as start node through
129    // another route.
130    bool match_start_node = false;
131    for (const string& regex : regexes) {
132      if (RE2::FullMatch(root->name(), regex)) {
133        res.push_back(root);
134        match_start_node = true;
135        break;
136      }
137    }
138    if (match_start_node) {
139      continue;
140    }
141    std::vector<GraphNode*> nroot =
142        SearchRoot(root->show_children, regexes, visited);
143    res.insert(res.end(), nroot.begin(), nroot.end());
144  }
145  return res;
146}
147
148void TFGraph::Format(const std::vector<GraphNode*> roots, string* display_str,
149                     GraphNodeProto* proto) {
150  for (GraphNode* node : roots) {
151    display_str->append(node->formatted_str);
152    GraphNodeProto* child = proto->add_children();
153    child->MergeFrom(node->proto());
154    Format(node->show_children, display_str, child);
155  }
156}
157
158std::vector<GraphNode*> TFGraph::PrintGraph(const std::vector<GraphNode*> roots,
159                                            const Options& opts, int depth,
160                                            int last_ident,
161                                            std::set<string>* visits) {
162  std::vector<GraphNode*> show_nodes;
163
164  for (GraphNode* node : roots) {
165    if (visits->find(node->name()) != visits->end()) continue;
166    visits->insert(node->name());
167
168    bool show = ShouldShow(node, opts, depth);
169    int indent = last_ident;
170    if (show) indent += 2;
171
172    std::vector<GraphNode*> show_cnodes;
173    if (!ShouldTrim(node, opts.trim_name_regexes) && depth <= opts.max_depth) {
174      show_cnodes =
175          PrintGraph(node->show_children, opts, depth + 1, indent, visits);
176    }
177    if (show) {
178      node->show_children.clear();
179      if (opts.account_displayed_op_only) {
180        node->ResetTotalStats();
181        node->AddSelfToTotalStats();
182      }
183
184      show_cnodes = SortNodes(show_cnodes, opts);
185      for (GraphNode* sc : show_cnodes) {
186        node->show_children.push_back(sc);
187        if (opts.account_displayed_op_only) {
188          node->AggregateTotalStats(sc);
189        }
190      }
191      node->formatted_str =
192          strings::Printf("%s%s\n", string(last_ident, ' ').c_str(),
193                          FormatNode(node, opts).c_str());
194
195      if (opts.select.find(kShown[4]) != opts.select.end()) {
196        std::unique_ptr<TFProfTensor> tfprof_tensor;
197        if (LookUpCheckPoint(node->name(), &tfprof_tensor)) {
198          string value_str;
199          tfprof_tensor->Display(&value_str,
200                                 node->mutable_proto()->mutable_tensor_value());
201          node->formatted_str += value_str;
202        }
203      }
204      show_nodes.push_back(node);
205    } else {
206      show_nodes.insert(show_nodes.end(), show_cnodes.begin(),
207                        show_cnodes.end());
208    }
209  }
210  return show_nodes;
211}
212
213std::vector<GraphNode*> TFGraph::Account(const std::vector<GraphNode*>& roots,
214                                         const Options& opts,
215                                         std::set<string>* visits) {
216  std::vector<GraphNode*> act_nodes;
217  for (GraphNode* node : roots) {
218    if (visits->find(node->name()) != visits->end()) continue;
219    visits->insert(node->name());
220    // Depth-first.
221    std::vector<GraphNode*> act_cnodes = Account(node->children, opts, visits);
222
223    node->account = ReAccount(node, opts);
224    if (node->account) {
225      node->show_children.clear();
226      node->ResetTotalStats();
227      node->AddSelfToTotalStats();
228      // Aggregate its accounted children stats.
229      for (GraphNode* c : act_cnodes) {
230        node->AggregateTotalStats(c);
231        node->show_children.push_back(c);
232      }
233      act_nodes.push_back(node);
234    } else {
235      // If the current node is not accounted, pass the children to the
236      // ancestor.
237      act_nodes.insert(act_nodes.end(), act_cnodes.begin(), act_cnodes.end());
238    }
239  }
240  return act_nodes;
241}
242}  // namespace tfprof
243}  // namespace tensorflow
244