17e47624f5f646dca74c3484c330cb43baec75b2aAsim Shankar# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
27e47624f5f646dca74c3484c330cb43baec75b2aAsim Shankar#
37e47624f5f646dca74c3484c330cb43baec75b2aAsim Shankar# Licensed under the Apache License, Version 2.0 (the "License");
47e47624f5f646dca74c3484c330cb43baec75b2aAsim Shankar# you may not use this file except in compliance with the License.
57e47624f5f646dca74c3484c330cb43baec75b2aAsim Shankar# You may obtain a copy of the License at
67e47624f5f646dca74c3484c330cb43baec75b2aAsim Shankar#
77e47624f5f646dca74c3484c330cb43baec75b2aAsim Shankar#     http://www.apache.org/licenses/LICENSE-2.0
87e47624f5f646dca74c3484c330cb43baec75b2aAsim Shankar#
97e47624f5f646dca74c3484c330cb43baec75b2aAsim Shankar# Unless required by applicable law or agreed to in writing, software
107e47624f5f646dca74c3484c330cb43baec75b2aAsim Shankar# distributed under the License is distributed on an "AS IS" BASIS,
117e47624f5f646dca74c3484c330cb43baec75b2aAsim Shankar# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
127e47624f5f646dca74c3484c330cb43baec75b2aAsim Shankar# See the License for the specific language governing permissions and
137e47624f5f646dca74c3484c330cb43baec75b2aAsim Shankar# limitations under the License.
147e47624f5f646dca74c3484c330cb43baec75b2aAsim Shankar# ==============================================================================
157e47624f5f646dca74c3484c330cb43baec75b2aAsim Shankarfrom __future__ import absolute_import
167e47624f5f646dca74c3484c330cb43baec75b2aAsim Shankarfrom __future__ import division
177e47624f5f646dca74c3484c330cb43baec75b2aAsim Shankarfrom __future__ import print_function
187e47624f5f646dca74c3484c330cb43baec75b2aAsim Shankar
19cb14c3527f4beabe17f5a4a72f17144dd7d25bc6Derek Murrayimport time
20cb14c3527f4beabe17f5a4a72f17144dd7d25bc6Derek Murray
21cb14c3527f4beabe17f5a4a72f17144dd7d25bc6Derek Murrayimport numpy as np
22cb14c3527f4beabe17f5a4a72f17144dd7d25bc6Derek Murray
239b03bcd74d117a2a6ee270af438a3a28e7123111Alexandre Passosfrom tensorflow.contrib import lookup
247e47624f5f646dca74c3484c330cb43baec75b2aAsim Shankarfrom tensorflow.contrib.eager.python import datasets
25bcf5dcc87ed2fe05197beaef30e536575608700bAsim Shankarfrom tensorflow.python.data import Dataset
267e47624f5f646dca74c3484c330cb43baec75b2aAsim Shankarfrom tensorflow.python.eager import test
279b03bcd74d117a2a6ee270af438a3a28e7123111Alexandre Passosfrom tensorflow.python.framework import constant_op
284bf27f8d4acee2cb8df27427668bddc92137e2efAsim Shankarfrom tensorflow.python.framework import dtypes
2968cb86ed592d714beabf71402322c9de0e611a69Derek Murrayfrom tensorflow.python.framework import errors
30bcf5dcc87ed2fe05197beaef30e536575608700bAsim Shankarfrom tensorflow.python.framework import ops
3168cb86ed592d714beabf71402322c9de0e611a69Derek Murrayfrom tensorflow.python.framework import sparse_tensor
327e47624f5f646dca74c3484c330cb43baec75b2aAsim Shankarfrom tensorflow.python.ops import math_ops
334bf27f8d4acee2cb8df27427668bddc92137e2efAsim Shankarfrom tensorflow.python.ops import script_ops
347e47624f5f646dca74c3484c330cb43baec75b2aAsim Shankar
357e47624f5f646dca74c3484c330cb43baec75b2aAsim Shankar
class IteratorTest(test.TestCase):
  """Tests for the eager-mode `datasets.Iterator` over `Dataset` pipelines."""

  def testBasic(self):
    """A Python `for` loop over the iterator yields every element in order."""
    got = [t.numpy() for t in datasets.Iterator(Dataset.range(4))]
    self.assertAllEqual([0, 1, 2, 3], got)

  def testGetNext(self):
    """`get_next()` returns elements one at a time, then raises at the end."""
    iterator = datasets.Iterator(Dataset.range(4))
    self.assertEqual(0, iterator.get_next().numpy())
    self.assertEqual(1, iterator.get_next().numpy())
    self.assertEqual(2, iterator.get_next().numpy())
    self.assertEqual(3, iterator.get_next().numpy())
    with self.assertRaises(errors.OutOfRangeError):
      iterator.get_next()

  def testMultipleIteratorsOnTheSameDataset(self):
    """Two iterators over one dataset each see the full sequence."""
    ds = Dataset.range(4)
    # Both iterators are created before either is consumed, so this checks
    # that exhausting it1 does not affect it2.
    it1 = datasets.Iterator(ds)
    it2 = datasets.Iterator(ds)
    got = [x.numpy() for x in it1]
    self.assertAllEqual([0, 1, 2, 3], got)

    got = [x.numpy() for x in it2]
    self.assertAllEqual([0, 1, 2, 3], got)

  def testNestedOutputs(self):
    """Zipped datasets yield a matching nested structure of tensors."""
    ds = Dataset.zip((Dataset.range(4), Dataset.zip((Dataset.range(4),
                                                     Dataset.range(4)))))
    total = 0
    # The Iterator will return a nested structure of Tensor objects.
    # Convert each leaf with .numpy() so it can be compared to plain integers.
    for (i, x) in enumerate(datasets.Iterator(ds)):
      want = (i, (i, i))
      got = (x[0].numpy(), (x[1][0].numpy(), x[1][1].numpy()))
      self.assertEqual(want, got)
      total += 1
    self.assertEqual(4, total)

  def testMapAndFilter(self):
    """`map` and `filter` transformations are applied during eager iteration."""
    def even(x):
      return math_ops.equal(math_ops.mod(x, 2), 0)

    it = datasets.Iterator(Dataset.range(8).map(math_ops.square).filter(even))
    got = [x.numpy() for x in it]
    self.assertAllEqual([0, 4, 16, 36], got)

  def testMapCaptureLookupTable(self):
    """A map function that captures a lookup table works when iterated."""
    default_val = -1
    keys = constant_op.constant(['brain', 'salad', 'surgery'])
    values = constant_op.constant([0, 1, 2], dtypes.int64)
    table = lookup.HashTable(
        lookup.KeyValueTensorInitializer(keys, values), default_val)
    dataset = Dataset.from_tensor_slices(['brain', 'salad', 'surgery'])
    dataset = dataset.map(table.lookup)
    it = datasets.Iterator(dataset)
    got = [x.numpy() for x in it]
    self.assertAllEqual([0, 1, 2], got)

  def testMultipleIteratorsOnADatasetThatUsesFunctions(self):
    """A dataset with a mapped function can be iterated more than once."""
    ds = Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6]).map(math_ops.square)

    got1 = [x.numpy() for x in datasets.Iterator(ds)]
    self.assertAllEqual([1, 4, 9, 16, 25, 36], got1)
    got2 = [x.numpy() for x in datasets.Iterator(ds)]
    self.assertAllEqual(got1, got2)

  def assertSparseValuesEqual(self, a, b):
    """Asserts `a` and `b` have equal indices, values and dense_shape."""
    self.assertAllEqual(a.indices, b.indices)
    self.assertAllEqual(a.values, b.values)
    self.assertAllEqual(a.dense_shape, b.dense_shape)

  def testSparseTensorElements(self):
    """Slicing a tuple of sparse values yields the expected sparse rows."""
    components = (sparse_tensor.SparseTensorValue(
        indices=np.array([[0, 0], [1, 0], [2, 0]]),
        values=np.array([0, 0, 0]),
        dense_shape=np.array([3, 1])),
                  sparse_tensor.SparseTensorValue(
                      indices=np.array([[0, 0], [1, 1], [2, 2]]),
                      values=np.array([1, 2, 3]),
                      dense_shape=np.array([3, 3])))

    expected = [
        (sparse_tensor.SparseTensorValue(
            indices=np.array([[0]]),
            values=np.array([0]),
            dense_shape=np.array([1])),
         sparse_tensor.SparseTensorValue(
             indices=np.array([[0]]),
             values=np.array([1]),
             dense_shape=np.array([3]))),
        (sparse_tensor.SparseTensorValue(
            indices=np.array([[0]]),
            values=np.array([0]),
            dense_shape=np.array([1])),
         sparse_tensor.SparseTensorValue(
             indices=np.array([[1]]),
             values=np.array([2]),
             dense_shape=np.array([3]))),
        (sparse_tensor.SparseTensorValue(
            indices=np.array([[0]]),
            values=np.array([0]),
            dense_shape=np.array([1])),
         sparse_tensor.SparseTensorValue(
             indices=np.array([[2]]),
             values=np.array([3]),
             dense_shape=np.array([3]))),
    ]

    results = list(datasets.Iterator(Dataset.from_tensor_slices(components)))
    # Check the element count explicitly: iterating and indexing into
    # `expected` alone would silently pass if fewer elements were produced.
    self.assertEqual(len(expected), len(results))
    for want, got in zip(expected, results):
      self.assertSparseValuesEqual(want[0], got[0])
      self.assertSparseValuesEqual(want[1], got[1])

  def testPyFunc(self):
    """A `py_func` inside `map` runs correctly under eager iteration."""

    def my_map(inp):
      return [[x + 1 for x in inp]]

    ds = Dataset.range(4).map(
        lambda x: script_ops.py_func(my_map, [[x]], dtypes.int64))
    got = [x.numpy() for x in datasets.Iterator(ds)]
    self.assertAllEqual([[1], [2], [3], [4]], got)

  def testTensorsPlacedOnDevice(self):
    """Elements fetched under a device scope can feed ops on that device."""
    ds = Dataset.from_tensors([0., 1.])
    with ops.device(test.gpu_device_name()):
      # Use the builtin next() rather than the Python-2-style .next() method.
      x = next(datasets.Iterator(ds))
      x = math_ops.add(x, x)
    self.assertAllEqual([0., 2.], x.numpy())
167bcf5dcc87ed2fe05197beaef30e536575608700bAsim Shankar
1687e47624f5f646dca74c3484c330cb43baec75b2aAsim Shankar
class DatasetConstructorBenchmark(test.Benchmark):
  """Benchmarks eager iteration over sliced/batched/repeated datasets."""

  def _run_and_report(self, dataset, description, name_prefix, input_size,
                      batch_size):
    """Times each step of eager iteration over `dataset` and reports it.

    Records a timestamp before the loop and after every element the iterator
    yields, then prints and reports the median time between timestamps.

    Args:
      dataset: the `Dataset` to iterate over.
      description: pipeline label at the start of the printed summary.
      name_prefix: prefix of the reported benchmark name.
      input_size: number of input rows, included in the summary and name.
      batch_size: batch size, included in the summary and name.
    """
    iterator = datasets.Iterator(dataset)

    ends = [time.time()]
    for _ in iterator:
      ends.append(time.time())

    deltas = np.ediff1d(ends)
    median_wall_time = np.median(deltas)
    print(
        '%s eager input size: %d batch size: %d Median wall time per '
        'element: %f'
        % (description, input_size, batch_size, median_wall_time))
    self.report_benchmark(
        iters=len(deltas),
        wall_time=median_wall_time,
        name='%s_eager_input_%d_batch_%d' %
        (name_prefix, input_size, batch_size))

  def benchmarkSliceRepeatBatchEager(self):
    """Benchmarks slice -> repeat -> batch."""
    input_size = 10000
    batch_size = 100
    num_epochs = 100

    input_data = np.random.randn(input_size)

    dataset = (
        Dataset.from_tensor_slices(input_data).repeat(num_epochs)
        .batch(batch_size))
    self._run_and_report(dataset, 'Slice/repeat/batch',
                         'benchmark_slice_repeat_batch', input_size,
                         batch_size)

  def benchmarkSliceBatchCacheRepeatCallable(self):
    """Benchmarks slice -> batch -> cache -> repeat."""
    input_size = 10000
    batch_size = 100
    num_epochs = 100

    input_data = np.random.randn(input_size)

    dataset = (
        Dataset.from_tensor_slices(input_data).batch(batch_size).cache()
        .repeat(num_epochs))
    self._run_and_report(dataset, 'Slice/batch/cache/repeat',
                         'benchmark_slice_batch_cache_repeat', input_size,
                         batch_size)
226cb14c3527f4beabe17f5a4a72f17144dd7d25bc6Derek Murray
227cb14c3527f4beabe17f5a4a72f17144dd7d25bc6Derek Murray
if __name__ == '__main__':
  # Script entry point: hand control to the TensorFlow test runner.
  test.main()
230