# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline statistics gathering ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
from tensorflow.contrib.data.python.ops import stats_ops
from tensorflow.core.framework import summary_pb2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test


class StatsDatasetTest(test.TestCase):
  """Tests that stats datasets publish histograms via a `StatsAggregator`."""

  def _assertSummaryHasCount(self, summary_str, tag, expected_value):
    """Asserts the histogram for `tag` in `summary_str` has `expected_value`
    observations (`histo.num`); fails if `tag` is absent from the summary."""
    summary_proto = summary_pb2.Summary()
    summary_proto.ParseFromString(summary_str)
    for value in summary_proto.value:
      if tag == value.tag:
        self.assertEqual(expected_value, value.histo.num)
        return
    self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))

  def _assertSummaryHasSum(self, summary_str, tag, expected_value):
    """Asserts the histogram for `tag` in `summary_str` sums to
    `expected_value` (`histo.sum`); fails if `tag` is absent."""
    summary_proto = summary_pb2.Summary()
    summary_proto.ParseFromString(summary_str)
    for value in summary_proto.value:
      if tag == value.tag:
        self.assertEqual(expected_value, value.histo.sum)
        return
    self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))

  def testBytesProduced(self):
    # Element i is `i` copies of the int64 value i, so it contributes
    # i * 8 bytes to the "bytes_produced" histogram sum.
    dataset = dataset_ops.Dataset.range(100).map(
        lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply(
            stats_ops.bytes_produced_stats("bytes_produced"))
    iterator = dataset.make_initializable_iterator()
    stats_aggregator = stats_ops.StatsAggregator()
    stats_aggregator_subscriber = stats_aggregator.subscribe(iterator)
    next_element = iterator.get_next()
    summary_t = stats_aggregator.get_summary()

    with self.test_session() as sess:
      sess.run([iterator.initializer, stats_aggregator_subscriber])
      expected_sum = 0.0
      for i in range(100):
        self.assertAllEqual(
            np.array([i] * i, dtype=np.int64), sess.run(next_element))
        summary_str = sess.run(summary_t)
        self._assertSummaryHasCount(summary_str, "bytes_produced", float(i + 1))
        expected_sum += i * 8.0
        self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)
      # The totals must be unchanged after the iterator is exhausted.
      summary_str = sess.run(summary_t)
      self._assertSummaryHasCount(summary_str, "bytes_produced", 100.0)
      self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)

  def testLatencyStats(self):
    dataset = dataset_ops.Dataset.range(100).apply(
        stats_ops.latency_stats("record_latency"))
    iterator = dataset.make_initializable_iterator()
    stats_aggregator = stats_ops.StatsAggregator()
    stats_aggregator_subscriber = stats_aggregator.subscribe(iterator)
    next_element = iterator.get_next()
    summary_t = stats_aggregator.get_summary()

    with self.test_session() as sess:
      sess.run([iterator.initializer, stats_aggregator_subscriber])
      for i in range(100):
        self.assertEqual(i, sess.run(next_element))
        # One latency observation is recorded per element produced.
        self._assertSummaryHasCount(
            sess.run(summary_t), "record_latency", float(i + 1))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)
      self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0)

  def testReinitialize(self):
    dataset = dataset_ops.Dataset.range(100).apply(
        stats_ops.latency_stats("record_latency"))
    iterator = dataset.make_initializable_iterator()
    stats_aggregator = stats_ops.StatsAggregator()
    stats_aggregator_subscriber = stats_aggregator.subscribe(iterator)
    next_element = iterator.get_next()
    summary_t = stats_aggregator.get_summary()

    with self.test_session() as sess:
      sess.run(stats_aggregator_subscriber)
      # Reinitializing the iterator must not reset the aggregator: counts
      # accumulate across the 5 epochs (j * 100 + i + 1 after each element).
      for j in range(5):
        sess.run(iterator.initializer)
        for i in range(100):
          self.assertEqual(i, sess.run(next_element))
          self._assertSummaryHasCount(
              sess.run(summary_t), "record_latency", float((j * 100) + i + 1))
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(next_element)
        self._assertSummaryHasCount(
            sess.run(summary_t), "record_latency", (j + 1) * 100.0)

  def testNoAggregatorRegistered(self):
    # A stats dataset with no subscribed aggregator must still iterate
    # normally (stats collection is a no-op, not an error).
    dataset = dataset_ops.Dataset.range(100).apply(
        stats_ops.latency_stats("record_latency"))
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()

    with self.test_session() as sess:
      sess.run(iterator.initializer)
      for i in range(100):
        self.assertEqual(i, sess.run(next_element))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)

  def testMultipleTags(self):
    # Two latency transformations with distinct tags produce two
    # independent histograms, each counting every element once.
    dataset = dataset_ops.Dataset.range(100).apply(
        stats_ops.latency_stats("record_latency")).apply(
            stats_ops.latency_stats("record_latency_2"))
    iterator = dataset.make_initializable_iterator()
    stats_aggregator = stats_ops.StatsAggregator()
    stats_aggregator_subscriber = stats_aggregator.subscribe(iterator)
    next_element = iterator.get_next()
    summary_t = stats_aggregator.get_summary()

    with self.test_session() as sess:
      sess.run([iterator.initializer, stats_aggregator_subscriber])
      for i in range(100):
        self.assertEqual(i, sess.run(next_element))
        self._assertSummaryHasCount(
            sess.run(summary_t), "record_latency", float(i + 1))
        self._assertSummaryHasCount(
            sess.run(summary_t), "record_latency_2", float(i + 1))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)
      self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0)
      self._assertSummaryHasCount(
          sess.run(summary_t), "record_latency_2", 100.0)

  def testRepeatedTags(self):
    # Two latency transformations sharing one tag merge into a single
    # histogram, so each element is counted twice.
    dataset = dataset_ops.Dataset.range(100).apply(
        stats_ops.latency_stats("record_latency")).apply(
            stats_ops.latency_stats("record_latency"))
    iterator = dataset.make_initializable_iterator()
    stats_aggregator = stats_ops.StatsAggregator()
    stats_aggregator_subscriber = stats_aggregator.subscribe(iterator)
    next_element = iterator.get_next()
    summary_t = stats_aggregator.get_summary()

    with self.test_session() as sess:
      sess.run([iterator.initializer, stats_aggregator_subscriber])
      for i in range(100):
        self.assertEqual(i, sess.run(next_element))
        self._assertSummaryHasCount(
            sess.run(summary_t), "record_latency", float(2 * (i + 1)))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)
      self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0)

  def testMultipleIteratorsSameAggregator(self):
    # Two iterators over the same dataset subscribed to one aggregator
    # contribute to the same "record_latency" histogram: 2 observations
    # per step.
    dataset = dataset_ops.Dataset.range(100).apply(
        stats_ops.latency_stats("record_latency"))
    iterator_0 = dataset.make_initializable_iterator()
    iterator_1 = dataset.make_initializable_iterator()
    stats_aggregator = stats_ops.StatsAggregator()
    stats_aggregator_subscribers = [stats_aggregator.subscribe(iterator_0),
                                    stats_aggregator.subscribe(iterator_1)]
    next_element = iterator_0.get_next() + iterator_1.get_next()
    summary_t = stats_aggregator.get_summary()

    with self.test_session() as sess:
      sess.run([iterator_0.initializer, iterator_1.initializer,
                stats_aggregator_subscribers])
      for i in range(100):
        self.assertEqual(i * 2, sess.run(next_element))
        self._assertSummaryHasCount(
            sess.run(summary_t), "record_latency", float(2 * (i + 1)))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)
      self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0)

  def testMultipleStatsAggregatorsSameIteratorFail(self):
    # Subscribing a second aggregator to an already-subscribed iterator
    # must fail with FailedPreconditionError.
    dataset = dataset_ops.Dataset.range(100).apply(
        stats_ops.latency_stats("record_latency"))
    iterator = dataset.make_initializable_iterator()
    stats_aggregator_0 = stats_ops.StatsAggregator()
    stats_aggregator_1 = stats_ops.StatsAggregator()

    with self.test_session() as sess:
      sess.run(stats_aggregator_0.subscribe(iterator))
      # TODO(mrry): Consider making this allowable (and also allowing
      # aggregators to unsubscribe).
      with self.assertRaises(errors.FailedPreconditionError):
        sess.run(stats_aggregator_1.subscribe(iterator))


class StatsDatasetSerializationTest(
    dataset_serialization_test_base.DatasetSerializationTestBase):
  """Checkpoint save/restore tests for the stats dataset transformations."""

  def _build_dataset_bytes_stats(self, num_elements):
    """Returns a range dataset of variable-length tiles with byte stats."""
    return dataset_ops.Dataset.range(num_elements).map(
        lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply(
            stats_ops.bytes_produced_stats("bytes_produced"))

  def testBytesStatsDatasetSaveableCore(self):
    num_outputs = 100
    self.run_core_tests(
        lambda: self._build_dataset_bytes_stats(num_outputs),
        lambda: self._build_dataset_bytes_stats(num_outputs // 10), num_outputs)

  def _build_dataset_latency_stats(self, num_elements, tag="record_latency"):
    """Returns a range dataset with a single latency-stats transformation."""
    return dataset_ops.Dataset.range(num_elements).apply(
        stats_ops.latency_stats(tag))

  def _build_dataset_multiple_tags(self,
                                   num_elements,
                                   tag1="record_latency",
                                   tag2="record_latency_2"):
    """Returns a range dataset with two chained latency-stats transformations."""
    return dataset_ops.Dataset.range(num_elements).apply(
        stats_ops.latency_stats(tag1)).apply(stats_ops.latency_stats(tag2))

  def testLatencyStatsDatasetSaveableCore(self):
    num_outputs = 100

    self.run_core_tests(
        lambda: self._build_dataset_latency_stats(num_outputs),
        lambda: self._build_dataset_latency_stats(num_outputs // 10),
        num_outputs)

    self.run_core_tests(lambda: self._build_dataset_multiple_tags(num_outputs),
                        None, num_outputs)

    # Repeated tags must also round-trip through save/restore.
    tag1 = "record_latency"
    tag2 = "record_latency"
    self.run_core_tests(
        lambda: self._build_dataset_multiple_tags(num_outputs, tag1, tag2),
        None, num_outputs)


if __name__ == "__main__":
  test.main()