1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the 'License'); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an 'AS IS' BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for pprof_profiler.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import gzip 22 23from proto import profile_pb2 24from tensorflow.core.framework import step_stats_pb2 25from tensorflow.core.protobuf import config_pb2 26from tensorflow.python.framework import constant_op 27from tensorflow.python.ops import control_flow_ops 28from tensorflow.python.ops import math_ops 29from tensorflow.python.platform import test 30from tensorflow.python.profiler import pprof_profiler 31 32 33class PprofProfilerTest(test.TestCase): 34 35 def testDataEmpty(self): 36 output_dir = test.get_temp_dir() 37 run_metadata = config_pb2.RunMetadata() 38 graph = test.mock.MagicMock() 39 graph.get_operations.return_value = [] 40 41 profiles = pprof_profiler.get_profiles(graph, run_metadata) 42 self.assertEquals(0, len(profiles)) 43 profile_files = pprof_profiler.profile( 44 graph, run_metadata, output_dir) 45 self.assertEquals(0, len(profile_files)) 46 47 def testRunMetadataEmpty(self): 48 output_dir = test.get_temp_dir() 49 run_metadata = config_pb2.RunMetadata() 50 graph = test.mock.MagicMock() 51 op1 = test.mock.MagicMock() 52 op1.name = 'Add/123' 53 op1.traceback = [('a/b/file1', 10, 'some_var')] 54 op1.type = 'add' 55 graph.get_operations.return_value = [op1] 56 57 profiles = pprof_profiler.get_profiles(graph, run_metadata) 58 self.assertEquals(0, len(profiles)) 59 profile_files = pprof_profiler.profile( 60 graph, run_metadata, output_dir) 61 self.assertEquals(0, len(profile_files)) 62 63 def testValidProfile(self): 64 output_dir = test.get_temp_dir() 65 run_metadata = config_pb2.RunMetadata() 66 67 node1 = step_stats_pb2.NodeExecStats( 68 node_name='Add/123', 69 op_start_rel_micros=3, 70 op_end_rel_micros=5, 71 all_end_rel_micros=4) 72 73 run_metadata = config_pb2.RunMetadata() 74 device1 = run_metadata.step_stats.dev_stats.add() 75 device1.device = 'deviceA' 76 device1.node_stats.extend([node1]) 77 78 graph = test.mock.MagicMock() 79 op1 = test.mock.MagicMock() 80 op1.name = 'Add/123' 81 op1.traceback = [ 82 ('a/b/file1', 10, 'apply_op', 'abc'), ('a/c/file2', 12, 'my_op', 'def')] 83 op1.type = 'add' 84 graph.get_operations.return_value = [op1] 85 86 expected_proto = """sample_type { 87 type: 5 88 unit: 5 89} 90sample_type { 91 type: 6 92 unit: 7 93} 94sample_type { 95 type: 8 96 unit: 7 97} 98sample { 99 value: 1 100 value: 4 101 value: 2 102 label { 103 key: 1 104 str: 2 105 } 106 label { 107 key: 3 108 str: 4 109 } 110} 111string_table: "" 112string_table: "node_name" 113string_table: "Add/123" 114string_table: "op_type" 115string_table: "add" 116string_table: "count" 117string_table: "all_time" 118string_table: "nanoseconds" 119string_table: "op_time" 120string_table: "Device 1 of 1: deviceA" 121comment: 9 122""" 123 # Test with protos 124 profiles = pprof_profiler.get_profiles(graph, run_metadata) 125 self.assertEquals(1, len(profiles)) 126 self.assertTrue('deviceA' in profiles) 127 self.assertEquals(expected_proto, str(profiles['deviceA'])) 128 # Test with files 129 profile_files = pprof_profiler.profile( 130 graph, run_metadata, output_dir) 131 self.assertEquals(1, len(profile_files)) 132 with gzip.open(profile_files[0]) as profile_file: 133 profile_contents = profile_file.read() 134 profile = profile_pb2.Profile() 135 profile.ParseFromString(profile_contents) 136 self.assertEquals(expected_proto, str(profile)) 137 138 def testProfileWithWhileLoop(self): 139 options = config_pb2.RunOptions() 140 options.trace_level = config_pb2.RunOptions.FULL_TRACE 141 run_metadata = config_pb2.RunMetadata() 142 143 num_iters = 5 144 with self.test_session() as sess: 145 i = constant_op.constant(0) 146 c = lambda i: math_ops.less(i, num_iters) 147 b = lambda i: math_ops.add(i, 1) 148 r = control_flow_ops.while_loop(c, b, [i]) 149 sess.run(r, options=options, run_metadata=run_metadata) 150 profiles = pprof_profiler.get_profiles(sess.graph, run_metadata) 151 self.assertEquals(1, len(profiles)) 152 profile = next(iter(profiles.values())) 153 add_samples = [] # Samples for the while/Add node 154 for sample in profile.sample: 155 if profile.string_table[sample.label[0].str] == 'while/Add': 156 add_samples.append(sample) 157 # Values for same nodes are aggregated. 158 self.assertEquals(1, len(add_samples)) 159 # Value of "count" should be equal to number of iterations. 160 self.assertEquals(num_iters, add_samples[0].value[0]) 161 162 163if __name__ == '__main__': 164 test.main() 165