1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for utilities working with arbitrarily nested structures."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.data.util import nest
22from tensorflow.python.data.util import sparse
23from tensorflow.python.framework import constant_op
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import sparse_tensor
27from tensorflow.python.framework import tensor_shape
28from tensorflow.python.platform import test
29
30
class SparseTest(test.TestCase):
  """Tests for the helpers in `tensorflow.python.data.util.sparse`.

  Note: a parenthesized single expression such as `(ops.Tensor)` is NOT a
  tuple in Python. Single-element tuple cases below therefore carry a trailing
  comma, e.g. `(ops.Tensor,)`, so that they exercise tuple structures instead
  of silently duplicating the bare-value cases.
  """

  def testAnySparse(self):
    """Checks `any_sparse` on bare classes and on (nested) tuples of them."""
    test_cases = (
        {
            "classes": (),
            "expected": False
        },
        {
            "classes": ops.Tensor,
            "expected": False
        },
        {
            "classes": (ops.Tensor,),
            "expected": False
        },
        {
            "classes": ((ops.Tensor,),),
            "expected": False
        },
        {
            "classes": (ops.Tensor, ops.Tensor),
            "expected": False
        },
        {
            "classes": (ops.Tensor, sparse_tensor.SparseTensor),
            "expected": True
        },
        {
            "classes": (sparse_tensor.SparseTensor, sparse_tensor.SparseTensor),
            "expected": True
        },
        {
            "classes": (sparse_tensor.SparseTensor, ops.Tensor),
            "expected": True
        },
        {
            "classes": sparse_tensor.SparseTensor,
            "expected": True
        },
        {
            "classes": ((sparse_tensor.SparseTensor,),),
            "expected": True
        },
    )
    for test_case in test_cases:
      self.assertEqual(
          sparse.any_sparse(test_case["classes"]), test_case["expected"])

  def assertShapesEqual(self, a, b):
    """Asserts that two (nested) structures of `TensorShape`s match.

    Args:
      a: A `TensorShape`, or a nested structure of `TensorShape`s.
      b: A `TensorShape`, or a nested structure of `TensorShape`s.
    """
    # Check the structure first. `zip` below silently truncates, so without
    # this a result that is missing components would compare as equal.
    nest.assert_same_structure(a, b)
    for a_shape, b_shape in zip(nest.flatten(a), nest.flatten(b)):
      self.assertEqual(a_shape.ndims, b_shape.ndims)
      # `as_list()` is not usable on shapes of unknown rank, so dimensions
      # are compared only when the rank is known.
      if a_shape.ndims is not None:
        self.assertEqual(a_shape.as_list(), b_shape.as_list())

  def testAsDenseShapes(self):
    """Checks that `as_dense_shapes` replaces sparse shapes by unknown ones."""
    test_cases = (
        {
            "types": (),
            "classes": (),
            "expected": ()
        },
        {
            "types": tensor_shape.scalar(),
            "classes": ops.Tensor,
            "expected": tensor_shape.scalar()
        },
        {
            "types": tensor_shape.scalar(),
            "classes": sparse_tensor.SparseTensor,
            "expected": tensor_shape.unknown_shape()
        },
        {
            "types": (tensor_shape.scalar(),),
            "classes": (ops.Tensor,),
            "expected": (tensor_shape.scalar(),)
        },
        {
            "types": (tensor_shape.scalar(),),
            "classes": (sparse_tensor.SparseTensor,),
            "expected": (tensor_shape.unknown_shape(),)
        },
        {
            "types": (tensor_shape.scalar(), ()),
            "classes": (ops.Tensor, ()),
            "expected": (tensor_shape.scalar(), ())
        },
        {
            "types": ((), tensor_shape.scalar()),
            "classes": ((), ops.Tensor),
            "expected": ((), tensor_shape.scalar())
        },
        {
            "types": (tensor_shape.scalar(), ()),
            "classes": (sparse_tensor.SparseTensor, ()),
            "expected": (tensor_shape.unknown_shape(), ())
        },
        {
            "types": ((), tensor_shape.scalar()),
            "classes": ((), sparse_tensor.SparseTensor),
            "expected": ((), tensor_shape.unknown_shape())
        },
        {
            "types": (tensor_shape.scalar(), (), tensor_shape.scalar()),
            "classes": (ops.Tensor, (), ops.Tensor),
            "expected": (tensor_shape.scalar(), (), tensor_shape.scalar())
        },
        {
            "types": (tensor_shape.scalar(), (), tensor_shape.scalar()),
            "classes": (sparse_tensor.SparseTensor, (),
                        sparse_tensor.SparseTensor),
            "expected": (tensor_shape.unknown_shape(), (),
                         tensor_shape.unknown_shape())
        },
        {
            "types": ((), tensor_shape.scalar(), ()),
            "classes": ((), ops.Tensor, ()),
            "expected": ((), tensor_shape.scalar(), ())
        },
        {
            "types": ((), tensor_shape.scalar(), ()),
            "classes": ((), sparse_tensor.SparseTensor, ()),
            "expected": ((), tensor_shape.unknown_shape(), ())
        },
    )
    for test_case in test_cases:
      self.assertShapesEqual(
          sparse.as_dense_shapes(test_case["types"], test_case["classes"]),
          test_case["expected"])

  def testAsDenseTypes(self):
    """Checks that `as_dense_types` maps sparse components to `variant`."""
    test_cases = (
        {
            "types": (),
            "classes": (),
            "expected": ()
        },
        {
            "types": dtypes.int32,
            "classes": ops.Tensor,
            "expected": dtypes.int32
        },
        {
            "types": dtypes.int32,
            "classes": sparse_tensor.SparseTensor,
            "expected": dtypes.variant
        },
        {
            "types": (dtypes.int32,),
            "classes": (ops.Tensor,),
            "expected": (dtypes.int32,)
        },
        {
            "types": (dtypes.int32,),
            "classes": (sparse_tensor.SparseTensor,),
            "expected": (dtypes.variant,)
        },
        {
            "types": (dtypes.int32, ()),
            "classes": (ops.Tensor, ()),
            "expected": (dtypes.int32, ())
        },
        {
            "types": ((), dtypes.int32),
            "classes": ((), ops.Tensor),
            "expected": ((), dtypes.int32)
        },
        {
            "types": (dtypes.int32, ()),
            "classes": (sparse_tensor.SparseTensor, ()),
            "expected": (dtypes.variant, ())
        },
        {
            "types": ((), dtypes.int32),
            "classes": ((), sparse_tensor.SparseTensor),
            "expected": ((), dtypes.variant)
        },
        {
            "types": (dtypes.int32, (), dtypes.int32),
            "classes": (ops.Tensor, (), ops.Tensor),
            "expected": (dtypes.int32, (), dtypes.int32)
        },
        {
            "types": (dtypes.int32, (), dtypes.int32),
            "classes": (sparse_tensor.SparseTensor, (),
                        sparse_tensor.SparseTensor),
            "expected": (dtypes.variant, (), dtypes.variant)
        },
        {
            "types": ((), dtypes.int32, ()),
            "classes": ((), ops.Tensor, ()),
            "expected": ((), dtypes.int32, ())
        },
        {
            "types": ((), dtypes.int32, ()),
            "classes": ((), sparse_tensor.SparseTensor, ()),
            "expected": ((), dtypes.variant, ())
        },
    )
    for test_case in test_cases:
      self.assertEqual(
          sparse.as_dense_types(test_case["types"], test_case["classes"]),
          test_case["expected"])

  def testGetClasses(self):
    """Checks that `get_classes` maps each component to its tensor class."""
    s = sparse_tensor.SparseTensor(indices=[[0]], values=[1], dense_shape=[1])
    d = ops.Tensor
    t = sparse_tensor.SparseTensor
    test_cases = (
        {
            "classes": (),
            "expected": ()
        },
        {
            "classes": s,
            "expected": t
        },
        {
            "classes": constant_op.constant([1]),
            "expected": d
        },
        {
            "classes": (s,),
            "expected": (t,)
        },
        {
            "classes": (constant_op.constant([1]),),
            "expected": (d,)
        },
        {
            "classes": (s, ()),
            "expected": (t, ())
        },
        {
            "classes": ((), s),
            "expected": ((), t)
        },
        {
            "classes": (constant_op.constant([1]), ()),
            "expected": (d, ())
        },
        {
            "classes": ((), constant_op.constant([1])),
            "expected": ((), d)
        },
        {
            "classes": (s, (), constant_op.constant([1])),
            "expected": (t, (), d)
        },
        {
            "classes": ((), s, ()),
            "expected": ((), t, ())
        },
        {
            "classes": ((), constant_op.constant([1]), ()),
            "expected": ((), d, ())
        },
    )
    for test_case in test_cases:
      self.assertEqual(
          sparse.get_classes(test_case["classes"]), test_case["expected"])

  def assertSparseValuesEqual(self, a, b):
    """Asserts that `a` and `b` are equal, evaluating `SparseTensor`s.

    If `a` is not a `SparseTensor`, then `b` must not be one either, and the
    two are compared with `assertEqual`. Otherwise both must be
    `SparseTensor`s whose evaluated indices, values and dense shapes match.

    Args:
      a: The first value to compare.
      b: The second value to compare.
    """
    if not isinstance(a, sparse_tensor.SparseTensor):
      self.assertNotIsInstance(b, sparse_tensor.SparseTensor)
      self.assertEqual(a, b)
      return
    self.assertIsInstance(b, sparse_tensor.SparseTensor)
    with self.test_session():
      # Evaluate each tensor once, not once per compared component.
      a_value = a.eval()
      b_value = b.eval()
    self.assertAllEqual(a_value.indices, b_value.indices)
    self.assertAllEqual(a_value.values, b_value.values)
    self.assertAllEqual(a_value.dense_shape, b_value.dense_shape)

  def _serializationTestCases(self):
    """Returns the structures round-tripped by the serialization tests."""
    return (
        (),
        sparse_tensor.SparseTensor(
            indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
        sparse_tensor.SparseTensor(
            indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
        sparse_tensor.SparseTensor(
            indices=[[0, 0], [3, 4]], values=[1, -1], dense_shape=[4, 5]),
        (sparse_tensor.SparseTensor(
            indices=[[0, 0]], values=[1], dense_shape=[1, 1]),),
        (sparse_tensor.SparseTensor(
            indices=[[0, 0]], values=[1], dense_shape=[1, 1]), ()),
        ((),
         sparse_tensor.SparseTensor(
             indices=[[0, 0]], values=[1], dense_shape=[1, 1])),
    )

  def _assertSerializeDeserializeRoundTrip(self, serialize_fn):
    """Asserts that `serialize_fn` and `deserialize_sparse_tensors` invert.

    Args:
      serialize_fn: The serialization function under test, e.g.
        `sparse.serialize_sparse_tensors`.
    """
    for expected in self._serializationTestCases():
      classes = sparse.get_classes(expected)
      shapes = nest.map_structure(lambda _: tensor_shape.TensorShape(None),
                                  classes)
      types = nest.map_structure(lambda _: dtypes.int32, classes)
      actual = sparse.deserialize_sparse_tensors(
          serialize_fn(expected), types, shapes, classes)
      nest.assert_same_structure(expected, actual)
      for a, e in zip(nest.flatten(actual), nest.flatten(expected)):
        self.assertSparseValuesEqual(a, e)

  def testSerializeDeserialize(self):
    """Round-trips structures through `serialize_sparse_tensors`."""
    self._assertSerializeDeserializeRoundTrip(sparse.serialize_sparse_tensors)

  def testSerializeManyDeserialize(self):
    """Round-trips structures through `serialize_many_sparse_tensors`."""
    self._assertSerializeDeserializeRoundTrip(
        sparse.serialize_many_sparse_tensors)
356
357
358if __name__ == "__main__":
359  test.main()
360