1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for the experimental input pipeline ops.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import numpy as np 21 22from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base 23from tensorflow.python.data.ops import dataset_ops 24from tensorflow.python.framework import sparse_tensor 25from tensorflow.python.ops import math_ops 26from tensorflow.python.platform import test 27 28 29class FilterDatasetSerializationTest( 30 dataset_serialization_test_base.DatasetSerializationTestBase): 31 32 def _build_filter_range_graph(self, div): 33 return dataset_ops.Dataset.range(100).filter( 34 lambda x: math_ops.not_equal(math_ops.mod(x, div), 2)) 35 36 def testFilterCore(self): 37 div = 3 38 num_outputs = np.sum([x % 3 is not 2 for x in range(100)]) 39 self.run_core_tests(lambda: self._build_filter_range_graph(div), 40 lambda: self._build_filter_range_graph(div * 2), 41 num_outputs) 42 43 def _build_filter_dict_graph(self): 44 return dataset_ops.Dataset.range(10).map( 45 lambda x: {"foo": x * 2, "bar": x ** 2}).filter( 46 lambda d: math_ops.equal(d["bar"] % 2, 0)).map( 47 lambda d: d["foo"] + d["bar"]) 48 49 def testFilterDictCore(self): 50 num_outputs = np.sum([(x**2) % 2 == 0 for x in range(10)]) 51 self.run_core_tests(self._build_filter_dict_graph, None, num_outputs) 52 53 def _build_sparse_filter(self): 54 55 def _map_fn(i): 56 return sparse_tensor.SparseTensor( 57 indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]), i 58 59 def _filter_fn(_, i): 60 return math_ops.equal(i % 2, 0) 61 62 return dataset_ops.Dataset.range(10).map(_map_fn).filter(_filter_fn).map( 63 lambda x, i: x) 64 65 def testSparseCore(self): 66 num_outputs = 5 67 self.run_core_tests(self._build_sparse_filter, None, num_outputs) 68 69 70if __name__ == "__main__": 71 test.main() 72