# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for utilities working with arbitrarily nested structures."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import test


class SparseTest(test.TestCase):
  """Tests for the helpers in `tensorflow.python.data.util.sparse`.

  NOTE: a parenthesized single element such as `(x)` is NOT a tuple in
  Python; one-element structures are written `(x,)` so that the nested
  cases actually exercise nesting instead of duplicating bare-leaf cases.
  """

  def testAnySparse(self):
    """Checks `any_sparse` against bare, flat and nested class structures."""
    test_cases = (
        {
            "classes": (),
            "expected": False
        },
        {
            "classes": ops.Tensor,
            "expected": False
        },
        {
            "classes": (ops.Tensor,),
            "expected": False
        },
        {
            "classes": (((ops.Tensor,),),),
            "expected": False
        },
        {
            "classes": (ops.Tensor, ops.Tensor),
            "expected": False
        },
        {
            "classes": (ops.Tensor, sparse_tensor.SparseTensor),
            "expected": True
        },
        {
            "classes": (sparse_tensor.SparseTensor, sparse_tensor.SparseTensor),
            "expected": True
        },
        {
            "classes": (sparse_tensor.SparseTensor, ops.Tensor),
            "expected": True
        },
        {
            "classes": sparse_tensor.SparseTensor,
            "expected": True
        },
        {
            "classes": (sparse_tensor.SparseTensor,),
            "expected": True
        },
        {
            "classes": (((sparse_tensor.SparseTensor,),),),
            "expected": True
        },
    )
    for test_case in test_cases:
      self.assertEqual(
          sparse.any_sparse(test_case["classes"]), test_case["expected"])

  def assertShapesEqual(self, a, b):
    """Asserts that `a` and `b` are identically nested structures of shapes.

    Corresponding leaves must have the same rank; when the rank is known,
    their dimension lists must also be equal.

    Args:
      a: A (possibly nested) structure of `TensorShape` objects.
      b: A structure of `TensorShape` objects to compare against `a`.
    """
    # Without this check, `zip` below would silently ignore any extra leaves
    # in either structure and let mismatched structures pass.
    nest.assert_same_structure(a, b)
    for shape_a, shape_b in zip(nest.flatten(a), nest.flatten(b)):
      self.assertEqual(shape_a.ndims, shape_b.ndims)
      if shape_a.ndims is not None:
        self.assertEqual(shape_a.as_list(), shape_b.as_list())

  def testAsDenseShapes(self):
    """Checks that `as_dense_shapes` maps sparse leaves to unknown shapes."""
    test_cases = (
        {
            "types": (),
            "classes": (),
            "expected": ()
        },
        {
            "types": tensor_shape.scalar(),
            "classes": ops.Tensor,
            "expected": tensor_shape.scalar()
        },
        {
            "types": tensor_shape.scalar(),
            "classes": sparse_tensor.SparseTensor,
            "expected": tensor_shape.unknown_shape()
        },
        {
            "types": (tensor_shape.scalar(),),
            "classes": (ops.Tensor,),
            "expected": (tensor_shape.scalar(),)
        },
        {
            "types": (tensor_shape.scalar(),),
            "classes": (sparse_tensor.SparseTensor,),
            "expected": (tensor_shape.unknown_shape(),)
        },
        {
            "types": (tensor_shape.scalar(), ()),
            "classes": (ops.Tensor, ()),
            "expected": (tensor_shape.scalar(), ())
        },
        {
            "types": ((), tensor_shape.scalar()),
            "classes": ((), ops.Tensor),
            "expected": ((), tensor_shape.scalar())
        },
        {
            "types": (tensor_shape.scalar(), ()),
            "classes": (sparse_tensor.SparseTensor, ()),
            "expected": (tensor_shape.unknown_shape(), ())
        },
        {
            "types": ((), tensor_shape.scalar()),
            "classes": ((), sparse_tensor.SparseTensor),
            "expected": ((), tensor_shape.unknown_shape())
        },
        {
            "types": (tensor_shape.scalar(), (), tensor_shape.scalar()),
            "classes": (ops.Tensor, (), ops.Tensor),
            "expected": (tensor_shape.scalar(), (), tensor_shape.scalar())
        },
        {
            "types": (tensor_shape.scalar(), (), tensor_shape.scalar()),
            "classes": (sparse_tensor.SparseTensor, (),
                        sparse_tensor.SparseTensor),
            "expected": (tensor_shape.unknown_shape(), (),
                         tensor_shape.unknown_shape())
        },
        {
            "types": ((), tensor_shape.scalar(), ()),
            "classes": ((), ops.Tensor, ()),
            "expected": ((), tensor_shape.scalar(), ())
        },
        {
            "types": ((), tensor_shape.scalar(), ()),
            "classes": ((), sparse_tensor.SparseTensor, ()),
            "expected": ((), tensor_shape.unknown_shape(), ())
        },
    )
    for test_case in test_cases:
      self.assertShapesEqual(
          sparse.as_dense_shapes(test_case["types"], test_case["classes"]),
          test_case["expected"])

  def testAsDenseTypes(self):
    """Checks that `as_dense_types` maps sparse leaves to `dtypes.variant`."""
    test_cases = (
        {
            "types": (),
            "classes": (),
            "expected": ()
        },
        {
            "types": dtypes.int32,
            "classes": ops.Tensor,
            "expected": dtypes.int32
        },
        {
            "types": dtypes.int32,
            "classes": sparse_tensor.SparseTensor,
            "expected": dtypes.variant
        },
        {
            "types": (dtypes.int32,),
            "classes": (ops.Tensor,),
            "expected": (dtypes.int32,)
        },
        {
            "types": (dtypes.int32,),
            "classes": (sparse_tensor.SparseTensor,),
            "expected": (dtypes.variant,)
        },
        {
            "types": (dtypes.int32, ()),
            "classes": (ops.Tensor, ()),
            "expected": (dtypes.int32, ())
        },
        {
            "types": ((), dtypes.int32),
            "classes": ((), ops.Tensor),
            "expected": ((), dtypes.int32)
        },
        {
            "types": (dtypes.int32, ()),
            "classes": (sparse_tensor.SparseTensor, ()),
            "expected": (dtypes.variant, ())
        },
        {
            "types": ((), dtypes.int32),
            "classes": ((), sparse_tensor.SparseTensor),
            "expected": ((), dtypes.variant)
        },
        {
            "types": (dtypes.int32, (), dtypes.int32),
            "classes": (ops.Tensor, (), ops.Tensor),
            "expected": (dtypes.int32, (), dtypes.int32)
        },
        {
            "types": (dtypes.int32, (), dtypes.int32),
            "classes": (sparse_tensor.SparseTensor, (),
                        sparse_tensor.SparseTensor),
            "expected": (dtypes.variant, (), dtypes.variant)
        },
        {
            "types": ((), dtypes.int32, ()),
            "classes": ((), ops.Tensor, ()),
            "expected": ((), dtypes.int32, ())
        },
        {
            "types": ((), dtypes.int32, ()),
            "classes": ((), sparse_tensor.SparseTensor, ()),
            "expected": ((), dtypes.variant, ())
        },
    )
    for test_case in test_cases:
      self.assertEqual(
          sparse.as_dense_types(test_case["types"], test_case["classes"]),
          test_case["expected"])

  def testGetClasses(self):
    """Checks that `get_classes` replaces each leaf value with its class."""
    s = sparse_tensor.SparseTensor(indices=[[0]], values=[1], dense_shape=[1])
    d = ops.Tensor
    t = sparse_tensor.SparseTensor
    test_cases = (
        {
            "classes": (),
            "expected": ()
        },
        {
            "classes": s,
            "expected": t
        },
        {
            "classes": constant_op.constant([1]),
            "expected": d
        },
        {
            "classes": (s,),
            "expected": (t,)
        },
        {
            "classes": (constant_op.constant([1]),),
            "expected": (d,)
        },
        {
            "classes": (s, ()),
            "expected": (t, ())
        },
        {
            "classes": ((), s),
            "expected": ((), t)
        },
        {
            "classes": (constant_op.constant([1]), ()),
            "expected": (d, ())
        },
        {
            "classes": ((), constant_op.constant([1])),
            "expected": ((), d)
        },
        {
            "classes": (s, (), constant_op.constant([1])),
            "expected": (t, (), d)
        },
        {
            "classes": ((), s, ()),
            "expected": ((), t, ())
        },
        {
            "classes": ((), constant_op.constant([1]), ()),
            "expected": ((), d, ())
        },
    )
    for test_case in test_cases:
      self.assertEqual(
          sparse.get_classes(test_case["classes"]), test_case["expected"])

  def assertSparseValuesEqual(self, a, b):
    """Asserts that two leaves are equal.

    Non-sparse leaves are compared with `assertEqual`. Sparse leaves are
    evaluated and compared component-wise (indices, values, dense shape).

    Args:
      a: A leaf value, possibly a `SparseTensor`.
      b: A leaf value to compare against `a`; it must be a `SparseTensor`
        exactly when `a` is.
    """
    if not isinstance(a, sparse_tensor.SparseTensor):
      self.assertNotIsInstance(b, sparse_tensor.SparseTensor)
      self.assertEqual(a, b)
      return
    self.assertIsInstance(b, sparse_tensor.SparseTensor)
    with self.test_session():
      # Evaluate each tensor once instead of once per compared component.
      a_value = a.eval()
      b_value = b.eval()
    self.assertAllEqual(a_value.indices, b_value.indices)
    self.assertAllEqual(a_value.values, b_value.values)
    self.assertAllEqual(a_value.dense_shape, b_value.dense_shape)

  def _serializationTestCases(self):
    """Returns the nested structures used by the serialization round trips."""
    return (
        (),
        sparse_tensor.SparseTensor(
            indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
        sparse_tensor.SparseTensor(
            indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
        sparse_tensor.SparseTensor(
            indices=[[0, 0], [3, 4]], values=[1, -1], dense_shape=[4, 5]),
        (sparse_tensor.SparseTensor(
            indices=[[0, 0]], values=[1], dense_shape=[1, 1]),),
        (sparse_tensor.SparseTensor(
            indices=[[0, 0]], values=[1], dense_shape=[1, 1]), ()),
        ((),
         sparse_tensor.SparseTensor(
             indices=[[0, 0]], values=[1], dense_shape=[1, 1])),
    )

  def _assertSerializationRoundTrip(self, serialize_fn):
    """Checks `deserialize_sparse_tensors(serialize_fn(x), ...)` returns `x`.

    Args:
      serialize_fn: The serialization function under test, called with one
        nested structure from `_serializationTestCases`.
    """
    for expected in self._serializationTestCases():
      classes = sparse.get_classes(expected)
      # Rank is left unknown for every leaf; all test values are int32.
      shapes = nest.map_structure(lambda _: tensor_shape.TensorShape(None),
                                  classes)
      types = nest.map_structure(lambda _: dtypes.int32, classes)
      actual = sparse.deserialize_sparse_tensors(
          serialize_fn(expected), types, shapes, classes)
      nest.assert_same_structure(expected, actual)
      for a, e in zip(nest.flatten(actual), nest.flatten(expected)):
        self.assertSparseValuesEqual(a, e)

  def testSerializeDeserialize(self):
    """Round-trips structures through `serialize_sparse_tensors`."""
    self._assertSerializationRoundTrip(sparse.serialize_sparse_tensors)

  def testSerializeManyDeserialize(self):
    """Round-trips structures through `serialize_many_sparse_tensors`."""
    self._assertSerializationRoundTrip(sparse.serialize_many_sparse_tensors)


if __name__ == "__main__":
  test.main()