1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for the datasets shape inference."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.python.data.ops import dataset_ops
24from tensorflow.python.data.ops import iterator_ops
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import meta_graph
27from tensorflow.python.framework import ops
28from tensorflow.python.framework import tensor_shape
29from tensorflow.python.grappler import item
30from tensorflow.python.ops import array_ops
31from tensorflow.python.platform import test
32
33
class GrapplerTest(test.TestCase):
  """Tests Grappler's static shape inference for tf.data iterator outputs.

  Every test builds a fresh graph, adds the `IteratorGetNext` output to the
  graph's TRAIN_OP collection, exports the graph as a MetaGraphDef, wraps it in
  a Grappler `item.Item`, and compares the shape of the first output reported
  by `GetOpProperties()['IteratorGetNext']` with the expected static shape.
  """

  @staticmethod
  def _scalar_vector_matrix_cases():
    """Returns a fresh list of rank-0, rank-1 and rank-2 test cases.

    Each case is a dict holding a `'tensor'` value and the `'shape'`
    (a `tensor_shape.TensorShape`) of a dataset element built from it with
    `Dataset.from_tensors`. A new list is built on every call so no test can
    observe another test's modifications.
    """
    return [{
        'tensor': 0,
        'shape': tensor_shape.TensorShape([])
    }, {
        'tensor': np.array([1, 2, 3]),
        'shape': tensor_shape.TensorShape([3])
    }, {
        'tensor': np.array([[1, 2, 3]]),
        'shape': tensor_shape.TensorShape([1, 3])
    }]

  @staticmethod
  def _make_repeat_dataset_fn(tensor):
    """Returns `n -> Dataset.from_tensors(tensor).repeat(n)`.

    `tensor` is bound eagerly through this factory so each test case gets a
    function closing over its own value.
    """

    def dataset_fn(n):
      return dataset_ops.Dataset.from_tensors(tensor).repeat(n)

    return dataset_fn

  def _infer_get_next_shape(self, graph, iterator):
    """Runs Grappler on `graph` and returns `IteratorGetNext`'s output shape.

    Args:
      graph: the `ops.Graph` that owns `iterator`. It must be the default graph
        when this is called, because `iterator.get_next()` adds an op to it.
      iterator: an `iterator_ops.Iterator` built in `graph` with a single
        output component.

    Returns:
      The `shape` field of the first entry of
      `item.Item(meta_graph_def).GetOpProperties()['IteratorGetNext']`.
    """
    get_next = iterator.get_next()
    # NOTE(review): registering the output as a train op presumably makes it a
    # fetch of the Grappler item so the op is kept for analysis; confirm
    # against item.Item.
    graph.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(get_next)
    meta_graph_def = meta_graph.create_meta_graph_def(graph=graph)
    op_properties = item.Item(meta_graph_def).GetOpProperties()
    return op_properties['IteratorGetNext'][0].shape

  def _infer_dataset_shape(self, graph, dataset):
    """Like `_infer_get_next_shape`, using a one-shot iterator over `dataset`."""
    return self._infer_get_next_shape(graph, dataset.make_one_shot_iterator())

  def _assertBatchedShape(self, expected_shape, inferred_proto):
    """Asserts an inferred batched shape against `expected_shape`.

    The leading (batch) dimension only has to be compatible with
    `expected_shape[0]`. All remaining dimensions must be equal.

    Args:
      expected_shape: a `tensor_shape.TensorShape` of rank >= 1.
      inferred_proto: the shape proto returned by `_infer_dataset_shape`.
    """
    inferred_shape = self.as_tensor_shape(inferred_proto)
    self.assertTrue(expected_shape[0].is_compatible_with(inferred_shape[0]))
    self.assertEqual(expected_shape[1:], inferred_shape[1:])

  def testFromTensors(self):
    for test_case in self._scalar_vector_matrix_cases():
      with ops.Graph().as_default() as g:
        dataset = dataset_ops.Dataset.from_tensors(test_case['tensor'])
        self.assertEqual(test_case['shape'],
                         self._infer_dataset_shape(g, dataset))

  def testFromTensorSlices(self):
    # Slicing drops the leading dimension of each input.
    test_cases = [{
        'tensor': np.array([1, 2, 3]),
        'shape': tensor_shape.TensorShape([])
    }, {
        'tensor': np.array([[1, 2, 3]]),
        'shape': tensor_shape.TensorShape([3])
    }, {
        'tensor': np.array([[[1, 2, 3]]]),
        'shape': tensor_shape.TensorShape([1, 3])
    }]

    for test_case in test_cases:
      with ops.Graph().as_default() as g:
        dataset = dataset_ops.Dataset.from_tensor_slices(test_case['tensor'])
        self.assertEqual(test_case['shape'],
                         self._infer_dataset_shape(g, dataset))

  def testFromGenerator(self):

    def make_generator(tensor):
      # The factory binds `tensor` now, so each generator yields its own case.

      def generator():
        yield tensor

      return generator

    for test_case in self._scalar_vector_matrix_cases():
      with ops.Graph().as_default() as g:
        dataset = dataset_ops.Dataset.from_generator(
            make_generator(test_case['tensor']),
            dtypes.int64,
            output_shapes=test_case['shape'])
        self.assertEqual(test_case['shape'],
                         self._infer_dataset_shape(g, dataset))

  def testRange(self):
    with ops.Graph().as_default() as g:
      dataset = dataset_ops.Dataset.range(42)
      self.assertEqual(tensor_shape.scalar(),
                       self._infer_dataset_shape(g, dataset))

  def _testTransformation(self, fn):
    """Checks that a shape-preserving transformation keeps the inferred shape.

    Args:
      fn: callable `(dataset, tensor, shape) -> dataset`. It is applied to
        `Dataset.from_tensors(tensor)`, whose element shape is `shape`, and
        the inferred element shape of its result must still equal `shape`.
    """
    for test_case in self._scalar_vector_matrix_cases():
      with ops.Graph().as_default() as g:
        dataset = dataset_ops.Dataset.from_tensors(test_case['tensor'])
        dataset = fn(dataset, test_case['tensor'], test_case['shape'])
        self.assertEqual(test_case['shape'],
                         self._infer_dataset_shape(g, dataset))

  def testConcatenate(self):

    def fn(dataset, tensor, shape):
      del shape
      return dataset.concatenate(dataset_ops.Dataset.from_tensors(tensor))

    self._testTransformation(fn)

  def testPrefetch(self):

    def fn(dataset, tensor, shape):
      del tensor, shape
      return dataset.prefetch(42)

    self._testTransformation(fn)

  def testRepeat(self):

    def fn(dataset, tensor, shape):
      del tensor, shape
      return dataset.repeat(42)

    self._testTransformation(fn)

  def testShuffle(self):

    def fn(dataset, tensor, shape):
      del tensor, shape
      return dataset.shuffle(42)

    self._testTransformation(fn)

  def testCache(self):

    def fn(dataset, tensor, shape):
      del tensor, shape
      return dataset.cache()

    self._testTransformation(fn)

  def testTake(self):

    def fn(dataset, tensor, shape):
      del tensor, shape
      return dataset.take(42)

    self._testTransformation(fn)

  def testSkip(self):

    def fn(dataset, tensor, shape):
      del tensor, shape
      return dataset.skip(42)

    self._testTransformation(fn)

  def testShard(self):

    def fn(dataset, tensor, shape):
      del tensor, shape
      return dataset.shard(42, 0)

    self._testTransformation(fn)

  def testFilter(self):

    def fn(dataset, tensor, shape):
      del tensor, shape
      return dataset.filter(lambda x: True)

    self._testTransformation(fn)

  def as_tensor_shape(self, proto_with_symbolic_values):
    """Converts a shape proto that may contain sizes below -1 to a TensorShape.

    Any dimension size below -1 (the parameter name suggests these are
    symbolic dimensions produced by Grappler) is treated like -1, i.e. as an
    unknown dimension. Unlike a direct `TensorShape(proto)`, this never fails
    on such sizes. The input proto is left unmodified.

    Args:
      proto_with_symbolic_values: a shape proto with `unknown_rank` and `dim`
        fields, as returned by `GetOpProperties()`.

    Returns:
      The equivalent `tensor_shape.TensorShape`.
    """
    if proto_with_symbolic_values.unknown_rank:
      return tensor_shape.TensorShape(None)
    return tensor_shape.TensorShape([
        None if dim.size < 0 else dim.size
        for dim in proto_with_symbolic_values.dim
    ])

  def testBatch(self):
    test_cases = [{
        'tensor': 0,
        'shape': tensor_shape.TensorShape([None])
    }, {
        'tensor': np.array([1, 2, 3]),
        'shape': tensor_shape.TensorShape([None, 3])
    }, {
        'tensor': np.array([[1, 2, 3]]),
        'shape': tensor_shape.TensorShape([None, 1, 3])
    }]

    for test_case in test_cases:
      with ops.Graph().as_default() as g:
        dataset = dataset_ops.Dataset.from_tensors(test_case['tensor'])
        dataset = dataset.batch(42)
        self._assertBatchedShape(test_case['shape'],
                                 self._infer_dataset_shape(g, dataset))

  def testPaddedBatch(self):
    test_cases = [{
        'tensor': 0,
        'shape': tensor_shape.TensorShape([None])
    }, {
        'tensor': np.array([1, 2, 3]),
        'shape': tensor_shape.TensorShape([None, 4])
    }, {
        'tensor': np.array([[1, 2, 3]]),
        'shape': tensor_shape.TensorShape([None, 2, 4])
    }]

    for test_case in test_cases:
      with ops.Graph().as_default() as g:
        dataset = dataset_ops.Dataset.from_tensors(test_case['tensor'])
        # Pad every non-batch dimension to the expected per-element shape.
        dataset = dataset.padded_batch(42, padded_shapes=test_case['shape'][1:])
        self._assertBatchedShape(test_case['shape'],
                                 self._infer_dataset_shape(g, dataset))

  def testFlatMap(self):
    for test_case in self._scalar_vector_matrix_cases():
      with ops.Graph().as_default() as g:
        dataset = dataset_ops.Dataset.range(42).flat_map(
            self._make_repeat_dataset_fn(test_case['tensor']))
        self.assertEqual(test_case['shape'],
                         self._infer_dataset_shape(g, dataset))

  def testInterleave(self):
    for test_case in self._scalar_vector_matrix_cases():
      with ops.Graph().as_default() as g:
        dataset = dataset_ops.Dataset.range(42).interleave(
            self._make_repeat_dataset_fn(test_case['tensor']), cycle_length=42)
        self.assertEqual(test_case['shape'],
                         self._infer_dataset_shape(g, dataset))

  def testMap(self):
    # `array_ops.transpose` reverses the dimension order of each element.
    test_cases = [{
        'tensor': 0,
        'shape': tensor_shape.TensorShape([])
    }, {
        'tensor': np.array([1, 2, 3]),
        'shape': tensor_shape.TensorShape([3])
    }, {
        'tensor': np.array([[1, 2, 3]]),
        'shape': tensor_shape.TensorShape([3, 1])
    }, {
        'tensor': np.array([[[1, 2, 3], [4, 5, 6]]]),
        'shape': tensor_shape.TensorShape([3, 2, 1])
    }]

    for test_case in test_cases:
      with ops.Graph().as_default() as g:
        dataset = dataset_ops.Dataset.from_tensors(test_case['tensor'])
        dataset = dataset.map(array_ops.transpose)
        self.assertEqual(test_case['shape'],
                         self._infer_dataset_shape(g, dataset))

  def testFromStructure(self):
    test_cases = [{
        'shape': tensor_shape.TensorShape([])
    }, {
        'shape': tensor_shape.TensorShape([3])
    }, {
        'shape': tensor_shape.TensorShape([1, 2])
    }, {
        'shape': tensor_shape.TensorShape([1, 2, 3])
    }]

    for test_case in test_cases:
      with ops.Graph().as_default() as g:
        iterator = iterator_ops.Iterator.from_structure(
            dtypes.int64, output_shapes=test_case['shape'])
        self.assertEqual(test_case['shape'],
                         self._infer_get_next_shape(g, iterator))

  def testFromStringHandle(self):
    test_cases = [{
        'shape': tensor_shape.TensorShape([])
    }, {
        'shape': tensor_shape.TensorShape([3])
    }, {
        'shape': tensor_shape.TensorShape([1, 2])
    }, {
        'shape': tensor_shape.TensorShape([1, 2, 3])
    }]

    for test_case in test_cases:
      with ops.Graph().as_default() as g:
        # The source iterator declares no shapes; only the handle-based
        # iterator carries `output_shapes`.
        source_iterator = iterator_ops.Iterator.from_structure(dtypes.int64)
        handle = source_iterator.string_handle()
        iterator = iterator_ops.Iterator.from_string_handle(
            handle, dtypes.int64, output_shapes=test_case['shape'])
        self.assertEqual(test_case['shape'],
                         self._infer_get_next_shape(g, iterator))
444
445
if __name__ == '__main__':
  # Entry point when run as a script: delegates to TensorFlow's test runner
  # (presumably it discovers and runs the TestCase classes in this module).
  test.main()
448