error_ops.py revision 18135df3a56d0fb0a4f8e93d7b8332e4de3283e2
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ignore_errors dataset transformations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.ops import gen_dataset_ops


def ignore_errors():
  """Creates a `Dataset` from another `Dataset` and silently ignores any errors.

  Use this transformation to produce a dataset that contains the same elements
  as the input, but silently drops any elements that caused an error. For
  example:

  ```python
  dataset = tf.data.Dataset.from_tensor_slices([1., 2., 0., 4.])

  # Computing `tf.check_numerics(1. / 0.)` will raise an InvalidArgumentError.
  dataset = dataset.map(lambda x: tf.check_numerics(1. / x, "error"))

  # Using `ignore_errors()` will drop the element that causes an error.
  dataset = dataset.apply(
      tf.contrib.data.ignore_errors())  # ==> { 1., 0.5, 0.25 }
  ```

  Returns:
    A `Dataset` transformation function, which can be passed to
    @{tf.data.Dataset.apply}.
  """

  def _apply_fn(dataset):
    # Wrap the incoming dataset; nothing is evaluated until the wrapper's
    # variant tensor is requested.
    return IgnoreErrorsDataset(dataset)

  return _apply_fn


class IgnoreErrorsDataset(dataset_ops.Dataset):
  """A `Dataset` that silently ignores errors when computing its input.

  The element structure (shapes and types) is taken unchanged from the wrapped
  input dataset.
  """

  def __init__(self, input_dataset):
    """See `ignore_errors()` for details.

    Args:
      input_dataset: The `Dataset` whose errors should be ignored.
    """
    super(IgnoreErrorsDataset, self).__init__()
    self._input_dataset = input_dataset

  def _as_variant_tensor(self):
    """Builds the `ignore_errors_dataset` op on top of the input dataset.

    Returns:
      The result of `gen_dataset_ops.ignore_errors_dataset` applied to the
      input dataset's variant tensor, with this dataset's output shapes and
      types passed as flat lists.
    """
    return gen_dataset_ops.ignore_errors_dataset(
        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
        output_shapes=nest.flatten(self.output_shapes),
        output_types=nest.flatten(self.output_types))

  @property
  def output_shapes(self):
    """The output shapes of the wrapped input dataset, returned unchanged."""
    return self._input_dataset.output_shapes

  @property
  def output_types(self):
    """The output types of the wrapped input dataset, returned unchanged."""
    return self._input_dataset.output_types