# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.python.platform.benchmark."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os
import random

from tensorflow.core.util import test_log_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import benchmark
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test

# Module-level flags flipped by SomeRandomBenchmark below; single-element
# lists are used so the benchmark methods can mutate them in place.
_ran_somebenchmark_1 = [False]
_ran_somebenchmark_2 = [False]
_ran_somebenchmark_but_shouldnt = [False]


class SomeRandomBenchmark(test.Benchmark):
  """This Benchmark should automatically be registered in the registry."""
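  # Subclassing test.Benchmark is enough to land this class in
  # GLOBAL_BENCHMARK_REGISTRY (checked in testGlobalBenchmarkRegistry below);
  # the runner then invokes only methods whose names start with "benchmark",
  # so the two methods immediately below must never run.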

  def _dontRunThisBenchmark(self):
    _ran_somebenchmark_but_shouldnt[0] = True

  def notBenchmarkMethod(self):
    _ran_somebenchmark_but_shouldnt[0] = True

  def benchmark1(self):
    _ran_somebenchmark_1[0] = True

  def benchmark2(self):
    _ran_somebenchmark_2[0] = True

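# A minimal sketch of how these benchmarks are normally driven outside this
# test (assuming the standard --benchmarks regex flag handled by the runner in
# tensorflow.python.platform.benchmark):
#
#   python benchmark_test.py --benchmarks=SomeRandomBenchmark.benchmark1
#
# BenchmarkTest below bypasses the flag and calls
# benchmark._run_benchmarks(regex) directly.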

class TestReportingBenchmark(test.Benchmark):
  """This benchmark reports results only when the reporting env var is set."""

  def benchmarkReport1(self):
    self.report_benchmark(iters=1)

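  # Extras supplied to report_benchmark are stored in the BenchmarkEntry
  # proto: numeric values land in double_value, strings in string_value
  # (verified in BenchmarkTest.testReportingBenchmark below).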
  def benchmarkReport2(self):
    self.report_benchmark(
        iters=2,
        name="custom_benchmark_name",
        extras={"number_key": 3,
                "other_key": "string"})

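  # run_op_benchmark times the op over min_iters iterations; with
  # store_trace=True it also attaches a Chrome-format timeline under
  # extras["full_trace_chrome_format"] (checked below).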
  def benchmark_times_an_op(self):
    with session.Session() as sess:
      a = constant_op.constant(0.0)
      a_plus_a = a + a
      self.run_op_benchmark(
          sess, a_plus_a, min_iters=1000, store_trace=True, name="op_benchmark")


class BenchmarkTest(test.TestCase):

  def testGlobalBenchmarkRegistry(self):
    registry = list(benchmark.GLOBAL_BENCHMARK_REGISTRY)
    self.assertEqual(len(registry), 2)
    self.assertIn(SomeRandomBenchmark, registry)
    self.assertIn(TestReportingBenchmark, registry)

  def testRunSomeRandomBenchmark(self):
    # Validate that SomeRandomBenchmark has not run yet.
    self.assertFalse(_ran_somebenchmark_1[0])
    self.assertFalse(_ran_somebenchmark_2[0])
    self.assertFalse(_ran_somebenchmark_but_shouldnt[0])

    # Run other benchmarks; this won't run the one we care about.
    benchmark._run_benchmarks("unrelated")

    # Validate that SomeRandomBenchmark still has not run.
    self.assertFalse(_ran_somebenchmark_1[0])
    self.assertFalse(_ran_somebenchmark_2[0])
    self.assertFalse(_ran_somebenchmark_but_shouldnt[0])

    # Run all the benchmarks, avoiding any report generation.
    if benchmark.TEST_REPORTER_TEST_ENV in os.environ:
      del os.environ[benchmark.TEST_REPORTER_TEST_ENV]
    benchmark._run_benchmarks("SomeRandom")

    # Validate that SomeRandomBenchmark ran correctly.
    self.assertTrue(_ran_somebenchmark_1[0])
    self.assertTrue(_ran_somebenchmark_2[0])
    self.assertFalse(_ran_somebenchmark_but_shouldnt[0])

    _ran_somebenchmark_1[0] = False
    _ran_somebenchmark_2[0] = False
    _ran_somebenchmark_but_shouldnt[0] = False

    # Run a specific method: the regex is matched against the qualified
    # "ClassName.methodName" benchmark name, so "SomeRandom.*1$" selects only
    # SomeRandomBenchmark.benchmark1.
    if benchmark.TEST_REPORTER_TEST_ENV in os.environ:
      del os.environ[benchmark.TEST_REPORTER_TEST_ENV]
    benchmark._run_benchmarks("SomeRandom.*1$")

    self.assertTrue(_ran_somebenchmark_1[0])
    self.assertFalse(_ran_somebenchmark_2[0])
    self.assertFalse(_ran_somebenchmark_but_shouldnt[0])

  def testReportingBenchmark(self):
    tempdir = test.get_temp_dir()
    try:
      gfile.MakeDirs(tempdir)
    except OSError as e:
      # It's OK if the directory already exists.
      if " exists:" not in str(e):
        raise

    prefix = os.path.join(tempdir,
                          "reporting_bench_%016x_" % random.getrandbits(64))
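    # The reporter writes one file per benchmark, named by appending the
    # qualified "<ClassName>.<benchmark name>" to the TEST_REPORTER_TEST_ENV
    # prefix; these are the three files we expect.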
    expected_output_file = "%s%s" % (prefix,
                                     "TestReportingBenchmark.benchmarkReport1")
    expected_output_file_2 = "%s%s" % (
        prefix, "TestReportingBenchmark.custom_benchmark_name")
    expected_output_file_3 = "%s%s" % (prefix,
                                       "TestReportingBenchmark.op_benchmark")
    try:
      self.assertFalse(gfile.Exists(expected_output_file))
      # Run a benchmark without the env var set; nothing should be written.
      if benchmark.TEST_REPORTER_TEST_ENV in os.environ:
        del os.environ[benchmark.TEST_REPORTER_TEST_ENV]
      reporting = TestReportingBenchmark()
      reporting.benchmarkReport1()  # This should run without writing anything.
      self.assertFalse(gfile.Exists(expected_output_file))

      # Run benchmarks with the env var set; each should write a report.
      os.environ[benchmark.TEST_REPORTER_TEST_ENV] = prefix

      reporting = TestReportingBenchmark()
      reporting.benchmarkReport1()  # This should write
      reporting.benchmarkReport2()  # This should write
      reporting.benchmark_times_an_op()  # This should write

      # Check the files were written
      self.assertTrue(gfile.Exists(expected_output_file))
      self.assertTrue(gfile.Exists(expected_output_file_2))
      self.assertTrue(gfile.Exists(expected_output_file_3))

      # Check the contents are correct
      expected_1 = test_log_pb2.BenchmarkEntry()
      expected_1.name = "TestReportingBenchmark.benchmarkReport1"
      expected_1.iters = 1

      expected_2 = test_log_pb2.BenchmarkEntry()
      expected_2.name = "TestReportingBenchmark.custom_benchmark_name"
      expected_2.iters = 2
      expected_2.extras["number_key"].double_value = 3
      expected_2.extras["other_key"].string_value = "string"

      expected_3 = test_log_pb2.BenchmarkEntry()
      expected_3.name = "TestReportingBenchmark.op_benchmark"
      expected_3.iters = 1000

      def read_benchmark_entry(f):
        s = gfile.GFile(f, "rb").read()
        entries = test_log_pb2.BenchmarkEntries.FromString(s)
        self.assertEqual(1, len(entries.entry))
        return entries.entry[0]

      read_benchmark_1 = read_benchmark_entry(expected_output_file)
      self.assertProtoEquals(expected_1, read_benchmark_1)

      read_benchmark_2 = read_benchmark_entry(expected_output_file_2)
      self.assertProtoEquals(expected_2, read_benchmark_2)

      read_benchmark_3 = read_benchmark_entry(expected_output_file_3)
      self.assertEqual(expected_3.name, read_benchmark_3.name)
      self.assertEqual(expected_3.iters, read_benchmark_3.iters)
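      # wall_time and the recorded trace are nondeterministic, so only check
      # that they were captured and are well-formed.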
      self.assertGreater(read_benchmark_3.wall_time, 0)
      full_trace = read_benchmark_3.extras["full_trace_chrome_format"]
      json_trace = json.loads(full_trace.string_value)
      self.assertIsInstance(json_trace, dict)
      self.assertIn("traceEvents", json_trace)
      allocator_keys = [k for k in read_benchmark_3.extras.keys()
                        if k.startswith("allocator_maximum_num_bytes_")]
      self.assertGreater(len(allocator_keys), 0)
      for k in allocator_keys:
        self.assertGreater(read_benchmark_3.extras[k].double_value, 0)

    finally:
      gfile.DeleteRecursively(tempdir)


if __name__ == "__main__":
  test.main()