1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Python wrappers for Datasets.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import abc 21import collections 22import threading 23 24import numpy as np 25import six 26 27from tensorflow.python.data.ops import iterator_ops 28from tensorflow.python.data.util import nest 29from tensorflow.python.data.util import sparse 30from tensorflow.python.eager import context 31from tensorflow.python.framework import constant_op 32from tensorflow.python.framework import dtypes 33from tensorflow.python.framework import function 34from tensorflow.python.framework import ops 35from tensorflow.python.framework import random_seed 36from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib 37from tensorflow.python.framework import tensor_shape 38from tensorflow.python.framework import tensor_util 39from tensorflow.python.ops import gen_dataset_ops 40from tensorflow.python.ops import gen_io_ops 41from tensorflow.python.ops import math_ops 42from tensorflow.python.ops import script_ops 43from tensorflow.python.util import deprecation 44from tensorflow.python.util.tf_export import tf_export 45 46 47@tf_export("data.Dataset") 48class Dataset(object): 49 """Represents a potentially large set of elements. 50 51 A `Dataset` can be used to represent an input pipeline as a 52 collection of elements (nested structures of tensors) and a "logical 53 plan" of transformations that act on those elements. 54 """ 55 __metaclass__ = abc.ABCMeta 56 57 def __init__(self): 58 pass 59 60 @abc.abstractmethod 61 def _as_variant_tensor(self): 62 """Creates a scalar `tf.Tensor` of `tf.variant` representing this dataset. 63 64 Returns: 65 A scalar `tf.Tensor` of `tf.variant` type, which represents this dataset. 66 """ 67 raise NotImplementedError("Dataset._as_variant_tensor") 68 69 def make_initializable_iterator(self, shared_name=None): 70 """Creates an `Iterator` for enumerating the elements of this dataset. 71 72 Note: The returned iterator will be in an uninitialized state, 73 and you must run the `iterator.initializer` operation before using it: 74 75 ```python 76 dataset = ... 77 iterator = dataset.make_initializable_iterator() 78 # ... 79 sess.run(iterator.initializer) 80 ``` 81 82 Args: 83 shared_name: (Optional.) If non-empty, the returned iterator will be 84 shared under the given name across multiple sessions that share the 85 same devices (e.g. when using a remote server). 86 87 Returns: 88 An `Iterator` over the elements of this dataset. 89 90 Raises: 91 RuntimeError: If eager execution is enabled. 
92 """ 93 if context.in_eager_mode(): 94 raise RuntimeError( 95 "dataset.make_initializable_iterator is not supported when eager " 96 "execution is enabled.") 97 if shared_name is None: 98 shared_name = "" 99 iterator_resource = gen_dataset_ops.iterator( 100 container="", 101 shared_name=shared_name, 102 output_types=nest.flatten( 103 sparse.as_dense_types(self.output_types, self.output_classes)), 104 output_shapes=nest.flatten( 105 sparse.as_dense_shapes(self.output_shapes, self.output_classes))) 106 with ops.colocate_with(iterator_resource): 107 initializer = gen_dataset_ops.make_iterator(self._as_variant_tensor(), 108 iterator_resource) 109 return iterator_ops.Iterator(iterator_resource, initializer, 110 self.output_types, self.output_shapes, 111 self.output_classes) 112 113 def make_one_shot_iterator(self): 114 """Creates an `Iterator` for enumerating the elements of this dataset. 115 116 Note: The returned iterator will be initialized automatically. 117 A "one-shot" iterator does not currently support re-initialization. 118 119 Returns: 120 An `Iterator` over the elements of this dataset. 121 122 Raises: 123 RuntimeError: If eager execution is enabled. 124 """ 125 if context.in_eager_mode(): 126 raise RuntimeError( 127 "dataset.make_one_shot_iterator is not supported when eager " 128 "execution is enabled.") 129 # NOTE(mrry): We capture by value here to ensure that `_make_dataset()` is 130 # a 0-argument function. 131 @function.Defun(capture_by_value=True) 132 def _make_dataset(): 133 return self._as_variant_tensor() # pylint: disable=protected-access 134 135 try: 136 _make_dataset.add_to_graph(ops.get_default_graph()) 137 except ValueError as err: 138 if "Cannot capture a stateful node" in str(err): 139 raise ValueError( 140 "Failed to create a one-shot iterator for a dataset. " 141 "`Dataset.make_one_shot_iterator()` does not support datasets that " 142 "capture stateful objects, such as a `Variable` or `LookupTable`. " 143 "In these cases, use `Dataset.make_initializable_iterator()`. " 144 "(Original error: %s)" % err) 145 else: 146 six.reraise(ValueError, err) 147 148 return iterator_ops.Iterator( 149 gen_dataset_ops.one_shot_iterator( 150 dataset_factory=_make_dataset, 151 output_types=nest.flatten( 152 sparse.as_dense_types(self.output_types, self.output_classes)), 153 output_shapes=nest.flatten( 154 sparse.as_dense_shapes(self.output_shapes, 155 self.output_classes))), None, 156 self.output_types, self.output_shapes, self.output_classes) 157 158 @abc.abstractproperty 159 def output_classes(self): 160 """Returns the class of each component of an element of this dataset. 161 162 The expected values are `tf.Tensor` and `tf.SparseTensor`. 163 164 Returns: 165 A nested structure of Python `type` objects corresponding to each 166 component of an element of this dataset. 167 """ 168 raise NotImplementedError("Dataset.output_classes") 169 170 @abc.abstractproperty 171 def output_shapes(self): 172 """Returns the shape of each component of an element of this dataset. 173 174 Returns: 175 A nested structure of `tf.TensorShape` objects corresponding to each 176 component of an element of this dataset. 177 """ 178 raise NotImplementedError("Dataset.output_shapes") 179 180 @abc.abstractproperty 181 def output_types(self): 182 """Returns the type of each component of an element of this dataset. 183 184 Returns: 185 A nested structure of `tf.DType` objects corresponding to each component 186 of an element of this dataset. 
187 """ 188 raise NotImplementedError("Dataset.output_types") 189 190 def __repr__(self): 191 output_shapes = nest.map_structure(str, self.output_shapes) 192 output_shapes = str(output_shapes).replace("'", "") 193 output_types = nest.map_structure(repr, self.output_types) 194 output_types = str(output_types).replace("'", "") 195 return ("<%s shapes: %s, types: %s>" % (type(self).__name__, output_shapes, 196 output_types)) 197 198 @staticmethod 199 def from_tensors(tensors): 200 """Creates a `Dataset` with a single element, comprising the given tensors. 201 202 Args: 203 tensors: A nested structure of tensors. 204 205 Returns: 206 Dataset: A `Dataset`. 207 """ 208 return TensorDataset(tensors) 209 210 @staticmethod 211 def from_tensor_slices(tensors): 212 """Creates a `Dataset` whose elements are slices of the given tensors. 213 214 Args: 215 tensors: A nested structure of tensors, each having the same size in the 216 0th dimension. 217 218 Returns: 219 Dataset: A `Dataset`. 220 """ 221 return TensorSliceDataset(tensors) 222 223 @staticmethod 224 @deprecation.deprecated(None, "Use `tf.data.Dataset.from_tensor_slices()`.") 225 def from_sparse_tensor_slices(sparse_tensor): 226 """Splits each rank-N `tf.SparseTensor` in this dataset row-wise. 227 228 Args: 229 sparse_tensor: A `tf.SparseTensor`. 230 231 Returns: 232 Dataset: A `Dataset` of rank-(N-1) sparse tensors. 233 """ 234 return SparseTensorSliceDataset(sparse_tensor) 235 236 class _GeneratorState(object): 237 """Stores outstanding iterators created from a Python generator. 238 239 This class keeps track of potentially multiple iterators that may have 240 been created from a generator, e.g. in the case that the dataset is 241 repeated, or nested within a parallel computation. 242 """ 243 244 def __init__(self, generator): 245 self._generator = generator 246 self._lock = threading.Lock() 247 self._next_id = 0 # GUARDED_BY(self._lock) 248 self._iterators = collections.defaultdict(lambda: iter(generator())) 249 250 def get_next_id(self): 251 with self._lock: 252 ret = self._next_id 253 self._next_id += 1 254 # NOTE(mrry): Explicitly create an array of `np.int64` because implicit 255 # casting in `py_func()` will create an array of `np.int32` on Windows, 256 # leading to a runtime error. 257 return np.array(ret, dtype=np.int64) 258 259 def get_iterator(self, iterator_id): 260 return self._iterators[iterator_id] 261 262 def iterator_completed(self, iterator_id): 263 del self._iterators[iterator_id] 264 265 @staticmethod 266 def from_generator(generator, output_types, output_shapes=None): 267 """Creates a `Dataset` whose elements are generated by `generator`. 268 269 The `generator` argument must be a callable object that returns 270 an object that support the `iter()` protocol (e.g. a generator function). 271 The elements generated by `generator` must be compatible with the given 272 `output_types` and (optional) `output_shapes` arguments. 273 274 For example: 275 276 ```python 277 import itertools 278 279 def gen(): 280 for i in itertools.count(1): 281 yield (i, [1] * i) 282 283 ds = Dataset.from_generator( 284 gen, (tf.int64, tf.int64), (tf.TensorShape([]), tf.TensorShape([None]))) 285 value = ds.make_one_shot_iterator().get_next() 286 287 sess.run(value) # (1, array([1])) 288 sess.run(value) # (2, array([1, 1])) 289 ``` 290 291 NOTE: The current implementation of `Dataset.from_generator()` uses 292 @{tf.py_func} and inherits the same constraints. 
In particular, it 293 requires the `Dataset`- and `Iterator`-related operations to be placed 294 on a device in the same process as the Python program that called 295 `Dataset.from_generator()`. The body of `generator` will not be 296 serialized in a `GraphDef`, and you should not use this method if you 297 need to serialize your model and restore it in a different environment. 298 299 NOTE: If `generator` depends on mutable global variables or other external 300 state, be aware that the runtime may invoke `generator` multiple times 301 (in order to support repeating the `Dataset`) and at any time 302 between the call to `Dataset.from_generator()` and the production of the 303 first element from the generator. Mutating global variables or external 304 state can cause undefined behavior, and we recommend that you explicitly 305 cache any external state in `generator` before calling 306 `Dataset.from_generator()`. 307 308 Args: 309 generator: A callable object that takes no arguments and returns an 310 object that supports the `iter()` protocol. 311 output_types: A nested structure of `tf.DType` objects corresponding to 312 each component of an element yielded by `generator`. 313 output_shapes: (Optional.) A nested structure of `tf.TensorShape` 314 objects corresponding to each component of an element yielded by 315 `generator`. 316 317 Returns: 318 Dataset: A `Dataset`. 319 """ 320 if not callable(generator): 321 raise TypeError("`generator` must be callable.") 322 if output_shapes is None: 323 output_shapes = nest.map_structure( 324 lambda _: tensor_shape.TensorShape(None), output_types) 325 else: 326 output_shapes = nest.map_structure_up_to( 327 output_types, tensor_shape.as_shape, output_shapes) 328 329 flattened_types = nest.flatten(output_types) 330 flattened_shapes = nest.flatten(output_shapes) 331 332 generator_state = Dataset._GeneratorState(generator) 333 334 def get_iterator_id_map_fn(unused_dummy): 335 """Creates a unique `iterator_id` for each pass over the dataset. 336 337 The "iterator_id" disambiguates between multiple concurrently 338 existing iterators. 339 340 Args: 341 unused_dummy: Ignored value. 342 343 Returns: 344 A `tf.int64` tensor whose value uniquely identifies an iterator in 345 `generator_state`. 346 """ 347 return script_ops.py_func( 348 generator_state.get_next_id, [], dtypes.int64, stateful=True) 349 350 def generator_map_fn(iterator_id_t): 351 """Generates the next element from iterator with ID `iterator_id_t`. 352 353 We map this function across an infinite repetition of the 354 `iterator_id_t`, and raise `StopIteration` to terminate the iteration. 355 356 Args: 357 iterator_id_t: A `tf.int64` tensor whose value uniquely identifies 358 the iterator in `generator_state` from which to generate an element. 359 360 Returns: 361 A nested structure of tensors representing an element from the iterator. 362 """ 363 364 def generator_py_func(iterator_id): 365 """A `py_func` that will be called to invoke the iterator.""" 366 try: 367 values = next(generator_state.get_iterator(iterator_id)) 368 except StopIteration: 369 generator_state.iterator_completed(iterator_id) 370 raise StopIteration("Iteration finished.") 371 372 # Use the same _convert function from the py_func() implementation to 373 # convert the returned values to arrays early, so that we can inspect 374 # their values. 
        # pylint: disable=protected-access
        ret_arrays = [
            script_ops.FuncRegistry._convert(ret, dtype=dtype.as_numpy_dtype)
            for ret, dtype in zip(
                nest.flatten_up_to(output_types, values), flattened_types)
        ]
        # pylint: enable=protected-access

        # Additional type and shape checking to ensure that the components
        # of the generated element match the `output_types` and `output_shapes`
        # arguments.
        for (ret_array, expected_dtype, expected_shape) in zip(
            ret_arrays, flattened_types, flattened_shapes):
          if ret_array.dtype != expected_dtype.as_numpy_dtype:
            raise TypeError(
                "`generator` yielded an element of type %s where an element "
                "of type %s was expected." % (ret_array.dtype,
                                              expected_dtype.as_numpy_dtype))
          if not expected_shape.is_compatible_with(ret_array.shape):
            raise ValueError(
                "`generator` yielded an element of shape %s where an element "
                "of shape %s was expected." % (ret_array.shape, expected_shape))

        return ret_arrays

      flat_values = script_ops.py_func(
          generator_py_func, [iterator_id_t], flattened_types, stateful=True)

      # The `py_func()` op drops the inferred shapes, so we add them back in
      # here.
      if output_shapes is not None:
        for ret_t, shape in zip(flat_values, flattened_shapes):
          ret_t.set_shape(shape)

      return nest.pack_sequence_as(output_types, flat_values)

    # This function associates each traversal of `generator` with a unique
    # iterator ID.
    def flat_map_fn(iterator_id_t):
      # First, generate an infinite dataset containing the iterator ID repeated
      # forever.
      repeated_id = Dataset.from_tensors(iterator_id_t).repeat(None)

      # The `generator_map_fn` gets the next element from the iterator with the
      # relevant ID, and raises StopIteration when that iterator contains no
      # more elements.
      return repeated_id.map(generator_map_fn)

    # A single-element dataset that, each time it is evaluated, contains a
    # freshly-generated and unique (for the returned dataset) int64
    # ID that will be used to identify the appropriate Python state, which
    # is encapsulated in `generator_state`, and captured in
    # `get_iterator_id_map_fn`.
    dummy = 0
    id_dataset = Dataset.from_tensors(dummy).map(get_iterator_id_map_fn)

    # A dataset that contains all of the elements generated by a
    # single iterator created from `generator`, identified by the
    # iterator ID contained in `id_dataset`. Lifting the iteration
    # into a flat_map here enables multiple repetitions and/or nested
    # versions of the returned dataset to be created, because it forces
    # the generation of a new ID for each version.
    return id_dataset.flat_map(flat_map_fn)

  @staticmethod
  def range(*args):
    """Creates a `Dataset` of a step-separated range of values.

    For example:

    ```python
    Dataset.range(5) == [0, 1, 2, 3, 4]
    Dataset.range(2, 5) == [2, 3, 4]
    Dataset.range(1, 5, 2) == [1, 3]
    Dataset.range(1, 5, -2) == []
    Dataset.range(5, 1) == []
    Dataset.range(5, 1, -2) == [5, 3]
    ```

    Args:
      *args: follows the same semantics as Python's xrange.
        len(args) == 1 -> start = 0, stop = args[0], step = 1
        len(args) == 2 -> start = args[0], stop = args[1], step = 1
        len(args) == 3 -> start = args[0], stop = args[1], step = args[2]

    Returns:
      Dataset: A `RangeDataset`.

    Raises:
      ValueError: if len(args) == 0.
    """
    return RangeDataset(*args)

  @staticmethod
  def zip(datasets):
    """Creates a `Dataset` by zipping together the given datasets.

    This method has similar semantics to the built-in `zip()` function
    in Python, with the main difference being that the `datasets`
    argument can be an arbitrary nested structure of `Dataset` objects.
    For example:

    ```python
    # NOTE: The following examples use `{ ... }` to represent the
    # contents of a dataset.
    a = { 1, 2, 3 }
    b = { 4, 5, 6 }
    c = { (7, 8), (9, 10), (11, 12) }
    d = { 13, 14 }

    # The nested structure of the `datasets` argument determines the
    # structure of elements in the resulting dataset.
    Dataset.zip((a, b)) == { (1, 4), (2, 5), (3, 6) }
    Dataset.zip((b, a)) == { (4, 1), (5, 2), (6, 3) }

    # The `datasets` argument may contain an arbitrary number of
    # datasets.
    Dataset.zip((a, b, c)) == { (1, 4, (7, 8)),
                                (2, 5, (9, 10)),
                                (3, 6, (11, 12)) }

    # The number of elements in the resulting dataset is the same as
    # the size of the smallest dataset in `datasets`.
    Dataset.zip((a, d)) == { (1, 13), (2, 14) }
    ```

    Args:
      datasets: A nested structure of datasets.

    Returns:
      Dataset: A `Dataset`.
    """
    return ZipDataset(datasets)

  def concatenate(self, dataset):
    """Creates a `Dataset` by concatenating the given dataset with this dataset.

    ```python
    # NOTE: The following examples use `{ ... }` to represent the
    # contents of a dataset.
    a = { 1, 2, 3 }
    b = { 4, 5, 6, 7 }

    # Input dataset and dataset to be concatenated should have same
    # nested structures and output types.
    # c = { (8, 9), (10, 11), (12, 13) }
    # d = { 14.0, 15.0, 16.0 }
    # a.concatenate(c) and a.concatenate(d) would result in error.

    a.concatenate(b) == { 1, 2, 3, 4, 5, 6, 7 }
    ```

    Args:
      dataset: `Dataset` to be concatenated.

    Returns:
      Dataset: A `Dataset`.
    """
    return ConcatenateDataset(self, dataset)

  def prefetch(self, buffer_size):
    """Creates a `Dataset` that prefetches elements from this dataset.

    Args:
      buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the
        maximum number of elements that will be buffered when prefetching.

    Returns:
      Dataset: A `Dataset`.
    """
    return PrefetchDataset(self, buffer_size)

  @staticmethod
  def list_files(file_pattern):
    """A dataset of all files matching a pattern.

    Example:
      If we had the following files on our filesystem:
        - /path/to/dir/a.txt
        - /path/to/dir/b.py
        - /path/to/dir/c.py
      If we pass "/path/to/dir/*.py" as the file pattern, the dataset would
      produce:
        - /path/to/dir/b.py
        - /path/to/dir/c.py

    NOTE: The order of the file names returned can be non-deterministic.

    Args:
      file_pattern: A string or scalar string `tf.Tensor`, representing
        the filename pattern that will be matched.

    Returns:
      Dataset: A `Dataset` of strings corresponding to file names.
    """
    return Dataset.from_tensor_slices(gen_io_ops.matching_files(file_pattern))

  def repeat(self, count=None):
    """Repeats this dataset `count` times.

    NOTE: If this dataset is a function of global state (e.g. a random number
    generator), then different repetitions may produce different elements.

    Args:
      count: (Optional.)
        A `tf.int64` scalar `tf.Tensor`, representing the
        number of times the dataset should be repeated. The default behavior
        (if `count` is `None` or `-1`) is for the dataset to be repeated
        indefinitely.

    Returns:
      Dataset: A `Dataset`.
    """
    return RepeatDataset(self, count)

  def _enumerate(self, start=0):

    max_value = np.iinfo(dtypes.int64.as_numpy_dtype).max
    return Dataset.zip((Dataset.range(start, max_value), self))

  def shuffle(self, buffer_size, seed=None, reshuffle_each_iteration=None):
    """Randomly shuffles the elements of this dataset.

    Args:
      buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the
        number of elements from this dataset from which the new
        dataset will sample.
      seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
        random seed that will be used to create the distribution. See
        @{tf.set_random_seed} for behavior.
      reshuffle_each_iteration: (Optional.) A boolean, which if true indicates
        that the dataset should be pseudorandomly reshuffled each time it is
        iterated over. (Defaults to `True`.)

    Returns:
      Dataset: A `Dataset`.
    """
    return ShuffleDataset(self, buffer_size, seed, reshuffle_each_iteration)

  def cache(self, filename=""):
    """Caches the elements in this dataset.

    Args:
      filename: A `tf.string` scalar `tf.Tensor`, representing the name of a
        directory on the filesystem to use for caching tensors in this Dataset.
        If a filename is not provided, the dataset will be cached in memory.

    Returns:
      Dataset: A `Dataset`.
    """
    return CacheDataset(self, filename)

  def take(self, count):
    """Creates a `Dataset` with at most `count` elements from this dataset.

    Args:
      count: A `tf.int64` scalar `tf.Tensor`, representing the number of
        elements of this dataset that should be taken to form the new dataset.
        If `count` is -1, or if `count` is greater than the size of this
        dataset, the new dataset will contain all elements of this dataset.

    Returns:
      Dataset: A `Dataset`.
    """
    return TakeDataset(self, count)

  def skip(self, count):
    """Creates a `Dataset` that skips `count` elements from this dataset.

    Args:
      count: A `tf.int64` scalar `tf.Tensor`, representing the number
        of elements of this dataset that should be skipped to form the
        new dataset. If `count` is greater than the size of this
        dataset, the new dataset will contain no elements. If `count`
        is -1, skips the entire dataset.

    Returns:
      Dataset: A `Dataset`.
    """
    return SkipDataset(self, count)

  def shard(self, num_shards, index):
    """Creates a `Dataset` that includes only 1/`num_shards` of this dataset.

    This dataset operator is very useful when running distributed training, as
    it allows each worker to read a unique subset.

    When reading a single input file, you can skip elements as follows:

    ```python
    d = tf.data.TFRecordDataset(FLAGS.input_file)
    d = d.shard(FLAGS.num_workers, FLAGS.worker_index)
    d = d.repeat(FLAGS.num_epochs)
    d = d.shuffle(FLAGS.shuffle_buffer_size)
    d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads)
    ```

    Important caveats:

    - Be sure to shard before you use any randomizing operator (such as
      shuffle).
    - Generally it is best if the shard operator is used early in the dataset
      pipeline.
      For example, when reading from a set of TFRecord files, shard
      before converting the dataset to input samples. This avoids reading every
      file on every worker. The following is an example of an efficient
      sharding strategy within a complete pipeline:

    ```python
    d = Dataset.list_files(FLAGS.pattern)
    d = d.shard(FLAGS.num_workers, FLAGS.worker_index)
    d = d.repeat(FLAGS.num_epochs)
    d = d.shuffle(FLAGS.shuffle_buffer_size)
    d = d.interleave(tf.data.TFRecordDataset,
                     cycle_length=FLAGS.num_readers, block_length=1)
    d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads)
    ```

    Args:
      num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of
        shards operating in parallel.
      index: A `tf.int64` scalar `tf.Tensor`, representing the worker index.

    Returns:
      Dataset: A `Dataset`.

    Raises:
      ValueError: if `num_shards` or `index` are illegal values. Note: error
        checking is done on a best-effort basis, and errors aren't guaranteed
        to be caught upon dataset creation. (e.g. providing a placeholder
        tensor bypasses the early checking, and will instead result in an
        error during a session.run call.)
    """
    num_shards = ops.convert_to_tensor(
        num_shards, name="num_shards", dtype=dtypes.int64)
    num_shards_static = tensor_util.constant_value(num_shards)
    index = ops.convert_to_tensor(index, name="index", dtype=dtypes.int64)
    index_static = tensor_util.constant_value(index)

    if num_shards_static is not None and num_shards_static < 1:
      raise ValueError("num_shards must be >= 1; got: %s" % num_shards_static)
    if index_static is not None and index_static < 0:
      raise ValueError("index must be >= 0; got: %s" % index_static)
    if (index_static is not None and num_shards_static is not None and
        index_static >= num_shards_static):
      raise ValueError("index must be < num_shards; %s is not < %s" %
                       (index_static, num_shards_static))

    def filter_fn(elem_index, _):
      mod_result = math_ops.mod(elem_index, num_shards)
      return math_ops.equal(mod_result, index)

    return self._enumerate().filter(filter_fn).map(lambda _, elem: elem)

  def batch(self, batch_size):
    """Combines consecutive elements of this dataset into batches.

    NOTE: If the number of elements (`N`) in this dataset is not an exact
    multiple of `batch_size`, the final batch will contain smaller tensors with
    shape `N % batch_size` in the batch dimension. If your program depends on
    the batches having the same shape, consider using the
    @{tf.contrib.data.batch_and_drop_remainder} transformation instead.

    Args:
      batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
        consecutive elements of this dataset to combine in a single batch.

    Returns:
      Dataset: A `Dataset`.
    """
    return BatchDataset(self, batch_size)

  def padded_batch(self, batch_size, padded_shapes, padding_values=None):
    """Combines consecutive elements of this dataset into padded batches.

    Like `Dataset.dense_to_sparse_batch()`, this method combines
    multiple consecutive elements of this dataset, which might have
    different shapes, into a single element. The tensors in the
    resulting element have an additional outer dimension, and are
    padded to the respective shape in `padded_shapes`.
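
    For example, the following is a minimal sketch of padding variable-length
    vectors to the longest element in each batch (the element values shown are
    purely illustrative):

    ```python
    # Elements are int64 vectors of increasing length:
    # [1], [2, 2], [3, 3, 3], [4, 4, 4, 4].
    elements = Dataset.range(1, 5).map(lambda x: tf.fill([x], x))

    # `padded_shapes=[None]` pads the single vector component to the length
    # of the longest element in each batch; numeric components default to a
    # padding value of `0`.
    batches = elements.padded_batch(2, padded_shapes=[None])

    # First batch:  [[1, 0], [2, 2]]
    # Second batch: [[3, 3, 3, 0], [4, 4, 4, 4]]
    ```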

    Args:
      batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
        consecutive elements of this dataset to combine in a single batch.
      padded_shapes: A nested structure of `tf.TensorShape` or
        `tf.int64` vector tensor-like objects representing the shape
        to which the respective component of each input element should
        be padded prior to batching. Any unknown dimensions
        (e.g. `tf.Dimension(None)` in a `tf.TensorShape` or `-1` in a
        tensor-like object) will be padded to the maximum size of that
        dimension in each batch.
      padding_values: (Optional.) A nested structure of scalar-shaped
        `tf.Tensor`, representing the padding values to use for the
        respective components. Defaults are `0` for numeric types and
        the empty string for string types.

    Returns:
      Dataset: A `Dataset`.
    """
    return PaddedBatchDataset(self, batch_size, padded_shapes, padding_values)

  def map(self, map_func, num_parallel_calls=None):
    """Maps `map_func` across this dataset.

    Args:
      map_func: A function mapping a nested structure of tensors (having
        shapes and types defined by `self.output_shapes` and
        `self.output_types`) to another nested structure of tensors.
      num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
        representing the number of elements to process in parallel. If not
        specified, elements will be processed sequentially.

    Returns:
      Dataset: A `Dataset`.
    """
    if num_parallel_calls is None:
      return MapDataset(self, map_func)
    else:
      return ParallelMapDataset(self, map_func, num_parallel_calls)

  def flat_map(self, map_func):
    """Maps `map_func` across this dataset and flattens the result.

    Args:
      map_func: A function mapping a nested structure of tensors (having shapes
        and types defined by `self.output_shapes` and `self.output_types`) to a
        `Dataset`.

    Returns:
      Dataset: A `Dataset`.
    """
    return FlatMapDataset(self, map_func)

  def interleave(self, map_func, cycle_length, block_length=1):
    """Maps `map_func` across this dataset, and interleaves the results.

    For example, you can use `Dataset.interleave()` to process many input files
    concurrently:

    ```python
    # Preprocess 4 files concurrently, and interleave blocks of 16 records from
    # each file.
    filenames = ["/var/data/file1.txt", "/var/data/file2.txt", ...]
    dataset = (Dataset.from_tensor_slices(filenames)
               .interleave(lambda x:
                   TextLineDataset(x).map(parse_fn, num_parallel_calls=1),
                   cycle_length=4, block_length=16))
    ```

    The `cycle_length` and `block_length` arguments control the order in which
    elements are produced. `cycle_length` controls the number of input elements
    that are processed concurrently. If you set `cycle_length` to 1, this
    transformation will handle one input element at a time, and will produce
    identical results to @{tf.data.Dataset.flat_map}. In general,
    this transformation will apply `map_func` to `cycle_length` input elements,
    open iterators on the returned `Dataset` objects, and cycle through them
    producing `block_length` consecutive elements from each iterator, and
    consuming the next input element each time it reaches the end of an
    iterator.

    For example:

    ```python
    # NOTE: The following examples use `{ ... }` to represent the
    # contents of a dataset.
    a = { 1, 2, 3, 4, 5 }

    # NOTE: New lines indicate "block" boundaries.
842 a.interleave(lambda x: Dataset.from_tensors(x).repeat(6), 843 cycle_length=2, block_length=4) == { 844 1, 1, 1, 1, 845 2, 2, 2, 2, 846 1, 1, 847 2, 2, 848 3, 3, 3, 3, 849 4, 4, 4, 4, 850 3, 3, 851 4, 4, 852 5, 5, 5, 5, 853 5, 5, 854 } 855 ``` 856 857 NOTE: The order of elements yielded by this transformation is 858 deterministic, as long as `map_func` is a pure function. If 859 `map_func` contains any stateful operations, the order in which 860 that state is accessed is undefined. 861 862 Args: 863 map_func: A function mapping a nested structure of tensors (having shapes 864 and types defined by `self.output_shapes` and `self.output_types`) to a 865 `Dataset`. 866 cycle_length: The number of elements from this dataset that will be 867 processed concurrently. 868 block_length: The number of consecutive elements to produce from each 869 input element before cycling to another input element. 870 871 Returns: 872 Dataset: A `Dataset`. 873 """ 874 return InterleaveDataset(self, map_func, cycle_length, block_length) 875 876 def filter(self, predicate): 877 """Filters this dataset according to `predicate`. 878 879 Args: 880 predicate: A function mapping a nested structure of tensors (having shapes 881 and types defined by `self.output_shapes` and `self.output_types`) to a 882 scalar `tf.bool` tensor. 883 884 Returns: 885 Dataset: A `Dataset`. 886 """ 887 return FilterDataset(self, predicate) 888 889 def apply(self, transformation_func): 890 """Apply a transformation function to this dataset. 891 892 `apply` enables chaining of custom `Dataset` transformations, which are 893 represented as functions that take one `Dataset` argument and return a 894 transformed `Dataset`. 895 896 For example: 897 898 ``` 899 dataset = (dataset.map(lambda x: x ** 2) 900 .apply(group_by_window(key_func, reduce_func, window_size)) 901 .map(lambda x: x ** 3)) 902 ``` 903 904 Args: 905 transformation_func: A function that takes one `Dataset` argument and 906 returns a `Dataset`. 907 908 Returns: 909 Dataset: The `Dataset` returned by applying `transformation_func` to this 910 dataset. 911 """ 912 dataset = transformation_func(self) 913 if not isinstance(dataset, Dataset): 914 raise TypeError("`transformation_func` must return a Dataset.") 915 return dataset 916 917 918class TensorDataset(Dataset): 919 """A `Dataset` with a single element, viz. 
a nested structure of tensors.""" 920 921 def __init__(self, tensors): 922 """See `Dataset.from_tensors()` for details.""" 923 super(TensorDataset, self).__init__() 924 with ops.name_scope("tensors"): 925 tensors = nest.pack_sequence_as(tensors, [ 926 sparse_tensor_lib.SparseTensor.from_value(t) 927 if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor( 928 t, name="component_%d" % i) 929 for i, t in enumerate(nest.flatten(tensors)) 930 ]) 931 932 self._tensors = sparse.serialize_sparse_tensors(tensors) 933 self._output_classes = sparse.get_classes(tensors) 934 self._output_shapes = nest.pack_sequence_as( 935 tensors, [t.get_shape() for t in nest.flatten(tensors)]) 936 self._output_types = nest.pack_sequence_as( 937 tensors, [t.dtype for t in nest.flatten(tensors)]) 938 939 def _as_variant_tensor(self): 940 return gen_dataset_ops.tensor_dataset( 941 nest.flatten(self._tensors), 942 output_shapes=nest.flatten( 943 sparse.as_dense_shapes(self.output_shapes, self.output_classes))) 944 945 @property 946 def output_classes(self): 947 return self._output_classes 948 949 @property 950 def output_shapes(self): 951 return self._output_shapes 952 953 @property 954 def output_types(self): 955 return self._output_types 956 957 958class TensorSliceDataset(Dataset): 959 """A `Dataset` of slices from a nested structure of tensors.""" 960 961 def __init__(self, tensors): 962 """See `Dataset.from_tensor_slices()` for details.""" 963 super(TensorSliceDataset, self).__init__() 964 with ops.name_scope("tensors"): 965 tensors = nest.pack_sequence_as(tensors, [ 966 sparse_tensor_lib.SparseTensor.from_value(t) 967 if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor( 968 t, name="component_%d" % i) 969 for i, t in enumerate(nest.flatten(tensors)) 970 ]) 971 flat_tensors = nest.flatten(tensors) 972 973 batch_dim = flat_tensors[0].get_shape()[0] 974 for t in flat_tensors[1:]: 975 batch_dim.assert_is_compatible_with(t.get_shape()[0]) 976 self._tensors = sparse.serialize_many_sparse_tensors(tensors) 977 self._output_classes = sparse.get_classes(tensors) 978 self._output_shapes = nest.pack_sequence_as( 979 tensors, [t.get_shape()[1:] for t in nest.flatten(tensors)]) 980 self._output_types = nest.pack_sequence_as( 981 tensors, [t.dtype for t in nest.flatten(tensors)]) 982 983 def _as_variant_tensor(self): 984 return gen_dataset_ops.tensor_slice_dataset( 985 nest.flatten(self._tensors), 986 output_shapes=nest.flatten( 987 sparse.as_dense_shapes(self.output_shapes, self.output_classes))) 988 989 @property 990 def output_classes(self): 991 return self._output_classes 992 993 @property 994 def output_shapes(self): 995 return self._output_shapes 996 997 @property 998 def output_types(self): 999 return self._output_types 1000 1001 1002class SparseTensorSliceDataset(Dataset): 1003 """A `Dataset` that splits a rank-N `tf.SparseTensor` into its rows.""" 1004 1005 def __init__(self, sparse_tensor): 1006 """See `Dataset.from_sparse_tensor_slices()` for details.""" 1007 super(SparseTensorSliceDataset, self).__init__() 1008 if not isinstance(sparse_tensor, sparse_tensor_lib.SparseTensor): 1009 raise TypeError("`sparse_tensor` must be a `tf.SparseTensor` object.") 1010 self._sparse_tensor = sparse_tensor 1011 1012 def _as_variant_tensor(self): 1013 return gen_dataset_ops.sparse_tensor_slice_dataset( 1014 self._sparse_tensor.indices, self._sparse_tensor.values, 1015 self._sparse_tensor.dense_shape) 1016 1017 @property 1018 def output_classes(self): 1019 return (ops.Tensor, ops.Tensor, ops.Tensor) 1020 1021 @property 
1022 def output_shapes(self): 1023 indices_shape = self._sparse_tensor.indices.get_shape() 1024 shape_shape = self._sparse_tensor.dense_shape.get_shape() 1025 rank = (indices_shape[1] - 1).merge_with(shape_shape[0] - 1) 1026 num_values = tensor_shape.Dimension(None) 1027 return (tensor_shape.TensorShape([num_values, rank]), 1028 tensor_shape.TensorShape([num_values]), 1029 tensor_shape.TensorShape([rank])) 1030 1031 @property 1032 def output_types(self): 1033 return (dtypes.int64, self._sparse_tensor.dtype, dtypes.int64) 1034 1035 1036class ZipDataset(Dataset): 1037 """A `Dataset` that zips its inputs together.""" 1038 1039 def __init__(self, datasets): 1040 """See `Dataset.zip()` for details.""" 1041 super(ZipDataset, self).__init__() 1042 for ds in nest.flatten(datasets): 1043 if not isinstance(ds, Dataset): 1044 if isinstance(ds, list): 1045 message = ("The argument to `Dataset.zip()` must be a nested " 1046 "structure of `Dataset` objects. Nested structures do not " 1047 "support Python lists; please use a tuple instead.") 1048 else: 1049 message = ("The argument to `Dataset.zip()` must be a nested " 1050 "structure of `Dataset` objects.") 1051 raise TypeError(message) 1052 self._datasets = datasets 1053 1054 def _as_variant_tensor(self): 1055 # pylint: disable=protected-access 1056 return gen_dataset_ops.zip_dataset( 1057 [ds._as_variant_tensor() for ds in nest.flatten(self._datasets)], 1058 output_shapes=[ 1059 s 1060 for ds in nest.flatten(self._datasets) 1061 for s in nest.flatten(ds.output_shapes) 1062 ], 1063 output_types=[ 1064 t 1065 for ds in nest.flatten(self._datasets) 1066 for t in nest.flatten(ds.output_types) 1067 ]) 1068 # pylint: enable=protected-access 1069 1070 @property 1071 def output_classes(self): 1072 return nest.pack_sequence_as( 1073 self._datasets, 1074 [ds.output_classes for ds in nest.flatten(self._datasets)]) 1075 1076 @property 1077 def output_shapes(self): 1078 return nest.pack_sequence_as( 1079 self._datasets, 1080 [ds.output_shapes for ds in nest.flatten(self._datasets)]) 1081 1082 @property 1083 def output_types(self): 1084 return nest.pack_sequence_as( 1085 self._datasets, 1086 [ds.output_types for ds in nest.flatten(self._datasets)]) 1087 1088 1089class ConcatenateDataset(Dataset): 1090 """A `Dataset` that concatenates its input with given dataset.""" 1091 1092 def __init__(self, input_dataset, dataset_to_concatenate): 1093 """See `Dataset.concatenate()` for details.""" 1094 super(ConcatenateDataset, self).__init__() 1095 self._input_dataset = input_dataset 1096 self._dataset_to_concatenate = dataset_to_concatenate 1097 nest.assert_same_structure(input_dataset.output_types, 1098 dataset_to_concatenate.output_types) 1099 for a, b in zip( 1100 nest.flatten(input_dataset.output_types), 1101 nest.flatten(dataset_to_concatenate.output_types)): 1102 if a != b: 1103 raise TypeError( 1104 "Two datasets to concatenate have different types %s and %s" % 1105 (input_dataset.output_types, dataset_to_concatenate.output_types)) 1106 1107 def _as_variant_tensor(self): 1108 # pylint: disable=protected-access 1109 return gen_dataset_ops.concatenate_dataset( 1110 self._input_dataset._as_variant_tensor(), 1111 self._dataset_to_concatenate._as_variant_tensor(), 1112 output_shapes=nest.flatten( 1113 sparse.as_dense_shapes(self.output_shapes, self.output_classes)), 1114 output_types=nest.flatten( 1115 sparse.as_dense_types(self.output_types, self.output_classes))) 1116 # pylint: enable=protected-access 1117 1118 @property 1119 def output_classes(self): 1120 return 
self._input_dataset.output_classes 1121 1122 @property 1123 def output_shapes(self): 1124 return nest.pack_sequence_as(self._input_dataset.output_shapes, [ 1125 ts1.most_specific_compatible_shape(ts2) 1126 for (ts1, ts2) in zip( 1127 nest.flatten(self._input_dataset.output_shapes), 1128 nest.flatten(self._dataset_to_concatenate.output_shapes)) 1129 ]) 1130 1131 @property 1132 def output_types(self): 1133 return self._input_dataset.output_types 1134 1135 1136class RepeatDataset(Dataset): 1137 """A `Dataset` that repeats its input several times.""" 1138 1139 def __init__(self, input_dataset, count): 1140 """See `Dataset.repeat()` for details.""" 1141 super(RepeatDataset, self).__init__() 1142 self._input_dataset = input_dataset 1143 if count is None: 1144 self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count") 1145 else: 1146 self._count = ops.convert_to_tensor( 1147 count, dtype=dtypes.int64, name="count") 1148 1149 def _as_variant_tensor(self): 1150 return gen_dataset_ops.repeat_dataset( 1151 self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access 1152 count=self._count, 1153 output_shapes=nest.flatten( 1154 sparse.as_dense_shapes(self.output_shapes, self.output_classes)), 1155 output_types=nest.flatten( 1156 sparse.as_dense_types(self.output_types, self.output_classes))) 1157 1158 @property 1159 def output_classes(self): 1160 return self._input_dataset.output_classes 1161 1162 @property 1163 def output_shapes(self): 1164 return self._input_dataset.output_shapes 1165 1166 @property 1167 def output_types(self): 1168 return self._input_dataset.output_types 1169 1170 1171class RangeDataset(Dataset): 1172 """A `Dataset` of a step separated range of values.""" 1173 1174 def __init__(self, *args): 1175 """See `Dataset.range()` for details.""" 1176 super(RangeDataset, self).__init__() 1177 self._parse_args(*args) 1178 1179 def _parse_args(self, *args): 1180 if len(args) == 1: 1181 self._start = self._build_tensor(0, "start") 1182 self._stop = self._build_tensor(args[0], "stop") 1183 self._step = self._build_tensor(1, "step") 1184 elif len(args) == 2: 1185 self._start = self._build_tensor(args[0], "start") 1186 self._stop = self._build_tensor(args[1], "stop") 1187 self._step = self._build_tensor(1, "step") 1188 elif len(args) == 3: 1189 self._start = self._build_tensor(args[0], "start") 1190 self._stop = self._build_tensor(args[1], "stop") 1191 self._step = self._build_tensor(args[2], "step") 1192 else: 1193 raise ValueError("Invalid arguments to RangeDataset: %s" % str(args)) 1194 1195 def _build_tensor(self, int64_value, name): 1196 return ops.convert_to_tensor(int64_value, dtype=dtypes.int64, name=name) 1197 1198 def _as_variant_tensor(self): 1199 return gen_dataset_ops.range_dataset( 1200 start=self._start, 1201 stop=self._stop, 1202 step=self._step, 1203 output_shapes=nest.flatten( 1204 sparse.as_dense_shapes(self.output_shapes, self.output_classes)), 1205 output_types=nest.flatten( 1206 sparse.as_dense_types(self.output_types, self.output_classes))) 1207 1208 @property 1209 def output_classes(self): 1210 return ops.Tensor 1211 1212 @property 1213 def output_shapes(self): 1214 return tensor_shape.scalar() 1215 1216 @property 1217 def output_types(self): 1218 return dtypes.int64 1219 1220 1221class CacheDataset(Dataset): 1222 """A `Dataset` that caches elements of its input.""" 1223 1224 def __init__(self, input_dataset, filename): 1225 """See `Dataset.cache()` for details.""" 1226 super(CacheDataset, self).__init__() 1227 self._input_dataset = 
input_dataset 1228 self._filename = ops.convert_to_tensor( 1229 filename, dtype=dtypes.string, name="filename") 1230 1231 def _as_variant_tensor(self): 1232 return gen_dataset_ops.cache_dataset( 1233 self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access 1234 filename=self._filename, 1235 output_shapes=nest.flatten( 1236 sparse.as_dense_shapes(self.output_shapes, self.output_classes)), 1237 output_types=nest.flatten( 1238 sparse.as_dense_types(self.output_types, self.output_classes))) 1239 1240 @property 1241 def output_classes(self): 1242 return self._input_dataset.output_classes 1243 1244 @property 1245 def output_shapes(self): 1246 return self._input_dataset.output_shapes 1247 1248 @property 1249 def output_types(self): 1250 return self._input_dataset.output_types 1251 1252 1253class ShuffleDataset(Dataset): 1254 """A `Dataset` that randomly shuffles the elements of its input.""" 1255 1256 def __init__(self, 1257 input_dataset, 1258 buffer_size, 1259 seed=None, 1260 reshuffle_each_iteration=None): 1261 """Randomly shuffles the elements of this dataset. 1262 1263 Args: 1264 input_dataset: The input dataset. 1265 buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the 1266 number of elements from this dataset from which the new 1267 dataset will sample. 1268 seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the 1269 random seed that will be used to create the distribution. See 1270 @{tf.set_random_seed} for behavior. 1271 reshuffle_each_iteration: (Optional.) A boolean, which if true indicates 1272 that the dataset should be pseudorandomly reshuffled each time it is 1273 iterated over. (Defaults to `True`.) 1274 1275 Returns: 1276 A `Dataset`. 1277 1278 Raises: 1279 ValueError: if invalid arguments are provided. 
1280 """ 1281 super(ShuffleDataset, self).__init__() 1282 self._input_dataset = input_dataset 1283 self._buffer_size = ops.convert_to_tensor( 1284 buffer_size, dtype=dtypes.int64, name="buffer_size") 1285 seed, seed2 = random_seed.get_seed(seed) 1286 if seed is None: 1287 self._seed = constant_op.constant(0, dtype=dtypes.int64, name="seed") 1288 else: 1289 self._seed = ops.convert_to_tensor(seed, dtype=dtypes.int64, name="seed") 1290 if seed2 is None: 1291 self._seed2 = constant_op.constant(0, dtype=dtypes.int64, name="seed2") 1292 else: 1293 self._seed2 = ops.convert_to_tensor( 1294 seed2, dtype=dtypes.int64, name="seed2") 1295 if reshuffle_each_iteration is None: 1296 self._reshuffle_each_iteration = True 1297 else: 1298 self._reshuffle_each_iteration = reshuffle_each_iteration 1299 1300 def _as_variant_tensor(self): 1301 return gen_dataset_ops.shuffle_dataset( 1302 self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access 1303 buffer_size=self._buffer_size, 1304 seed=self._seed, 1305 seed2=self._seed2, 1306 reshuffle_each_iteration=self._reshuffle_each_iteration, 1307 output_shapes=nest.flatten( 1308 sparse.as_dense_shapes(self.output_shapes, self.output_classes)), 1309 output_types=nest.flatten( 1310 sparse.as_dense_types(self.output_types, self.output_classes))) 1311 1312 @property 1313 def output_classes(self): 1314 return self._input_dataset.output_classes 1315 1316 @property 1317 def output_shapes(self): 1318 return self._input_dataset.output_shapes 1319 1320 @property 1321 def output_types(self): 1322 return self._input_dataset.output_types 1323 1324 1325class TakeDataset(Dataset): 1326 """A `Dataset` containing the first `count` elements from its input.""" 1327 1328 def __init__(self, input_dataset, count): 1329 """See `Dataset.take()` for details.""" 1330 super(TakeDataset, self).__init__() 1331 self._input_dataset = input_dataset 1332 self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count") 1333 1334 def _as_variant_tensor(self): 1335 return gen_dataset_ops.take_dataset( 1336 self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access 1337 count=self._count, 1338 output_shapes=nest.flatten( 1339 sparse.as_dense_shapes(self.output_shapes, self.output_classes)), 1340 output_types=nest.flatten( 1341 sparse.as_dense_types(self.output_types, self.output_classes))) 1342 1343 @property 1344 def output_classes(self): 1345 return self._input_dataset.output_classes 1346 1347 @property 1348 def output_shapes(self): 1349 return self._input_dataset.output_shapes 1350 1351 @property 1352 def output_types(self): 1353 return self._input_dataset.output_types 1354 1355 1356class SkipDataset(Dataset): 1357 """A `Dataset` skipping the first `count` elements from its input.""" 1358 1359 def __init__(self, input_dataset, count): 1360 """See `Dataset.skip()` for details.""" 1361 super(SkipDataset, self).__init__() 1362 self._input_dataset = input_dataset 1363 self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count") 1364 1365 def _as_variant_tensor(self): 1366 return gen_dataset_ops.skip_dataset( 1367 self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access 1368 count=self._count, 1369 output_shapes=nest.flatten( 1370 sparse.as_dense_shapes(self.output_shapes, self.output_classes)), 1371 output_types=nest.flatten( 1372 sparse.as_dense_types(self.output_types, self.output_classes))) 1373 1374 @property 1375 def output_classes(self): 1376 return self._input_dataset.output_classes 1377 1378 @property 1379 def 
output_shapes(self): 1380 return self._input_dataset.output_shapes 1381 1382 @property 1383 def output_types(self): 1384 return self._input_dataset.output_types 1385 1386 1387class BatchDataset(Dataset): 1388 """A `Dataset` that batches contiguous elements from its input.""" 1389 1390 def __init__(self, input_dataset, batch_size): 1391 """See `Dataset.batch()` for details.""" 1392 super(BatchDataset, self).__init__() 1393 self._input_dataset = input_dataset 1394 self._batch_size = ops.convert_to_tensor( 1395 batch_size, dtype=dtypes.int64, name="batch_size") 1396 1397 def _as_variant_tensor(self): 1398 return gen_dataset_ops.batch_dataset( 1399 self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access 1400 batch_size=self._batch_size, 1401 output_shapes=nest.flatten( 1402 sparse.as_dense_shapes(self.output_shapes, self.output_classes)), 1403 output_types=nest.flatten( 1404 sparse.as_dense_types(self.output_types, self.output_classes))) 1405 1406 @property 1407 def output_classes(self): 1408 return self._input_dataset.output_classes 1409 1410 @property 1411 def output_shapes(self): 1412 input_shapes = self._input_dataset.output_shapes 1413 return nest.pack_sequence_as(input_shapes, [ 1414 tensor_shape.vector(None).concatenate(s) 1415 for s in nest.flatten(self._input_dataset.output_shapes) 1416 ]) 1417 1418 @property 1419 def output_types(self): 1420 return self._input_dataset.output_types 1421 1422 1423def _partial_shape_to_tensor(shape_like): 1424 try: 1425 # First attempt to convert the input to a shape, and return the 1426 # "canonical" tensor representation, which uses `-1` in place of 1427 # `None`. 1428 shape_like = tensor_shape.as_shape(shape_like) 1429 return ops.convert_to_tensor( 1430 [dim if dim is not None else -1 for dim in shape_like.as_list()], 1431 dtype=dtypes.int64) 1432 except (TypeError, ValueError): 1433 # The argument was not trivially convertible to a 1434 # `tf.TensorShape`, so fall back on the conversion to tensor 1435 # machinery. 1436 return ops.convert_to_tensor(shape_like, dtype=dtypes.int64) 1437 1438 1439def _padding_value_to_tensor(value, output_type): 1440 """Converts the padding value to a tensor. 1441 1442 Args: 1443 value: The padding value. 1444 output_type: Its expected dtype. 1445 1446 Returns: 1447 A scalar `Tensor`. 1448 1449 Raises: 1450 ValueError: if the padding value is not a scalar. 1451 TypeError: if the padding value's type does not match `output_type`. 
1452 """ 1453 value = ops.convert_to_tensor(value, name="padding_value") 1454 if not value.shape.is_compatible_with(tensor_shape.scalar()): 1455 raise ValueError("Padding value should be a scalar, but is not: %s" % value) 1456 if value.dtype != output_type: 1457 raise TypeError("Padding value tensor (%s) does not match output type: %s" % 1458 (value, output_type)) 1459 return value 1460 1461 1462def _default_padding(input_dataset): 1463 1464 def make_zero(t): 1465 if t.base_dtype == dtypes.string: 1466 return "" 1467 elif t.base_dtype == dtypes.variant: 1468 raise TypeError("Unable to create padding for field of type 'variant'") 1469 else: 1470 return np.zeros_like(t.as_numpy_dtype()) 1471 1472 return nest.map_structure(make_zero, input_dataset.output_types) 1473 1474 1475class PaddedBatchDataset(Dataset): 1476 """A `Dataset` that batches and pads contiguous elements from its input.""" 1477 1478 def __init__(self, input_dataset, batch_size, padded_shapes, padding_values): 1479 """See `Dataset.batch()` for details.""" 1480 super(PaddedBatchDataset, self).__init__() 1481 if sparse.any_sparse(input_dataset.output_classes): 1482 # TODO(b/63669786): support batching of sparse tensors 1483 raise TypeError( 1484 "Batching of padded sparse tensors is not currently supported") 1485 self._input_dataset = input_dataset 1486 self._batch_size = ops.convert_to_tensor( 1487 batch_size, dtype=dtypes.int64, name="batch_size") 1488 padding_values = ( 1489 padding_values 1490 if padding_values is not None else _default_padding(input_dataset)) 1491 self._padded_shapes = nest.map_structure_up_to( 1492 input_dataset.output_shapes, _partial_shape_to_tensor, padded_shapes) 1493 self._padding_values = nest.map_structure_up_to( 1494 input_dataset.output_shapes, _padding_value_to_tensor, padding_values, 1495 input_dataset.output_types) 1496 1497 def _as_variant_tensor(self): 1498 return gen_dataset_ops.padded_batch_dataset( 1499 self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access 1500 batch_size=self._batch_size, 1501 padded_shapes=[ 1502 ops.convert_to_tensor(s, dtype=dtypes.int64) 1503 for s in nest.flatten(self._padded_shapes) 1504 ], 1505 padding_values=nest.flatten(self._padding_values), 1506 output_shapes=nest.flatten( 1507 sparse.as_dense_shapes(self.output_shapes, self.output_classes))) 1508 1509 @property 1510 def output_classes(self): 1511 return self._input_dataset.output_classes 1512 1513 @property 1514 def output_shapes(self): 1515 1516 def _padded_shape_to_batch_shape(s): 1517 return tensor_shape.vector(None).concatenate( 1518 tensor_util.constant_value_as_shape(s)) 1519 1520 return nest.map_structure(_padded_shape_to_batch_shape, self._padded_shapes) 1521 1522 @property 1523 def output_types(self): 1524 return self._input_dataset.output_types 1525 1526 1527def _should_unpack_args(args): 1528 """Returns `True` if `args` should be `*args` when passed to a callable.""" 1529 return type(args) is tuple # pylint: disable=unidiomatic-typecheck 1530 1531 1532class MapDataset(Dataset): 1533 """A `Dataset` that maps a function over elements in its input.""" 1534 1535 def __init__(self, input_dataset, map_func): 1536 """See `Dataset.map()` for details.""" 1537 super(MapDataset, self).__init__() 1538 self._input_dataset = input_dataset 1539 1540 self._output_classes = None 1541 self._output_shapes = None 1542 self._output_types = None 1543 1544 @function.Defun(*nest.flatten( 1545 sparse.as_dense_types(input_dataset.output_types, 1546 input_dataset.output_classes))) 1547 def 
    tf_map_func(*args):
      """A wrapper for Defun that facilitates shape inference."""
      # Pass in shape information from the input_dataset.
      dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
                                            input_dataset.output_classes)
      for arg, shape in zip(args, nest.flatten(dense_shapes)):
        arg.set_shape(shape)

      nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
      nested_args = sparse.deserialize_sparse_tensors(
          nested_args, input_dataset.output_types, input_dataset.output_shapes,
          input_dataset.output_classes)
      if _should_unpack_args(nested_args):
        ret = map_func(*nested_args)
      else:
        ret = map_func(nested_args)

      # If `map_func` returns a list of tensors, `nest.flatten()` and
      # `ops.convert_to_tensor()` would conspire to attempt to stack
      # those tensors into a single tensor, because the customized
      # version of `nest.flatten()` does not recurse into lists. Since
      # it is more likely that the list arose from returning the
      # result of an operation (such as `tf.py_func()`) that returns a
      # list of not-necessarily-stackable tensors, we treat the
      # returned value as a `tuple` instead. A user wishing to pack
      # the return value into a single tensor can use an explicit
      # `tf.stack()` before returning.
      if isinstance(ret, list):
        ret = tuple(ret)

      # Convert any `SparseTensorValue`s to `SparseTensor`s.
      ret = nest.pack_sequence_as(ret, [
          sparse_tensor_lib.SparseTensor.from_value(t)
          if sparse_tensor_lib.is_sparse(t) else t for t in nest.flatten(ret)
      ])

      self._output_classes = sparse.get_classes(ret)
      self._output_shapes = nest.pack_sequence_as(
          ret, [t.get_shape() for t in nest.flatten(ret)])
      self._output_types = nest.pack_sequence_as(
          ret, [t.dtype for t in nest.flatten(ret)])

      # Serialize any sparse tensors and convert result to tensors.
      ret = nest.pack_sequence_as(ret, [
          ops.convert_to_tensor(t)
          for t in nest.flatten(sparse.serialize_sparse_tensors(ret))
      ])
      return nest.flatten(ret)

    self._map_func = tf_map_func
    self._map_func.add_to_graph(ops.get_default_graph())

  def _as_variant_tensor(self):
    input_t = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
    return gen_dataset_ops.map_dataset(
        input_t,
        self._map_func.captured_inputs,
        f=self._map_func,
        output_types=nest.flatten(
            sparse.as_dense_types(self.output_types, self.output_classes)),
        output_shapes=nest.flatten(
            sparse.as_dense_shapes(self.output_shapes, self.output_classes)))

  @property
  def output_classes(self):
    return self._output_classes

  @property
  def output_shapes(self):
    return self._output_shapes

  @property
  def output_types(self):
    return self._output_types


class ParallelMapDataset(MapDataset):
  """A `Dataset` that maps a function over elements in its input in parallel."""

  def __init__(self, input_dataset, map_func, num_parallel_calls):
    """See `Dataset.map()` for details."""
    super(ParallelMapDataset, self).__init__(input_dataset, map_func)

    self._num_parallel_calls = ops.convert_to_tensor(
        num_parallel_calls, dtype=dtypes.int32, name="num_parallel_calls")

  def _as_variant_tensor(self):
    input_t = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
    # pylint: disable=protected-access
    return gen_dataset_ops.parallel_map_dataset(
        input_t,
        self._map_func.captured_inputs,
        f=self._map_func,
        num_parallel_calls=self._num_parallel_calls,
        output_types=nest.flatten(
            sparse.as_dense_types(self.output_types, self.output_classes)),
        output_shapes=nest.flatten(
            sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
    # pylint: enable=protected-access


class FlatMapDataset(Dataset):
  """A `Dataset` that maps a function over its input and flattens the result."""

  def __init__(self, input_dataset, map_func):
    """See `Dataset.flat_map()` for details."""
    super(FlatMapDataset, self).__init__()
    self._input_dataset = input_dataset

    @function.Defun(*nest.flatten(
        sparse.as_dense_types(input_dataset.output_types,
                              input_dataset.output_classes)))
    def tf_map_func(*args):
      """A wrapper for Defun that facilitates shape inference."""
      # Pass in shape information from the input_dataset.
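      # (The `Defun` arguments arrive without static shape information, so the
      # known component shapes of `input_dataset` are re-applied here to aid
      # shape inference inside `map_func`.)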
      dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
                                            input_dataset.output_classes)
      for arg, shape in zip(args, nest.flatten(dense_shapes)):
        arg.set_shape(shape)

      nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
      nested_args = sparse.deserialize_sparse_tensors(
          nested_args, input_dataset.output_types, input_dataset.output_shapes,
          input_dataset.output_classes)
      if _should_unpack_args(nested_args):
        dataset = map_func(*nested_args)
      else:
        dataset = map_func(nested_args)

      if not isinstance(dataset, Dataset):
        raise TypeError("`map_func` must return a `Dataset` object.")

      self._output_classes = dataset.output_classes
      self._output_types = dataset.output_types
      self._output_shapes = dataset.output_shapes

      return dataset._as_variant_tensor()  # pylint: disable=protected-access

    self._map_func = tf_map_func
    self._map_func.add_to_graph(ops.get_default_graph())

  def _as_variant_tensor(self):
    return gen_dataset_ops.flat_map_dataset(
        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
        self._map_func.captured_inputs,
        f=self._map_func,
        output_types=nest.flatten(
            sparse.as_dense_types(self.output_types, self.output_classes)),
        output_shapes=nest.flatten(
            sparse.as_dense_shapes(self.output_shapes, self.output_classes)))

  @property
  def output_classes(self):
    return self._output_classes

  @property
  def output_shapes(self):
    return self._output_shapes

  @property
  def output_types(self):
    return self._output_types


class InterleaveDataset(Dataset):
  """A `Dataset` that maps a function over its input and interleaves the result.
  """

  def __init__(self, input_dataset, map_func, cycle_length, block_length):
    """See `Dataset.interleave()` for details."""
    super(InterleaveDataset, self).__init__()
    self._input_dataset = input_dataset

    @function.Defun(*nest.flatten(
        sparse.as_dense_types(input_dataset.output_types,
                              input_dataset.output_classes)))
    def tf_map_func(*args):
      """A wrapper for Defun that facilitates shape inference."""
      # Pass in shape information from the input_dataset.
      dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
                                            input_dataset.output_classes)
      for arg, shape in zip(args, nest.flatten(dense_shapes)):
        arg.set_shape(shape)

      nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
      nested_args = sparse.deserialize_sparse_tensors(
          nested_args, input_dataset.output_types, input_dataset.output_shapes,
          input_dataset.output_classes)
      if _should_unpack_args(nested_args):
        dataset = map_func(*nested_args)
      else:
        dataset = map_func(nested_args)

      if not isinstance(dataset, Dataset):
        raise TypeError("`map_func` must return a `Dataset` object.")

      self._output_classes = dataset.output_classes
      self._output_types = dataset.output_types
      self._output_shapes = dataset.output_shapes

      return dataset._as_variant_tensor()  # pylint: disable=protected-access

    self._map_func = tf_map_func
    self._map_func.add_to_graph(ops.get_default_graph())

    self._cycle_length = ops.convert_to_tensor(
        cycle_length, dtype=dtypes.int64, name="cycle_length")
    self._block_length = ops.convert_to_tensor(
        block_length, dtype=dtypes.int64, name="block_length")

  def _as_variant_tensor(self):
    return gen_dataset_ops.interleave_dataset(
        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
        self._map_func.captured_inputs,
        self._cycle_length,
        self._block_length,
        f=self._map_func,
        output_types=nest.flatten(
            sparse.as_dense_types(self.output_types, self.output_classes)),
        output_shapes=nest.flatten(
            sparse.as_dense_shapes(self.output_shapes, self.output_classes)))

  @property
  def output_classes(self):
    return self._output_classes

  @property
  def output_shapes(self):
    return self._output_shapes

  @property
  def output_types(self):
    return self._output_types


class FilterDataset(Dataset):
  """A `Dataset` that filters its input according to a predicate function."""

  def __init__(self, input_dataset, predicate):
    """See `Dataset.filter()` for details."""
    super(FilterDataset, self).__init__()
    self._input_dataset = input_dataset

    @function.Defun(*nest.flatten(
        sparse.as_dense_types(input_dataset.output_types,
                              input_dataset.output_classes)))
    def tf_predicate(*args):
      """A wrapper for Defun that facilitates shape inference."""
      # Pass in shape information from the input_dataset.
      dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
                                            input_dataset.output_classes)
      for arg, shape in zip(args, nest.flatten(dense_shapes)):
        arg.set_shape(shape)

      nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
      nested_args = sparse.deserialize_sparse_tensors(
          nested_args, input_dataset.output_types, input_dataset.output_shapes,
          input_dataset.output_classes)
      if _should_unpack_args(nested_args):
        ret = predicate(*nested_args)
      else:
        ret = predicate(nested_args)

      ret = ops.convert_to_tensor(ret, dtype=dtypes.bool)
      if not (ret.dtype == dtypes.bool and
              ret.shape.is_compatible_with(tensor_shape.scalar())):
        raise ValueError("`predicate` must return a scalar boolean tensor.")

      return ret

    self._predicate = tf_predicate
    self._predicate.add_to_graph(ops.get_default_graph())

  def _as_variant_tensor(self):
    return gen_dataset_ops.filter_dataset(
        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
        other_arguments=self._predicate.captured_inputs,
        predicate=self._predicate,
        output_types=nest.flatten(
            sparse.as_dense_types(self.output_types, self.output_classes)),
        output_shapes=nest.flatten(
            sparse.as_dense_shapes(self.output_shapes, self.output_classes)))

  @property
  def output_classes(self):
    return self._input_dataset.output_classes

  @property
  def output_shapes(self):
    return self._input_dataset.output_shapes

  @property
  def output_types(self):
    return self._input_dataset.output_types


class PrefetchDataset(Dataset):
  """A `Dataset` that asynchronously prefetches its input."""

  def __init__(self, input_dataset, buffer_size):
    """See `Dataset.prefetch()` for details."""
    super(PrefetchDataset, self).__init__()
    self._input_dataset = input_dataset
    self._buffer_size = ops.convert_to_tensor(
        buffer_size, dtype=dtypes.int64, name="buffer_size")

  def _as_variant_tensor(self):
    return gen_dataset_ops.prefetch_dataset(
        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
        buffer_size=self._buffer_size,
        output_shapes=nest.flatten(
            sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
        output_types=nest.flatten(
            sparse.as_dense_types(self.output_types, self.output_classes)))

  @property
  def output_classes(self):
    return self._input_dataset.output_classes

  @property
  def output_shapes(self):
    return self._input_dataset.output_shapes

  @property
  def output_types(self):
    return self._input_dataset.output_types
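

# A minimal usage sketch (illustrative only; not part of this module). The
# transformation classes above are normally constructed indirectly through the
# public `tf.data.Dataset` methods, e.g.:
#
#   ds = tf.data.Dataset.range(100)
#   ds = ds.map(lambda x: tf.fill([x], x))         # -> MapDataset
#   ds = ds.filter(lambda x: tf.size(x) > 0)       # -> FilterDataset
#   ds = ds.padded_batch(4, padded_shapes=[None])  # -> PaddedBatchDataset
#   ds = ds.prefetch(1)                            # -> PrefetchDataset
#   iterator = ds.make_one_shot_iterator()
#   next_element = iterator.get_next()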