1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Ignore_errors dataset transformations."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.contrib.data.python.ops import contrib_op_loader  # pylint: disable=unused-import
21from tensorflow.contrib.data.python.ops import gen_dataset_ops
22from tensorflow.python.data.ops import dataset_ops
23from tensorflow.python.data.util import nest
24from tensorflow.python.data.util import sparse
25
26
def ignore_errors():
  """Creates a `Dataset` from another `Dataset` and silently ignores any errors.

  Use this transformation to produce a dataset that contains the same elements
  as the input, but silently drops any elements that caused an error. For
  example:

  ```python
  dataset = tf.data.Dataset.from_tensor_slices([1., 2., 0., 4.])

  # Computing `tf.check_numerics(1. / 0.)` will raise an InvalidArgumentError.
  dataset = dataset.map(lambda x: tf.check_numerics(1. / x, "error"))

  # Using `ignore_errors()` will drop the element that causes an error.
  dataset = dataset.apply(
      tf.contrib.data.ignore_errors())  # ==> { 1., 0.5, 0.25 }
  ```

  Returns:
    A `Dataset` transformation function, which can be passed to
    @{tf.data.Dataset.apply}.
  """

  def _apply_fn(dataset):
    """Wraps `dataset` in an `IgnoreErrorsDataset`."""
    return IgnoreErrorsDataset(dataset)

  return _apply_fn
54
55
class IgnoreErrorsDataset(dataset_ops.Dataset):
  """A `Dataset` that silently ignores errors when computing its input.

  The element structure (classes, shapes and types) is forwarded unchanged
  from the input dataset. Usually created through
  `tf.contrib.data.ignore_errors()` rather than constructed directly.
  """

  def __init__(self, input_dataset):
    """Creates an `IgnoreErrorsDataset`.

    See `tf.contrib.data.ignore_errors()` for details.

    Args:
      input_dataset: The input `Dataset` whose errors should be ignored.
    """
    super(IgnoreErrorsDataset, self).__init__()
    self._input_dataset = input_dataset

  def _as_variant_tensor(self):
    """Builds the variant tensor for the `IgnoreErrorsDataset` op."""
    # The op's `output_shapes`/`output_types` attributes take flat lists, so
    # the nested structure is flattened and any sparse components are
    # described by their dense shapes/types before being passed in.
    return gen_dataset_ops.ignore_errors_dataset(
        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
        output_shapes=nest.flatten(
            sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
        output_types=nest.flatten(
            sparse.as_dense_types(self.output_types, self.output_classes)))

  @property
  def output_classes(self):
    """Classes of the elements; identical to the input dataset's."""
    return self._input_dataset.output_classes

  @property
  def output_shapes(self):
    """Shapes of the elements; identical to the input dataset's."""
    return self._input_dataset.output_shapes

  @property
  def output_types(self):
    """Types of the elements; identical to the input dataset's."""
    return self._input_dataset.output_types
83