1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python wrappers for Datasets and Iterators."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.contrib.data.python.ops import batching
21from tensorflow.contrib.data.python.ops import enumerate_ops
22from tensorflow.contrib.data.python.ops import error_ops
23from tensorflow.contrib.data.python.ops import grouping
24from tensorflow.python.data.ops import dataset_ops
25from tensorflow.python.data.util import nest
26from tensorflow.python.ops import gen_dataset_ops
27from tensorflow.python.ops import gen_io_ops
28from tensorflow.python.util import deprecation
29
30
31class Dataset(dataset_ops.Dataset):
32  """Represents a potentially large set of elements.
33
34  A `Dataset` can be used to represent an input pipeline as a
35  collection of elements (nested structures of tensors) and a "logical
36  plan" of transformations that act on those elements.
37  """
38
39  def __init__(self, dataset):
40    super(Dataset, self).__init__()
41    self._dataset = dataset
42
43  @deprecation.deprecated(None, "Use `ds._as_variant_tensor()`.")
44  def make_dataset_resource(self):
45    return self._as_variant_tensor()
46
47  def _as_variant_tensor(self):
48    return self._dataset._as_variant_tensor()  # pylint: disable=protected-access
49
50  @property
51  def output_classes(self):
52    return self._dataset.output_classes
53
54  @property
55  def output_shapes(self):
56    return self._dataset.output_shapes
57
58  @property
59  def output_types(self):
60    return self._dataset.output_types
61
62  @staticmethod
63  @deprecation.deprecated(None, "Use `tf.data.Dataset.from_tensors()`.")
64  def from_tensors(tensors):
65    """Creates a `Dataset` with a single element, comprising the given tensors.
66
67    Args:
68      tensors: A nested structure of tensors.
69
70    Returns:
71      A `Dataset`.
72    """
73    return Dataset(dataset_ops.TensorDataset(tensors))
74
75  @staticmethod
76  @deprecation.deprecated(None, "Use `tf.data.Dataset.from_tensor_slices()`.")
77  def from_tensor_slices(tensors):
78    """Creates a `Dataset` whose elements are slices of the given tensors.
79
80    Args:
81      tensors: A nested structure of tensors, each having the same size in the
82        0th dimension.
83
84    Returns:
85      A `Dataset`.
86    """
87    return Dataset(dataset_ops.TensorSliceDataset(tensors))
88
89  @staticmethod
90  @deprecation.deprecated(None,
91                          "Use `tf.data.Dataset.from_sparse_tensor_slices()`.")
92  def from_sparse_tensor_slices(sparse_tensor):
93    """Splits each rank-N `tf.SparseTensor` in this dataset row-wise.
94
95    Args:
96      sparse_tensor: A `tf.SparseTensor`.
97
98    Returns:
99      A `Dataset` of rank-(N-1) sparse tensors.
100    """
101    return Dataset(dataset_ops.SparseTensorSliceDataset(sparse_tensor))
102
103  @staticmethod
104  @deprecation.deprecated(None, "Use `tf.data.Dataset.from_generator()`.")
105  def from_generator(generator, output_types, output_shapes=None):
106    """Creates a `Dataset` whose elements are generated by `generator`.
107
108    The `generator` argument must be a callable object that returns
109    an object that support the `iter()` protocol (e.g. a generator function).
110    The elements generated by `generator` must be compatible with the given
111    `output_types` and (optional) `output_shapes` arguments.
112
113    For example:
114
115    ```python
116    import itertools
117
118    def gen():
119      for i in itertools.count(1):
120        yield (i, [1] * i)
121
122    ds = Dataset.from_generator(
123        gen, (tf.int64, tf.int64), (tf.TensorShape([]), tf.TensorShape([None])))
124    value = ds.make_one_shot_iterator().get_next()
125
126    sess.run(value)  # (1, array([1]))
127    sess.run(value)  # (2, array([1, 1]))
128    ```
129
130    Args:
131      generator: A callable object that takes no arguments and returns an
132        object that supports the `iter()` protocol.
133      output_types: A nested structure of `tf.DType` objects corresponding to
134        each component of an element yielded by `generator`.
135      output_shapes: (Optional.) A nested structure of `tf.TensorShape`
136        objects corresponding to each component of an element yielded by
137        `generator`.
138
139    Returns:
140      A `Dataset`.
141    """
142    return Dataset(dataset_ops.Dataset.from_generator(
143        generator, output_types, output_shapes))
144
145  @staticmethod
146  @deprecation.deprecated(None, "Use `tf.data.Dataset.range()`.")
147  def range(*args):
148    """Creates a `Dataset` of a step-separated range of values.
149
150    For example:
151
152    ```python
153    Dataset.range(5) == [0, 1, 2, 3, 4]
154    Dataset.range(2, 5) == [2, 3, 4]
155    Dataset.range(1, 5, 2) == [1, 3]
156    Dataset.range(1, 5, -2) == []
157    Dataset.range(5, 1) == []
158    Dataset.range(5, 1, -2) == [5, 3]
159    ```
160
161    Args:
162      *args: follow same semantics as python's xrange.
163        len(args) == 1 -> start = 0, stop = args[0], step = 1
164        len(args) == 2 -> start = args[0], stop = args[1], step = 1
165        len(args) == 3 -> start = args[0], stop = args[1, stop = args[2]
166
167    Returns:
168      A `RangeDataset`.
169
170    Raises:
171      ValueError: if len(args) == 0.
172    """
173    return Dataset(dataset_ops.RangeDataset(*args))
174
175  @staticmethod
176  @deprecation.deprecated(None, "Use `tf.data.Dataset.zip()`.")
177  def zip(datasets):
178    """Creates a `Dataset` by zipping together the given datasets.
179
180    This method has similar semantics to the built-in `zip()` function
181    in Python, with the main difference being that the `datasets`
182    argument can be an arbitrary nested structure of `Dataset` objects.
183    For example:
184
185    ```python
186    # NOTE: The following examples use `{ ... }` to represent the
187    # contents of a dataset.
188    a = { 1, 2, 3 }
189    b = { 4, 5, 6 }
190    c = { (7, 8), (9, 10), (11, 12) }
191    d = { 13, 14 }
192
193    # The nested structure of the `datasets` argument determines the
194    # structure of elements in the resulting dataset.
195    Dataset.zip((a, b)) == { (1, 4), (2, 5), (3, 6) }
196    Dataset.zip((b, a)) == { (4, 1), (5, 2), (6, 3) }
197
198    # The `datasets` argument may contain an arbitrary number of
199    # datasets.
200    Dataset.zip((a, b, c)) == { (1, 4, (7, 8)),
201                                (2, 5, (9, 10)),
202                                (3, 6, (11, 12)) }
203
204    # The number of elements in the resulting dataset is the same as
205    # the size of the smallest dataset in `datasets`.
206    Dataset.zip((a, d)) == { (1, 13), (2, 14) }
207    ```
208
209    Args:
210      datasets: A nested structure of datasets.
211
212    Returns:
213      A `Dataset`.
214    """
215    return Dataset(dataset_ops.ZipDataset(datasets))
216
217  def concatenate(self, dataset):
218    """Creates a `Dataset` by concatenating given dataset with this dataset.
219
220    ```python
221    # NOTE: The following examples use `{ ... }` to represent the
222    # contents of a dataset.
223    a = { 1, 2, 3 }
224    b = { 4, 5, 6, 7 }
225
226    # Input dataset and dataset to be concatenated should have same
227    # nested structures and output types.
228    # c = { (8, 9), (10, 11), (12, 13) }
229    # d = { 14.0, 15.0, 16.0 }
230    # a.concatenate(c) and a.concatenate(d) would result in error.
231
232    a.concatenate(b) == { 1, 2, 3, 4, 5, 6, 7 }
233    ```
234
235    Args:
236      dataset: `Dataset` to be concatenated.
237
238    Returns:
239      A `Dataset`.
240    """
241    return Dataset(dataset_ops.ConcatenateDataset(self._dataset, dataset))
242
243  def prefetch(self, buffer_size):
244    """Creates a `Dataset` that prefetches elements from this dataset.
245
246    Args:
247      buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the
248        maximum number elements that will be buffered when prefetching.
249
250    Returns:
251      A `Dataset`.
252    """
253    return Dataset(dataset_ops.PrefetchDataset(self._dataset, buffer_size))
254
255  @staticmethod
256  @deprecation.deprecated(None, "Use `tf.data.Dataset.list_files()`.")
257  def list_files(file_pattern):
258    """A dataset of all files matching a pattern.
259
260    Example:
261      If we had the following files on our filesystem:
262        - /path/to/dir/a.txt
263        - /path/to/dir/b.py
264        - /path/to/dir/c.py
265      If we pass "/path/to/dir/*.py" as the directory, the dataset would
266      produce:
267        - /path/to/dir/b.py
268        - /path/to/dir/c.py
269
270    Args:
271      file_pattern: A string or scalar string `tf.Tensor`, representing
272        the filename pattern that will be matched.
273
274    Returns:
275     A `Dataset` of strings corresponding to file names.
276    """
277    return Dataset.from_tensor_slices(gen_io_ops.matching_files(file_pattern))
278
279  def repeat(self, count=None):
280    """Repeats this dataset `count` times.
281
282    Args:
283      count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
284        number of times the elements of this dataset should be repeated. The
285        default behavior (if `count` is `None` or `-1`) is for the elements to
286        be repeated indefinitely.
287
288    Returns:
289      A `Dataset`.
290    """
291    return Dataset(dataset_ops.RepeatDataset(self._dataset, count))
292
293  @deprecation.deprecated(
294      None, "Use `ds.apply(tf.contrib.data.enumerate_dataset())`.")
295  def enumerate(self, start=0):
296    """Deprecated: Use `Dataset.apply(tf.contrib.data.enumerate_dataset(..)`."""
297
298    return self.apply(enumerate_ops.enumerate_dataset(start))
299
300  def shuffle(self, buffer_size, seed=None):
301    """Randomly shuffles the elements of this dataset.
302
303    Args:
304      buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the
305        number of elements from this dataset from which the new
306        dataset will sample.
307      seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
308        random seed that will be used to create the distribution. See
309        @{tf.set_random_seed} for behavior.
310
311    Returns:
312      A `Dataset`.
313    """
314    return Dataset(dataset_ops.ShuffleDataset(self._dataset, buffer_size, seed))
315
316  def cache(self, filename=""):
317    """Caches the elements in this dataset.
318
319    Args:
320      filename: A `tf.string` scalar `tf.Tensor`, representing the name of a
321        directory on the filesystem to use for caching tensors in this Dataset.
322        If a filename is not provided, the dataset will be cached in memory.
323
324    Returns:
325      A `Dataset`.
326    """
327    return Dataset(dataset_ops.CacheDataset(self._dataset, filename))
328
329  def take(self, count):
330    """Creates a `Dataset` with at most `count` elements from this dataset.
331
332    Args:
333      count: A `tf.int64` scalar `tf.Tensor`, representing the number of
334        elements of this dataset that should be taken to form the new dataset.
335        If `count` is -1, or if `count` is greater than the size of this
336        dataset, the new dataset will contain all elements of this dataset.
337
338    Returns:
339      A `Dataset`.
340    """
341    return Dataset(dataset_ops.TakeDataset(self._dataset, count))
342
343  def skip(self, count):
344    """Creates a `Dataset` that skips `count` elements from this dataset.
345
346    Args:
347      count: A `tf.int64` scalar `tf.Tensor`, representing the number
348        of elements of this dataset that should be skipped to form the
349        new dataset.  If `count` is greater than the size of this
350        dataset, the new dataset will contain no elements.  If `count`
351        is -1, skips the entire dataset.
352
353    Returns:
354      A `Dataset`.
355    """
356    return Dataset(dataset_ops.SkipDataset(self._dataset, count))
357
358  def shard(self, num_shards, index):
359    """Creates a `Dataset` that includes only 1/`num_shards` of this dataset.
360
361    This dataset operator is very useful when running distributed training, as
362    it allows each worker to read a unique subset.
363
364    When reading a single input file, you can skip elements as follows:
365
366    ```python
367    d = tf.data.TFRecordDataset(FLAGS.input_file)
368    d = d.shard(FLAGS.num_workers, FLAGS.worker_index)
369    d = d.repeat(FLAGS.num_epochs)
370    d = d.shuffle(FLAGS.shuffle_buffer_size)
371    d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads)
372    ```
373
374    Important caveats:
375
376    - Be sure to shard before you use any randomizing operator (such as
377      shuffle).
378    - Generally it is best if the shard operator is used early in the dataset
379      pipeline. For example, when reading from a set of TFRecord files, shard
380      before converting the dataset to input samples. This avoids reading every
381      file on every worker. The following is an example of an efficient
382      sharding strategy within a complete pipeline:
383
384    ```python
385    d = tf.data.Dataset.list_files(FLAGS.pattern)
386    d = d.shard(FLAGS.num_workers, FLAGS.worker_index)
387    d = d.repeat(FLAGS.num_epochs)
388    d = d.shuffle(FLAGS.shuffle_buffer_size)
389    d = d.interleave(tf.data.TFRecordDataset,
390                     cycle_length=FLAGS.num_readers, block_length=1)
391    d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads)
392    ```
393
394    Args:
395      num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of
396        shards operating in parallel.
397      index: A `tf.int64` scalar `tf.Tensor`, representing the worker index.
398
399    Returns:
400      A `Dataset`.
401
402    Raises:
403      ValueError: if `num_shards` or `index` are illegal values. Note: error
404        checking is done on a best-effort basis, and aren't guaranteed to be
405        caught upon dataset creation. (e.g. providing in a placeholder tensor
406        bypasses the early checking, and will instead result in an error during
407        a session.run call.)
408    """
409    return Dataset(self._dataset.shard(num_shards, index))
410
411  @deprecation.deprecated(
412      None, "Use `ds.apply(tf.contrib.data.ignore_errors())`.")
413  def ignore_errors(self):
414    """Deprecated: Use `Dataset.apply(tf.contrib.data.ignore_errors())`."""
415
416    return self.apply(error_ops.ignore_errors())
417
418  def batch(self, batch_size):
419    """Combines consecutive elements of this dataset into batches.
420
421    Args:
422      batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
423        consecutive elements of this dataset to combine in a single batch.
424
425    Returns:
426      A `Dataset`.
427    """
428    return Dataset(dataset_ops.BatchDataset(self._dataset, batch_size))
429
430  def padded_batch(self, batch_size, padded_shapes, padding_values=None):
431    """Combines consecutive elements of this dataset into padded batches.
432
433    Like `Dataset.dense_to_sparse_batch()`, this method combines
434    multiple consecutive elements of this dataset, which might have
435    different shapes, into a single element. The tensors in the
436    resulting element have an additional outer dimension, and are
437    padded to the respective shape in `padded_shapes`.
438
439    Args:
440      batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
441        consecutive elements of this dataset to combine in a single batch.
442      padded_shapes: A nested structure of `tf.TensorShape` or
443        `tf.int64` vector tensor-like objects representing the shape
444        to which the respective component of each input element should
445        be padded prior to batching. Any unknown dimensions
446        (e.g. `tf.Dimension(None)` in a `tf.TensorShape` or `-1` in a
447        tensor-like object) will be padded to the maximum size of that
448        dimension in each batch.
449      padding_values: (Optional.) A nested structure of scalar-shaped
450        `tf.Tensor`, representing the padding values to use for the
451        respective components.  Defaults are `0` for numeric types and
452        the empty string for string types.
453
454    Returns:
455      A `Dataset`.
456    """
457    return Dataset(
458        dataset_ops.PaddedBatchDataset(self._dataset, batch_size, padded_shapes,
459                                       padding_values))
460
461  @deprecation.deprecated(
462      None, "Use `ds.apply(tf.contrib.data.dense_to_sparse_batch())`.")
463  def dense_to_sparse_batch(self, batch_size, row_shape):
464    """Use: `Dataset.apply(tf.contrib.data.dense_to_sparse_batch(...))`."""
465
466    return self.apply(batching.dense_to_sparse_batch(batch_size, row_shape))
467
468  @deprecation.deprecated(
469      None, "Use `ds.apply(tf.contrib.data.group_by_window())`.")
470  def group_by_window(self, key_func, reduce_func, window_size):
471    """Deprecated: Use `Dataset.apply(tf.contrib.data.group_by_window(...))`."""
472
473    return self.apply(
474        grouping.group_by_window(key_func, reduce_func, window_size))
475
476  @deprecation.deprecated_args(
477      None,
478      "`output_buffer_size=N` with `ds.prefetch(N)` on the returned dataset.",
479      "num_threads", "output_buffer_size")
480  def map(self,
481          map_func,
482          num_threads=None,
483          output_buffer_size=None,
484          num_parallel_calls=None):
485    """Maps `map_func` across this dataset.
486
487    Args:
488      map_func: A function mapping a nested structure of tensors (having
489        shapes and types defined by `self.output_shapes` and
490       `self.output_types`) to another nested structure of tensors.
491      num_threads: (Optional.) Deprecated, use `num_parallel_calls` instead.
492      output_buffer_size: (Optional.) A `tf.int64` scalar `tf.Tensor`,
493        representing the maximum number of processed elements that will be
494        buffered.
495      num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
496        representing the number elements to process in parallel. If not
497        specified, elements will be processed sequentially.
498
499    Returns:
500      A `Dataset`.
501    """
502    if num_threads is None and num_parallel_calls is None:
503      ret = Dataset(dataset_ops.MapDataset(self._dataset, map_func))
504    else:
505      if num_threads is None:
506        ret = Dataset(
507            dataset_ops.ParallelMapDataset(self._dataset, map_func,
508                                           num_parallel_calls))
509      else:
510        ret = Dataset(
511            dataset_ops.ParallelMapDataset(self._dataset, map_func,
512                                           num_threads))
513    if output_buffer_size is not None:
514      ret = ret.prefetch(output_buffer_size)
515    return ret
516
517  def flat_map(self, map_func):
518    """Maps `map_func` across this dataset and flattens the result.
519
520    Args:
521      map_func: A function mapping a nested structure of tensors (having shapes
522        and types defined by `self.output_shapes` and `self.output_types`) to a
523        `Dataset`.
524
525    Returns:
526      A `Dataset`.
527    """
528    return Dataset(dataset_ops.FlatMapDataset(self._dataset, map_func))
529
530  def interleave(self, map_func, cycle_length, block_length=1):
531    """Maps `map_func` across this dataset, and interleaves the results.
532
533    For example, you can use `Dataset.interleave()` to process many input files
534    concurrently:
535
536    ```python
537    # Preprocess 4 files concurrently, and interleave blocks of 16 records from
538    # each file.
539    filenames = ["/var/data/file1.txt", "/var/data/file2.txt", ...]
540    dataset = (Dataset.from_tensor_slices(filenames)
541               .interleave(lambda x:
542                   TextLineDataset(x).map(parse_fn, num_parallel_calls=1),
543                   cycle_length=4, block_length=16))
544    ```
545
546    The `cycle_length` and `block_length` arguments control the order in which
547    elements are produced. `cycle_length` controls the number of input elements
548    that are processed concurrently. If you set `cycle_length` to 1, this
549    transformation will handle one input element at a time, and will produce
550    identical results = to @{tf.data.Dataset.flat_map}. In general,
551    this transformation will apply `map_func` to `cycle_length` input elements,
552    open iterators on the returned `Dataset` objects, and cycle through them
553    producing `block_length` consecutive elements from each iterator, and
554    consuming the next input element each time it reaches the end of an
555    iterator.
556
557    For example:
558
559    ```python
560    # NOTE: The following examples use `{ ... }` to represent the
561    # contents of a dataset.
562    a = { 1, 2, 3, 4, 5 }
563
564    # NOTE: New lines indicate "block" boundaries.
565    a.interleave(lambda x: Dataset.from_tensors(x).repeat(6),
566                 cycle_length=2, block_length=4) == {
567        1, 1, 1, 1,
568        2, 2, 2, 2,
569        1, 1,
570        2, 2,
571        3, 3, 3, 3,
572        4, 4, 4, 4,
573        3, 3,
574        4, 4,
575        5, 5, 5, 5,
576        5, 5,
577    }
578    ```
579
580    NOTE: The order of elements yielded by this transformation is
581    deterministic, as long as `map_func` is a pure function. If
582    `map_func` contains any stateful operations, the order in which
583    that state is accessed is undefined.
584
585    Args:
586      map_func: A function mapping a nested structure of tensors (having shapes
587        and types defined by `self.output_shapes` and `self.output_types`) to a
588        `Dataset`.
589      cycle_length: The number of elements from this dataset that will be
590        processed concurrently.
591      block_length: The number of consecutive elements to produce from each
592        input element before cycling to another input element.
593
594    Returns:
595      A `Dataset`.
596    """
597    return Dataset(
598        dataset_ops.InterleaveDataset(self._dataset, map_func, cycle_length,
599                                      block_length))
600
601  @deprecation.deprecated(None, "Use `ds.apply(tf.contrib.data.unbatch())`.")
602  def unbatch(self):
603    """Deprecated: Use `Dataset.apply(tf.contrib.data.unbatch()`."""
604
605    return self.apply(batching.unbatch())
606
607  def filter(self, predicate):
608    """Filters this dataset according to `predicate`.
609
610    Args:
611      predicate: A function mapping a nested structure of tensors (having shapes
612        and types defined by `self.output_shapes` and `self.output_types`) to a
613        scalar `tf.bool` tensor.
614
615    Returns:
616      A `Dataset`.
617    """
618    return Dataset(dataset_ops.FilterDataset(self._dataset, predicate))
619
620  def apply(self, transformation_func):
621    """Apply a transformation function to this dataset.
622
623    `apply` enables chaining of custom `Dataset` transformations, which are
624    represented as functions that take one `Dataset` argument and return a
625    transformed `Dataset`.
626
627    For example:
628
629    ```
630    dataset = (dataset.map(lambda x: x ** 2)
631               .(group_by_window(key_func, reduce_func, window_size))
632               .map(lambda x: x ** 3))
633    ```
634
635    Args:
636      transformation_func: A function that takes one `Dataset` argument and
637        returns a `Dataset`.
638
639    Returns:
640      The `Dataset` returned by applying `transformation_func` to this dataset.
641    """
642    dataset = transformation_func(self)
643    if not isinstance(dataset, dataset_ops.Dataset):
644      raise TypeError("`transformation_func` must return a Dataset.")
645    return Dataset(dataset)
646
647
648def get_single_element(dataset):
649  """Returns the single element in `dataset` as a nested structure of tensors.
650
651  This function enables you to use a @{tf.data.Dataset} in a stateless
652  "tensor-in tensor-out" expression, without creating a @{tf.data.Iterator}.
653  This can be useful when your preprocessing transformations are expressed
654  as a `Dataset`, and you want to use the transformation at serving time.
655  For example:
656
657  ```python
658  input_batch = tf.placeholder(tf.string, shape=[BATCH_SIZE])
659
660  def preprocessing_fn(input_str):
661    # ...
662    return image, label
663
664  dataset = (tf.data.Dataset.from_tensor_slices(input_batch)
665             .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
666             .batch(BATCH_SIZE))
667
668  image_batch, label_batch = tf.contrib.data.get_single_element(dataset)
669  ```
670
671  Args:
672    dataset: A @{tf.data.Dataset} object containing a single element.
673
674  Returns:
675    A nested structure of @{tf.Tensor} objects, corresponding to the single
676    element of `dataset`.
677
678  Raises:
679    TypeError: if `dataset` is not a `tf.data.Dataset` object.
680    InvalidArgumentError (at runtime): if `dataset` does not contain exactly
681      one element.
682  """
683  if not isinstance(dataset, dataset_ops.Dataset):
684    raise TypeError("`dataset` must be a `tf.data.Dataset` object.")
685  return nest.pack_sequence_as(
686      dataset.output_types,
687      gen_dataset_ops.dataset_to_single_element(
688          dataset._as_variant_tensor(),  # pylint: disable=protected-access
689          output_types=nest.flatten(dataset.output_types),
690          output_shapes=nest.flatten(dataset.output_shapes)))
691