1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for the datasets shape inference.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.python.data.ops import dataset_ops 24from tensorflow.python.data.ops import iterator_ops 25from tensorflow.python.framework import dtypes 26from tensorflow.python.framework import meta_graph 27from tensorflow.python.framework import ops 28from tensorflow.python.framework import tensor_shape 29from tensorflow.python.grappler import item 30from tensorflow.python.ops import array_ops 31from tensorflow.python.platform import test 32 33 34class GrapplerTest(test.TestCase): 35 36 def testFromTensors(self): 37 test_cases = [{ 38 'tensor': 0, 39 'shape': tensor_shape.TensorShape([]) 40 }, { 41 'tensor': np.array([1, 2, 3]), 42 'shape': tensor_shape.TensorShape([3]) 43 }, { 44 'tensor': np.array([[1, 2, 3]]), 45 'shape': tensor_shape.TensorShape([1, 3]) 46 }] 47 48 for test_case in test_cases: 49 with ops.Graph().as_default() as g: 50 dataset = dataset_ops.Dataset.from_tensors(test_case['tensor']) 51 iterator = dataset.make_one_shot_iterator() 52 get_next = iterator.get_next() 53 train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) 54 train_op.append(get_next) 55 mg = meta_graph.create_meta_graph_def(graph=g) 56 grappler_item = item.Item(mg) 57 op_properties = grappler_item.GetOpProperties() 58 self.assertEqual(test_case['shape'], 59 op_properties['IteratorGetNext'][0].shape) 60 61 def testFromTensorSlices(self): 62 test_cases = [{ 63 'tensor': np.array([1, 2, 3]), 64 'shape': tensor_shape.TensorShape([]) 65 }, { 66 'tensor': np.array([[1, 2, 3]]), 67 'shape': tensor_shape.TensorShape([3]) 68 }, { 69 'tensor': np.array([[[1, 2, 3]]]), 70 'shape': tensor_shape.TensorShape([1, 3]) 71 }] 72 73 for test_case in test_cases: 74 with ops.Graph().as_default() as g: 75 dataset = dataset_ops.Dataset.from_tensor_slices(test_case['tensor']) 76 iterator = dataset.make_one_shot_iterator() 77 get_next = iterator.get_next() 78 train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) 79 train_op.append(get_next) 80 mg = meta_graph.create_meta_graph_def(graph=g) 81 grappler_item = item.Item(mg) 82 op_properties = grappler_item.GetOpProperties() 83 self.assertEqual(test_case['shape'], 84 op_properties['IteratorGetNext'][0].shape) 85 86 def testFromGenerator(self): 87 test_cases = [{ 88 'tensor': 0, 89 'shape': tensor_shape.TensorShape([]) 90 }, { 91 'tensor': np.array([1, 2, 3]), 92 'shape': tensor_shape.TensorShape([3]) 93 }, { 94 'tensor': np.array([[1, 2, 3]]), 95 'shape': tensor_shape.TensorShape([1, 3]) 96 }] 97 98 for test_case in test_cases: 99 100 def make_generator(tensor): 101 102 def generator(): 103 yield tensor 104 105 return generator 106 107 with ops.Graph().as_default() as g: 108 dataset = dataset_ops.Dataset.from_generator( 109 make_generator(test_case['tensor']), 110 dtypes.int64, 111 output_shapes=test_case['shape']) 112 iterator = dataset.make_one_shot_iterator() 113 get_next = iterator.get_next() 114 train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) 115 train_op.append(get_next) 116 mg = meta_graph.create_meta_graph_def(graph=g) 117 grappler_item = item.Item(mg) 118 op_properties = grappler_item.GetOpProperties() 119 self.assertEqual(test_case['shape'], 120 op_properties['IteratorGetNext'][0].shape) 121 122 def testRange(self): 123 with ops.Graph().as_default() as g: 124 dataset = dataset_ops.Dataset.range(42) 125 iterator = dataset.make_one_shot_iterator() 126 get_next = iterator.get_next() 127 train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) 128 train_op.append(get_next) 129 mg = meta_graph.create_meta_graph_def(graph=g) 130 grappler_item = item.Item(mg) 131 op_properties = grappler_item.GetOpProperties() 132 self.assertEqual(tensor_shape.scalar(), 133 op_properties['IteratorGetNext'][0].shape) 134 135 def _testTransformation(self, fn): 136 test_cases = [{ 137 'tensor': 0, 138 'shape': tensor_shape.TensorShape({}) 139 }, { 140 'tensor': np.array([1, 2, 3]), 141 'shape': tensor_shape.TensorShape([3]) 142 }, { 143 'tensor': np.array([[1, 2, 3]]), 144 'shape': tensor_shape.TensorShape([1, 3]) 145 }] 146 147 for test_case in test_cases: 148 with ops.Graph().as_default() as g: 149 dataset = dataset_ops.Dataset.from_tensors(test_case['tensor']) 150 dataset = fn(dataset, test_case['tensor'], test_case['shape']) 151 iterator = dataset.make_one_shot_iterator() 152 get_next = iterator.get_next() 153 train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) 154 train_op.append(get_next) 155 mg = meta_graph.create_meta_graph_def(graph=g) 156 grappler_item = item.Item(mg) 157 op_properties = grappler_item.GetOpProperties() 158 self.assertEqual(test_case['shape'], 159 op_properties['IteratorGetNext'][0].shape) 160 161 def testConcatenate(self): 162 163 def fn(dataset, tensor, shape): 164 del shape 165 return dataset.concatenate(dataset_ops.Dataset.from_tensors(tensor)) 166 167 self._testTransformation(fn) 168 169 def testPrefetch(self): 170 171 def fn(dataset, tensor, shape): 172 del tensor, shape 173 return dataset.prefetch(42) 174 175 self._testTransformation(fn) 176 177 def testRepeat(self): 178 179 def fn(dataset, tensor, shape): 180 del tensor, shape 181 return dataset.repeat(42) 182 183 self._testTransformation(fn) 184 185 def testShuffle(self): 186 187 def fn(dataset, tensor, shape): 188 del tensor, shape 189 return dataset.shuffle(42) 190 191 self._testTransformation(fn) 192 193 def testCache(self): 194 195 def fn(dataset, tensor, shape): 196 del tensor, shape 197 return dataset.cache() 198 199 self._testTransformation(fn) 200 201 def testTake(self): 202 203 def fn(dataset, tensor, shape): 204 del tensor, shape 205 return dataset.take(42) 206 207 self._testTransformation(fn) 208 209 def testSkip(self): 210 211 def fn(dataset, tensor, shape): 212 del tensor, shape 213 return dataset.skip(42) 214 215 self._testTransformation(fn) 216 217 def testShard(self): 218 219 def fn(dataset, tensor, shape): 220 del tensor, shape 221 return dataset.shard(42, 0) 222 223 self._testTransformation(fn) 224 225 def testFilter(self): 226 227 def fn(dataset, tensor, shape): 228 del tensor, shape 229 return dataset.filter(lambda x: True) 230 231 self._testTransformation(fn) 232 233 def as_tensor_shape(self, proto_with_symbolic_values): 234 for i in range(len(proto_with_symbolic_values.dim)): 235 if proto_with_symbolic_values.dim[i].size < -1: 236 proto_with_symbolic_values.dim[i].size = -1 237 return tensor_shape.TensorShape(proto_with_symbolic_values) 238 239 def testBatch(self): 240 test_cases = [{ 241 'tensor': 0, 242 'shape': tensor_shape.TensorShape([None]) 243 }, { 244 'tensor': np.array([1, 2, 3]), 245 'shape': tensor_shape.TensorShape([None, 3]) 246 }, { 247 'tensor': np.array([[1, 2, 3]]), 248 'shape': tensor_shape.TensorShape([None, 1, 3]) 249 }] 250 251 for test_case in test_cases: 252 with ops.Graph().as_default() as g: 253 dataset = dataset_ops.Dataset.from_tensors(test_case['tensor']) 254 dataset = dataset.batch(42) 255 iterator = dataset.make_one_shot_iterator() 256 get_next = iterator.get_next() 257 train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) 258 train_op.append(get_next) 259 mg = meta_graph.create_meta_graph_def(graph=g) 260 grappler_item = item.Item(mg) 261 op_properties = grappler_item.GetOpProperties() 262 inferred_shape = self.as_tensor_shape( 263 op_properties['IteratorGetNext'][0].shape) 264 self.assertTrue(test_case['shape'][0].is_compatible_with( 265 inferred_shape[0])) 266 self.assertEqual(test_case['shape'][1:], inferred_shape[1:]) 267 268 def testPaddedBatch(self): 269 test_cases = [{ 270 'tensor': 0, 271 'shape': tensor_shape.TensorShape([None]) 272 }, { 273 'tensor': np.array([1, 2, 3]), 274 'shape': tensor_shape.TensorShape([None, 4]) 275 }, { 276 'tensor': np.array([[1, 2, 3]]), 277 'shape': tensor_shape.TensorShape([None, 2, 4]) 278 }] 279 280 for test_case in test_cases: 281 with ops.Graph().as_default() as g: 282 dataset = dataset_ops.Dataset.from_tensors(test_case['tensor']) 283 dataset = dataset.padded_batch(42, padded_shapes=test_case['shape'][1:]) 284 iterator = dataset.make_one_shot_iterator() 285 get_next = iterator.get_next() 286 train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) 287 train_op.append(get_next) 288 mg = meta_graph.create_meta_graph_def(graph=g) 289 grappler_item = item.Item(mg) 290 op_properties = grappler_item.GetOpProperties() 291 inferred_shape = self.as_tensor_shape( 292 op_properties['IteratorGetNext'][0].shape) 293 self.assertTrue(test_case['shape'][0].is_compatible_with( 294 inferred_shape[0])) 295 self.assertEqual(test_case['shape'][1:], inferred_shape[1:]) 296 297 def testFlatMap(self): 298 test_cases = [{ 299 'tensor': 0, 300 'shape': tensor_shape.TensorShape([]) 301 }, { 302 'tensor': np.array([1, 2, 3]), 303 'shape': tensor_shape.TensorShape([3]) 304 }, { 305 'tensor': np.array([[1, 2, 3]]), 306 'shape': tensor_shape.TensorShape([1, 3]) 307 }] 308 309 for test_case in test_cases: 310 with ops.Graph().as_default() as g: 311 dataset = dataset_ops.Dataset.range(42) 312 313 def make_dataset(tensor): 314 315 def dataset_fn(n): 316 return dataset_ops.Dataset.from_tensors(tensor).repeat(n) 317 318 return dataset_fn 319 320 dataset = dataset.flat_map(make_dataset(test_case['tensor'])) 321 iterator = dataset.make_one_shot_iterator() 322 get_next = iterator.get_next() 323 train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) 324 train_op.append(get_next) 325 mg = meta_graph.create_meta_graph_def(graph=g) 326 grappler_item = item.Item(mg) 327 op_properties = grappler_item.GetOpProperties() 328 self.assertEqual(test_case['shape'], 329 op_properties['IteratorGetNext'][0].shape) 330 331 def testInterleave(self): 332 test_cases = [{ 333 'tensor': 0, 334 'shape': tensor_shape.TensorShape([]) 335 }, { 336 'tensor': np.array([1, 2, 3]), 337 'shape': tensor_shape.TensorShape([3]) 338 }, { 339 'tensor': np.array([[1, 2, 3]]), 340 'shape': tensor_shape.TensorShape([1, 3]) 341 }] 342 343 for test_case in test_cases: 344 with ops.Graph().as_default() as g: 345 dataset = dataset_ops.Dataset.range(42) 346 347 def make_dataset(tensor): 348 349 def dataset_fn(n): 350 return dataset_ops.Dataset.from_tensors(tensor).repeat(n) 351 352 return dataset_fn 353 354 dataset = dataset.interleave( 355 make_dataset(test_case['tensor']), cycle_length=42) 356 iterator = dataset.make_one_shot_iterator() 357 get_next = iterator.get_next() 358 train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) 359 train_op.append(get_next) 360 mg = meta_graph.create_meta_graph_def(graph=g) 361 grappler_item = item.Item(mg) 362 op_properties = grappler_item.GetOpProperties() 363 self.assertEqual(test_case['shape'], 364 op_properties['IteratorGetNext'][0].shape) 365 366 def testMap(self): 367 test_cases = [{ 368 'tensor': 0, 369 'shape': tensor_shape.TensorShape([]) 370 }, { 371 'tensor': np.array([1, 2, 3]), 372 'shape': tensor_shape.TensorShape([3]) 373 }, { 374 'tensor': np.array([[1, 2, 3]]), 375 'shape': tensor_shape.TensorShape([3, 1]) 376 }, { 377 'tensor': np.array([[[1, 2, 3], [4, 5, 6]]]), 378 'shape': tensor_shape.TensorShape([3, 2, 1]) 379 }] 380 381 for test_case in test_cases: 382 with ops.Graph().as_default() as g: 383 dataset = dataset_ops.Dataset.from_tensors(test_case['tensor']) 384 dataset = dataset.map(array_ops.transpose) 385 iterator = dataset.make_one_shot_iterator() 386 get_next = iterator.get_next() 387 train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) 388 train_op.append(get_next) 389 mg = meta_graph.create_meta_graph_def(graph=g) 390 grappler_item = item.Item(mg) 391 op_properties = grappler_item.GetOpProperties() 392 self.assertEqual(test_case['shape'], 393 op_properties['IteratorGetNext'][0].shape) 394 395 def testFromStructure(self): 396 test_cases = [{ 397 'shape': tensor_shape.TensorShape([]) 398 }, { 399 'shape': tensor_shape.TensorShape([3]) 400 }, { 401 'shape': tensor_shape.TensorShape([1, 2]) 402 }, { 403 'shape': tensor_shape.TensorShape([1, 2, 3]) 404 }] 405 406 for test_case in test_cases: 407 with ops.Graph().as_default() as g: 408 iterator = iterator_ops.Iterator.from_structure( 409 dtypes.int64, output_shapes=test_case['shape']) 410 get_next = iterator.get_next() 411 train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) 412 train_op.append(get_next) 413 mg = meta_graph.create_meta_graph_def(graph=g) 414 grappler_item = item.Item(mg) 415 op_properties = grappler_item.GetOpProperties() 416 self.assertEqual(test_case['shape'], 417 op_properties['IteratorGetNext'][0].shape) 418 419 def testFromStringHandle(self): 420 test_cases = [{ 421 'shape': tensor_shape.TensorShape([]) 422 }, { 423 'shape': tensor_shape.TensorShape([3]) 424 }, { 425 'shape': tensor_shape.TensorShape([1, 2]) 426 }, { 427 'shape': tensor_shape.TensorShape([1, 2, 3]) 428 }] 429 430 for test_case in test_cases: 431 with ops.Graph().as_default() as g: 432 iterator = iterator_ops.Iterator.from_structure(dtypes.int64) 433 handle = iterator.string_handle() 434 iterator = iterator_ops.Iterator.from_string_handle( 435 handle, dtypes.int64, output_shapes=test_case['shape']) 436 get_next = iterator.get_next() 437 train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) 438 train_op.append(get_next) 439 mg = meta_graph.create_meta_graph_def(graph=g) 440 grappler_item = item.Item(mg) 441 op_properties = grappler_item.GetOpProperties() 442 self.assertEqual(test_case['shape'], 443 op_properties['IteratorGetNext'][0].shape) 444 445 446if __name__ == '__main__': 447 test.main() 448