# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.python.platform.benchmark."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os
import random

from tensorflow.core.util import test_log_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import benchmark
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test

# Used by SomeRandomBenchmark class below.
# Single-element lists act as mutable flags that the benchmark methods flip,
# so the tests can observe which methods the benchmark runner invoked.
_ran_somebenchmark_1 = [False]
_ran_somebenchmark_2 = [False]
_ran_somebenchmark_but_shouldnt = [False]


def _clear_reporter_env():
  """Unsets the benchmark reporter env var so no report files are written."""
  os.environ.pop(benchmark.TEST_REPORTER_TEST_ENV, None)


class SomeRandomBenchmark(test.Benchmark):
  """This Benchmark should automatically be registered in the registry.

  Only the `benchmark*` methods are expected to be run by the runner; the
  other two methods set `_ran_somebenchmark_but_shouldnt` if they are called.
  """

  def _dontRunThisBenchmark(self):
    _ran_somebenchmark_but_shouldnt[0] = True

  def notBenchmarkMethod(self):
    _ran_somebenchmark_but_shouldnt[0] = True

  def benchmark1(self):
    _ran_somebenchmark_1[0] = True

  def benchmark2(self):
    _ran_somebenchmark_2[0] = True


class TestReportingBenchmark(test.Benchmark):
  """Benchmark whose methods report results via the Benchmark API.

  The test below expects report files to be written only when
  `benchmark.TEST_REPORTER_TEST_ENV` is set in the environment.
  """

  def benchmarkReport1(self):
    self.report_benchmark(iters=1)

  def benchmarkReport2(self):
    self.report_benchmark(
        iters=2,
        name="custom_benchmark_name",
        extras={"number_key": 3,
                "other_key": "string"})

  def benchmark_times_an_op(self):
    with session.Session() as sess:
      a = constant_op.constant(0.0)
      a_plus_a = a + a
      self.run_op_benchmark(
          sess, a_plus_a, min_iters=1000, store_trace=True, name="op_benchmark")


class BenchmarkTest(test.TestCase):
  """Tests benchmark registration, filtering and report writing."""

  def setUp(self):
    super(BenchmarkTest, self).setUp()
    # Tests below set or clear the reporter env var; remember its original
    # value so tearDown can restore it and state does not leak across tests.
    self._saved_reporter_env = os.environ.get(benchmark.TEST_REPORTER_TEST_ENV)

  def tearDown(self):
    if self._saved_reporter_env is None:
      _clear_reporter_env()
    else:
      os.environ[benchmark.TEST_REPORTER_TEST_ENV] = self._saved_reporter_env
    super(BenchmarkTest, self).tearDown()

  def testGlobalBenchmarkRegistry(self):
    registry = list(benchmark.GLOBAL_BENCHMARK_REGISTRY)
    self.assertEqual(len(registry), 2)
    self.assertIn(SomeRandomBenchmark, registry)
    self.assertIn(TestReportingBenchmark, registry)

  def testRunSomeRandomBenchmark(self):
    # Validate that SomeRandomBenchmark has not run yet.
    self.assertFalse(_ran_somebenchmark_1[0])
    self.assertFalse(_ran_somebenchmark_2[0])
    self.assertFalse(_ran_somebenchmark_but_shouldnt[0])

    # Run other benchmarks; the filter does not match the one we care about.
    benchmark._run_benchmarks("unrelated")

    # Validate that SomeRandomBenchmark still has not run.
    self.assertFalse(_ran_somebenchmark_1[0])
    self.assertFalse(_ran_somebenchmark_2[0])
    self.assertFalse(_ran_somebenchmark_but_shouldnt[0])

    # Run the SomeRandom benchmarks with reporting disabled.
    _clear_reporter_env()
    benchmark._run_benchmarks("SomeRandom")

    # Validate that SomeRandomBenchmark ran only its benchmark methods.
    self.assertTrue(_ran_somebenchmark_1[0])
    self.assertTrue(_ran_somebenchmark_2[0])
    self.assertFalse(_ran_somebenchmark_but_shouldnt[0])

    _ran_somebenchmark_1[0] = False
    _ran_somebenchmark_2[0] = False
    _ran_somebenchmark_but_shouldnt[0] = False

    # Run a single method of SomeRandomBenchmark selected by regex.
    _clear_reporter_env()
    benchmark._run_benchmarks("SomeRandom.*1$")

    self.assertTrue(_ran_somebenchmark_1[0])
    self.assertFalse(_ran_somebenchmark_2[0])
    self.assertFalse(_ran_somebenchmark_but_shouldnt[0])

  def testReportingBenchmark(self):
    # Write into a fresh, uniquely named subdirectory so the cleanup below
    # removes only this test's output rather than the whole directory
    # returned by test.get_temp_dir().
    tempdir = os.path.join(test.get_temp_dir(),
                           "reporting_bench_%016x" % random.getrandbits(64))
    try:
      gfile.MakeDirs(tempdir)
    except OSError as e:
      # It's OK if the directory already exists.
      if " exists:" not in str(e):
        raise

    prefix = os.path.join(tempdir, "reporting_bench_")
    expected_output_file = (
        prefix + "TestReportingBenchmark.benchmarkReport1")
    expected_output_file_2 = (
        prefix + "TestReportingBenchmark.custom_benchmark_name")
    expected_output_file_3 = (
        prefix + "TestReportingBenchmark.op_benchmark")
    try:
      self.assertFalse(gfile.Exists(expected_output_file))
      # Run benchmark without the env var; it shouldn't write anything.
      _clear_reporter_env()
      reporting = TestReportingBenchmark()
      reporting.benchmarkReport1()  # This should run without writing anything
      self.assertFalse(gfile.Exists(expected_output_file))

      # Run benchmark with the env var set; it should write report files.
      os.environ[benchmark.TEST_REPORTER_TEST_ENV] = prefix

      reporting = TestReportingBenchmark()
      reporting.benchmarkReport1()  # This should write
      reporting.benchmarkReport2()  # This should write
      reporting.benchmark_times_an_op()  # This should write

      # Check the files were written.
      self.assertTrue(gfile.Exists(expected_output_file))
      self.assertTrue(gfile.Exists(expected_output_file_2))
      self.assertTrue(gfile.Exists(expected_output_file_3))

      # Check the contents are correct.
      expected_1 = test_log_pb2.BenchmarkEntry()
      expected_1.name = "TestReportingBenchmark.benchmarkReport1"
      expected_1.iters = 1

      expected_2 = test_log_pb2.BenchmarkEntry()
      expected_2.name = "TestReportingBenchmark.custom_benchmark_name"
      expected_2.iters = 2
      expected_2.extras["number_key"].double_value = 3
      expected_2.extras["other_key"].string_value = "string"

      expected_3 = test_log_pb2.BenchmarkEntry()
      expected_3.name = "TestReportingBenchmark.op_benchmark"
      expected_3.iters = 1000

      def read_benchmark_entry(f):
        """Parses report file `f` and returns its single BenchmarkEntry."""
        with gfile.GFile(f, "rb") as fh:
          s = fh.read()
        entries = test_log_pb2.BenchmarkEntries.FromString(s)
        self.assertEqual(1, len(entries.entry))
        return entries.entry[0]

      read_benchmark_1 = read_benchmark_entry(expected_output_file)
      self.assertProtoEquals(expected_1, read_benchmark_1)

      read_benchmark_2 = read_benchmark_entry(expected_output_file_2)
      self.assertProtoEquals(expected_2, read_benchmark_2)

      # The op benchmark carries timing data, so compare field by field.
      read_benchmark_3 = read_benchmark_entry(expected_output_file_3)
      self.assertEqual(expected_3.name, read_benchmark_3.name)
      self.assertEqual(expected_3.iters, read_benchmark_3.iters)
      self.assertGreater(read_benchmark_3.wall_time, 0)
      full_trace = read_benchmark_3.extras["full_trace_chrome_format"]
      json_trace = json.loads(full_trace.string_value)
      self.assertIsInstance(json_trace, dict)
      self.assertIn("traceEvents", json_trace)
      allocator_keys = [k for k in read_benchmark_3.extras.keys()
                        if k.startswith("allocator_maximum_num_bytes_")]
      self.assertGreater(len(allocator_keys), 0)
      for k in allocator_keys:
        self.assertGreater(read_benchmark_3.extras[k].double_value, 0)

    finally:
      gfile.DeleteRecursively(tempdir)


if __name__ == "__main__":
  test.main()