1/* Copyright 2016 The TensorFlow Authors All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/profiler/internal/tfprof_show.h"
17
18#include <memory>
19#include <set>
20
21#include "tensorflow/core/lib/strings/stringprintf.h"
22#include "tensorflow/core/platform/env.h"
23#include "tensorflow/core/platform/regexp.h"
24
25namespace tensorflow {
26namespace tfprof {
27
28const GraphNodeProto& TFShow::Show(const string& prefix, const Options& opts) {
29  if (opts.output_type == kOutput[0]) {
30    Timeline timeline(opts.step, opts.output_options.at(kTimelineOpts[0]));
31    return ShowInternal(opts, &timeline)->proto();
32  } else {
33    const ShowNode* ret = ShowInternal(opts, nullptr);
34    if (opts.output_type == kOutput[1]) {
35      printf("%s", (prefix + ret->formatted_str).c_str());
36      fflush(stdout);
37    } else if (opts.output_type == kOutput[2]) {
38      Status s = WriteStringToFile(Env::Default(),
39                                   opts.output_options.at(kFileOpts[0]),
40                                   prefix + ret->formatted_str);
41      if (!s.ok()) {
42        fprintf(stderr, "%s\n", s.ToString().c_str());
43      }
44    } else if (opts.output_type == kOutput[3] ||
45               opts.output_type == kOutput[4]) {
46    } else {
47      fprintf(stderr, "Unknown output type: %s\n", opts.output_type.c_str());
48    }
49    return ret->proto();
50  }
51}
52
53bool TFShow::LookUpCheckPoint(const string& name,
54                              std::unique_ptr<TFProfTensor>* tensor) {
55  if (name == kTFProfRoot || !ckpt_reader_ || !tensor) {
56    return false;
57  }
58  std::unique_ptr<Tensor> out_tensor;
59  TF_Status* status = TF_NewStatus();
60  ckpt_reader_->GetTensor(name, &out_tensor, status);
61  if (TF_GetCode(status) != TF_OK) {
62    fprintf(stderr, "%s\n", TF_Message(status));
63    TF_DeleteStatus(status);
64    return false;
65  }
66  tensor->reset(new TFProfTensor(std::move(out_tensor)));
67  TF_DeleteStatus(status);
68  return true;
69}
70
71bool TFShow::ShouldShow(const ShowNode* node, const Options& opts,
72                        int depth) const {
73  // Always show kTFProfRoot.
74  if (node->name() == kTFProfRoot) return true;
75
76  if (node->proto().total_requested_bytes() < opts.min_bytes ||
77      node->proto().total_peak_bytes() < opts.min_peak_bytes ||
78      node->proto().total_residual_bytes() < opts.min_residual_bytes ||
79      node->proto().total_output_bytes() < opts.min_output_bytes ||
80      node->proto().total_exec_micros() < opts.min_micros ||
81      node->proto().total_accelerator_exec_micros() <
82          opts.min_accelerator_micros ||
83      node->proto().total_cpu_exec_micros() < opts.min_cpu_micros ||
84      node->proto().parameters() < opts.min_params ||
85      node->proto().float_ops() < opts.min_float_ops ||
86      node->proto().run_count() < opts.min_occurrence ||
87      depth > opts.max_depth || !ShouldShowIfExtra(node, opts, depth)) {
88    return false;
89  }
90
91  bool show = false;
92  if (opts.show_name_regexes.size() == 1 && opts.show_name_regexes[0] == ".*") {
93    show = true;
94  } else {
95    for (const string& regex : opts.show_name_regexes) {
96      if (RE2::FullMatch(node->name(), regex)) {
97        show = true;
98        break;
99      }
100    }
101  }
102  // Don't show if show_name_regexes don't cover it.
103  if (!show) return false;
104  // Don't show if hide_name_regexes cover it.
105  for (const string& regex : opts.hide_name_regexes) {
106    if (RE2::FullMatch(node->name(), regex)) return false;
107  }
108  return true;
109}
110
111bool TFShow::ShouldTrim(const ShowNode* node,
112                        const std::vector<string>& regexes) const {
113  for (const string& regex : regexes) {
114    if (RE2::FullMatch(node->name(), regex)) {
115      return true;
116    }
117  }
118  return false;
119}
120
121bool TFShow::ReAccount(ShowNode* node, const Options& opts) {
122  node->ReInit(opts.step);
123  if (opts.account_type_regexes.size() == 1 &&
124      opts.account_type_regexes[0] == ".*") {
125    return true;
126  }
127  for (const string& regex : opts.account_type_regexes) {
128    for (const string& type : node->node->op_types()) {
129      if (RE2::FullMatch(type, regex)) {
130        return true;
131      }
132    }
133  }
134  return false;
135}
136
137string TFShow::FormatNodeMemory(ShowNode* node, int64 bytes,
138                                int64 total_bytes) const {
139  string memory = FormatMemory(total_bytes);
140  if (node->account) {
141    memory = FormatMemory(bytes) + "/" + memory;
142  } else {
143    memory = "--/" + memory;
144  }
145  return memory;
146}
147
148string TFShow::FormatNode(ShowNode* node, const Options& opts) const {
149  std::vector<string> info;
150  if (opts.select.find(kShown[2]) != opts.select.end()) {
151    const string shape = FormatShapes(node->node->shape());
152    if (!shape.empty()) {
153      info.push_back(shape);
154    }
155    string params = FormatNumber(node->proto().total_parameters()) + " params";
156    if (node->account) {
157      params = FormatNumber(node->proto().parameters()) + "/" + params;
158    } else {
159      params = "--/" + params;
160    }
161    info.push_back(params);
162  }
163  if (opts.select.find(kShown[3]) != opts.select.end()) {
164    string fops = FormatNumber(node->proto().total_float_ops()) + " flops";
165    if (node->account) {
166      fops = FormatNumber(node->proto().float_ops()) + "/" + fops;
167    } else {
168      fops = "--/" + fops;
169    }
170    info.push_back(fops);
171  }
172  std::vector<string> attrs;
173  if (opts.select.find(kShown[0]) != opts.select.end()) {
174    info.push_back(FormatNodeMemory(node, node->proto().requested_bytes(),
175                                    node->proto().total_requested_bytes()));
176  }
177  if (opts.select.find(kShown[11]) != opts.select.end()) {
178    info.push_back(FormatNodeMemory(node, node->proto().peak_bytes(),
179                                    node->proto().total_peak_bytes()));
180  }
181  if (opts.select.find(kShown[12]) != opts.select.end()) {
182    info.push_back(FormatNodeMemory(node, node->proto().residual_bytes(),
183                                    node->proto().total_residual_bytes()));
184  }
185  if (opts.select.find(kShown[13]) != opts.select.end()) {
186    info.push_back(FormatNodeMemory(node, node->proto().output_bytes(),
187                                    node->proto().total_output_bytes()));
188  }
189  if (opts.select.find(kShown[1]) != opts.select.end()) {
190    info.push_back(FormatTotalExecTime(node, opts));
191    info.push_back(FormatAcceleratorExecTime(node, opts));
192    info.push_back(FormatCPUExecTime(node, opts));
193  }
194  if (opts.select.find(kShown[9]) != opts.select.end() &&
195      opts.select.find(kShown[1]) == opts.select.end()) {
196    info.push_back(FormatAcceleratorExecTime(node, opts));
197  }
198  if (opts.select.find(kShown[10]) != opts.select.end() &&
199      opts.select.find(kShown[1]) == opts.select.end()) {
200    info.push_back(FormatCPUExecTime(node, opts));
201  }
202  if (opts.select.find(kShown[5]) != opts.select.end()) {
203    if (node->proto().devices_size() > 0) {
204      info.push_back(str_util::Join(node->proto().devices(), "|"));
205    }
206  }
207  if (opts.select.find(kShown[6]) != opts.select.end()) {
208    const std::set<string>& op_types = node->node->op_types();
209    info.push_back(str_util::Join(op_types, "|"));
210  }
211  if (opts.select.find(kShown[7]) != opts.select.end()) {
212    string run = FormatNumber(node->proto().total_run_count());
213    if (node->account) {
214      run = FormatNumber(node->proto().run_count()) + "/" + run;
215    } else {
216      run = "--/" + run;
217    }
218    string definition = FormatNumber(node->proto().total_definition_count());
219    if (node->account) {
220      definition = "1/" + definition;
221    } else {
222      definition = "--/" + definition;
223    }
224    info.push_back(run + "|" + definition);
225  }
226  if (opts.select.find(kShown[8]) != opts.select.end()) {
227    std::vector<string> shape_vec;
228    for (const auto& s : node->node->input_shapes()) {
229      if (s.second.empty()) {
230        shape_vec.push_back(strings::Printf("%d:unknown", s.first));
231      } else {
232        shape_vec.push_back(strings::Printf(
233            "%d:%s", s.first, str_util::Join(s.second, "x").c_str()));
234      }
235    }
236    info.push_back(str_util::Join(shape_vec, "|"));
237  }
238
239  return strings::Printf("%s (%s)", node->name().c_str(),
240                         str_util::Join(info, ", ").c_str());
241}
242
243string TFShow::FormatLegend(const Options& opts) const {
244  std::vector<string> legends;
245  if (opts.select.find(kShown[2]) != opts.select.end()) {
246    legends.push_back("# parameters");
247  }
248  if (opts.select.find(kShown[3]) != opts.select.end()) {
249    legends.push_back("# float_ops");
250  }
251  if (opts.select.find(kShown[0]) != opts.select.end()) {
252    legends.push_back("requested bytes");
253  }
254  if (opts.select.find(kShown[11]) != opts.select.end()) {
255    legends.push_back("peak bytes");
256  }
257  if (opts.select.find(kShown[12]) != opts.select.end()) {
258    legends.push_back("residual bytes");
259  }
260  if (opts.select.find(kShown[13]) != opts.select.end()) {
261    legends.push_back("output bytes");
262  }
263  if (opts.select.find(kShown[1]) != opts.select.end()) {
264    legends.push_back("total execution time");
265    legends.push_back("accelerator execution time");
266    legends.push_back("cpu execution time");
267  }
268  if (opts.select.find(kShown[9]) != opts.select.end() &&
269      opts.select.find(kShown[1]) == opts.select.end()) {
270    legends.push_back("accelerator execution time");
271  }
272  if (opts.select.find(kShown[10]) != opts.select.end() &&
273      opts.select.find(kShown[1]) == opts.select.end()) {
274    legends.push_back("cpu execution time");
275  }
276  if (opts.select.find(kShown[5]) != opts.select.end()) {
277    legends.push_back("assigned devices");
278  }
279  if (opts.select.find(kShown[6]) != opts.select.end()) {
280    legends.push_back("op types");
281  }
282  if (opts.select.find(kShown[7]) != opts.select.end()) {
283    legends.push_back("op count (run|defined)");
284  }
285  if (opts.select.find(kShown[8]) != opts.select.end()) {
286    legends.push_back("input shapes");
287  }
288  return strings::Printf("node name | %s\n",
289                         str_util::Join(legends, " | ").c_str());
290}
291
292}  // namespace tfprof
293}  // namespace tensorflow
294