1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for set_ops.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.python.framework import constant_op 24from tensorflow.python.framework import dtypes 25from tensorflow.python.framework import errors_impl 26from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib 27from tensorflow.python.framework import test_util 28from tensorflow.python.ops import array_ops 29from tensorflow.python.ops import math_ops 30from tensorflow.python.ops import sets 31from tensorflow.python.ops import sparse_ops 32from tensorflow.python.platform import googletest 33 34_DTYPES = set([ 35 dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, 36 dtypes.uint16, dtypes.string 37]) 38 39 40def _values(values, dtype): 41 return np.array( 42 values, 43 dtype=(np.unicode if (dtype == dtypes.string) else dtype.as_numpy_dtype)) 44 45 46def _constant(values, dtype): 47 return constant_op.constant(_values(values, dtype), dtype=dtype) 48 49 50def _dense_to_sparse(dense, dtype): 51 indices = [] 52 values = [] 53 max_row_len = 0 54 for row in dense: 55 max_row_len = max(max_row_len, len(row)) 56 shape = [len(dense), max_row_len] 57 row_ix = 0 58 for row in dense: 59 col_ix = 0 60 for cell in row: 61 indices.append([row_ix, col_ix]) 62 values.append(str(cell) if dtype == dtypes.string else cell) 63 col_ix += 1 64 row_ix += 1 65 return sparse_tensor_lib.SparseTensor( 66 constant_op.constant(indices, dtypes.int64), 67 constant_op.constant(values, dtype), 68 constant_op.constant(shape, dtypes.int64)) 69 70 71class SetOpsTest(test_util.TensorFlowTestCase): 72 73 def test_set_size_2d(self): 74 for dtype in _DTYPES: 75 self._test_set_size_2d(dtype) 76 77 def _test_set_size_2d(self, dtype): 78 self.assertAllEqual([1], self._set_size(_dense_to_sparse([[1]], dtype))) 79 self.assertAllEqual([2, 1], 80 self._set_size(_dense_to_sparse([[1, 9], [1]], dtype))) 81 self.assertAllEqual( 82 [3, 0], self._set_size(_dense_to_sparse([[1, 9, 2], []], dtype))) 83 self.assertAllEqual( 84 [0, 3], self._set_size(_dense_to_sparse([[], [1, 9, 2]], dtype))) 85 86 def test_set_size_duplicates_2d(self): 87 for dtype in _DTYPES: 88 self._test_set_size_duplicates_2d(dtype) 89 90 def _test_set_size_duplicates_2d(self, dtype): 91 self.assertAllEqual( 92 [1], self._set_size(_dense_to_sparse([[1, 1, 1, 1, 1, 1]], dtype))) 93 self.assertAllEqual([2, 7, 3, 0, 1], 94 self._set_size( 95 _dense_to_sparse([[1, 9], [ 96 6, 7, 8, 8, 6, 7, 5, 3, 3, 0, 6, 6, 9, 0, 0, 0 97 ], [999, 1, -1000], [], [-1]], dtype))) 98 99 def test_set_size_3d(self): 100 for dtype in _DTYPES: 101 self._test_set_size_3d(dtype) 102 103 def test_set_size_3d_invalid_indices(self): 104 for dtype in _DTYPES: 105 self._test_set_size_3d(dtype, invalid_indices=True) 106 107 def _test_set_size_3d(self, dtype, invalid_indices=False): 108 if invalid_indices: 109 indices = constant_op.constant([ 110 [0, 1, 0], [0, 1, 1], # 0,1 111 [1, 0, 0], # 1,0 112 [1, 1, 0], [1, 1, 1], [1, 1, 2], # 1,1 113 [0, 0, 0], [0, 0, 2], # 0,0 114 # 2,0 115 [2, 1, 1] # 2,1 116 ], dtypes.int64) 117 else: 118 indices = constant_op.constant([ 119 [0, 0, 0], [0, 0, 2], # 0,0 120 [0, 1, 0], [0, 1, 1], # 0,1 121 [1, 0, 0], # 1,0 122 [1, 1, 0], [1, 1, 1], [1, 1, 2], # 1,1 123 # 2,0 124 [2, 1, 1] # 2,1 125 ], dtypes.int64) 126 127 sp = sparse_tensor_lib.SparseTensor( 128 indices, 129 _constant([ 130 1, 9, # 0,0 131 3, 3, # 0,1 132 1, # 1,0 133 9, 7, 8, # 1,1 134 # 2,0 135 5 # 2,1 136 ], dtype), 137 constant_op.constant([3, 2, 3], dtypes.int64)) 138 139 if invalid_indices: 140 with self.assertRaisesRegexp(errors_impl.OpError, "out of order"): 141 self._set_size(sp) 142 else: 143 self.assertAllEqual([ 144 [2, # 0,0 145 1], # 0,1 146 [1, # 1,0 147 3], # 1,1 148 [0, # 2,0 149 1] # 2,1 150 ], self._set_size(sp)) 151 152 def _set_size(self, sparse_data): 153 # Validate that we get the same results with or without `validate_indices`. 154 ops = [ 155 sets.set_size(sparse_data, validate_indices=True), 156 sets.set_size(sparse_data, validate_indices=False) 157 ] 158 for op in ops: 159 self.assertEqual(None, op.get_shape().dims) 160 self.assertEqual(dtypes.int32, op.dtype) 161 with self.test_session() as sess: 162 results = sess.run(ops) 163 self.assertAllEqual(results[0], results[1]) 164 return results[0] 165 166 def test_set_intersection_multirow_2d(self): 167 for dtype in _DTYPES: 168 self._test_set_intersection_multirow_2d(dtype) 169 170 def _test_set_intersection_multirow_2d(self, dtype): 171 a_values = [[9, 1, 5], [2, 4, 3]] 172 b_values = [[1, 9], [1]] 173 expected_indices = [[0, 0], [0, 1]] 174 expected_values = _values([1, 9], dtype) 175 expected_shape = [2, 2] 176 expected_counts = [2, 0] 177 178 # Dense to sparse. 179 a = _constant(a_values, dtype=dtype) 180 sp_b = _dense_to_sparse(b_values, dtype=dtype) 181 intersection = self._set_intersection(a, sp_b) 182 self._assert_set_operation( 183 expected_indices, 184 expected_values, 185 expected_shape, 186 intersection, 187 dtype=dtype) 188 self.assertAllEqual(expected_counts, self._set_intersection_count(a, sp_b)) 189 190 # Sparse to sparse. 191 sp_a = _dense_to_sparse(a_values, dtype=dtype) 192 intersection = self._set_intersection(sp_a, sp_b) 193 self._assert_set_operation( 194 expected_indices, 195 expected_values, 196 expected_shape, 197 intersection, 198 dtype=dtype) 199 self.assertAllEqual(expected_counts, 200 self._set_intersection_count(sp_a, sp_b)) 201 202 def test_dense_set_intersection_multirow_2d(self): 203 for dtype in _DTYPES: 204 self._test_dense_set_intersection_multirow_2d(dtype) 205 206 def _test_dense_set_intersection_multirow_2d(self, dtype): 207 a_values = [[9, 1, 5], [2, 4, 3]] 208 b_values = [[1, 9], [1, 5]] 209 expected_indices = [[0, 0], [0, 1]] 210 expected_values = _values([1, 9], dtype) 211 expected_shape = [2, 2] 212 expected_counts = [2, 0] 213 214 # Dense to dense. 215 a = _constant(a_values, dtype) 216 b = _constant(b_values, dtype) 217 intersection = self._set_intersection(a, b) 218 self._assert_set_operation( 219 expected_indices, 220 expected_values, 221 expected_shape, 222 intersection, 223 dtype=dtype) 224 self.assertAllEqual(expected_counts, self._set_intersection_count(a, b)) 225 226 def test_set_intersection_duplicates_2d(self): 227 for dtype in _DTYPES: 228 self._test_set_intersection_duplicates_2d(dtype) 229 230 def _test_set_intersection_duplicates_2d(self, dtype): 231 a_values = [[1, 1, 3]] 232 b_values = [[1]] 233 expected_indices = [[0, 0]] 234 expected_values = _values([1], dtype) 235 expected_shape = [1, 1] 236 expected_counts = [1] 237 238 # Dense to dense. 239 a = _constant(a_values, dtype=dtype) 240 b = _constant(b_values, dtype=dtype) 241 intersection = self._set_intersection(a, b) 242 self._assert_set_operation( 243 expected_indices, 244 expected_values, 245 expected_shape, 246 intersection, 247 dtype=dtype) 248 self.assertAllEqual(expected_counts, self._set_intersection_count(a, b)) 249 250 # Dense to sparse. 251 sp_b = _dense_to_sparse(b_values, dtype=dtype) 252 intersection = self._set_intersection(a, sp_b) 253 self._assert_set_operation( 254 expected_indices, 255 expected_values, 256 expected_shape, 257 intersection, 258 dtype=dtype) 259 self.assertAllEqual(expected_counts, self._set_intersection_count(a, sp_b)) 260 261 # Sparse to sparse. 262 sp_a = _dense_to_sparse(a_values, dtype=dtype) 263 intersection = self._set_intersection(sp_a, sp_b) 264 self._assert_set_operation( 265 expected_indices, 266 expected_values, 267 expected_shape, 268 intersection, 269 dtype=dtype) 270 self.assertAllEqual(expected_counts, 271 self._set_intersection_count(sp_a, sp_b)) 272 273 def test_set_intersection_3d(self): 274 for dtype in _DTYPES: 275 self._test_set_intersection_3d(dtype=dtype) 276 277 def test_set_intersection_3d_invalid_indices(self): 278 for dtype in _DTYPES: 279 self._test_set_intersection_3d(dtype=dtype, invalid_indices=True) 280 281 def _test_set_intersection_3d(self, dtype, invalid_indices=False): 282 if invalid_indices: 283 indices = constant_op.constant( 284 [ 285 [0, 1, 0], 286 [0, 1, 1], # 0,1 287 [1, 0, 0], # 1,0 288 [1, 1, 0], 289 [1, 1, 1], 290 [1, 1, 2], # 1,1 291 [0, 0, 0], 292 [0, 0, 2], # 0,0 293 # 2,0 294 [2, 1, 1] # 2,1 295 # 3,* 296 ], 297 dtypes.int64) 298 else: 299 indices = constant_op.constant( 300 [ 301 [0, 0, 0], 302 [0, 0, 2], # 0,0 303 [0, 1, 0], 304 [0, 1, 1], # 0,1 305 [1, 0, 0], # 1,0 306 [1, 1, 0], 307 [1, 1, 1], 308 [1, 1, 2], # 1,1 309 # 2,0 310 [2, 1, 1] # 2,1 311 # 3,* 312 ], 313 dtypes.int64) 314 sp_a = sparse_tensor_lib.SparseTensor( 315 indices, 316 _constant( 317 [ 318 1, 319 9, # 0,0 320 3, 321 3, # 0,1 322 1, # 1,0 323 9, 324 7, 325 8, # 1,1 326 # 2,0 327 5 # 2,1 328 # 3,* 329 ], 330 dtype), 331 constant_op.constant([4, 2, 3], dtypes.int64)) 332 sp_b = sparse_tensor_lib.SparseTensor( 333 constant_op.constant( 334 [ 335 [0, 0, 0], 336 [0, 0, 3], # 0,0 337 # 0,1 338 [1, 0, 0], # 1,0 339 [1, 1, 0], 340 [1, 1, 1], # 1,1 341 [2, 0, 1], # 2,0 342 [2, 1, 1], # 2,1 343 [3, 0, 0], # 3,0 344 [3, 1, 0] # 3,1 345 ], 346 dtypes.int64), 347 _constant( 348 [ 349 1, 350 3, # 0,0 351 # 0,1 352 3, # 1,0 353 7, 354 8, # 1,1 355 2, # 2,0 356 5, # 2,1 357 4, # 3,0 358 4 # 3,1 359 ], 360 dtype), 361 constant_op.constant([4, 2, 4], dtypes.int64)) 362 363 if invalid_indices: 364 with self.assertRaisesRegexp(errors_impl.OpError, "out of order"): 365 self._set_intersection(sp_a, sp_b) 366 else: 367 expected_indices = [ 368 [0, 0, 0], # 0,0 369 # 0,1 370 # 1,0 371 [1, 1, 0], 372 [1, 1, 1], # 1,1 373 # 2,0 374 [2, 1, 0], # 2,1 375 # 3,* 376 ] 377 expected_values = _values( 378 [ 379 1, # 0,0 380 # 0,1 381 # 1,0 382 7, 383 8, # 1,1 384 # 2,0 385 5, # 2,1 386 # 3,* 387 ], 388 dtype) 389 expected_shape = [4, 2, 2] 390 expected_counts = [ 391 [ 392 1, # 0,0 393 0 # 0,1 394 ], 395 [ 396 0, # 1,0 397 2 # 1,1 398 ], 399 [ 400 0, # 2,0 401 1 # 2,1 402 ], 403 [ 404 0, # 3,0 405 0 # 3,1 406 ] 407 ] 408 409 # Sparse to sparse. 410 intersection = self._set_intersection(sp_a, sp_b) 411 self._assert_set_operation( 412 expected_indices, 413 expected_values, 414 expected_shape, 415 intersection, 416 dtype=dtype) 417 self.assertAllEqual(expected_counts, 418 self._set_intersection_count(sp_a, sp_b)) 419 420 # NOTE: sparse_to_dense doesn't support uint8 and uint16. 421 if dtype not in [dtypes.uint8, dtypes.uint16]: 422 # Dense to sparse. 423 a = math_ops.cast( 424 sparse_ops.sparse_to_dense( 425 sp_a.indices, 426 sp_a.dense_shape, 427 sp_a.values, 428 default_value="-1" if dtype == dtypes.string else -1), 429 dtype=dtype) 430 intersection = self._set_intersection(a, sp_b) 431 self._assert_set_operation( 432 expected_indices, 433 expected_values, 434 expected_shape, 435 intersection, 436 dtype=dtype) 437 self.assertAllEqual(expected_counts, 438 self._set_intersection_count(a, sp_b)) 439 440 # Dense to dense. 441 b = math_ops.cast( 442 sparse_ops.sparse_to_dense( 443 sp_b.indices, 444 sp_b.dense_shape, 445 sp_b.values, 446 default_value="-2" if dtype == dtypes.string else -2), 447 dtype=dtype) 448 intersection = self._set_intersection(a, b) 449 self._assert_set_operation( 450 expected_indices, 451 expected_values, 452 expected_shape, 453 intersection, 454 dtype=dtype) 455 self.assertAllEqual(expected_counts, self._set_intersection_count(a, b)) 456 457 def _assert_static_shapes(self, input_tensor, result_sparse_tensor): 458 if isinstance(input_tensor, sparse_tensor_lib.SparseTensor): 459 sparse_shape_dims = input_tensor.dense_shape.get_shape().dims 460 if sparse_shape_dims is None: 461 expected_rank = None 462 else: 463 expected_rank = sparse_shape_dims[0].value 464 else: 465 expected_rank = input_tensor.get_shape().ndims 466 self.assertAllEqual((None, expected_rank), 467 result_sparse_tensor.indices.get_shape().as_list()) 468 self.assertAllEqual((None,), 469 result_sparse_tensor.values.get_shape().as_list()) 470 self.assertAllEqual((expected_rank,), 471 result_sparse_tensor.dense_shape.get_shape().as_list()) 472 473 def _run_equivalent_set_ops(self, ops): 474 """Assert all ops return the same shapes, and return 1st result.""" 475 # Collect shapes and results for all ops, and assert static shapes match. 476 dynamic_indices_shape_ops = [] 477 dynamic_values_shape_ops = [] 478 static_indices_shape = None 479 static_values_shape = None 480 with self.test_session() as sess: 481 for op in ops: 482 if static_indices_shape is None: 483 static_indices_shape = op.indices.get_shape() 484 else: 485 self.assertAllEqual( 486 static_indices_shape.as_list(), op.indices.get_shape().as_list()) 487 if static_values_shape is None: 488 static_values_shape = op.values.get_shape() 489 else: 490 self.assertAllEqual( 491 static_values_shape.as_list(), op.values.get_shape().as_list()) 492 dynamic_indices_shape_ops.append(array_ops.shape(op.indices)) 493 dynamic_values_shape_ops.append(array_ops.shape(op.values)) 494 results = sess.run( 495 list(ops) + dynamic_indices_shape_ops + dynamic_values_shape_ops) 496 op_count = len(ops) 497 op_results = results[0:op_count] 498 dynamic_indices_shapes = results[op_count:2 * op_count] 499 dynamic_values_shapes = results[2 * op_count:3 * op_count] 500 501 # Assert static and dynamic tensor shapes, and result shapes, are all 502 # consistent. 503 static_indices_shape.assert_is_compatible_with(dynamic_indices_shapes[0]) 504 static_values_shape.assert_is_compatible_with(dynamic_values_shapes[0]) 505 self.assertAllEqual(dynamic_indices_shapes[0], op_results[0].indices.shape) 506 self.assertAllEqual(dynamic_values_shapes[0], op_results[0].values.shape) 507 508 # Assert dynamic shapes and values are the same for all ops. 509 for i in range(1, len(ops)): 510 self.assertAllEqual(dynamic_indices_shapes[0], dynamic_indices_shapes[i]) 511 self.assertAllEqual(dynamic_values_shapes[0], dynamic_values_shapes[i]) 512 self.assertAllEqual(op_results[0].indices, op_results[i].indices) 513 self.assertAllEqual(op_results[0].values, op_results[i].values) 514 self.assertAllEqual(op_results[0].dense_shape, op_results[i].dense_shape) 515 516 return op_results[0] 517 518 def _set_intersection(self, a, b): 519 # Validate that we get the same results with or without `validate_indices`, 520 # and with a & b swapped. 521 ops = ( 522 sets.set_intersection( 523 a, b, validate_indices=True), 524 sets.set_intersection( 525 a, b, validate_indices=False), 526 sets.set_intersection( 527 b, a, validate_indices=True), 528 sets.set_intersection( 529 b, a, validate_indices=False),) 530 for op in ops: 531 self._assert_static_shapes(a, op) 532 return self._run_equivalent_set_ops(ops) 533 534 def _set_intersection_count(self, a, b): 535 op = sets.set_size(sets.set_intersection(a, b)) 536 with self.test_session() as sess: 537 return sess.run(op) 538 539 def test_set_difference_multirow_2d(self): 540 for dtype in _DTYPES: 541 self._test_set_difference_multirow_2d(dtype) 542 543 def _test_set_difference_multirow_2d(self, dtype): 544 a_values = [[1, 1, 1], [1, 5, 9], [4, 5, 3], [5, 5, 1]] 545 b_values = [[], [1, 2], [1, 2, 2], []] 546 547 # a - b. 548 expected_indices = [[0, 0], [1, 0], [1, 1], [2, 0], [2, 1], [2, 2], [3, 0], 549 [3, 1]] 550 expected_values = _values([1, 5, 9, 3, 4, 5, 1, 5], dtype) 551 expected_shape = [4, 3] 552 expected_counts = [1, 2, 3, 2] 553 554 # Dense to sparse. 555 a = _constant(a_values, dtype=dtype) 556 sp_b = _dense_to_sparse(b_values, dtype=dtype) 557 difference = self._set_difference(a, sp_b, True) 558 self._assert_set_operation( 559 expected_indices, 560 expected_values, 561 expected_shape, 562 difference, 563 dtype=dtype) 564 self.assertAllEqual(expected_counts, 565 self._set_difference_count(a, sp_b, True)) 566 567 # Sparse to sparse. 568 sp_a = _dense_to_sparse(a_values, dtype=dtype) 569 difference = self._set_difference(sp_a, sp_b, True) 570 self._assert_set_operation( 571 expected_indices, 572 expected_values, 573 expected_shape, 574 difference, 575 dtype=dtype) 576 self.assertAllEqual(expected_counts, 577 self._set_difference_count(sp_a, sp_b, True)) 578 579 # b - a. 580 expected_indices = [[1, 0], [2, 0], [2, 1]] 581 expected_values = _values([2, 1, 2], dtype) 582 expected_shape = [4, 2] 583 expected_counts = [0, 1, 2, 0] 584 585 # Dense to sparse. 586 difference = self._set_difference(a, sp_b, False) 587 self._assert_set_operation( 588 expected_indices, 589 expected_values, 590 expected_shape, 591 difference, 592 dtype=dtype) 593 self.assertAllEqual(expected_counts, 594 self._set_difference_count(a, sp_b, False)) 595 596 # Sparse to sparse. 597 difference = self._set_difference(sp_a, sp_b, False) 598 self._assert_set_operation( 599 expected_indices, 600 expected_values, 601 expected_shape, 602 difference, 603 dtype=dtype) 604 self.assertAllEqual(expected_counts, 605 self._set_difference_count(sp_a, sp_b, False)) 606 607 def test_dense_set_difference_multirow_2d(self): 608 for dtype in _DTYPES: 609 self._test_dense_set_difference_multirow_2d(dtype) 610 611 def _test_dense_set_difference_multirow_2d(self, dtype): 612 a_values = [[1, 5, 9], [4, 5, 3]] 613 b_values = [[1, 2, 6], [1, 2, 2]] 614 615 # a - b. 616 expected_indices = [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]] 617 expected_values = _values([5, 9, 3, 4, 5], dtype) 618 expected_shape = [2, 3] 619 expected_counts = [2, 3] 620 621 # Dense to dense. 622 a = _constant(a_values, dtype=dtype) 623 b = _constant(b_values, dtype=dtype) 624 difference = self._set_difference(a, b, True) 625 self._assert_set_operation( 626 expected_indices, 627 expected_values, 628 expected_shape, 629 difference, 630 dtype=dtype) 631 self.assertAllEqual(expected_counts, self._set_difference_count(a, b, True)) 632 633 # b - a. 634 expected_indices = [[0, 0], [0, 1], [1, 0], [1, 1]] 635 expected_values = _values([2, 6, 1, 2], dtype) 636 expected_shape = [2, 2] 637 expected_counts = [2, 2] 638 639 # Dense to dense. 640 difference = self._set_difference(a, b, False) 641 self._assert_set_operation( 642 expected_indices, 643 expected_values, 644 expected_shape, 645 difference, 646 dtype=dtype) 647 self.assertAllEqual(expected_counts, 648 self._set_difference_count(a, b, False)) 649 650 def test_sparse_set_difference_multirow_2d(self): 651 for dtype in _DTYPES: 652 self._test_sparse_set_difference_multirow_2d(dtype) 653 654 def _test_sparse_set_difference_multirow_2d(self, dtype): 655 sp_a = _dense_to_sparse( 656 [[], [1, 5, 9], [4, 5, 3, 3, 4, 5], [5, 1]], dtype=dtype) 657 sp_b = _dense_to_sparse([[], [1, 2], [1, 2, 2], []], dtype=dtype) 658 659 # a - b. 660 expected_indices = [[1, 0], [1, 1], [2, 0], [2, 1], [2, 2], [3, 0], [3, 1]] 661 expected_values = _values([5, 9, 3, 4, 5, 1, 5], dtype) 662 expected_shape = [4, 3] 663 expected_counts = [0, 2, 3, 2] 664 665 difference = self._set_difference(sp_a, sp_b, True) 666 self._assert_set_operation( 667 expected_indices, 668 expected_values, 669 expected_shape, 670 difference, 671 dtype=dtype) 672 self.assertAllEqual(expected_counts, 673 self._set_difference_count(sp_a, sp_b, True)) 674 675 # b - a. 676 expected_indices = [[1, 0], [2, 0], [2, 1]] 677 expected_values = _values([2, 1, 2], dtype) 678 expected_shape = [4, 2] 679 expected_counts = [0, 1, 2, 0] 680 681 difference = self._set_difference(sp_a, sp_b, False) 682 self._assert_set_operation( 683 expected_indices, 684 expected_values, 685 expected_shape, 686 difference, 687 dtype=dtype) 688 self.assertAllEqual(expected_counts, 689 self._set_difference_count(sp_a, sp_b, False)) 690 691 def test_set_difference_duplicates_2d(self): 692 for dtype in _DTYPES: 693 self._test_set_difference_duplicates_2d(dtype) 694 695 def _test_set_difference_duplicates_2d(self, dtype): 696 a_values = [[1, 1, 3]] 697 b_values = [[1, 2, 2]] 698 699 # a - b. 700 expected_indices = [[0, 0]] 701 expected_values = _values([3], dtype) 702 expected_shape = [1, 1] 703 expected_counts = [1] 704 705 # Dense to sparse. 706 a = _constant(a_values, dtype=dtype) 707 sp_b = _dense_to_sparse(b_values, dtype=dtype) 708 difference = self._set_difference(a, sp_b, True) 709 self._assert_set_operation( 710 expected_indices, 711 expected_values, 712 expected_shape, 713 difference, 714 dtype=dtype) 715 self.assertAllEqual(expected_counts, 716 self._set_difference_count(a, sp_b, True)) 717 718 # Sparse to sparse. 719 sp_a = _dense_to_sparse(a_values, dtype=dtype) 720 difference = self._set_difference(sp_a, sp_b, True) 721 self._assert_set_operation( 722 expected_indices, 723 expected_values, 724 expected_shape, 725 difference, 726 dtype=dtype) 727 self.assertAllEqual(expected_counts, 728 self._set_difference_count(a, sp_b, True)) 729 730 # b - a. 731 expected_indices = [[0, 0]] 732 expected_values = _values([2], dtype) 733 expected_shape = [1, 1] 734 expected_counts = [1] 735 736 # Dense to sparse. 737 difference = self._set_difference(a, sp_b, False) 738 self._assert_set_operation( 739 expected_indices, 740 expected_values, 741 expected_shape, 742 difference, 743 dtype=dtype) 744 self.assertAllEqual(expected_counts, 745 self._set_difference_count(a, sp_b, False)) 746 747 # Sparse to sparse. 748 difference = self._set_difference(sp_a, sp_b, False) 749 self._assert_set_operation( 750 expected_indices, 751 expected_values, 752 expected_shape, 753 difference, 754 dtype=dtype) 755 self.assertAllEqual(expected_counts, 756 self._set_difference_count(a, sp_b, False)) 757 758 def test_sparse_set_difference_3d(self): 759 for dtype in _DTYPES: 760 self._test_sparse_set_difference_3d(dtype) 761 762 def test_sparse_set_difference_3d_invalid_indices(self): 763 for dtype in _DTYPES: 764 self._test_sparse_set_difference_3d(dtype, invalid_indices=True) 765 766 def _test_sparse_set_difference_3d(self, dtype, invalid_indices=False): 767 if invalid_indices: 768 indices = constant_op.constant( 769 [ 770 [0, 1, 0], 771 [0, 1, 1], # 0,1 772 [1, 0, 0], # 1,0 773 [1, 1, 0], 774 [1, 1, 1], 775 [1, 1, 2], # 1,1 776 [0, 0, 0], 777 [0, 0, 2], # 0,0 778 # 2,0 779 [2, 1, 1] # 2,1 780 # 3,* 781 ], 782 dtypes.int64) 783 else: 784 indices = constant_op.constant( 785 [ 786 [0, 0, 0], 787 [0, 0, 2], # 0,0 788 [0, 1, 0], 789 [0, 1, 1], # 0,1 790 [1, 0, 0], # 1,0 791 [1, 1, 0], 792 [1, 1, 1], 793 [1, 1, 2], # 1,1 794 # 2,0 795 [2, 1, 1] # 2,1 796 # 3,* 797 ], 798 dtypes.int64) 799 sp_a = sparse_tensor_lib.SparseTensor( 800 indices, 801 _constant( 802 [ 803 1, 804 9, # 0,0 805 3, 806 3, # 0,1 807 1, # 1,0 808 9, 809 7, 810 8, # 1,1 811 # 2,0 812 5 # 2,1 813 # 3,* 814 ], 815 dtype), 816 constant_op.constant([4, 2, 3], dtypes.int64)) 817 sp_b = sparse_tensor_lib.SparseTensor( 818 constant_op.constant( 819 [ 820 [0, 0, 0], 821 [0, 0, 3], # 0,0 822 # 0,1 823 [1, 0, 0], # 1,0 824 [1, 1, 0], 825 [1, 1, 1], # 1,1 826 [2, 0, 1], # 2,0 827 [2, 1, 1], # 2,1 828 [3, 0, 0], # 3,0 829 [3, 1, 0] # 3,1 830 ], 831 dtypes.int64), 832 _constant( 833 [ 834 1, 835 3, # 0,0 836 # 0,1 837 3, # 1,0 838 7, 839 8, # 1,1 840 2, # 2,0 841 5, # 2,1 842 4, # 3,0 843 4 # 3,1 844 ], 845 dtype), 846 constant_op.constant([4, 2, 4], dtypes.int64)) 847 848 if invalid_indices: 849 with self.assertRaisesRegexp(errors_impl.OpError, "out of order"): 850 self._set_difference(sp_a, sp_b, False) 851 with self.assertRaisesRegexp(errors_impl.OpError, "out of order"): 852 self._set_difference(sp_a, sp_b, True) 853 else: 854 # a-b 855 expected_indices = [ 856 [0, 0, 0], # 0,0 857 [0, 1, 0], # 0,1 858 [1, 0, 0], # 1,0 859 [1, 1, 0], # 1,1 860 # 2,* 861 # 3,* 862 ] 863 expected_values = _values( 864 [ 865 9, # 0,0 866 3, # 0,1 867 1, # 1,0 868 9, # 1,1 869 # 2,* 870 # 3,* 871 ], 872 dtype) 873 expected_shape = [4, 2, 1] 874 expected_counts = [ 875 [ 876 1, # 0,0 877 1 # 0,1 878 ], 879 [ 880 1, # 1,0 881 1 # 1,1 882 ], 883 [ 884 0, # 2,0 885 0 # 2,1 886 ], 887 [ 888 0, # 3,0 889 0 # 3,1 890 ] 891 ] 892 893 difference = self._set_difference(sp_a, sp_b, True) 894 self._assert_set_operation( 895 expected_indices, 896 expected_values, 897 expected_shape, 898 difference, 899 dtype=dtype) 900 self.assertAllEqual(expected_counts, 901 self._set_difference_count(sp_a, sp_b)) 902 903 # b-a 904 expected_indices = [ 905 [0, 0, 0], # 0,0 906 # 0,1 907 [1, 0, 0], # 1,0 908 # 1,1 909 [2, 0, 0], # 2,0 910 # 2,1 911 [3, 0, 0], # 3,0 912 [3, 1, 0] # 3,1 913 ] 914 expected_values = _values( 915 [ 916 3, # 0,0 917 # 0,1 918 3, # 1,0 919 # 1,1 920 2, # 2,0 921 # 2,1 922 4, # 3,0 923 4, # 3,1 924 ], 925 dtype) 926 expected_shape = [4, 2, 1] 927 expected_counts = [ 928 [ 929 1, # 0,0 930 0 # 0,1 931 ], 932 [ 933 1, # 1,0 934 0 # 1,1 935 ], 936 [ 937 1, # 2,0 938 0 # 2,1 939 ], 940 [ 941 1, # 3,0 942 1 # 3,1 943 ] 944 ] 945 946 difference = self._set_difference(sp_a, sp_b, False) 947 self._assert_set_operation( 948 expected_indices, 949 expected_values, 950 expected_shape, 951 difference, 952 dtype=dtype) 953 self.assertAllEqual(expected_counts, 954 self._set_difference_count(sp_a, sp_b, False)) 955 956 def _set_difference(self, a, b, aminusb=True): 957 # Validate that we get the same results with or without `validate_indices`, 958 # and with a & b swapped. 959 ops = ( 960 sets.set_difference( 961 a, b, aminusb=aminusb, validate_indices=True), 962 sets.set_difference( 963 a, b, aminusb=aminusb, validate_indices=False), 964 sets.set_difference( 965 b, a, aminusb=not aminusb, validate_indices=True), 966 sets.set_difference( 967 b, a, aminusb=not aminusb, validate_indices=False),) 968 for op in ops: 969 self._assert_static_shapes(a, op) 970 return self._run_equivalent_set_ops(ops) 971 972 def _set_difference_count(self, a, b, aminusb=True): 973 op = sets.set_size(sets.set_difference(a, b, aminusb)) 974 with self.test_session() as sess: 975 return sess.run(op) 976 977 def test_set_union_multirow_2d(self): 978 for dtype in _DTYPES: 979 self._test_set_union_multirow_2d(dtype) 980 981 def _test_set_union_multirow_2d(self, dtype): 982 a_values = [[9, 1, 5], [2, 4, 3]] 983 b_values = [[1, 9], [1]] 984 expected_indices = [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2], [1, 3]] 985 expected_values = _values([1, 5, 9, 1, 2, 3, 4], dtype) 986 expected_shape = [2, 4] 987 expected_counts = [3, 4] 988 989 # Dense to sparse. 990 a = _constant(a_values, dtype=dtype) 991 sp_b = _dense_to_sparse(b_values, dtype=dtype) 992 union = self._set_union(a, sp_b) 993 self._assert_set_operation( 994 expected_indices, expected_values, expected_shape, union, dtype=dtype) 995 self.assertAllEqual(expected_counts, self._set_union_count(a, sp_b)) 996 997 # Sparse to sparse. 998 sp_a = _dense_to_sparse(a_values, dtype=dtype) 999 union = self._set_union(sp_a, sp_b) 1000 self._assert_set_operation( 1001 expected_indices, expected_values, expected_shape, union, dtype=dtype) 1002 self.assertAllEqual(expected_counts, self._set_union_count(sp_a, sp_b)) 1003 1004 def test_dense_set_union_multirow_2d(self): 1005 for dtype in _DTYPES: 1006 self._test_dense_set_union_multirow_2d(dtype) 1007 1008 def _test_dense_set_union_multirow_2d(self, dtype): 1009 a_values = [[9, 1, 5], [2, 4, 3]] 1010 b_values = [[1, 9], [1, 2]] 1011 expected_indices = [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2], [1, 3]] 1012 expected_values = _values([1, 5, 9, 1, 2, 3, 4], dtype) 1013 expected_shape = [2, 4] 1014 expected_counts = [3, 4] 1015 1016 # Dense to dense. 1017 a = _constant(a_values, dtype=dtype) 1018 b = _constant(b_values, dtype=dtype) 1019 union = self._set_union(a, b) 1020 self._assert_set_operation( 1021 expected_indices, expected_values, expected_shape, union, dtype=dtype) 1022 self.assertAllEqual(expected_counts, self._set_union_count(a, b)) 1023 1024 def test_set_union_duplicates_2d(self): 1025 for dtype in _DTYPES: 1026 self._test_set_union_duplicates_2d(dtype) 1027 1028 def _test_set_union_duplicates_2d(self, dtype): 1029 a_values = [[1, 1, 3]] 1030 b_values = [[1]] 1031 expected_indices = [[0, 0], [0, 1]] 1032 expected_values = _values([1, 3], dtype) 1033 expected_shape = [1, 2] 1034 1035 # Dense to sparse. 1036 a = _constant(a_values, dtype=dtype) 1037 sp_b = _dense_to_sparse(b_values, dtype=dtype) 1038 union = self._set_union(a, sp_b) 1039 self._assert_set_operation( 1040 expected_indices, expected_values, expected_shape, union, dtype=dtype) 1041 self.assertAllEqual([2], self._set_union_count(a, sp_b)) 1042 1043 # Sparse to sparse. 1044 sp_a = _dense_to_sparse(a_values, dtype=dtype) 1045 union = self._set_union(sp_a, sp_b) 1046 self._assert_set_operation( 1047 expected_indices, expected_values, expected_shape, union, dtype=dtype) 1048 self.assertAllEqual([2], self._set_union_count(sp_a, sp_b)) 1049 1050 def test_sparse_set_union_3d(self): 1051 for dtype in _DTYPES: 1052 self._test_sparse_set_union_3d(dtype) 1053 1054 def test_sparse_set_union_3d_invalid_indices(self): 1055 for dtype in _DTYPES: 1056 self._test_sparse_set_union_3d(dtype, invalid_indices=True) 1057 1058 def _test_sparse_set_union_3d(self, dtype, invalid_indices=False): 1059 if invalid_indices: 1060 indices = constant_op.constant( 1061 [ 1062 [0, 1, 0], 1063 [0, 1, 1], # 0,1 1064 [1, 0, 0], # 1,0 1065 [0, 0, 0], 1066 [0, 0, 2], # 0,0 1067 [1, 1, 0], 1068 [1, 1, 1], 1069 [1, 1, 2], # 1,1 1070 # 2,0 1071 [2, 1, 1] # 2,1 1072 # 3,* 1073 ], 1074 dtypes.int64) 1075 else: 1076 indices = constant_op.constant( 1077 [ 1078 [0, 0, 0], 1079 [0, 0, 2], # 0,0 1080 [0, 1, 0], 1081 [0, 1, 1], # 0,1 1082 [1, 0, 0], # 1,0 1083 [1, 1, 0], 1084 [1, 1, 1], 1085 [1, 1, 2], # 1,1 1086 # 2,0 1087 [2, 1, 1] # 2,1 1088 # 3,* 1089 ], 1090 dtypes.int64) 1091 sp_a = sparse_tensor_lib.SparseTensor( 1092 indices, 1093 _constant( 1094 [ 1095 1, 1096 9, # 0,0 1097 3, 1098 3, # 0,1 1099 1, # 1,0 1100 9, 1101 7, 1102 8, # 1,1 1103 # 2,0 1104 5 # 2,1 1105 # 3,* 1106 ], 1107 dtype), 1108 constant_op.constant([4, 2, 3], dtypes.int64)) 1109 sp_b = sparse_tensor_lib.SparseTensor( 1110 constant_op.constant( 1111 [ 1112 [0, 0, 0], 1113 [0, 0, 3], # 0,0 1114 # 0,1 1115 [1, 0, 0], # 1,0 1116 [1, 1, 0], 1117 [1, 1, 1], # 1,1 1118 [2, 0, 1], # 2,0 1119 [2, 1, 1], # 2,1 1120 [3, 0, 0], # 3,0 1121 [3, 1, 0] # 3,1 1122 ], 1123 dtypes.int64), 1124 _constant( 1125 [ 1126 1, 1127 3, # 0,0 1128 # 0,1 1129 3, # 1,0 1130 7, 1131 8, # 1,1 1132 2, # 2,0 1133 5, # 2,1 1134 4, # 3,0 1135 4 # 3,1 1136 ], 1137 dtype), 1138 constant_op.constant([4, 2, 4], dtypes.int64)) 1139 1140 if invalid_indices: 1141 with self.assertRaisesRegexp(errors_impl.OpError, "out of order"): 1142 self._set_union(sp_a, sp_b) 1143 else: 1144 expected_indices = [ 1145 [0, 0, 0], 1146 [0, 0, 1], 1147 [0, 0, 2], # 0,0 1148 [0, 1, 0], # 0,1 1149 [1, 0, 0], 1150 [1, 0, 1], # 1,0 1151 [1, 1, 0], 1152 [1, 1, 1], 1153 [1, 1, 2], # 1,1 1154 [2, 0, 0], # 2,0 1155 [2, 1, 0], # 2,1 1156 [3, 0, 0], # 3,0 1157 [3, 1, 0], # 3,1 1158 ] 1159 expected_values = _values( 1160 [ 1161 1, 1162 3, 1163 9, # 0,0 1164 3, # 0,1 1165 1, 1166 3, # 1,0 1167 7, 1168 8, 1169 9, # 1,1 1170 2, # 2,0 1171 5, # 2,1 1172 4, # 3,0 1173 4, # 3,1 1174 ], 1175 dtype) 1176 expected_shape = [4, 2, 3] 1177 expected_counts = [ 1178 [ 1179 3, # 0,0 1180 1 # 0,1 1181 ], 1182 [ 1183 2, # 1,0 1184 3 # 1,1 1185 ], 1186 [ 1187 1, # 2,0 1188 1 # 2,1 1189 ], 1190 [ 1191 1, # 3,0 1192 1 # 3,1 1193 ] 1194 ] 1195 1196 intersection = self._set_union(sp_a, sp_b) 1197 self._assert_set_operation( 1198 expected_indices, 1199 expected_values, 1200 expected_shape, 1201 intersection, 1202 dtype=dtype) 1203 self.assertAllEqual(expected_counts, self._set_union_count(sp_a, sp_b)) 1204 1205 def _set_union(self, a, b): 1206 # Validate that we get the same results with or without `validate_indices`, 1207 # and with a & b swapped. 1208 ops = ( 1209 sets.set_union( 1210 a, b, validate_indices=True), 1211 sets.set_union( 1212 a, b, validate_indices=False), 1213 sets.set_union( 1214 b, a, validate_indices=True), 1215 sets.set_union( 1216 b, a, validate_indices=False),) 1217 for op in ops: 1218 self._assert_static_shapes(a, op) 1219 return self._run_equivalent_set_ops(ops) 1220 1221 def _set_union_count(self, a, b): 1222 op = sets.set_size(sets.set_union(a, b)) 1223 with self.test_session() as sess: 1224 return sess.run(op) 1225 1226 def _assert_set_operation(self, expected_indices, expected_values, 1227 expected_shape, sparse_tensor_value, dtype): 1228 self.assertAllEqual(expected_indices, sparse_tensor_value.indices) 1229 self.assertAllEqual(len(expected_indices), len(expected_values)) 1230 self.assertAllEqual(len(expected_values), len(sparse_tensor_value.values)) 1231 expected_set = set() 1232 actual_set = set() 1233 last_indices = None 1234 for indices, expected_value, actual_value in zip( 1235 expected_indices, expected_values, sparse_tensor_value.values): 1236 if dtype == dtypes.string: 1237 actual_value = actual_value.decode("utf-8") 1238 if last_indices and (last_indices[:-1] != indices[:-1]): 1239 self.assertEqual(expected_set, actual_set, 1240 "Expected %s, got %s, at %s." % (expected_set, 1241 actual_set, indices)) 1242 expected_set.clear() 1243 actual_set.clear() 1244 expected_set.add(expected_value) 1245 actual_set.add(actual_value) 1246 last_indices = indices 1247 self.assertEqual(expected_set, actual_set, 1248 "Expected %s, got %s, at %s." % (expected_set, actual_set, 1249 last_indices)) 1250 self.assertAllEqual(expected_shape, sparse_tensor_value.dense_shape) 1251 1252 1253if __name__ == "__main__": 1254 googletest.main() 1255