1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for the experimental input pipeline statistics gathering ops."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import numpy as np
21
22from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
23from tensorflow.contrib.data.python.ops import stats_ops
24from tensorflow.core.framework import summary_pb2
25from tensorflow.python.data.ops import dataset_ops
26from tensorflow.python.framework import errors
27from tensorflow.python.framework import ops
28from tensorflow.python.ops import array_ops
29from tensorflow.python.platform import test
30
31
class StatsDatasetTest(test.TestCase):
  """Tests for collecting statistics from input pipelines.

  Each test attaches a stats transformation (`bytes_produced_stats` or
  `latency_stats`) to a dataset, subscribes a `StatsAggregator` to the
  iterator, and verifies the histogram recorded under the expected tag in
  the aggregator's serialized summary.
  """

  def _findHistogramValue(self, summary_str, tag):
    """Returns the histogram recorded under `tag` in a serialized summary.

    Args:
      summary_str: A serialized `Summary` protocol buffer.
      tag: The summary tag whose histogram should be returned.

    Returns:
      The `histo` field of the matching summary value.

    Fails the test if no value with the given tag is present.
    """
    summary_proto = summary_pb2.Summary()
    summary_proto.ParseFromString(summary_str)
    for value in summary_proto.value:
      if tag == value.tag:
        return value.histo
    self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))

  def _assertSummaryHasCount(self, summary_str, tag, expected_value):
    """Asserts that the histogram for `tag` contains `expected_value` samples."""
    self.assertEqual(expected_value,
                     self._findHistogramValue(summary_str, tag).num)

  def _assertSummaryHasSum(self, summary_str, tag, expected_value):
    """Asserts that the histogram for `tag` sums to `expected_value`."""
    self.assertEqual(expected_value,
                     self._findHistogramValue(summary_str, tag).sum)

  def testBytesProduced(self):
    """Verifies per-element byte counts recorded by `bytes_produced_stats`."""
    # Element i is a vector of i int64 values, i.e. 8 * i bytes.
    dataset = dataset_ops.Dataset.range(100).map(
        lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply(
            stats_ops.bytes_produced_stats("bytes_produced"))
    iterator = dataset.make_initializable_iterator()
    stats_aggregator = stats_ops.StatsAggregator()
    stats_aggregator_subscriber = stats_aggregator.subscribe(iterator)
    next_element = iterator.get_next()
    summary_t = stats_aggregator.get_summary()

    with self.test_session() as sess:
      sess.run([iterator.initializer, stats_aggregator_subscriber])
      expected_sum = 0.0
      for i in range(100):
        self.assertAllEqual(
            np.array([i] * i, dtype=np.int64), sess.run(next_element))
        summary_str = sess.run(summary_t)
        self._assertSummaryHasCount(summary_str, "bytes_produced", float(i + 1))
        expected_sum += i * 8.0
        self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)
      # Exhausting the iterator must not change the recorded statistics.
      summary_str = sess.run(summary_t)
      self._assertSummaryHasCount(summary_str, "bytes_produced", 100.0)
      self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)

  def testLatencyStats(self):
    """Verifies that `latency_stats` records one sample per element."""
    dataset = dataset_ops.Dataset.range(100).apply(
        stats_ops.latency_stats("record_latency"))
    iterator = dataset.make_initializable_iterator()
    stats_aggregator = stats_ops.StatsAggregator()
    stats_aggregator_subscriber = stats_aggregator.subscribe(iterator)
    next_element = iterator.get_next()
    summary_t = stats_aggregator.get_summary()

    with self.test_session() as sess:
      sess.run([iterator.initializer, stats_aggregator_subscriber])
      for i in range(100):
        self.assertEqual(i, sess.run(next_element))
        self._assertSummaryHasCount(
            sess.run(summary_t), "record_latency", float(i + 1))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)
      self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0)

  def testReinitialize(self):
    """Verifies that statistics accumulate across iterator re-initializations."""
    dataset = dataset_ops.Dataset.range(100).apply(
        stats_ops.latency_stats("record_latency"))
    iterator = dataset.make_initializable_iterator()
    stats_aggregator = stats_ops.StatsAggregator()
    stats_aggregator_subscriber = stats_aggregator.subscribe(iterator)
    next_element = iterator.get_next()
    summary_t = stats_aggregator.get_summary()

    with self.test_session() as sess:
      sess.run(stats_aggregator_subscriber)
      for j in range(5):
        sess.run(iterator.initializer)
        for i in range(100):
          self.assertEqual(i, sess.run(next_element))
          # Counts carry over from the previous passes over the dataset.
          self._assertSummaryHasCount(
              sess.run(summary_t), "record_latency", float((j * 100) + i + 1))
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(next_element)
        self._assertSummaryHasCount(
            sess.run(summary_t), "record_latency", (j + 1) * 100.0)

  def testNoAggregatorRegistered(self):
    """Verifies a stats dataset works even when no aggregator is subscribed."""
    dataset = dataset_ops.Dataset.range(100).apply(
        stats_ops.latency_stats("record_latency"))
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()

    with self.test_session() as sess:
      sess.run(iterator.initializer)
      for i in range(100):
        self.assertEqual(i, sess.run(next_element))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)

  def testMultipleTags(self):
    """Verifies that two differently-tagged stats are recorded independently."""
    dataset = dataset_ops.Dataset.range(100).apply(
        stats_ops.latency_stats("record_latency")).apply(
            stats_ops.latency_stats("record_latency_2"))
    iterator = dataset.make_initializable_iterator()
    stats_aggregator = stats_ops.StatsAggregator()
    stats_aggregator_subscriber = stats_aggregator.subscribe(iterator)
    next_element = iterator.get_next()
    summary_t = stats_aggregator.get_summary()

    with self.test_session() as sess:
      sess.run([iterator.initializer, stats_aggregator_subscriber])
      for i in range(100):
        self.assertEqual(i, sess.run(next_element))
        self._assertSummaryHasCount(
            sess.run(summary_t), "record_latency", float(i + 1))
        self._assertSummaryHasCount(
            sess.run(summary_t), "record_latency_2", float(i + 1))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)
      self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0)
      self._assertSummaryHasCount(
          sess.run(summary_t), "record_latency_2", 100.0)

  def testRepeatedTags(self):
    """Verifies that two stats with the same tag share a single histogram."""
    dataset = dataset_ops.Dataset.range(100).apply(
        stats_ops.latency_stats("record_latency")).apply(
            stats_ops.latency_stats("record_latency"))
    iterator = dataset.make_initializable_iterator()
    stats_aggregator = stats_ops.StatsAggregator()
    stats_aggregator_subscriber = stats_aggregator.subscribe(iterator)
    next_element = iterator.get_next()
    summary_t = stats_aggregator.get_summary()

    with self.test_session() as sess:
      sess.run([iterator.initializer, stats_aggregator_subscriber])
      for i in range(100):
        self.assertEqual(i, sess.run(next_element))
        # Both transformations write to the same tag, so each element
        # contributes two samples.
        self._assertSummaryHasCount(
            sess.run(summary_t), "record_latency", float(2 * (i + 1)))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)
      self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0)

  def testMultipleIteratorsSameAggregator(self):
    """Verifies two iterators may feed statistics into one aggregator."""
    dataset = dataset_ops.Dataset.range(100).apply(
        stats_ops.latency_stats("record_latency"))
    iterator_0 = dataset.make_initializable_iterator()
    iterator_1 = dataset.make_initializable_iterator()
    stats_aggregator = stats_ops.StatsAggregator()
    stats_aggregator_subscribers = [stats_aggregator.subscribe(iterator_0),
                                    stats_aggregator.subscribe(iterator_1)]
    next_element = iterator_0.get_next() + iterator_1.get_next()
    summary_t = stats_aggregator.get_summary()

    with self.test_session() as sess:
      sess.run([iterator_0.initializer, iterator_1.initializer,
                stats_aggregator_subscribers])
      for i in range(100):
        self.assertEqual(i * 2, sess.run(next_element))
        # Each step advances both iterators, so two samples per step.
        self._assertSummaryHasCount(
            sess.run(summary_t), "record_latency", float(2 * (i + 1)))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)
      self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0)

  def testMultipleStatsAggregatorsSameIteratorFail(self):
    """Verifies an iterator rejects a second aggregator subscription."""
    dataset = dataset_ops.Dataset.range(100).apply(
        stats_ops.latency_stats("record_latency"))
    iterator = dataset.make_initializable_iterator()
    stats_aggregator_0 = stats_ops.StatsAggregator()
    stats_aggregator_1 = stats_ops.StatsAggregator()

    with self.test_session() as sess:
      sess.run(stats_aggregator_0.subscribe(iterator))
      # TODO(mrry): Consider making this allowable (and also allowing
      # aggregators to unsubscribe).
      with self.assertRaises(errors.FailedPreconditionError):
        sess.run(stats_aggregator_1.subscribe(iterator))
211
212
class StatsDatasetSerializationTest(
    dataset_serialization_test_base.DatasetSerializationTestBase):
  """Checkpoint/restore tests for datasets that contain stats ops."""

  def _build_dataset_bytes_stats(self, num_elements):
    """Builds a range dataset of growing tensors with byte stats attached."""
    ds = dataset_ops.Dataset.range(num_elements)
    ds = ds.map(lambda x: array_ops.tile([x], ops.convert_to_tensor([x])))
    return ds.apply(stats_ops.bytes_produced_stats("bytes_produced"))

  def testBytesStatsDatasetSaveableCore(self):
    num_outputs = 100

    def _build_full():
      return self._build_dataset_bytes_stats(num_outputs)

    def _build_truncated():
      return self._build_dataset_bytes_stats(num_outputs // 10)

    self.run_core_tests(_build_full, _build_truncated, num_outputs)

  def _build_dataset_latency_stats(self, num_elements, tag="record_latency"):
    """Builds a range dataset instrumented with a single latency stat."""
    ds = dataset_ops.Dataset.range(num_elements)
    return ds.apply(stats_ops.latency_stats(tag))

  def _build_dataset_multiple_tags(self,
                                   num_elements,
                                   tag1="record_latency",
                                   tag2="record_latency_2"):
    """Builds a range dataset instrumented with two latency stats."""
    ds = dataset_ops.Dataset.range(num_elements)
    ds = ds.apply(stats_ops.latency_stats(tag1))
    return ds.apply(stats_ops.latency_stats(tag2))

  def testLatencyStatsDatasetSaveableCore(self):
    num_outputs = 100

    def _build_latency():
      return self._build_dataset_latency_stats(num_outputs)

    def _build_latency_truncated():
      return self._build_dataset_latency_stats(num_outputs // 10)

    self.run_core_tests(_build_latency, _build_latency_truncated, num_outputs)

    self.run_core_tests(lambda: self._build_dataset_multiple_tags(num_outputs),
                        None, num_outputs)

    # Exercise two stats transformations that share the same tag.
    duplicate_tag = "record_latency"
    self.run_core_tests(
        lambda: self._build_dataset_multiple_tags(
            num_outputs, duplicate_tag, duplicate_tag),
        None, num_outputs)
254
255
# Run all tests under the TensorFlow test runner when executed directly.
if __name__ == "__main__":
  test.main()
258