1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python wrappers for Datasets."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import abc
21import collections
22import threading
23
24import numpy as np
25import six
26
27from tensorflow.python.data.ops import iterator_ops
28from tensorflow.python.data.util import nest
29from tensorflow.python.data.util import sparse
30from tensorflow.python.eager import context
31from tensorflow.python.framework import constant_op
32from tensorflow.python.framework import dtypes
33from tensorflow.python.framework import function
34from tensorflow.python.framework import ops
35from tensorflow.python.framework import random_seed
36from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
37from tensorflow.python.framework import tensor_shape
38from tensorflow.python.framework import tensor_util
39from tensorflow.python.ops import gen_dataset_ops
40from tensorflow.python.ops import gen_io_ops
41from tensorflow.python.ops import math_ops
42from tensorflow.python.ops import script_ops
43from tensorflow.python.util import deprecation
44from tensorflow.python.util.tf_export import tf_export
45
46
47@tf_export("data.Dataset")
48class Dataset(object):
49  """Represents a potentially large set of elements.
50
51  A `Dataset` can be used to represent an input pipeline as a
52  collection of elements (nested structures of tensors) and a "logical
53  plan" of transformations that act on those elements.
54  """
55  __metaclass__ = abc.ABCMeta
56
57  def __init__(self):
58    pass
59
60  @abc.abstractmethod
61  def _as_variant_tensor(self):
62    """Creates a scalar `tf.Tensor` of `tf.variant` representing this dataset.
63
64    Returns:
65      A scalar `tf.Tensor` of `tf.variant` type, which represents this dataset.
66    """
67    raise NotImplementedError("Dataset._as_variant_tensor")
68
69  def make_initializable_iterator(self, shared_name=None):
70    """Creates an `Iterator` for enumerating the elements of this dataset.
71
72    Note: The returned iterator will be in an uninitialized state,
73    and you must run the `iterator.initializer` operation before using it:
74
75    ```python
76    dataset = ...
77    iterator = dataset.make_initializable_iterator()
78    # ...
79    sess.run(iterator.initializer)
80    ```
81
82    Args:
83      shared_name: (Optional.) If non-empty, the returned iterator will be
84        shared under the given name across multiple sessions that share the
85        same devices (e.g. when using a remote server).
86
87    Returns:
88      An `Iterator` over the elements of this dataset.
89
90    Raises:
91      RuntimeError: If eager execution is enabled.
92    """
93    if context.in_eager_mode():
94      raise RuntimeError(
95          "dataset.make_initializable_iterator is not supported when eager "
96          "execution is enabled.")
97    if shared_name is None:
98      shared_name = ""
99    iterator_resource = gen_dataset_ops.iterator(
100        container="",
101        shared_name=shared_name,
102        output_types=nest.flatten(
103            sparse.as_dense_types(self.output_types, self.output_classes)),
104        output_shapes=nest.flatten(
105            sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
106    with ops.colocate_with(iterator_resource):
107      initializer = gen_dataset_ops.make_iterator(self._as_variant_tensor(),
108                                                  iterator_resource)
109    return iterator_ops.Iterator(iterator_resource, initializer,
110                                 self.output_types, self.output_shapes,
111                                 self.output_classes)
112
113  def make_one_shot_iterator(self):
114    """Creates an `Iterator` for enumerating the elements of this dataset.
115
116    Note: The returned iterator will be initialized automatically.
117    A "one-shot" iterator does not currently support re-initialization.
118
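    For example (a sketch; assumes graph mode and that `sess` is an active
    `tf.Session`):

    ```python
    dataset = tf.data.Dataset.range(3)
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()

    print(sess.run(next_element))  # ==> 0
    print(sess.run(next_element))  # ==> 1
    print(sess.run(next_element))  # ==> 2
    ```
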
119    Returns:
120      An `Iterator` over the elements of this dataset.
121
122    Raises:
123      RuntimeError: If eager execution is enabled.
124    """
125    if context.in_eager_mode():
126      raise RuntimeError(
127          "dataset.make_one_shot_iterator is not supported when eager "
128          "execution is enabled.")
129    # NOTE(mrry): We capture by value here to ensure that `_make_dataset()` is
130    # a 0-argument function.
131    @function.Defun(capture_by_value=True)
132    def _make_dataset():
133      return self._as_variant_tensor()  # pylint: disable=protected-access
134
135    try:
136      _make_dataset.add_to_graph(ops.get_default_graph())
137    except ValueError as err:
138      if "Cannot capture a stateful node" in str(err):
139        raise ValueError(
140            "Failed to create a one-shot iterator for a dataset. "
141            "`Dataset.make_one_shot_iterator()` does not support datasets that "
142            "capture stateful objects, such as a `Variable` or `LookupTable`. "
143            "In these cases, use `Dataset.make_initializable_iterator()`. "
144            "(Original error: %s)" % err)
145      else:
146        six.reraise(ValueError, err)
147
148    return iterator_ops.Iterator(
149        gen_dataset_ops.one_shot_iterator(
150            dataset_factory=_make_dataset,
151            output_types=nest.flatten(
152                sparse.as_dense_types(self.output_types, self.output_classes)),
153            output_shapes=nest.flatten(
154                sparse.as_dense_shapes(self.output_shapes,
155                                       self.output_classes))), None,
156        self.output_types, self.output_shapes, self.output_classes)
157
158  @abc.abstractproperty
159  def output_classes(self):
160    """Returns the class of each component of an element of this dataset.
161
162    The expected values are `tf.Tensor` and `tf.SparseTensor`.
163
164    Returns:
165      A nested structure of Python `type` objects corresponding to each
166      component of an element of this dataset.
167    """
168    raise NotImplementedError("Dataset.output_classes")
169
170  @abc.abstractproperty
171  def output_shapes(self):
172    """Returns the shape of each component of an element of this dataset.
173
174    Returns:
175      A nested structure of `tf.TensorShape` objects corresponding to each
176      component of an element of this dataset.
177    """
178    raise NotImplementedError("Dataset.output_shapes")
179
180  @abc.abstractproperty
181  def output_types(self):
182    """Returns the type of each component of an element of this dataset.
183
184    Returns:
185      A nested structure of `tf.DType` objects corresponding to each component
186      of an element of this dataset.
187    """
188    raise NotImplementedError("Dataset.output_types")
189
190  def __repr__(self):
191    output_shapes = nest.map_structure(str, self.output_shapes)
192    output_shapes = str(output_shapes).replace("'", "")
193    output_types = nest.map_structure(repr, self.output_types)
194    output_types = str(output_types).replace("'", "")
195    return ("<%s shapes: %s, types: %s>" % (type(self).__name__, output_shapes,
196                                            output_types))
197
198  @staticmethod
199  def from_tensors(tensors):
200    """Creates a `Dataset` with a single element, comprising the given tensors.
201
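    For example (a sketch; the values are placeholders chosen for
    illustration):

    ```python
    # The resulting dataset contains exactly one element: the 2x2 matrix.
    dataset = tf.data.Dataset.from_tensors([[1, 2], [3, 4]])
    ```
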
202    Args:
203      tensors: A nested structure of tensors.
204
205    Returns:
206      Dataset: A `Dataset`.
207    """
208    return TensorDataset(tensors)
209
210  @staticmethod
211  def from_tensor_slices(tensors):
212    """Creates a `Dataset` whose elements are slices of the given tensors.
213
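    For example (a sketch; the values are placeholders chosen for
    illustration):

    ```python
    # The resulting dataset contains three elements: [1, 2], [3, 4], and
    # [5, 6], sliced along the 0th dimension of the input tensor.
    dataset = tf.data.Dataset.from_tensor_slices([[1, 2], [3, 4], [5, 6]])
    ```
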
214    Args:
215      tensors: A nested structure of tensors, each having the same size in the
216        0th dimension.
217
218    Returns:
219      Dataset: A `Dataset`.
220    """
221    return TensorSliceDataset(tensors)
222
223  @staticmethod
224  @deprecation.deprecated(None, "Use `tf.data.Dataset.from_tensor_slices()`.")
225  def from_sparse_tensor_slices(sparse_tensor):
226    """Splits each rank-N `tf.SparseTensor` in this dataset row-wise.
227
228    Args:
229      sparse_tensor: A `tf.SparseTensor`.
230
231    Returns:
232      Dataset: A `Dataset` of rank-(N-1) sparse tensors.
233    """
234    return SparseTensorSliceDataset(sparse_tensor)
235
236  class _GeneratorState(object):
237    """Stores outstanding iterators created from a Python generator.
238
239    This class keeps track of potentially multiple iterators that may have
240    been created from a generator, e.g. in the case that the dataset is
241    repeated, or nested within a parallel computation.
242    """
243
244    def __init__(self, generator):
245      self._generator = generator
246      self._lock = threading.Lock()
247      self._next_id = 0  # GUARDED_BY(self._lock)
248      self._iterators = collections.defaultdict(lambda: iter(generator()))
249
250    def get_next_id(self):
251      with self._lock:
252        ret = self._next_id
253        self._next_id += 1
254      # NOTE(mrry): Explicitly create an array of `np.int64` because implicit
255      # casting in `py_func()` will create an array of `np.int32` on Windows,
256      # leading to a runtime error.
257      return np.array(ret, dtype=np.int64)
258
259    def get_iterator(self, iterator_id):
260      return self._iterators[iterator_id]
261
262    def iterator_completed(self, iterator_id):
263      del self._iterators[iterator_id]
264
265  @staticmethod
266  def from_generator(generator, output_types, output_shapes=None):
267    """Creates a `Dataset` whose elements are generated by `generator`.
268
269    The `generator` argument must be a callable object that returns
    an object that supports the `iter()` protocol (e.g. a generator function).
271    The elements generated by `generator` must be compatible with the given
272    `output_types` and (optional) `output_shapes` arguments.
273
274    For example:
275
276    ```python
277    import itertools
278
279    def gen():
280      for i in itertools.count(1):
281        yield (i, [1] * i)
282
283    ds = Dataset.from_generator(
284        gen, (tf.int64, tf.int64), (tf.TensorShape([]), tf.TensorShape([None])))
285    value = ds.make_one_shot_iterator().get_next()
286
287    sess.run(value)  # (1, array([1]))
288    sess.run(value)  # (2, array([1, 1]))
289    ```
290
291    NOTE: The current implementation of `Dataset.from_generator()` uses
292    @{tf.py_func} and inherits the same constraints. In particular, it
293    requires the `Dataset`- and `Iterator`-related operations to be placed
294    on a device in the same process as the Python program that called
295    `Dataset.from_generator()`. The body of `generator` will not be
296    serialized in a `GraphDef`, and you should not use this method if you
297    need to serialize your model and restore it in a different environment.
298
299    NOTE: If `generator` depends on mutable global variables or other external
300    state, be aware that the runtime may invoke `generator` multiple times
301    (in order to support repeating the `Dataset`) and at any time
302    between the call to `Dataset.from_generator()` and the production of the
303    first element from the generator. Mutating global variables or external
304    state can cause undefined behavior, and we recommend that you explicitly
305    cache any external state in `generator` before calling
306    `Dataset.from_generator()`.
307
308    Args:
309      generator: A callable object that takes no arguments and returns an
310        object that supports the `iter()` protocol.
311      output_types: A nested structure of `tf.DType` objects corresponding to
312        each component of an element yielded by `generator`.
313      output_shapes: (Optional.) A nested structure of `tf.TensorShape`
314        objects corresponding to each component of an element yielded by
315        `generator`.
316
317    Returns:
318      Dataset: A `Dataset`.
319    """
320    if not callable(generator):
321      raise TypeError("`generator` must be callable.")
322    if output_shapes is None:
323      output_shapes = nest.map_structure(
324          lambda _: tensor_shape.TensorShape(None), output_types)
325    else:
326      output_shapes = nest.map_structure_up_to(
327          output_types, tensor_shape.as_shape, output_shapes)
328
329    flattened_types = nest.flatten(output_types)
330    flattened_shapes = nest.flatten(output_shapes)
331
332    generator_state = Dataset._GeneratorState(generator)
333
334    def get_iterator_id_map_fn(unused_dummy):
335      """Creates a unique `iterator_id` for each pass over the dataset.
336
337      The "iterator_id" disambiguates between multiple concurrently
338      existing iterators.
339
340      Args:
341        unused_dummy: Ignored value.
342
343      Returns:
344        A `tf.int64` tensor whose value uniquely identifies an iterator in
345        `generator_state`.
346      """
347      return script_ops.py_func(
348          generator_state.get_next_id, [], dtypes.int64, stateful=True)
349
350    def generator_map_fn(iterator_id_t):
351      """Generates the next element from iterator with ID `iterator_id_t`.
352
353      We map this function across an infinite repetition of the
354      `iterator_id_t`, and raise `StopIteration` to terminate the iteration.
355
356      Args:
357        iterator_id_t: A `tf.int64` tensor whose value uniquely identifies
358          the iterator in `generator_state` from which to generate an element.
359
360      Returns:
361        A nested structure of tensors representing an element from the iterator.
362      """
363
364      def generator_py_func(iterator_id):
365        """A `py_func` that will be called to invoke the iterator."""
366        try:
367          values = next(generator_state.get_iterator(iterator_id))
368        except StopIteration:
369          generator_state.iterator_completed(iterator_id)
370          raise StopIteration("Iteration finished.")
371
372        # Use the same _convert function from the py_func() implementation to
373        # convert the returned values to arrays early, so that we can inspect
374        # their values.
375        # pylint: disable=protected-access
376        ret_arrays = [
377            script_ops.FuncRegistry._convert(ret, dtype=dtype.as_numpy_dtype)
378            for ret, dtype in zip(
379                nest.flatten_up_to(output_types, values), flattened_types)
380        ]
381        # pylint: enable=protected-access
382
383        # Additional type and shape checking to ensure that the components
384        # of the generated element match the `output_types` and `output_shapes`
385        # arguments.
386        for (ret_array, expected_dtype, expected_shape) in zip(
387            ret_arrays, flattened_types, flattened_shapes):
388          if ret_array.dtype != expected_dtype.as_numpy_dtype:
389            raise TypeError(
390                "`generator` yielded an element of type %s where an element "
391                "of type %s was expected." % (ret_array.dtype,
392                                              expected_dtype.as_numpy_dtype))
393          if not expected_shape.is_compatible_with(ret_array.shape):
394            raise ValueError(
395                "`generator` yielded an element of shape %s where an element "
396                "of shape %s was expected." % (ret_array.shape, expected_shape))
397
398        return ret_arrays
399
400      flat_values = script_ops.py_func(
401          generator_py_func, [iterator_id_t], flattened_types, stateful=True)
402
403      # The `py_func()` op drops the inferred shapes, so we add them back in
404      # here.
405      if output_shapes is not None:
406        for ret_t, shape in zip(flat_values, flattened_shapes):
407          ret_t.set_shape(shape)
408
409      return nest.pack_sequence_as(output_types, flat_values)
410
411    # This function associates each traversal of `generator` with a unique
412    # iterator ID.
413    def flat_map_fn(iterator_id_t):
414      # First, generate an infinite dataset containing the iterator ID repeated
415      # forever.
416      repeated_id = Dataset.from_tensors(iterator_id_t).repeat(None)
417
418      # The `generator_map_fn` gets the next element from the iterator with the
419      # relevant ID, and raises StopIteration when that iterator contains no
420      # more elements.
421      return repeated_id.map(generator_map_fn)
422
423    # A single-element dataset that, each time it is evaluated, contains a
424    # freshly-generated and unique (for the returned dataset) int64
425    # ID that will be used to identify the appropriate Python state, which
426    # is encapsulated in `generator_state`, and captured in
427    # `get_iterator_id_map_fn`.
428    dummy = 0
429    id_dataset = Dataset.from_tensors(dummy).map(get_iterator_id_map_fn)
430
431    # A dataset that contains all of the elements generated by a
432    # single iterator created from `generator`, identified by the
433    # iterator ID contained in `id_dataset`. Lifting the iteration
434    # into a flat_map here enables multiple repetitions and/or nested
435    # versions of the returned dataset to be created, because it forces
436    # the generation of a new ID for each version.
437    return id_dataset.flat_map(flat_map_fn)
438
439  @staticmethod
440  def range(*args):
441    """Creates a `Dataset` of a step-separated range of values.
442
443    For example:
444
445    ```python
446    Dataset.range(5) == [0, 1, 2, 3, 4]
447    Dataset.range(2, 5) == [2, 3, 4]
448    Dataset.range(1, 5, 2) == [1, 3]
449    Dataset.range(1, 5, -2) == []
450    Dataset.range(5, 1) == []
451    Dataset.range(5, 1, -2) == [5, 3]
452    ```
453
454    Args:
      *args: follows the same semantics as Python's `xrange`.
        len(args) == 1 -> start = 0, stop = args[0], step = 1
        len(args) == 2 -> start = args[0], stop = args[1], step = 1
        len(args) == 3 -> start = args[0], stop = args[1], step = args[2]
459
460    Returns:
461      Dataset: A `RangeDataset`.
462
463    Raises:
464      ValueError: if len(args) == 0.
465    """
466    return RangeDataset(*args)
467
468  @staticmethod
469  def zip(datasets):
470    """Creates a `Dataset` by zipping together the given datasets.
471
472    This method has similar semantics to the built-in `zip()` function
473    in Python, with the main difference being that the `datasets`
474    argument can be an arbitrary nested structure of `Dataset` objects.
475    For example:
476
477    ```python
478    # NOTE: The following examples use `{ ... }` to represent the
479    # contents of a dataset.
480    a = { 1, 2, 3 }
481    b = { 4, 5, 6 }
482    c = { (7, 8), (9, 10), (11, 12) }
483    d = { 13, 14 }
484
485    # The nested structure of the `datasets` argument determines the
486    # structure of elements in the resulting dataset.
487    Dataset.zip((a, b)) == { (1, 4), (2, 5), (3, 6) }
488    Dataset.zip((b, a)) == { (4, 1), (5, 2), (6, 3) }
489
490    # The `datasets` argument may contain an arbitrary number of
491    # datasets.
492    Dataset.zip((a, b, c)) == { (1, 4, (7, 8)),
493                                (2, 5, (9, 10)),
494                                (3, 6, (11, 12)) }
495
496    # The number of elements in the resulting dataset is the same as
497    # the size of the smallest dataset in `datasets`.
498    Dataset.zip((a, d)) == { (1, 13), (2, 14) }
499    ```
500
501    Args:
502      datasets: A nested structure of datasets.
503
504    Returns:
505      Dataset: A `Dataset`.
506    """
507    return ZipDataset(datasets)
508
509  def concatenate(self, dataset):
510    """Creates a `Dataset` by concatenating given dataset with this dataset.
511
512    ```python
513    # NOTE: The following examples use `{ ... }` to represent the
514    # contents of a dataset.
515    a = { 1, 2, 3 }
516    b = { 4, 5, 6, 7 }
517
    # The input dataset and the dataset to be concatenated should have the
    # same nested structures and output types.
520    # c = { (8, 9), (10, 11), (12, 13) }
521    # d = { 14.0, 15.0, 16.0 }
522    # a.concatenate(c) and a.concatenate(d) would result in error.
523
524    a.concatenate(b) == { 1, 2, 3, 4, 5, 6, 7 }
525    ```
526
527    Args:
528      dataset: `Dataset` to be concatenated.
529
530    Returns:
531      Dataset: A `Dataset`.
532    """
533    return ConcatenateDataset(self, dataset)
534
535  def prefetch(self, buffer_size):
536    """Creates a `Dataset` that prefetches elements from this dataset.
537
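    For example, `prefetch()` is typically the last transformation in an input
    pipeline (a sketch; `parse_fn` is a hypothetical parsing function and
    `dataset` an existing `Dataset`):

    ```python
    # Overlap the preprocessing of the next batch with the current step by
    # keeping one batch buffered ahead of time. `parse_fn` is a placeholder
    # for a user-defined parsing function.
    dataset = dataset.map(parse_fn).batch(32).prefetch(1)
    ```
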
538    Args:
      buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the maximum
        number of elements that will be buffered when prefetching.
541
542    Returns:
543      Dataset: A `Dataset`.
544    """
545    return PrefetchDataset(self, buffer_size)
546
547  @staticmethod
548  def list_files(file_pattern):
549    """A dataset of all files matching a pattern.
550
551    Example:
552      If we had the following files on our filesystem:
553        - /path/to/dir/a.txt
554        - /path/to/dir/b.py
555        - /path/to/dir/c.py
      If we pass "/path/to/dir/*.py" as the file pattern, the dataset would
557      produce:
558        - /path/to/dir/b.py
559        - /path/to/dir/c.py
560
561    NOTE: The order of the file names returned can be non-deterministic.
562
563    Args:
564      file_pattern: A string or scalar string `tf.Tensor`, representing
565        the filename pattern that will be matched.
566
567    Returns:
568     Dataset: A `Dataset` of strings corresponding to file names.
569    """
570    return Dataset.from_tensor_slices(gen_io_ops.matching_files(file_pattern))
571
572  def repeat(self, count=None):
573    """Repeats this dataset `count` times.
574
575    NOTE: If this dataset is a function of global state (e.g. a random number
576    generator), then different repetitions may produce different elements.
577
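    For example (a sketch, using the same illustrative `==` notation as
    `Dataset.range()` above):

    ```python
    Dataset.range(3).repeat(2) == [0, 1, 2, 0, 1, 2]
    Dataset.range(3).repeat() == [0, 1, 2, 0, 1, 2, ...]  # indefinitely
    ```
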
578    Args:
579      count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
580        number of times the dataset should be repeated. The default behavior
        (if `count` is `None` or `-1`) is for the dataset to be repeated
582        indefinitely.
583
584    Returns:
585      Dataset: A `Dataset`.
586    """
587    return RepeatDataset(self, count)
588
589  def _enumerate(self, start=0):
590
591    max_value = np.iinfo(dtypes.int64.as_numpy_dtype).max
592    return Dataset.zip((Dataset.range(start, max_value), self))
593
594  def shuffle(self, buffer_size, seed=None, reshuffle_each_iteration=None):
595    """Randomly shuffles the elements of this dataset.
596
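    For example (a sketch; the seed value is arbitrary):

    ```python
    # With `buffer_size` at least as large as the dataset, the output is a
    # uniformly shuffled permutation of all ten elements.
    dataset = tf.data.Dataset.range(10).shuffle(buffer_size=10, seed=42)
    ```
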
597    Args:
598      buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the
599        number of elements from this dataset from which the new
600        dataset will sample.
601      seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
602        random seed that will be used to create the distribution. See
603        @{tf.set_random_seed} for behavior.
604      reshuffle_each_iteration: (Optional.) A boolean, which if true indicates
605        that the dataset should be pseudorandomly reshuffled each time it is
606        iterated over. (Defaults to `True`.)
607
608    Returns:
609      Dataset: A `Dataset`.
610    """
611    return ShuffleDataset(self, buffer_size, seed, reshuffle_each_iteration)
612
613  def cache(self, filename=""):
614    """Caches the elements in this dataset.
615
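    For example (a sketch; `parse_fn` and the cache path are placeholders):

    ```python
    # Cache the parsed elements on the filesystem so that `map()` only runs
    # during the first pass over the data.
    dataset = dataset.map(parse_fn).cache("/tmp/example_cache")

    # Alternatively, omit the filename to cache entirely in memory.
    dataset = dataset.map(parse_fn).cache()
    ```
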
616    Args:
617      filename: A `tf.string` scalar `tf.Tensor`, representing the name of a
618        directory on the filesystem to use for caching tensors in this Dataset.
619        If a filename is not provided, the dataset will be cached in memory.
620
621    Returns:
622      Dataset: A `Dataset`.
623    """
624    return CacheDataset(self, filename)
625
626  def take(self, count):
627    """Creates a `Dataset` with at most `count` elements from this dataset.
628
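    For example (a sketch, using the same illustrative `==` notation as
    `Dataset.range()` above):

    ```python
    Dataset.range(5).take(3) == [0, 1, 2]
    ```
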
629    Args:
630      count: A `tf.int64` scalar `tf.Tensor`, representing the number of
631        elements of this dataset that should be taken to form the new dataset.
632        If `count` is -1, or if `count` is greater than the size of this
633        dataset, the new dataset will contain all elements of this dataset.
634
635    Returns:
636      Dataset: A `Dataset`.
637    """
638    return TakeDataset(self, count)
639
640  def skip(self, count):
641    """Creates a `Dataset` that skips `count` elements from this dataset.
642
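    For example (a sketch, using the same illustrative `==` notation as
    `Dataset.range()` above):

    ```python
    Dataset.range(5).skip(3) == [3, 4]
    ```
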
643    Args:
644      count: A `tf.int64` scalar `tf.Tensor`, representing the number
645        of elements of this dataset that should be skipped to form the
646        new dataset.  If `count` is greater than the size of this
647        dataset, the new dataset will contain no elements.  If `count`
648        is -1, skips the entire dataset.
649
650    Returns:
651      Dataset: A `Dataset`.
652    """
653    return SkipDataset(self, count)
654
655  def shard(self, num_shards, index):
656    """Creates a `Dataset` that includes only 1/`num_shards` of this dataset.
657
658    This dataset operator is very useful when running distributed training, as
659    it allows each worker to read a unique subset.
660
661    When reading a single input file, you can skip elements as follows:
662
663    ```python
664    d = tf.data.TFRecordDataset(FLAGS.input_file)
665    d = d.shard(FLAGS.num_workers, FLAGS.worker_index)
666    d = d.repeat(FLAGS.num_epochs)
667    d = d.shuffle(FLAGS.shuffle_buffer_size)
668    d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads)
669    ```
670
671    Important caveats:
672
673    - Be sure to shard before you use any randomizing operator (such as
674      shuffle).
675    - Generally it is best if the shard operator is used early in the dataset
676      pipeline. For example, when reading from a set of TFRecord files, shard
677      before converting the dataset to input samples. This avoids reading every
678      file on every worker. The following is an example of an efficient
679      sharding strategy within a complete pipeline:
680
681    ```python
682    d = Dataset.list_files(FLAGS.pattern)
683    d = d.shard(FLAGS.num_workers, FLAGS.worker_index)
684    d = d.repeat(FLAGS.num_epochs)
685    d = d.shuffle(FLAGS.shuffle_buffer_size)
687    d = d.interleave(tf.data.TFRecordDataset,
688                     cycle_length=FLAGS.num_readers, block_length=1)
689    d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads)
690    ```
691
692    Args:
693      num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of
694        shards operating in parallel.
695      index: A `tf.int64` scalar `tf.Tensor`, representing the worker index.
696
697    Returns:
698      Dataset: A `Dataset`.
699
700    Raises:
      ValueError: if `num_shards` or `index` are illegal values. Note: error
        checking is done on a best-effort basis, and errors are not guaranteed
        to be caught upon dataset creation. (e.g. providing a placeholder
        tensor bypasses the early checking, and will instead result in an error
        during a `session.run` call.)
706    """
707    num_shards = ops.convert_to_tensor(
708        num_shards, name="num_shards", dtype=dtypes.int64)
709    num_shards_static = tensor_util.constant_value(num_shards)
710    index = ops.convert_to_tensor(index, name="index", dtype=dtypes.int64)
711    index_static = tensor_util.constant_value(index)
712
713    if num_shards_static is not None and num_shards_static < 1:
714      raise ValueError("num_shards must be >= 1; got: %s" % num_shards_static)
715    if index_static is not None and index_static < 0:
716      raise ValueError("index must be >= 0; got: %s" % index_static)
717    if (index_static is not None and num_shards_static is not None and
718        index_static >= num_shards_static):
719      raise ValueError("index must be <= num_shards; %s is not < %s" %
720                       (index_static, num_shards_static))
721
722    def filter_fn(elem_index, _):
723      mod_result = math_ops.mod(elem_index, num_shards)
724      return math_ops.equal(mod_result, index)
725
726    return self._enumerate().filter(filter_fn).map(lambda _, elem: elem)
727
728  def batch(self, batch_size):
729    """Combines consecutive elements of this dataset into batches.
730
731    NOTE: If the number of elements (`N`) in this dataset is not an exact
    multiple of `batch_size`, the final batch will contain smaller tensors
    with shape `N % batch_size` in the batch dimension. If your program
    depends on the batches having the same shape, consider using the
735    @{tf.contrib.data.batch_and_drop_remainder} transformation instead.
736
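    For example (a sketch, using the same illustrative `==` notation as
    `Dataset.range()` above; note the smaller final batch):

    ```python
    Dataset.range(8).batch(3) == [[0, 1, 2], [3, 4, 5], [6, 7]]
    ```
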
737    Args:
738      batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
739        consecutive elements of this dataset to combine in a single batch.
740
741    Returns:
742      Dataset: A `Dataset`.
743    """
744    return BatchDataset(self, batch_size)
745
746  def padded_batch(self, batch_size, padded_shapes, padding_values=None):
747    """Combines consecutive elements of this dataset into padded batches.
748
    Like `tf.contrib.data.dense_to_sparse_batch()`, this method combines
750    multiple consecutive elements of this dataset, which might have
751    different shapes, into a single element. The tensors in the
752    resulting element have an additional outer dimension, and are
753    padded to the respective shape in `padded_shapes`.
754
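    For example (a sketch, using the same illustrative `==` notation as
    `Dataset.range()` above; each element is a vector whose length equals its
    value, padded with zeros up to the longest vector in its batch):

    ```python
    elements = Dataset.range(1, 4).map(lambda x: tf.fill([x], x))
    # elements == [[1], [2, 2], [3, 3, 3]]

    elements.padded_batch(2, padded_shapes=[None]) == [
        [[1, 0], [2, 2]],
        [[3, 3, 3]],
    ]
    ```
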
755    Args:
756      batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
757        consecutive elements of this dataset to combine in a single batch.
758      padded_shapes: A nested structure of `tf.TensorShape` or
759        `tf.int64` vector tensor-like objects representing the shape
760        to which the respective component of each input element should
761        be padded prior to batching. Any unknown dimensions
762        (e.g. `tf.Dimension(None)` in a `tf.TensorShape` or `-1` in a
763        tensor-like object) will be padded to the maximum size of that
764        dimension in each batch.
765      padding_values: (Optional.) A nested structure of scalar-shaped
766        `tf.Tensor`, representing the padding values to use for the
767        respective components.  Defaults are `0` for numeric types and
768        the empty string for string types.
769
770    Returns:
771      Dataset: A `Dataset`.
772    """
773    return PaddedBatchDataset(self, batch_size, padded_shapes, padding_values)
774
775  def map(self, map_func, num_parallel_calls=None):
776    """Maps `map_func` across this dataset.
777
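    For example (a sketch, using the same illustrative `==` notation as
    `Dataset.range()` above):

    ```python
    Dataset.range(5).map(lambda x: x * 2) == [0, 2, 4, 6, 8]
    ```
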
778    Args:
      map_func: A function mapping a nested structure of tensors (having
        shapes and types defined by `self.output_shapes` and
        `self.output_types`) to another nested structure of tensors.
      num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
        representing the number of elements to process in parallel. If not
        specified, elements will be processed sequentially.
785
786    Returns:
787      Dataset: A `Dataset`.
788    """
789    if num_parallel_calls is None:
790      return MapDataset(self, map_func)
791    else:
792      return ParallelMapDataset(self, map_func, num_parallel_calls)
793
794  def flat_map(self, map_func):
795    """Maps `map_func` across this dataset and flattens the result.
796
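    For example (a sketch, using the same illustrative `==` notation as
    `Dataset.range()` above; `map_func` returns a `Dataset` for each input
    element, and the resulting datasets are concatenated in order):

    ```python
    Dataset.range(3).flat_map(
        lambda x: Dataset.from_tensors(x).repeat(2)) == [0, 0, 1, 1, 2, 2]
    ```
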
797    Args:
798      map_func: A function mapping a nested structure of tensors (having shapes
799        and types defined by `self.output_shapes` and `self.output_types`) to a
800        `Dataset`.
801
802    Returns:
803      Dataset: A `Dataset`.
804    """
805    return FlatMapDataset(self, map_func)
806
807  def interleave(self, map_func, cycle_length, block_length=1):
808    """Maps `map_func` across this dataset, and interleaves the results.
809
810    For example, you can use `Dataset.interleave()` to process many input files
811    concurrently:
812
813    ```python
814    # Preprocess 4 files concurrently, and interleave blocks of 16 records from
815    # each file.
816    filenames = ["/var/data/file1.txt", "/var/data/file2.txt", ...]
817    dataset = (Dataset.from_tensor_slices(filenames)
818               .interleave(lambda x:
819                   TextLineDataset(x).map(parse_fn, num_parallel_calls=1),
820                   cycle_length=4, block_length=16))
821    ```
822
823    The `cycle_length` and `block_length` arguments control the order in which
824    elements are produced. `cycle_length` controls the number of input elements
825    that are processed concurrently. If you set `cycle_length` to 1, this
826    transformation will handle one input element at a time, and will produce
    identical results to @{tf.data.Dataset.flat_map}. In general,
828    this transformation will apply `map_func` to `cycle_length` input elements,
829    open iterators on the returned `Dataset` objects, and cycle through them
830    producing `block_length` consecutive elements from each iterator, and
831    consuming the next input element each time it reaches the end of an
832    iterator.
833
834    For example:
835
836    ```python
837    # NOTE: The following examples use `{ ... }` to represent the
838    # contents of a dataset.
839    a = { 1, 2, 3, 4, 5 }
840
841    # NOTE: New lines indicate "block" boundaries.
842    a.interleave(lambda x: Dataset.from_tensors(x).repeat(6),
843                 cycle_length=2, block_length=4) == {
844        1, 1, 1, 1,
845        2, 2, 2, 2,
846        1, 1,
847        2, 2,
848        3, 3, 3, 3,
849        4, 4, 4, 4,
850        3, 3,
851        4, 4,
852        5, 5, 5, 5,
853        5, 5,
854    }
855    ```
856
857    NOTE: The order of elements yielded by this transformation is
858    deterministic, as long as `map_func` is a pure function. If
859    `map_func` contains any stateful operations, the order in which
860    that state is accessed is undefined.
861
862    Args:
863      map_func: A function mapping a nested structure of tensors (having shapes
864        and types defined by `self.output_shapes` and `self.output_types`) to a
865        `Dataset`.
866      cycle_length: The number of elements from this dataset that will be
867        processed concurrently.
868      block_length: The number of consecutive elements to produce from each
869        input element before cycling to another input element.
870
871    Returns:
872      Dataset: A `Dataset`.
873    """
874    return InterleaveDataset(self, map_func, cycle_length, block_length)
875
876  def filter(self, predicate):
877    """Filters this dataset according to `predicate`.
878
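    For example (a sketch, using the same illustrative `==` notation as
    `Dataset.range()` above):

    ```python
    Dataset.range(5).filter(lambda x: x > 2) == [3, 4]
    ```
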
879    Args:
880      predicate: A function mapping a nested structure of tensors (having shapes
881        and types defined by `self.output_shapes` and `self.output_types`) to a
882        scalar `tf.bool` tensor.
883
884    Returns:
885      Dataset: A `Dataset`.
886    """
887    return FilterDataset(self, predicate)
888
889  def apply(self, transformation_func):
890    """Apply a transformation function to this dataset.
891
892    `apply` enables chaining of custom `Dataset` transformations, which are
893    represented as functions that take one `Dataset` argument and return a
894    transformed `Dataset`.
895
896    For example:
897
    ```python
899    dataset = (dataset.map(lambda x: x ** 2)
900               .apply(group_by_window(key_func, reduce_func, window_size))
901               .map(lambda x: x ** 3))
902    ```
903
904    Args:
905      transformation_func: A function that takes one `Dataset` argument and
906          returns a `Dataset`.
907
908    Returns:
909      Dataset: The `Dataset` returned by applying `transformation_func` to this
910          dataset.
911    """
912    dataset = transformation_func(self)
913    if not isinstance(dataset, Dataset):
914      raise TypeError("`transformation_func` must return a Dataset.")
915    return dataset
916
917
918class TensorDataset(Dataset):
919  """A `Dataset` with a single element, viz. a nested structure of tensors."""
920
921  def __init__(self, tensors):
922    """See `Dataset.from_tensors()` for details."""
923    super(TensorDataset, self).__init__()
924    with ops.name_scope("tensors"):
925      tensors = nest.pack_sequence_as(tensors, [
926          sparse_tensor_lib.SparseTensor.from_value(t)
927          if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(
928              t, name="component_%d" % i)
929          for i, t in enumerate(nest.flatten(tensors))
930      ])
931
932    self._tensors = sparse.serialize_sparse_tensors(tensors)
933    self._output_classes = sparse.get_classes(tensors)
934    self._output_shapes = nest.pack_sequence_as(
935        tensors, [t.get_shape() for t in nest.flatten(tensors)])
936    self._output_types = nest.pack_sequence_as(
937        tensors, [t.dtype for t in nest.flatten(tensors)])
938
939  def _as_variant_tensor(self):
940    return gen_dataset_ops.tensor_dataset(
941        nest.flatten(self._tensors),
942        output_shapes=nest.flatten(
943            sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
944
945  @property
946  def output_classes(self):
947    return self._output_classes
948
949  @property
950  def output_shapes(self):
951    return self._output_shapes
952
953  @property
954  def output_types(self):
955    return self._output_types
956
957
958class TensorSliceDataset(Dataset):
959  """A `Dataset` of slices from a nested structure of tensors."""
960
961  def __init__(self, tensors):
962    """See `Dataset.from_tensor_slices()` for details."""
963    super(TensorSliceDataset, self).__init__()
964    with ops.name_scope("tensors"):
965      tensors = nest.pack_sequence_as(tensors, [
966          sparse_tensor_lib.SparseTensor.from_value(t)
967          if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(
968              t, name="component_%d" % i)
969          for i, t in enumerate(nest.flatten(tensors))
970      ])
971      flat_tensors = nest.flatten(tensors)
972
973    batch_dim = flat_tensors[0].get_shape()[0]
974    for t in flat_tensors[1:]:
975      batch_dim.assert_is_compatible_with(t.get_shape()[0])
976    self._tensors = sparse.serialize_many_sparse_tensors(tensors)
977    self._output_classes = sparse.get_classes(tensors)
978    self._output_shapes = nest.pack_sequence_as(
979        tensors, [t.get_shape()[1:] for t in nest.flatten(tensors)])
980    self._output_types = nest.pack_sequence_as(
981        tensors, [t.dtype for t in nest.flatten(tensors)])
982
983  def _as_variant_tensor(self):
984    return gen_dataset_ops.tensor_slice_dataset(
985        nest.flatten(self._tensors),
986        output_shapes=nest.flatten(
987            sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
988
989  @property
990  def output_classes(self):
991    return self._output_classes
992
993  @property
994  def output_shapes(self):
995    return self._output_shapes
996
997  @property
998  def output_types(self):
999    return self._output_types
1000
1001
1002class SparseTensorSliceDataset(Dataset):
1003  """A `Dataset` that splits a rank-N `tf.SparseTensor` into its rows."""
1004
1005  def __init__(self, sparse_tensor):
1006    """See `Dataset.from_sparse_tensor_slices()` for details."""
1007    super(SparseTensorSliceDataset, self).__init__()
1008    if not isinstance(sparse_tensor, sparse_tensor_lib.SparseTensor):
1009      raise TypeError("`sparse_tensor` must be a `tf.SparseTensor` object.")
1010    self._sparse_tensor = sparse_tensor
1011
1012  def _as_variant_tensor(self):
1013    return gen_dataset_ops.sparse_tensor_slice_dataset(
1014        self._sparse_tensor.indices, self._sparse_tensor.values,
1015        self._sparse_tensor.dense_shape)
1016
1017  @property
1018  def output_classes(self):
1019    return (ops.Tensor, ops.Tensor, ops.Tensor)
1020
1021  @property
1022  def output_shapes(self):
1023    indices_shape = self._sparse_tensor.indices.get_shape()
1024    shape_shape = self._sparse_tensor.dense_shape.get_shape()
1025    rank = (indices_shape[1] - 1).merge_with(shape_shape[0] - 1)
1026    num_values = tensor_shape.Dimension(None)
1027    return (tensor_shape.TensorShape([num_values, rank]),
1028            tensor_shape.TensorShape([num_values]),
1029            tensor_shape.TensorShape([rank]))
1030
1031  @property
1032  def output_types(self):
1033    return (dtypes.int64, self._sparse_tensor.dtype, dtypes.int64)
1034
1035
1036class ZipDataset(Dataset):
1037  """A `Dataset` that zips its inputs together."""
1038
1039  def __init__(self, datasets):
1040    """See `Dataset.zip()` for details."""
1041    super(ZipDataset, self).__init__()
1042    for ds in nest.flatten(datasets):
1043      if not isinstance(ds, Dataset):
1044        if isinstance(ds, list):
1045          message = ("The argument to `Dataset.zip()` must be a nested "
1046                     "structure of `Dataset` objects. Nested structures do not "
1047                     "support Python lists; please use a tuple instead.")
1048        else:
1049          message = ("The argument to `Dataset.zip()` must be a nested "
1050                     "structure of `Dataset` objects.")
1051        raise TypeError(message)
1052    self._datasets = datasets
1053
1054  def _as_variant_tensor(self):
1055    # pylint: disable=protected-access
1056    return gen_dataset_ops.zip_dataset(
1057        [ds._as_variant_tensor() for ds in nest.flatten(self._datasets)],
1058        output_shapes=[
1059            s
1060            for ds in nest.flatten(self._datasets)
1061            for s in nest.flatten(ds.output_shapes)
1062        ],
1063        output_types=[
1064            t
1065            for ds in nest.flatten(self._datasets)
1066            for t in nest.flatten(ds.output_types)
1067        ])
1068    # pylint: enable=protected-access
1069
1070  @property
1071  def output_classes(self):
1072    return nest.pack_sequence_as(
1073        self._datasets,
1074        [ds.output_classes for ds in nest.flatten(self._datasets)])
1075
1076  @property
1077  def output_shapes(self):
1078    return nest.pack_sequence_as(
1079        self._datasets,
1080        [ds.output_shapes for ds in nest.flatten(self._datasets)])
1081
1082  @property
1083  def output_types(self):
1084    return nest.pack_sequence_as(
1085        self._datasets,
1086        [ds.output_types for ds in nest.flatten(self._datasets)])
1087
1088
1089class ConcatenateDataset(Dataset):
1090  """A `Dataset` that concatenates its input with given dataset."""
1091
1092  def __init__(self, input_dataset, dataset_to_concatenate):
1093    """See `Dataset.concatenate()` for details."""
1094    super(ConcatenateDataset, self).__init__()
1095    self._input_dataset = input_dataset
1096    self._dataset_to_concatenate = dataset_to_concatenate
1097    nest.assert_same_structure(input_dataset.output_types,
1098                               dataset_to_concatenate.output_types)
1099    for a, b in zip(
1100        nest.flatten(input_dataset.output_types),
1101        nest.flatten(dataset_to_concatenate.output_types)):
1102      if a != b:
1103        raise TypeError(
1104            "Two datasets to concatenate have different types %s and %s" %
1105            (input_dataset.output_types, dataset_to_concatenate.output_types))
1106
1107  def _as_variant_tensor(self):
1108    # pylint: disable=protected-access
1109    return gen_dataset_ops.concatenate_dataset(
1110        self._input_dataset._as_variant_tensor(),
1111        self._dataset_to_concatenate._as_variant_tensor(),
1112        output_shapes=nest.flatten(
1113            sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
1114        output_types=nest.flatten(
1115            sparse.as_dense_types(self.output_types, self.output_classes)))
1116    # pylint: enable=protected-access
1117
1118  @property
1119  def output_classes(self):
1120    return self._input_dataset.output_classes
1121
1122  @property
1123  def output_shapes(self):
1124    return nest.pack_sequence_as(self._input_dataset.output_shapes, [
1125        ts1.most_specific_compatible_shape(ts2)
1126        for (ts1, ts2) in zip(
1127            nest.flatten(self._input_dataset.output_shapes),
1128            nest.flatten(self._dataset_to_concatenate.output_shapes))
1129    ])
1130
1131  @property
1132  def output_types(self):
1133    return self._input_dataset.output_types
1134
1135
1136class RepeatDataset(Dataset):
1137  """A `Dataset` that repeats its input several times."""
1138
1139  def __init__(self, input_dataset, count):
1140    """See `Dataset.repeat()` for details."""
1141    super(RepeatDataset, self).__init__()
1142    self._input_dataset = input_dataset
1143    if count is None:
1144      self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count")
1145    else:
1146      self._count = ops.convert_to_tensor(
1147          count, dtype=dtypes.int64, name="count")
1148
1149  def _as_variant_tensor(self):
1150    return gen_dataset_ops.repeat_dataset(
1151        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
1152        count=self._count,
1153        output_shapes=nest.flatten(
1154            sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
1155        output_types=nest.flatten(
1156            sparse.as_dense_types(self.output_types, self.output_classes)))
1157
1158  @property
1159  def output_classes(self):
1160    return self._input_dataset.output_classes
1161
1162  @property
1163  def output_shapes(self):
1164    return self._input_dataset.output_shapes
1165
1166  @property
1167  def output_types(self):
1168    return self._input_dataset.output_types
1169
1170
1171class RangeDataset(Dataset):
1172  """A `Dataset` of a step separated range of values."""
1173
1174  def __init__(self, *args):
1175    """See `Dataset.range()` for details."""
1176    super(RangeDataset, self).__init__()
1177    self._parse_args(*args)
1178
1179  def _parse_args(self, *args):
1180    if len(args) == 1:
1181      self._start = self._build_tensor(0, "start")
1182      self._stop = self._build_tensor(args[0], "stop")
1183      self._step = self._build_tensor(1, "step")
1184    elif len(args) == 2:
1185      self._start = self._build_tensor(args[0], "start")
1186      self._stop = self._build_tensor(args[1], "stop")
1187      self._step = self._build_tensor(1, "step")
1188    elif len(args) == 3:
1189      self._start = self._build_tensor(args[0], "start")
1190      self._stop = self._build_tensor(args[1], "stop")
1191      self._step = self._build_tensor(args[2], "step")
1192    else:
1193      raise ValueError("Invalid arguments to RangeDataset: %s" % str(args))
1194
1195  def _build_tensor(self, int64_value, name):
1196    return ops.convert_to_tensor(int64_value, dtype=dtypes.int64, name=name)
1197
1198  def _as_variant_tensor(self):
1199    return gen_dataset_ops.range_dataset(
1200        start=self._start,
1201        stop=self._stop,
1202        step=self._step,
1203        output_shapes=nest.flatten(
1204            sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
1205        output_types=nest.flatten(
1206            sparse.as_dense_types(self.output_types, self.output_classes)))
1207
1208  @property
1209  def output_classes(self):
1210    return ops.Tensor
1211
1212  @property
1213  def output_shapes(self):
1214    return tensor_shape.scalar()
1215
1216  @property
1217  def output_types(self):
1218    return dtypes.int64
1219
1220
1221class CacheDataset(Dataset):
1222  """A `Dataset` that caches elements of its input."""
1223
1224  def __init__(self, input_dataset, filename):
1225    """See `Dataset.cache()` for details."""
1226    super(CacheDataset, self).__init__()
1227    self._input_dataset = input_dataset
1228    self._filename = ops.convert_to_tensor(
1229        filename, dtype=dtypes.string, name="filename")
1230
1231  def _as_variant_tensor(self):
1232    return gen_dataset_ops.cache_dataset(
1233        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
1234        filename=self._filename,
1235        output_shapes=nest.flatten(
1236            sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
1237        output_types=nest.flatten(
1238            sparse.as_dense_types(self.output_types, self.output_classes)))
1239
1240  @property
1241  def output_classes(self):
1242    return self._input_dataset.output_classes
1243
1244  @property
1245  def output_shapes(self):
1246    return self._input_dataset.output_shapes
1247
1248  @property
1249  def output_types(self):
1250    return self._input_dataset.output_types
1251
1252
1253class ShuffleDataset(Dataset):
1254  """A `Dataset` that randomly shuffles the elements of its input."""
1255
1256  def __init__(self,
1257               input_dataset,
1258               buffer_size,
1259               seed=None,
1260               reshuffle_each_iteration=None):
1261    """Randomly shuffles the elements of this dataset.
1262
1263    Args:
1264      input_dataset: The input dataset.
1265      buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the
1266        number of elements from this dataset from which the new
1267        dataset will sample.
1268      seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
1269        random seed that will be used to create the distribution. See
1270        @{tf.set_random_seed} for behavior.
1271      reshuffle_each_iteration: (Optional.) A boolean, which if true indicates
1272        that the dataset should be pseudorandomly reshuffled each time it is
1273        iterated over. (Defaults to `True`.)
1274
1275    Returns:
1276      A `Dataset`.
1277
1278    Raises:
1279      ValueError: if invalid arguments are provided.
1280    """
1281    super(ShuffleDataset, self).__init__()
1282    self._input_dataset = input_dataset
1283    self._buffer_size = ops.convert_to_tensor(
1284        buffer_size, dtype=dtypes.int64, name="buffer_size")
1285    seed, seed2 = random_seed.get_seed(seed)
1286    if seed is None:
1287      self._seed = constant_op.constant(0, dtype=dtypes.int64, name="seed")
1288    else:
1289      self._seed = ops.convert_to_tensor(seed, dtype=dtypes.int64, name="seed")
1290    if seed2 is None:
1291      self._seed2 = constant_op.constant(0, dtype=dtypes.int64, name="seed2")
1292    else:
1293      self._seed2 = ops.convert_to_tensor(
1294          seed2, dtype=dtypes.int64, name="seed2")
1295    if reshuffle_each_iteration is None:
1296      self._reshuffle_each_iteration = True
1297    else:
1298      self._reshuffle_each_iteration = reshuffle_each_iteration
1299
1300  def _as_variant_tensor(self):
1301    return gen_dataset_ops.shuffle_dataset(
1302        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
1303        buffer_size=self._buffer_size,
1304        seed=self._seed,
1305        seed2=self._seed2,
1306        reshuffle_each_iteration=self._reshuffle_each_iteration,
1307        output_shapes=nest.flatten(
1308            sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
1309        output_types=nest.flatten(
1310            sparse.as_dense_types(self.output_types, self.output_classes)))
1311
1312  @property
1313  def output_classes(self):
1314    return self._input_dataset.output_classes
1315
1316  @property
1317  def output_shapes(self):
1318    return self._input_dataset.output_shapes
1319
1320  @property
1321  def output_types(self):
1322    return self._input_dataset.output_types
1323
1324
1325class TakeDataset(Dataset):
1326  """A `Dataset` containing the first `count` elements from its input."""
1327
1328  def __init__(self, input_dataset, count):
1329    """See `Dataset.take()` for details."""
1330    super(TakeDataset, self).__init__()
1331    self._input_dataset = input_dataset
1332    self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count")
1333
1334  def _as_variant_tensor(self):
1335    return gen_dataset_ops.take_dataset(
1336        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
1337        count=self._count,
1338        output_shapes=nest.flatten(
1339            sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
1340        output_types=nest.flatten(
1341            sparse.as_dense_types(self.output_types, self.output_classes)))
1342
1343  @property
1344  def output_classes(self):
1345    return self._input_dataset.output_classes
1346
1347  @property
1348  def output_shapes(self):
1349    return self._input_dataset.output_shapes
1350
1351  @property
1352  def output_types(self):
1353    return self._input_dataset.output_types
1354
1355
1356class SkipDataset(Dataset):
1357  """A `Dataset` skipping the first `count` elements from its input."""
1358
1359  def __init__(self, input_dataset, count):
1360    """See `Dataset.skip()` for details."""
1361    super(SkipDataset, self).__init__()
1362    self._input_dataset = input_dataset
1363    self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count")
1364
1365  def _as_variant_tensor(self):
1366    return gen_dataset_ops.skip_dataset(
1367        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
1368        count=self._count,
1369        output_shapes=nest.flatten(
1370            sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
1371        output_types=nest.flatten(
1372            sparse.as_dense_types(self.output_types, self.output_classes)))
1373
1374  @property
1375  def output_classes(self):
1376    return self._input_dataset.output_classes
1377
1378  @property
1379  def output_shapes(self):
1380    return self._input_dataset.output_shapes
1381
1382  @property
1383  def output_types(self):
1384    return self._input_dataset.output_types
1385
1386
1387class BatchDataset(Dataset):
1388  """A `Dataset` that batches contiguous elements from its input."""
1389
1390  def __init__(self, input_dataset, batch_size):
1391    """See `Dataset.batch()` for details."""
1392    super(BatchDataset, self).__init__()
1393    self._input_dataset = input_dataset
1394    self._batch_size = ops.convert_to_tensor(
1395        batch_size, dtype=dtypes.int64, name="batch_size")
1396
1397  def _as_variant_tensor(self):
1398    return gen_dataset_ops.batch_dataset(
1399        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
1400        batch_size=self._batch_size,
1401        output_shapes=nest.flatten(
1402            sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
1403        output_types=nest.flatten(
1404            sparse.as_dense_types(self.output_types, self.output_classes)))
1405
1406  @property
1407  def output_classes(self):
1408    return self._input_dataset.output_classes
1409
1410  @property
1411  def output_shapes(self):
1412    input_shapes = self._input_dataset.output_shapes
1413    return nest.pack_sequence_as(input_shapes, [
1414        tensor_shape.vector(None).concatenate(s)
1415        for s in nest.flatten(self._input_dataset.output_shapes)
1416    ])
1417
1418  @property
1419  def output_types(self):
1420    return self._input_dataset.output_types
1421
1422
1423def _partial_shape_to_tensor(shape_like):
1424  try:
1425    # First attempt to convert the input to a shape, and return the
1426    # "canonical" tensor representation, which uses `-1` in place of
1427    # `None`.
1428    shape_like = tensor_shape.as_shape(shape_like)
1429    return ops.convert_to_tensor(
1430        [dim if dim is not None else -1 for dim in shape_like.as_list()],
1431        dtype=dtypes.int64)
1432  except (TypeError, ValueError):
1433    # The argument was not trivially convertible to a
1434    # `tf.TensorShape`, so fall back on the conversion to tensor
1435    # machinery.
1436    return ops.convert_to_tensor(shape_like, dtype=dtypes.int64)


def _padding_value_to_tensor(value, output_type):
  """Converts the padding value to a tensor.

  Args:
    value: The padding value.
    output_type: Its expected dtype.

  Returns:
    A scalar `Tensor`.

  Raises:
    ValueError: if the padding value is not a scalar.
    TypeError: if the padding value's type does not match `output_type`.
  """
  value = ops.convert_to_tensor(value, name="padding_value")
  if not value.shape.is_compatible_with(tensor_shape.scalar()):
    raise ValueError("Padding value should be a scalar, but is not: %s" % value)
  if value.dtype != output_type:
    raise TypeError("Padding value tensor (%s) does not match output type: %s" %
                    (value, output_type))
  return value


def _default_padding(input_dataset):
  """Returns default padding values (zeros or "") for the dataset's types."""

  def make_zero(t):
    if t.base_dtype == dtypes.string:
      return ""
    elif t.base_dtype == dtypes.variant:
      raise TypeError("Unable to create padding for field of type 'variant'")
    else:
      return np.zeros_like(t.as_numpy_dtype())

  return nest.map_structure(make_zero, input_dataset.output_types)
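
# For example (illustrative only): for a dataset whose `output_types` are
# `(tf.int64, tf.string)`, `_default_padding(dataset)` returns `(0, "")`.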


class PaddedBatchDataset(Dataset):
  """A `Dataset` that batches and pads contiguous elements from its input."""

  def __init__(self, input_dataset, batch_size, padded_shapes, padding_values):
    """See `Dataset.padded_batch()` for details."""
    super(PaddedBatchDataset, self).__init__()
    if sparse.any_sparse(input_dataset.output_classes):
      # TODO(b/63669786): support batching of sparse tensors
      raise TypeError(
          "Batching of padded sparse tensors is not currently supported")
    self._input_dataset = input_dataset
    self._batch_size = ops.convert_to_tensor(
        batch_size, dtype=dtypes.int64, name="batch_size")
    padding_values = (
        padding_values
        if padding_values is not None else _default_padding(input_dataset))
    self._padded_shapes = nest.map_structure_up_to(
        input_dataset.output_shapes, _partial_shape_to_tensor, padded_shapes)
    self._padding_values = nest.map_structure_up_to(
        input_dataset.output_shapes, _padding_value_to_tensor, padding_values,
        input_dataset.output_types)

  def _as_variant_tensor(self):
    return gen_dataset_ops.padded_batch_dataset(
        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
        batch_size=self._batch_size,
        padded_shapes=[
            ops.convert_to_tensor(s, dtype=dtypes.int64)
            for s in nest.flatten(self._padded_shapes)
        ],
        padding_values=nest.flatten(self._padding_values),
        output_shapes=nest.flatten(
            sparse.as_dense_shapes(self.output_shapes, self.output_classes)))

  @property
  def output_classes(self):
    return self._input_dataset.output_classes

  @property
  def output_shapes(self):

    def _padded_shape_to_batch_shape(s):
      return tensor_shape.vector(None).concatenate(
          tensor_util.constant_value_as_shape(s))

    return nest.map_structure(_padded_shape_to_batch_shape,
                              self._padded_shapes)

  @property
  def output_types(self):
    return self._input_dataset.output_types
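
# Illustrative sketch of the corresponding `Dataset.padded_batch()`
# transformation. Variable-length vectors are padded to the longest element in
# each batch, using the default padding value for their dtype:
#
#   dataset = ...  # elements are 1-D tf.int64 vectors of varying length
#   dataset = dataset.padded_batch(2, padded_shapes=[None])
#   print(dataset.output_shapes)  # (?, ?): batch size and padded length unknown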


def _should_unpack_args(args):
  """Returns `True` if `args` should be `*args` when passed to a callable."""
  return type(args) is tuple  # pylint: disable=unidiomatic-typecheck
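
# For example (illustrative only): tuple-structured elements are unpacked into
# separate arguments, while any other structure is passed as a single argument:
#
#   _should_unpack_args((x, y))    # ==> True:  calls map_func(x, y)
#   _should_unpack_args({"a": x})  # ==> False: calls map_func({"a": x})
#   _should_unpack_args([x, y])    # ==> False: calls map_func([x, y])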


class MapDataset(Dataset):
  """A `Dataset` that maps a function over elements in its input."""

  def __init__(self, input_dataset, map_func):
    """See `Dataset.map()` for details."""
    super(MapDataset, self).__init__()
    self._input_dataset = input_dataset

    self._output_classes = None
    self._output_shapes = None
    self._output_types = None

    @function.Defun(*nest.flatten(
        sparse.as_dense_types(input_dataset.output_types,
                              input_dataset.output_classes)))
    def tf_map_func(*args):
      """A wrapper for Defun that facilitates shape inference."""
      # Pass in shape information from the input_dataset.
      dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
                                            input_dataset.output_classes)
      for arg, shape in zip(args, nest.flatten(dense_shapes)):
        arg.set_shape(shape)

      nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
      nested_args = sparse.deserialize_sparse_tensors(
          nested_args, input_dataset.output_types, input_dataset.output_shapes,
          input_dataset.output_classes)
      if _should_unpack_args(nested_args):
        ret = map_func(*nested_args)
      else:
        ret = map_func(nested_args)

      # If `map_func` returns a list of tensors, `nest.flatten()` and
      # `ops.convert_to_tensor()` would conspire to attempt to stack
      # those tensors into a single tensor, because the customized
      # version of `nest.flatten()` does not recurse into lists. Since
      # it is more likely that the list arose from returning the
      # result of an operation (such as `tf.py_func()`) that returns a
      # list of not-necessarily-stackable tensors, we treat the
      # returned value as a `tuple` instead. A user wishing to pack
      # the return value into a single tensor can use an explicit
      # `tf.stack()` before returning.
      if isinstance(ret, list):
        ret = tuple(ret)

      # Convert any `SparseTensorValue`s to `SparseTensor`s.
      ret = nest.pack_sequence_as(ret, [
          sparse_tensor_lib.SparseTensor.from_value(t)
          if sparse_tensor_lib.is_sparse(t) else t for t in nest.flatten(ret)
      ])

      self._output_classes = sparse.get_classes(ret)
      self._output_shapes = nest.pack_sequence_as(
          ret, [t.get_shape() for t in nest.flatten(ret)])
      self._output_types = nest.pack_sequence_as(
          ret, [t.dtype for t in nest.flatten(ret)])

      # Serialize any sparse tensors and convert result to tensors.
      ret = nest.pack_sequence_as(ret, [
          ops.convert_to_tensor(t)
          for t in nest.flatten(sparse.serialize_sparse_tensors(ret))
      ])
      return nest.flatten(ret)

    self._map_func = tf_map_func
    self._map_func.add_to_graph(ops.get_default_graph())

  def _as_variant_tensor(self):
    input_t = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
    return gen_dataset_ops.map_dataset(
        input_t,
        self._map_func.captured_inputs,
        f=self._map_func,
        output_types=nest.flatten(
            sparse.as_dense_types(self.output_types, self.output_classes)),
        output_shapes=nest.flatten(
            sparse.as_dense_shapes(self.output_shapes, self.output_classes)))

  @property
  def output_classes(self):
    return self._output_classes

  @property
  def output_shapes(self):
    return self._output_shapes

  @property
  def output_types(self):
    return self._output_types
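
# Illustrative sketch of the corresponding `Dataset.map()` transformation. As
# described in the comment above, a `map_func` that returns a Python list is
# treated as returning a tuple of components:
#
#   dataset = Dataset.range(5).map(lambda x: [x, x * 2])
#   # ==> elements are (x, 2 * x) pairs; wrap the result in tf.stack() inside
#   #     `map_func` to produce a single rank-1 tensor element instead.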


class ParallelMapDataset(MapDataset):
  """A `Dataset` that maps a function over elements in its input in parallel."""

  def __init__(self, input_dataset, map_func, num_parallel_calls):
    """See `Dataset.map()` for details."""
    super(ParallelMapDataset, self).__init__(input_dataset, map_func)

    self._num_parallel_calls = ops.convert_to_tensor(
        num_parallel_calls, dtype=dtypes.int32, name="num_parallel_calls")

  def _as_variant_tensor(self):
    input_t = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
    # pylint: disable=protected-access
    return gen_dataset_ops.parallel_map_dataset(
        input_t,
        self._map_func.captured_inputs,
        f=self._map_func,
        num_parallel_calls=self._num_parallel_calls,
        output_types=nest.flatten(
            sparse.as_dense_types(self.output_types, self.output_classes)),
        output_shapes=nest.flatten(
            sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
    # pylint: enable=protected-access
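
# Illustrative sketch: `Dataset.map()` presumably constructs this class when a
# `num_parallel_calls` argument is supplied (`parse_fn` below stands in for any
# per-element function):
#
#   dataset = dataset.map(parse_fn, num_parallel_calls=4)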


class FlatMapDataset(Dataset):
  """A `Dataset` that maps a function over its input and flattens the result."""

  def __init__(self, input_dataset, map_func):
    """See `Dataset.flat_map()` for details."""
    super(FlatMapDataset, self).__init__()
    self._input_dataset = input_dataset

    @function.Defun(*nest.flatten(
        sparse.as_dense_types(input_dataset.output_types,
                              input_dataset.output_classes)))
    def tf_map_func(*args):
      """A wrapper for Defun that facilitates shape inference."""
      # Pass in shape information from the input_dataset.
      dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
                                            input_dataset.output_classes)
      for arg, shape in zip(args, nest.flatten(dense_shapes)):
        arg.set_shape(shape)

      nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
      nested_args = sparse.deserialize_sparse_tensors(
          nested_args, input_dataset.output_types, input_dataset.output_shapes,
          input_dataset.output_classes)
      if _should_unpack_args(nested_args):
        dataset = map_func(*nested_args)
      else:
        dataset = map_func(nested_args)

      if not isinstance(dataset, Dataset):
        raise TypeError("`map_func` must return a `Dataset` object.")

      self._output_classes = dataset.output_classes
      self._output_types = dataset.output_types
      self._output_shapes = dataset.output_shapes

      return dataset._as_variant_tensor()  # pylint: disable=protected-access

    self._map_func = tf_map_func
    self._map_func.add_to_graph(ops.get_default_graph())

  def _as_variant_tensor(self):
    return gen_dataset_ops.flat_map_dataset(
        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
        self._map_func.captured_inputs,
        f=self._map_func,
        output_types=nest.flatten(
            sparse.as_dense_types(self.output_types, self.output_classes)),
        output_shapes=nest.flatten(
            sparse.as_dense_shapes(self.output_shapes, self.output_classes)))

  @property
  def output_classes(self):
    return self._output_classes

  @property
  def output_shapes(self):
    return self._output_shapes

  @property
  def output_types(self):
    return self._output_types
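
# Illustrative sketch of the corresponding `Dataset.flat_map()` transformation.
# The `map_func` must return a `Dataset`, and the returned datasets are
# concatenated in order:
#
#   dataset = Dataset.range(3).flat_map(
#       lambda x: Dataset.from_tensors(x).repeat(2))
#   # ==> 0, 0, 1, 1, 2, 2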


class InterleaveDataset(Dataset):
  """A `Dataset` that maps a function over its input and interleaves the result.
  """

  def __init__(self, input_dataset, map_func, cycle_length, block_length):
    """See `Dataset.interleave()` for details."""
    super(InterleaveDataset, self).__init__()
    self._input_dataset = input_dataset

    @function.Defun(*nest.flatten(
        sparse.as_dense_types(input_dataset.output_types,
                              input_dataset.output_classes)))
    def tf_map_func(*args):
      """A wrapper for Defun that facilitates shape inference."""
      # Pass in shape information from the input_dataset.
      dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
                                            input_dataset.output_classes)
      for arg, shape in zip(args, nest.flatten(dense_shapes)):
        arg.set_shape(shape)

      nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
      nested_args = sparse.deserialize_sparse_tensors(
          nested_args, input_dataset.output_types, input_dataset.output_shapes,
          input_dataset.output_classes)
      if _should_unpack_args(nested_args):
        dataset = map_func(*nested_args)
      else:
        dataset = map_func(nested_args)

      if not isinstance(dataset, Dataset):
        raise TypeError("`map_func` must return a `Dataset` object.")

      self._output_classes = dataset.output_classes
      self._output_types = dataset.output_types
      self._output_shapes = dataset.output_shapes

      return dataset._as_variant_tensor()  # pylint: disable=protected-access

    self._map_func = tf_map_func
    self._map_func.add_to_graph(ops.get_default_graph())

    self._cycle_length = ops.convert_to_tensor(
        cycle_length, dtype=dtypes.int64, name="cycle_length")
    self._block_length = ops.convert_to_tensor(
        block_length, dtype=dtypes.int64, name="block_length")

  def _as_variant_tensor(self):
    return gen_dataset_ops.interleave_dataset(
        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
        self._map_func.captured_inputs,
        self._cycle_length,
        self._block_length,
        f=self._map_func,
        output_types=nest.flatten(
            sparse.as_dense_types(self.output_types, self.output_classes)),
        output_shapes=nest.flatten(
            sparse.as_dense_shapes(self.output_shapes, self.output_classes)))

  @property
  def output_classes(self):
    return self._output_classes

  @property
  def output_shapes(self):
    return self._output_shapes

  @property
  def output_types(self):
    return self._output_types
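
# Illustrative sketch of the corresponding `Dataset.interleave()`
# transformation. Up to `cycle_length` inner datasets are drawn from at a time,
# taking `block_length` consecutive elements from each in turn:
#
#   dataset = Dataset.range(2).interleave(
#       lambda x: Dataset.from_tensors(x).repeat(4),
#       cycle_length=2, block_length=2)
#   # ==> 0, 0, 1, 1, 0, 0, 1, 1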


class FilterDataset(Dataset):
  """A `Dataset` that filters its input according to a predicate function."""

  def __init__(self, input_dataset, predicate):
    """See `Dataset.filter()` for details."""
    super(FilterDataset, self).__init__()
    self._input_dataset = input_dataset

    @function.Defun(*nest.flatten(
        sparse.as_dense_types(input_dataset.output_types,
                              input_dataset.output_classes)))
    def tf_predicate(*args):
      """A wrapper for Defun that facilitates shape inference."""
      # Pass in shape information from the input_dataset.
      dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
                                            input_dataset.output_classes)
      for arg, shape in zip(args, nest.flatten(dense_shapes)):
        arg.set_shape(shape)

      nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
      nested_args = sparse.deserialize_sparse_tensors(
          nested_args, input_dataset.output_types, input_dataset.output_shapes,
          input_dataset.output_classes)
      if _should_unpack_args(nested_args):
        ret = predicate(*nested_args)
      else:
        ret = predicate(nested_args)

      ret = ops.convert_to_tensor(ret, dtype=dtypes.bool)
      if not (ret.dtype == dtypes.bool and
              ret.shape.is_compatible_with(tensor_shape.scalar())):
        raise ValueError("`predicate` must return a scalar boolean tensor.")

      return ret

    self._predicate = tf_predicate
    self._predicate.add_to_graph(ops.get_default_graph())

  def _as_variant_tensor(self):
    return gen_dataset_ops.filter_dataset(
        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
        other_arguments=self._predicate.captured_inputs,
        predicate=self._predicate,
        output_types=nest.flatten(
            sparse.as_dense_types(self.output_types, self.output_classes)),
        output_shapes=nest.flatten(
            sparse.as_dense_shapes(self.output_shapes, self.output_classes)))

  @property
  def output_classes(self):
    return self._input_dataset.output_classes

  @property
  def output_shapes(self):
    return self._input_dataset.output_shapes

  @property
  def output_types(self):
    return self._input_dataset.output_types
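
# Illustrative sketch of the corresponding `Dataset.filter()` transformation.
# The predicate must return a scalar boolean tensor:
#
#   dataset = Dataset.range(10).filter(lambda x: tf.equal(x % 2, 0))
#   # ==> 0, 2, 4, 6, 8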


class PrefetchDataset(Dataset):
  """A `Dataset` that asynchronously prefetches its input."""

  def __init__(self, input_dataset, buffer_size):
    """See `Dataset.prefetch()` for details."""
    super(PrefetchDataset, self).__init__()
    self._input_dataset = input_dataset
    self._buffer_size = ops.convert_to_tensor(
        buffer_size, dtype=dtypes.int64, name="buffer_size")

  def _as_variant_tensor(self):
    return gen_dataset_ops.prefetch_dataset(
        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
        buffer_size=self._buffer_size,
        output_shapes=nest.flatten(
            sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
        output_types=nest.flatten(
            sparse.as_dense_types(self.output_types, self.output_classes)))

  @property
  def output_classes(self):
    return self._input_dataset.output_classes

  @property
  def output_shapes(self):
    return self._input_dataset.output_shapes

  @property
  def output_types(self):
    return self._input_dataset.output_types

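
# Illustrative sketch of the corresponding `Dataset.prefetch()` transformation.
# Prefetching decouples the production of input elements from their
# consumption, and is typically applied as the final transformation in a
# pipeline:
#
#   dataset = dataset.batch(32).prefetch(buffer_size=1)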