1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for the experimental input pipeline ops."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import numpy as np
21
22from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
23from tensorflow.python.data.ops import dataset_ops
24from tensorflow.python.framework import sparse_tensor
25from tensorflow.python.ops import math_ops
26from tensorflow.python.platform import test
27
28
class FilterDatasetSerializationTest(
    dataset_serialization_test_base.DatasetSerializationTestBase):
  """Serialization (checkpoint/restore) tests for `Dataset.filter` pipelines.

  Each test builds a filtered dataset and hands it, together with the
  expected number of elements, to `run_core_tests` from the base class.
  """

  def _build_filter_range_graph(self, div):
    """Builds `range(100)` keeping elements whose remainder mod `div` is not 2.

    Args:
      div: Divisor used in the filter predicate.

    Returns:
      A `Dataset` of the retained integers.
    """
    return dataset_ops.Dataset.range(100).filter(
        lambda x: math_ops.not_equal(math_ops.mod(x, div), 2))

  def testFilterCore(self):
    div = 3
    # The expected count must mirror the dataset's predicate exactly. Use the
    # same `div` and a value comparison (`!=`); `is not` compares object
    # identity and is not a valid way to compare integers.
    num_outputs = np.sum([x % div != 2 for x in range(100)])
    self.run_core_tests(lambda: self._build_filter_range_graph(div),
                        lambda: self._build_filter_range_graph(div * 2),
                        num_outputs)

  def _build_filter_dict_graph(self):
    """Builds a pipeline that filters dict-structured elements.

    Maps `x` in `range(10)` to `{"foo": 2 * x, "bar": x ** 2}`, keeps the
    elements whose "bar" entry is even, then emits `foo + bar`.
    """
    return dataset_ops.Dataset.range(10).map(
        lambda x: {"foo": x * 2, "bar": x ** 2}).filter(
            lambda d: math_ops.equal(d["bar"] % 2, 0)).map(
                lambda d: d["foo"] + d["bar"])

  def testFilterDictCore(self):
    # Same predicate as the dataset: keep x whose square is even.
    num_outputs = np.sum([(x**2) % 2 == 0 for x in range(10)])
    self.run_core_tests(self._build_filter_dict_graph, None, num_outputs)

  def _build_sparse_filter(self):
    """Builds a pipeline that filters elements containing a `SparseTensor`.

    Each `i` in `range(10)` becomes a pair `(SparseTensor, i)`, where the sparse
    tensor has dense shape [1, 1] and a single entry at [0, 0]. Pairs with an
    odd `i` are dropped, and then only the sparse component is kept.
    """

    def _map_fn(i):
      return sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]), i

    def _filter_fn(_, i):
      return math_ops.equal(i % 2, 0)

    return dataset_ops.Dataset.range(10).map(_map_fn).filter(_filter_fn).map(
        lambda x, i: x)

  def testSparseCore(self):
    # Derived from the same predicate as `_filter_fn` (even i in range(10))
    # rather than hard-coded, so the two cannot drift apart.
    num_outputs = np.sum([i % 2 == 0 for i in range(10)])
    self.run_core_tests(self._build_sparse_filter, None, num_outputs)
68
69
if __name__ == "__main__":
  # Executed as a script: delegate to the TensorFlow test runner, which runs
  # the test cases defined in this module.
  test.main()
72