error_ops.py revision b0b4b608dcc68a9efeaa325e069275bae0de045d
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Ignore_errors dataset transformations."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python.data.ops import dataset_ops
21from tensorflow.python.data.util import nest
22from tensorflow.python.ops import gen_dataset_ops
23
24
def ignore_errors():
  """Creates a `Dataset` from another `Dataset` and silently ignores any errors.

  Use this transformation to produce a dataset that contains the same elements
  as the input, but silently drops any elements that caused an error. For
  example:

  ```python
  dataset = tf.contrib.data.Dataset.from_tensor_slices([1., 2., 0., 4.])

  # Computing `tf.check_numerics(1. / 0.)` will raise an InvalidArgumentError.
  dataset = dataset.map(lambda x: tf.check_numerics(1. / x, "error"))

  # Using `ignore_errors()` will drop the element that causes an error.
  dataset = dataset.apply(
      tf.contrib.data.ignore_errors())  # ==> { 1., 0.5, 0.25 }
  ```

  Returns:
    A `Dataset` transformation function, which can be passed to
    @{tf.contrib.data.Dataset.apply}.
  """

  def _apply_fn(dataset):
    """Wraps `dataset` in an `IgnoreErrorsDataset`.

    Args:
      dataset: The input `Dataset` whose element errors should be dropped.

    Returns:
      An `IgnoreErrorsDataset` wrapping `dataset`.
    """
    return IgnoreErrorsDataset(dataset)

  return _apply_fn
52
53
class IgnoreErrorsDataset(dataset_ops.Dataset):
  """Wraps an input `Dataset`, dropping any element whose computation fails.

  Element types and shapes are passed through unchanged from the wrapped
  dataset; the error-skipping itself is delegated to the
  `ignore_errors_dataset` op built in `_as_variant_tensor`.
  """

  def __init__(self, input_dataset):
    """Records the dataset to wrap.

    Usually constructed through the module-level `ignore_errors()`
    transformation, e.g. `dataset.apply(ignore_errors())`.

    Args:
      input_dataset: The `Dataset` whose errors should be ignored.
    """
    super(IgnoreErrorsDataset, self).__init__()
    self._input_dataset = input_dataset

  @property
  def output_types(self):
    """Element types, forwarded verbatim from the input dataset."""
    return self._input_dataset.output_types

  @property
  def output_shapes(self):
    """Element shapes, forwarded verbatim from the input dataset."""
    return self._input_dataset.output_shapes

  def _as_variant_tensor(self):
    """Builds the `ignore_errors_dataset` op on top of the input's tensor."""
    # Gather the op's arguments in order: input variant, then the flattened
    # shapes, then the flattened types.
    # pylint: disable=protected-access
    input_variant = self._input_dataset._as_variant_tensor()
    # pylint: enable=protected-access
    flat_shapes = nest.flatten(self.output_shapes)
    flat_types = nest.flatten(self.output_types)
    return gen_dataset_ops.ignore_errors_dataset(
        input_variant, output_shapes=flat_shapes, output_types=flat_types)
75