# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline statistics gathering ops."""
1615907659888a3e36e8de3d5a95de8d3327cb7c46Derek Murrayfrom __future__ import absolute_import
1715907659888a3e36e8de3d5a95de8d3327cb7c46Derek Murrayfrom __future__ import division
1815907659888a3e36e8de3d5a95de8d3327cb7c46Derek Murrayfrom __future__ import print_function
1915907659888a3e36e8de3d5a95de8d3327cb7c46Derek Murray
2015907659888a3e36e8de3d5a95de8d3327cb7c46Derek Murrayimport numpy as np
2115907659888a3e36e8de3d5a95de8d3327cb7c46Derek Murray
2222fe6558a958c6cc81d16d371031c06e262b1c83Shivani Agrawalfrom tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
2315907659888a3e36e8de3d5a95de8d3327cb7c46Derek Murrayfrom tensorflow.contrib.data.python.ops import stats_ops
2415907659888a3e36e8de3d5a95de8d3327cb7c46Derek Murrayfrom tensorflow.core.framework import summary_pb2
2515907659888a3e36e8de3d5a95de8d3327cb7c46Derek Murrayfrom tensorflow.python.data.ops import dataset_ops
2615907659888a3e36e8de3d5a95de8d3327cb7c46Derek Murrayfrom tensorflow.python.framework import errors
2715907659888a3e36e8de3d5a95de8d3327cb7c46Derek Murrayfrom tensorflow.python.framework import ops
2815907659888a3e36e8de3d5a95de8d3327cb7c46Derek Murrayfrom tensorflow.python.ops import array_ops
2915907659888a3e36e8de3d5a95de8d3327cb7c46Derek Murrayfrom tensorflow.python.platform import test
3015907659888a3e36e8de3d5a95de8d3327cb7c46Derek Murray
3115907659888a3e36e8de3d5a95de8d3327cb7c46Derek Murray
class StatsDatasetTest(test.TestCase):
  """Checks the statistics reported by `stats_ops` dataset transformations.

  Most tests build a `Dataset.range(100)` pipeline, subscribe a
  `StatsAggregator` to its iterator, and inspect the histogram entries of the
  serialized summary fetched from `StatsAggregator.get_summary()`.
  """

  def _findSummaryValue(self, summary_str, tag):
    """Returns the first value carrying `tag` in a serialized summary.

    Args:
      summary_str: A serialized `summary_pb2.Summary` protocol buffer.
      tag: The tag to look for among the summary's values.

    Returns:
      The first entry of the parsed summary's `value` field whose tag equals
      `tag`. If no entry matches, the test is failed via `self.fail()`.
    """
    summary_proto = summary_pb2.Summary()
    summary_proto.ParseFromString(summary_str)
    for value in summary_proto.value:
      if tag == value.tag:
        return value
    self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))

  def _assertSummaryHasCount(self, summary_str, tag, expected_value):
    """Asserts that the histogram tagged `tag` has `expected_value` entries.

    Args:
      summary_str: A serialized `summary_pb2.Summary` protocol buffer.
      tag: The tag of the summary value to check.
      expected_value: The expected value of the histogram's `num` field.
    """
    value = self._findSummaryValue(summary_str, tag)
    self.assertEqual(expected_value, value.histo.num)

  def _assertSummaryHasSum(self, summary_str, tag, expected_value):
    """Asserts that the histogram tagged `tag` sums to `expected_value`.

    Args:
      summary_str: A serialized `summary_pb2.Summary` protocol buffer.
      tag: The tag of the summary value to check.
      expected_value: The expected value of the histogram's `sum` field.
    """
    value = self._findSummaryValue(summary_str, tag)
    self.assertEqual(expected_value, value.histo.sum)

  def testBytesProduced(self):
    """Checks the entry count and byte total of `bytes_produced_stats`."""
    dataset = dataset_ops.Dataset.range(100).map(
        lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply(
            stats_ops.bytes_produced_stats("bytes_produced"))
    iterator = dataset.make_initializable_iterator()
    stats_aggregator = stats_ops.StatsAggregator()
    stats_aggregator_subscriber = stats_aggregator.subscribe(iterator)
    next_element = iterator.get_next()
    summary_t = stats_aggregator.get_summary()

    with self.test_session() as sess:
      sess.run([iterator.initializer, stats_aggregator_subscriber])
      expected_sum = 0.0
      for i in range(100):
        # Element `i` is the value `i` repeated `i` times, as int64.
        self.assertAllEqual(
            np.array([i] * i, dtype=np.int64), sess.run(next_element))
        summary_str = sess.run(summary_t)
        self._assertSummaryHasCount(summary_str, "bytes_produced", float(i + 1))
        # The expected running total grows by 8 bytes per int64 value.
        expected_sum += i * 8.0
        self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)
      # The failed `get_next()` must leave both statistics unchanged.
      summary_str = sess.run(summary_t)
      self._assertSummaryHasCount(summary_str, "bytes_produced", 100.0)
      self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)

  def testLatencyStats(self):
    """Checks that `latency_stats` records one entry per produced element."""
    dataset = dataset_ops.Dataset.range(100).apply(
        stats_ops.latency_stats("record_latency"))
    iterator = dataset.make_initializable_iterator()
    stats_aggregator = stats_ops.StatsAggregator()
    stats_aggregator_subscriber = stats_aggregator.subscribe(iterator)
    next_element = iterator.get_next()
    summary_t = stats_aggregator.get_summary()

    with self.test_session() as sess:
      sess.run([iterator.initializer, stats_aggregator_subscriber])
      for i in range(100):
        self.assertEqual(i, sess.run(next_element))
        self._assertSummaryHasCount(
            sess.run(summary_t), "record_latency", float(i + 1))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)
      # The count stays at 100 after the iterator is exhausted.
      self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0)

  def testReinitialize(self):
    """Checks that counts keep accumulating across iterator re-initialization."""
    dataset = dataset_ops.Dataset.range(100).apply(
        stats_ops.latency_stats("record_latency"))
    iterator = dataset.make_initializable_iterator()
    stats_aggregator = stats_ops.StatsAggregator()
    stats_aggregator_subscriber = stats_aggregator.subscribe(iterator)
    next_element = iterator.get_next()
    summary_t = stats_aggregator.get_summary()

    with self.test_session() as sess:
      # Subscribe once; only the iterator is re-initialized on each pass.
      sess.run(stats_aggregator_subscriber)
      for j in range(5):
        sess.run(iterator.initializer)
        for i in range(100):
          self.assertEqual(i, sess.run(next_element))
          # Pass `j` starts from the 100 * j entries of the earlier passes.
          self._assertSummaryHasCount(
              sess.run(summary_t), "record_latency", float((j * 100) + i + 1))
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(next_element)
        self._assertSummaryHasCount(
            sess.run(summary_t), "record_latency", (j + 1) * 100.0)

  def testNoAggregatorRegistered(self):
    """Checks that a stats dataset iterates normally with no aggregator."""
    dataset = dataset_ops.Dataset.range(100).apply(
        stats_ops.latency_stats("record_latency"))
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()

    with self.test_session() as sess:
      sess.run(iterator.initializer)
      for i in range(100):
        self.assertEqual(i, sess.run(next_element))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)

  def testMultipleTags(self):
    """Checks that two differently tagged `latency_stats` count separately."""
    dataset = dataset_ops.Dataset.range(100).apply(
        stats_ops.latency_stats("record_latency")).apply(
            stats_ops.latency_stats("record_latency_2"))
    iterator = dataset.make_initializable_iterator()
    stats_aggregator = stats_ops.StatsAggregator()
    stats_aggregator_subscriber = stats_aggregator.subscribe(iterator)
    next_element = iterator.get_next()
    summary_t = stats_aggregator.get_summary()

    with self.test_session() as sess:
      sess.run([iterator.initializer, stats_aggregator_subscriber])
      for i in range(100):
        self.assertEqual(i, sess.run(next_element))
        self._assertSummaryHasCount(
            sess.run(summary_t), "record_latency", float(i + 1))
        self._assertSummaryHasCount(
            sess.run(summary_t), "record_latency_2", float(i + 1))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)
      self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0)
      self._assertSummaryHasCount(
          sess.run(summary_t), "record_latency_2", 100.0)

  def testRepeatedTags(self):
    """Checks that two `latency_stats` sharing a tag feed one histogram."""
    dataset = dataset_ops.Dataset.range(100).apply(
        stats_ops.latency_stats("record_latency")).apply(
            stats_ops.latency_stats("record_latency"))
    iterator = dataset.make_initializable_iterator()
    stats_aggregator = stats_ops.StatsAggregator()
    stats_aggregator_subscriber = stats_aggregator.subscribe(iterator)
    next_element = iterator.get_next()
    summary_t = stats_aggregator.get_summary()

    with self.test_session() as sess:
      sess.run([iterator.initializer, stats_aggregator_subscriber])
      for i in range(100):
        self.assertEqual(i, sess.run(next_element))
        # Both transformations use the same tag, so two entries per element.
        self._assertSummaryHasCount(
            sess.run(summary_t), "record_latency", float(2 * (i + 1)))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)
      self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0)

  def testMultipleIteratorsSameAggregator(self):
    """Checks that one aggregator collects entries from two iterators."""
    dataset = dataset_ops.Dataset.range(100).apply(
        stats_ops.latency_stats("record_latency"))
    iterator_0 = dataset.make_initializable_iterator()
    iterator_1 = dataset.make_initializable_iterator()
    stats_aggregator = stats_ops.StatsAggregator()
    stats_aggregator_subscribers = [stats_aggregator.subscribe(iterator_0),
                                    stats_aggregator.subscribe(iterator_1)]
    # Fetching the sum pulls one element from each iterator per step.
    next_element = iterator_0.get_next() + iterator_1.get_next()
    summary_t = stats_aggregator.get_summary()

    with self.test_session() as sess:
      sess.run([iterator_0.initializer, iterator_1.initializer,
                stats_aggregator_subscribers])
      for i in range(100):
        self.assertEqual(i * 2, sess.run(next_element))
        self._assertSummaryHasCount(
            sess.run(summary_t), "record_latency", float(2 * (i + 1)))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)
      self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0)

  def testMultipleStatsAggregatorsSameIteratorFail(self):
    """Checks that subscribing a second aggregator to an iterator fails."""
    dataset = dataset_ops.Dataset.range(100).apply(
        stats_ops.latency_stats("record_latency"))
    iterator = dataset.make_initializable_iterator()
    stats_aggregator_0 = stats_ops.StatsAggregator()
    stats_aggregator_1 = stats_ops.StatsAggregator()

    with self.test_session() as sess:
      sess.run(stats_aggregator_0.subscribe(iterator))
      # TODO(mrry): Consider making this allowable (and also allowing
      # aggregators to unsubscribe).
      with self.assertRaises(errors.FailedPreconditionError):
        sess.run(stats_aggregator_1.subscribe(iterator))
21215907659888a3e36e8de3d5a95de8d3327cb7c46Derek Murray
class StatsDatasetSerializationTest(
    dataset_serialization_test_base.DatasetSerializationTestBase):
  """Runs the shared serialization core tests against stats datasets."""

  def _build_dataset_bytes_stats(self, num_elements):
    """Builds a range dataset of tiled elements with `bytes_produced_stats`.

    Args:
      num_elements: Upper bound passed to `Dataset.range()`.

    Returns:
      The dataset with `bytes_produced_stats("bytes_produced")` applied.
    """
    numbers = dataset_ops.Dataset.range(num_elements)
    tiled = numbers.map(
        lambda x: array_ops.tile([x], ops.convert_to_tensor([x])))
    return tiled.apply(stats_ops.bytes_produced_stats("bytes_produced"))

  def _build_dataset_latency_stats(self, num_elements, tag="record_latency"):
    """Builds a range dataset with one `latency_stats(tag)` applied.

    Args:
      num_elements: Upper bound passed to `Dataset.range()`.
      tag: Tag handed to `latency_stats`.

    Returns:
      The range dataset with the latency statistics transformation applied.
    """
    numbers = dataset_ops.Dataset.range(num_elements)
    return numbers.apply(stats_ops.latency_stats(tag))

  def _build_dataset_multiple_tags(self,
                                   num_elements,
                                   tag1="record_latency",
                                   tag2="record_latency_2"):
    """Builds a range dataset with two chained `latency_stats` applied.

    Args:
      num_elements: Upper bound passed to `Dataset.range()`.
      tag1: Tag of the first (inner) `latency_stats` transformation.
      tag2: Tag of the second (outer) `latency_stats` transformation.

    Returns:
      The range dataset with both transformations applied in order.
    """
    numbers = dataset_ops.Dataset.range(num_elements)
    inner = numbers.apply(stats_ops.latency_stats(tag1))
    return inner.apply(stats_ops.latency_stats(tag2))

  def testBytesStatsDatasetSaveableCore(self):
    """Runs the core tests on the `bytes_produced_stats` dataset."""
    num_outputs = 100

    def full_dataset():
      return self._build_dataset_bytes_stats(num_outputs)

    def reduced_dataset():
      # Second builder produces a dataset a tenth of the size.
      return self._build_dataset_bytes_stats(num_outputs // 10)

    self.run_core_tests(full_dataset, reduced_dataset, num_outputs)

  def testLatencyStatsDatasetSaveableCore(self):
    """Runs the core tests on single, multi-tag and repeated-tag datasets."""
    num_outputs = 100

    # Scenario 1: a single `latency_stats` transformation.
    def single_tag_dataset():
      return self._build_dataset_latency_stats(num_outputs)

    def reduced_single_tag_dataset():
      return self._build_dataset_latency_stats(num_outputs // 10)

    self.run_core_tests(single_tag_dataset, reduced_single_tag_dataset,
                        num_outputs)

    # Scenario 2: two transformations using the builder's default tags.
    def distinct_tags_dataset():
      return self._build_dataset_multiple_tags(num_outputs)

    self.run_core_tests(distinct_tags_dataset, None, num_outputs)

    # Scenario 3: two transformations that share the same tag.
    repeated_tag = "record_latency"

    def repeated_tag_dataset():
      return self._build_dataset_multiple_tags(num_outputs, repeated_tag,
                                               repeated_tag)

    self.run_core_tests(repeated_tag_dataset, None, num_outputs)
25522fe6558a958c6cc81d16d371031c06e262b1c83Shivani Agrawal
# Invoke the TensorFlow test entry point when this file is executed directly.
if __name__ == "__main__":
  test.main()
258