1/* Copyright 2016 The TensorFlow Authors All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/profiler/internal/tfprof_op.h"
17
18#include <stdio.h>
19#include <utility>
20
21#include "tensorflow/core/lib/strings/strcat.h"
22#include "tensorflow/core/lib/strings/stringprintf.h"
23#include "tensorflow/core/platform/regexp.h"
24#include "tensorflow/core/profiler/internal/tfprof_constants.h"
25#include "tensorflow/core/profiler/internal/tfprof_tensor.h"
26
27namespace tensorflow {
28namespace tfprof {
29namespace {
30string FormatToalExecTime(const ShowMultiNode* node,
31                          const ShowMultiNode* root) {
32  double accu_pct = 0.0;
33  double pct = 0.0;
34  if (node->proto().total_exec_micros() > 0) {
35    accu_pct = 100.0 * node->proto().total_exec_micros() /
36               root->proto().total_exec_micros();
37    pct =
38        100.0 * node->proto().exec_micros() / root->proto().total_exec_micros();
39  }
40
41  return strings::Printf(
42      "%30s", strings::Printf("%s (%.2f%%, %.2f%%)",
43                              FormatTime(node->proto().exec_micros()).c_str(),
44                              accu_pct, pct)
45                  .c_str());
46}
47string FormatCPUExecTime(const ShowMultiNode* node, const ShowMultiNode* root) {
48  double accu_pct = 0.0;
49  double pct = 0.0;
50  if (node->proto().total_cpu_exec_micros() > 0) {
51    accu_pct = 100.0 * node->proto().total_cpu_exec_micros() /
52               root->proto().total_cpu_exec_micros();
53    pct = 100.0 * node->proto().cpu_exec_micros() /
54          root->proto().total_cpu_exec_micros();
55  }
56
57  return strings::Printf(
58      "%30s",
59      strings::Printf("%s (%.2f%%, %.2f%%)",
60                      FormatTime(node->proto().cpu_exec_micros()).c_str(),
61                      accu_pct, pct)
62          .c_str());
63}
64string FormatAcceleratorExecTime(const ShowMultiNode* node,
65                                 const ShowMultiNode* root) {
66  double accu_pct = 0.0;
67  double pct = 0.0;
68  if (node->proto().total_accelerator_exec_micros() > 0) {
69    accu_pct = 100.0 * node->proto().total_accelerator_exec_micros() /
70               root->proto().total_accelerator_exec_micros();
71    pct = 100.0 * node->proto().accelerator_exec_micros() /
72          root->proto().total_accelerator_exec_micros();
73  }
74
75  return strings::Printf(
76      "%30s", strings::Printf(
77                  "%s (%.2f%%, %.2f%%)",
78                  FormatTime(node->proto().accelerator_exec_micros()).c_str(),
79                  accu_pct, pct)
80                  .c_str());
81}
82}  // namespace
83
84void TFOp::AddNode(TFGraphNode* node) {
85  const string& op = node->op();
86  if (tfcnodes_map_.find(op) == tfcnodes_map_.end()) {
87    tfcnodes_map_[op] =
88        std::unique_ptr<TFMultiGraphNode>(new TFMultiGraphNode(op));
89  }
90  TFMultiGraphNode* tfcnode = tfcnodes_map_[op].get();
91  tfcnode->AddGraphNode(node);
92}
93
94void TFOp::Build() {
95  for (auto& tn : tfcnodes_map_) {
96    cnodes_map_[tn.first] =
97        std::unique_ptr<OpNode>(new OpNode(tn.second.get()));
98  }
99
100  tfcnodes_map_[kTFProfRoot] =
101      std::unique_ptr<TFMultiGraphNode>(new TFMultiGraphNode(kTFProfRoot));
102  root_.reset(new OpNode(tfcnodes_map_[kTFProfRoot].get()));
103}
104
105const ShowMultiNode* TFOp::ShowInternal(const Options& opts,
106                                        Timeline* timeline) {
107  root_->ResetTotalStats();
108  if (opts.output_type == kOutput[3]) {
109    fprintf(stderr, "Only 'code' view supports pprof output now.\n");
110    return root_.get();
111  }
112  if (opts.output_type == kOutput[1] || opts.output_type == kOutput[2]) {
113    root_->formatted_str = FormatNode(root_.get(), root_.get(), opts);
114  }
115  if (timeline) {
116    fprintf(stderr,
117            "op view doesn't support timeline yet. "
118            "Consider graph/scope/code view.\n");
119    return root_.get();
120  }
121  if (cnodes_map_.empty()) {
122    return root_.get();
123  }
124
125  std::vector<OpNode*> nodes;
126  for (auto& n : cnodes_map_) {
127    n.second->account = ReAccount(n.second.get(), opts);
128    n.second->ResetTotalStats();
129    n.second->AddSelfToTotalStats();
130    nodes.push_back(n.second.get());
131  }
132  nodes = SortNodes(nodes, opts);
133  // pre keeps track of previous visited node.
134  OpNode* pre = nullptr;
135  std::vector<OpNode*> account_nodes;
136  for (auto it = nodes.rbegin(); it != nodes.rend(); ++it) {
137    if ((*it)->account) {
138      if (pre) (*it)->AggregateTotalStats(pre);
139      account_nodes.push_back(*it);
140      pre = *it;
141    }
142  }
143  std::reverse(std::begin(account_nodes), std::end(account_nodes));
144  if (pre) {
145    root_->AggregateTotalStats(pre);
146  }
147
148  // Perform the display and optionally redo accounting.
149  int64 depth = 0;
150  std::vector<OpNode*> show_nodes;
151  int64 start = SearchRoot(account_nodes, opts.start_name_regexes);
152  for (int64 i = start; i < account_nodes.size(); ++i, ++depth) {
153    OpNode* n = account_nodes[i];
154    if (ShouldTrim(n, opts.trim_name_regexes) || depth > opts.max_depth) {
155      break;
156    }
157    n->show = ShouldShow(n, opts, depth);
158    if (n->show) show_nodes.push_back(n);
159  }
160
161  pre = nullptr;
162  for (auto it = show_nodes.rbegin(); it != show_nodes.rend(); ++it) {
163    if (opts.account_displayed_op_only) {
164      (*it)->ResetTotalStats();
165      (*it)->AddSelfToTotalStats();
166      if (pre) (*it)->AggregateTotalStats(pre);
167    }
168    pre = *it;
169  }
170  if (opts.account_displayed_op_only) {
171    root_->ResetTotalStats();
172    if (pre) {
173      root_->AggregateTotalStats(pre);
174    }
175  }
176  if (opts.output_type == kOutput[1] || opts.output_type == kOutput[2]) {
177    string display_str = FormatLegend(opts);
178    for (OpNode* node : show_nodes) {
179      display_str += FormatNode(node, root_.get(), opts);
180    }
181    // In op view, we don't show root (total). But it will still in proto.
182    // TODO(xpan): Is it the right choice?
183    root_->formatted_str = display_str;
184  }
185  // Populate the chidren field.
186  auto* pre_pb = root_->mutable_proto();
187  for (auto& show_node : show_nodes) {
188    pre_pb->clear_children();
189    pre_pb->add_children()->Swap(show_node->mutable_proto());
190    pre_pb = pre_pb->mutable_children(0);
191  }
192  return root_.get();
193}
194
195int64 TFOp::SearchRoot(const std::vector<OpNode*> nodes,
196                       const std::vector<string>& regexes) {
197  if (regexes.empty() || (regexes.size() == 1 && regexes[0] == ".*")) {
198    return 0;
199  }
200  int64 i = 0;
201  for (; i < nodes.size(); ++i) {
202    for (const string& regex : regexes) {
203      if (RE2::FullMatch(nodes[i]->name(), regex)) {
204        return i;
205      }
206    }
207  }
208  return i;
209}
210
211string TFOp::FormatMemoryNode(int64 node_total_bytes, int64 root_total_bytes,
212                              int64 node_bytes) const {
213  double accu_pct = 0.0;
214  double pct = 0.0;
215  if (node_bytes > 0) {
216    accu_pct = 100.0 * node_total_bytes / root_total_bytes;
217    pct = 100.0 * node_bytes / root_total_bytes;
218  }
219  return strings::Printf(
220      "%30s", strings::Printf("%s (%.2f%%, %.2f%%)",
221                              FormatMemory(node_bytes).c_str(), accu_pct, pct)
222                  .c_str());
223}
224
225string TFOp::FormatNode(OpNode* node, OpNode* root, const Options& opts) const {
226  std::vector<string> attrs;
227
228  if (opts.select.find(kShown[0]) != opts.select.end()) {
229    attrs.push_back(FormatMemoryNode(node->proto().total_requested_bytes(),
230                                     root->proto().total_requested_bytes(),
231                                     node->proto().requested_bytes()));
232  }
233
234  if (opts.select.find(kShown[11]) != opts.select.end()) {
235    attrs.push_back(FormatMemoryNode(node->proto().total_peak_bytes(),
236                                     root->proto().total_peak_bytes(),
237                                     node->proto().peak_bytes()));
238  }
239
240  if (opts.select.find(kShown[12]) != opts.select.end()) {
241    attrs.push_back(FormatMemoryNode(node->proto().total_residual_bytes(),
242                                     root->proto().total_residual_bytes(),
243                                     node->proto().residual_bytes()));
244  }
245  if (opts.select.find(kShown[13]) != opts.select.end()) {
246    attrs.push_back(FormatMemoryNode(node->proto().total_output_bytes(),
247                                     root->proto().total_output_bytes(),
248                                     node->proto().output_bytes()));
249  }
250
251  if (opts.select.find(kShown[1]) != opts.select.end()) {
252    attrs.push_back(FormatToalExecTime(node, root));
253    attrs.push_back(FormatAcceleratorExecTime(node, root));
254    attrs.push_back(FormatCPUExecTime(node, root));
255  }
256  if (opts.select.find(kShown[9]) != opts.select.end() &&
257      opts.select.find(kShown[1]) == opts.select.end()) {
258    attrs.push_back(FormatAcceleratorExecTime(node, root));
259  }
260  if (opts.select.find(kShown[10]) != opts.select.end() &&
261      opts.select.find(kShown[1]) == opts.select.end()) {
262    attrs.push_back(FormatCPUExecTime(node, root));
263  }
264  if (opts.select.find(kShown[2]) != opts.select.end()) {
265    double accu_pct = 0.0;
266    double pct = 0.0;
267    if (node->proto().total_parameters() > 0) {
268      accu_pct = 100.0 * node->proto().total_parameters() /
269                 root->proto().total_parameters();
270      pct =
271          100.0 * node->proto().parameters() / root->proto().total_parameters();
272    }
273    attrs.push_back(strings::Printf(
274        "%30s",
275        strings::Printf("%s params (%.2f%%, %.2f%%)",
276                        FormatNumber(node->proto().parameters()).c_str(),
277                        accu_pct, pct)
278            .c_str()));
279  }
280
281  if (opts.select.find(kShown[3]) != opts.select.end()) {
282    double accu_pct = 0.0;
283    double pct = 0.0;
284    if (node->proto().total_float_ops() > 0) {
285      accu_pct = 100.0 * node->proto().total_float_ops() /
286                 root->proto().total_float_ops();
287      pct = 100.0 * node->proto().float_ops() / root->proto().total_float_ops();
288    }
289
290    attrs.push_back(strings::Printf(
291        "%30s", strings::Printf("%s float_ops (%.2f%%, %.2f%%)",
292                                FormatNumber(node->proto().float_ops()).c_str(),
293                                accu_pct, pct)
294                    .c_str()));
295  }
296
297  if (opts.select.find(kShown[5]) != opts.select.end()) {
298    attrs.push_back(str_util::Join(node->node->devices(), "|"));
299  }
300
301  if (opts.select.find(kShown[6]) != opts.select.end()) {
302    std::set<string> op_types = node->node->op_types();
303    attrs.push_back(str_util::Join(op_types, "|"));
304  }
305
306  if (opts.select.find(kShown[7]) != opts.select.end()) {
307    int64 total_runs = 0;
308    for (const auto& gnode : node->proto().graph_nodes()) {
309      total_runs += gnode.run_count();
310    }
311    attrs.push_back(strings::Printf(
312        "%10s",
313        strings::Printf("%lld|%d", total_runs, node->proto().graph_nodes_size())
314            .c_str()));
315  }
316
317  string node_str = strings::Printf("%-25s%s\n", node->name().c_str(),
318                                    str_util::Join(attrs, ", ").c_str());
319
320  if (opts.select.find(kShown[8]) != opts.select.end()) {
321    string input_shape_str = FormatInputShapes(node->proto());
322    if (!input_shape_str.empty()) {
323      node_str = strings::Printf("%s\n%s\n\n", node_str.c_str(),
324                                 input_shape_str.c_str());
325    }
326  }
327  return node_str;
328}
329}  // namespace tfprof
330}  // namespace tensorflow
331