1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the 'License');
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an 'AS IS' BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for pprof_profiler."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import gzip
22
23from proto import profile_pb2
24from tensorflow.core.framework import step_stats_pb2
25from tensorflow.core.protobuf import config_pb2
26from tensorflow.python.framework import constant_op
27from tensorflow.python.ops import control_flow_ops
28from tensorflow.python.ops import math_ops
29from tensorflow.python.platform import test
30from tensorflow.python.profiler import pprof_profiler
31
32
class PprofProfilerTest(test.TestCase):
  """Tests for `pprof_profiler.get_profiles` and `pprof_profiler.profile`."""

  def testDataEmpty(self):
    """A graph with no operations and empty run metadata gives no profiles."""
    output_dir = test.get_temp_dir()
    run_metadata = config_pb2.RunMetadata()
    graph = test.mock.MagicMock()
    graph.get_operations.return_value = []

    profiles = pprof_profiler.get_profiles(graph, run_metadata)
    self.assertEqual(0, len(profiles))
    profile_files = pprof_profiler.profile(
        graph, run_metadata, output_dir)
    self.assertEqual(0, len(profile_files))

  def testRunMetadataEmpty(self):
    """A graph with one operation but empty run metadata gives no profiles."""
    output_dir = test.get_temp_dir()
    run_metadata = config_pb2.RunMetadata()
    graph = test.mock.MagicMock()
    op1 = test.mock.MagicMock()
    op1.name = 'Add/123'
    op1.traceback = [('a/b/file1', 10, 'some_var')]
    op1.type = 'add'
    graph.get_operations.return_value = [op1]

    profiles = pprof_profiler.get_profiles(graph, run_metadata)
    self.assertEqual(0, len(profiles))
    profile_files = pprof_profiler.profile(
        graph, run_metadata, output_dir)
    self.assertEqual(0, len(profile_files))

  def testValidProfile(self):
    """One node stat on one device gives exactly one matching profile.

    The profile is checked both as the in-memory proto returned by
    `get_profiles` and as the gzipped file written by `profile`.
    """
    output_dir = test.get_temp_dir()

    node1 = step_stats_pb2.NodeExecStats(
        node_name='Add/123',
        op_start_rel_micros=3,
        op_end_rel_micros=5,
        all_end_rel_micros=4)

    run_metadata = config_pb2.RunMetadata()
    device1 = run_metadata.step_stats.dev_stats.add()
    device1.device = 'deviceA'
    device1.node_stats.extend([node1])

    # The mocked operation shares its name with the node stat above.
    graph = test.mock.MagicMock()
    op1 = test.mock.MagicMock()
    op1.name = 'Add/123'
    op1.traceback = [
        ('a/b/file1', 10, 'apply_op', 'abc'), ('a/c/file2', 12, 'my_op', 'def')]
    op1.type = 'add'
    graph.get_operations.return_value = [op1]

    # Text-format rendering of the expected profile proto. Integer fields such
    # as `type`, `unit`, `key`, `str` and `comment` are indices into
    # `string_table`.
    expected_proto = """sample_type {
  type: 5
  unit: 5
}
sample_type {
  type: 6
  unit: 7
}
sample_type {
  type: 8
  unit: 7
}
sample {
  value: 1
  value: 4
  value: 2
  label {
    key: 1
    str: 2
  }
  label {
    key: 3
    str: 4
  }
}
string_table: ""
string_table: "node_name"
string_table: "Add/123"
string_table: "op_type"
string_table: "add"
string_table: "count"
string_table: "all_time"
string_table: "nanoseconds"
string_table: "op_time"
string_table: "Device 1 of 1: deviceA"
comment: 9
"""
    # Test with protos
    profiles = pprof_profiler.get_profiles(graph, run_metadata)
    self.assertEqual(1, len(profiles))
    self.assertIn('deviceA', profiles)
    self.assertEqual(expected_proto, str(profiles['deviceA']))
    # Test with files
    profile_files = pprof_profiler.profile(
        graph, run_metadata, output_dir)
    self.assertEqual(1, len(profile_files))
    # The written file is gzip-compressed, so decompress before parsing.
    with gzip.open(profile_files[0]) as profile_file:
      profile_contents = profile_file.read()
      profile = profile_pb2.Profile()
      profile.ParseFromString(profile_contents)
      self.assertEqual(expected_proto, str(profile))

  def testProfileWithWhileLoop(self):
    """Samples for a node run repeatedly in a while loop are aggregated."""
    options = config_pb2.RunOptions()
    options.trace_level = config_pb2.RunOptions.FULL_TRACE
    run_metadata = config_pb2.RunMetadata()

    num_iters = 5
    with self.test_session() as sess:
      i = constant_op.constant(0)
      c = lambda i: math_ops.less(i, num_iters)
      b = lambda i: math_ops.add(i, 1)
      r = control_flow_ops.while_loop(c, b, [i])
      sess.run(r, options=options, run_metadata=run_metadata)
      profiles = pprof_profiler.get_profiles(sess.graph, run_metadata)
      self.assertEqual(1, len(profiles))
      profile = next(iter(profiles.values()))
      # Samples for the while/Add node; the first label's string is compared
      # against the node name.
      add_samples = [
          sample for sample in profile.sample
          if profile.string_table[sample.label[0].str] == 'while/Add']
      # Values for same nodes are aggregated.
      self.assertEqual(1, len(add_samples))
      # Value of "count" should be equal to number of iterations.
      self.assertEqual(num_iters, add_samples[0].value[0])
161
162
# Invoke the TensorFlow test entry point when this file is executed as a
# script rather than imported.
if __name__ == '__main__':
  test.main()
165