1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for set_ops."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.python.framework import constant_op
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import errors_impl
26from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
27from tensorflow.python.framework import test_util
28from tensorflow.python.ops import array_ops
29from tensorflow.python.ops import math_ops
30from tensorflow.python.ops import sets
31from tensorflow.python.ops import sparse_ops
32from tensorflow.python.platform import googletest
33
# Element dtypes that every parameterized test in this file iterates over.
_DTYPES = {
    dtypes.int8,
    dtypes.int16,
    dtypes.int32,
    dtypes.int64,
    dtypes.uint8,
    dtypes.uint16,
    dtypes.string,
}
38
39
def _values(values, dtype):
  """Converts `values` to a numpy array matching the TensorFlow `dtype`.

  Args:
    values: Python scalar or (nested) list of values.
    dtype: TensorFlow `DType` the array should correspond to.

  Returns:
    A `numpy.ndarray`. For `dtypes.string` the elements are converted to numpy
    strings; otherwise the array uses `dtype.as_numpy_dtype`.
  """
  # `np.unicode` was only a deprecated alias of the builtin text type and is
  # absent from recent numpy releases; `np.str_` is available in all of them.
  np_dtype = np.str_ if dtype == dtypes.string else dtype.as_numpy_dtype
  return np.array(values, dtype=np_dtype)
44
45
def _constant(values, dtype):
  """Returns a constant tensor of `dtype` built from `_values(values, dtype)`."""
  np_values = _values(values, dtype)
  return constant_op.constant(np_values, dtype=dtype)
48
49
def _dense_to_sparse(dense, dtype):
  """Packs a list of (possibly ragged) rows into a 2D `SparseTensor`.

  Args:
    dense: List of rows, each a list of cell values. Rows may differ in length.
    dtype: `DType` of the sparse values. For `dtypes.string` every cell is
      passed through `str()` first.

  Returns:
    A `SparseTensor` with one entry per cell, indexed by (row, column), whose
    dense shape is `[len(dense), length of the longest row]`.
  """
  coords = []
  cells = []
  widest = 0
  for r, row in enumerate(dense):
    widest = max(widest, len(row))
    for c, cell in enumerate(row):
      coords.append([r, c])
      cells.append(str(cell) if dtype == dtypes.string else cell)
  return sparse_tensor_lib.SparseTensor(
      constant_op.constant(coords, dtypes.int64),
      constant_op.constant(cells, dtype),
      constant_op.constant([len(dense), widest], dtypes.int64))
69
70
71class SetOpsTest(test_util.TensorFlowTestCase):
72
73  def test_set_size_2d(self):
74    for dtype in _DTYPES:
75      self._test_set_size_2d(dtype)
76
77  def _test_set_size_2d(self, dtype):
78    self.assertAllEqual([1], self._set_size(_dense_to_sparse([[1]], dtype)))
79    self.assertAllEqual([2, 1],
80                        self._set_size(_dense_to_sparse([[1, 9], [1]], dtype)))
81    self.assertAllEqual(
82        [3, 0], self._set_size(_dense_to_sparse([[1, 9, 2], []], dtype)))
83    self.assertAllEqual(
84        [0, 3], self._set_size(_dense_to_sparse([[], [1, 9, 2]], dtype)))
85
86  def test_set_size_duplicates_2d(self):
87    for dtype in _DTYPES:
88      self._test_set_size_duplicates_2d(dtype)
89
90  def _test_set_size_duplicates_2d(self, dtype):
91    self.assertAllEqual(
92        [1], self._set_size(_dense_to_sparse([[1, 1, 1, 1, 1, 1]], dtype)))
93    self.assertAllEqual([2, 7, 3, 0, 1],
94                        self._set_size(
95                            _dense_to_sparse([[1, 9], [
96                                6, 7, 8, 8, 6, 7, 5, 3, 3, 0, 6, 6, 9, 0, 0, 0
97                            ], [999, 1, -1000], [], [-1]], dtype)))
98
99  def test_set_size_3d(self):
100    for dtype in _DTYPES:
101      self._test_set_size_3d(dtype)
102
103  def test_set_size_3d_invalid_indices(self):
104    for dtype in _DTYPES:
105      self._test_set_size_3d(dtype, invalid_indices=True)
106
107  def _test_set_size_3d(self, dtype, invalid_indices=False):
108    if invalid_indices:
109      indices = constant_op.constant([
110          [0, 1, 0], [0, 1, 1],             # 0,1
111          [1, 0, 0],                        # 1,0
112          [1, 1, 0], [1, 1, 1], [1, 1, 2],  # 1,1
113          [0, 0, 0], [0, 0, 2],             # 0,0
114                                            # 2,0
115          [2, 1, 1]                         # 2,1
116      ], dtypes.int64)
117    else:
118      indices = constant_op.constant([
119          [0, 0, 0], [0, 0, 2],             # 0,0
120          [0, 1, 0], [0, 1, 1],             # 0,1
121          [1, 0, 0],                        # 1,0
122          [1, 1, 0], [1, 1, 1], [1, 1, 2],  # 1,1
123                                            # 2,0
124          [2, 1, 1]                         # 2,1
125      ], dtypes.int64)
126
127    sp = sparse_tensor_lib.SparseTensor(
128        indices,
129        _constant([
130            1, 9,     # 0,0
131            3, 3,     # 0,1
132            1,        # 1,0
133            9, 7, 8,  # 1,1
134                      # 2,0
135            5         # 2,1
136        ], dtype),
137        constant_op.constant([3, 2, 3], dtypes.int64))
138
139    if invalid_indices:
140      with self.assertRaisesRegexp(errors_impl.OpError, "out of order"):
141        self._set_size(sp)
142    else:
143      self.assertAllEqual([
144          [2,   # 0,0
145           1],  # 0,1
146          [1,   # 1,0
147           3],  # 1,1
148          [0,   # 2,0
149           1]   # 2,1
150      ], self._set_size(sp))
151
152  def _set_size(self, sparse_data):
153    # Validate that we get the same results with or without `validate_indices`.
154    ops = [
155        sets.set_size(sparse_data, validate_indices=True),
156        sets.set_size(sparse_data, validate_indices=False)
157    ]
158    for op in ops:
159      self.assertEqual(None, op.get_shape().dims)
160      self.assertEqual(dtypes.int32, op.dtype)
161    with self.test_session() as sess:
162      results = sess.run(ops)
163    self.assertAllEqual(results[0], results[1])
164    return results[0]
165
166  def test_set_intersection_multirow_2d(self):
167    for dtype in _DTYPES:
168      self._test_set_intersection_multirow_2d(dtype)
169
170  def _test_set_intersection_multirow_2d(self, dtype):
171    a_values = [[9, 1, 5], [2, 4, 3]]
172    b_values = [[1, 9], [1]]
173    expected_indices = [[0, 0], [0, 1]]
174    expected_values = _values([1, 9], dtype)
175    expected_shape = [2, 2]
176    expected_counts = [2, 0]
177
178    # Dense to sparse.
179    a = _constant(a_values, dtype=dtype)
180    sp_b = _dense_to_sparse(b_values, dtype=dtype)
181    intersection = self._set_intersection(a, sp_b)
182    self._assert_set_operation(
183        expected_indices,
184        expected_values,
185        expected_shape,
186        intersection,
187        dtype=dtype)
188    self.assertAllEqual(expected_counts, self._set_intersection_count(a, sp_b))
189
190    # Sparse to sparse.
191    sp_a = _dense_to_sparse(a_values, dtype=dtype)
192    intersection = self._set_intersection(sp_a, sp_b)
193    self._assert_set_operation(
194        expected_indices,
195        expected_values,
196        expected_shape,
197        intersection,
198        dtype=dtype)
199    self.assertAllEqual(expected_counts,
200                        self._set_intersection_count(sp_a, sp_b))
201
202  def test_dense_set_intersection_multirow_2d(self):
203    for dtype in _DTYPES:
204      self._test_dense_set_intersection_multirow_2d(dtype)
205
206  def _test_dense_set_intersection_multirow_2d(self, dtype):
207    a_values = [[9, 1, 5], [2, 4, 3]]
208    b_values = [[1, 9], [1, 5]]
209    expected_indices = [[0, 0], [0, 1]]
210    expected_values = _values([1, 9], dtype)
211    expected_shape = [2, 2]
212    expected_counts = [2, 0]
213
214    # Dense to dense.
215    a = _constant(a_values, dtype)
216    b = _constant(b_values, dtype)
217    intersection = self._set_intersection(a, b)
218    self._assert_set_operation(
219        expected_indices,
220        expected_values,
221        expected_shape,
222        intersection,
223        dtype=dtype)
224    self.assertAllEqual(expected_counts, self._set_intersection_count(a, b))
225
226  def test_set_intersection_duplicates_2d(self):
227    for dtype in _DTYPES:
228      self._test_set_intersection_duplicates_2d(dtype)
229
230  def _test_set_intersection_duplicates_2d(self, dtype):
231    a_values = [[1, 1, 3]]
232    b_values = [[1]]
233    expected_indices = [[0, 0]]
234    expected_values = _values([1], dtype)
235    expected_shape = [1, 1]
236    expected_counts = [1]
237
238    # Dense to dense.
239    a = _constant(a_values, dtype=dtype)
240    b = _constant(b_values, dtype=dtype)
241    intersection = self._set_intersection(a, b)
242    self._assert_set_operation(
243        expected_indices,
244        expected_values,
245        expected_shape,
246        intersection,
247        dtype=dtype)
248    self.assertAllEqual(expected_counts, self._set_intersection_count(a, b))
249
250    # Dense to sparse.
251    sp_b = _dense_to_sparse(b_values, dtype=dtype)
252    intersection = self._set_intersection(a, sp_b)
253    self._assert_set_operation(
254        expected_indices,
255        expected_values,
256        expected_shape,
257        intersection,
258        dtype=dtype)
259    self.assertAllEqual(expected_counts, self._set_intersection_count(a, sp_b))
260
261    # Sparse to sparse.
262    sp_a = _dense_to_sparse(a_values, dtype=dtype)
263    intersection = self._set_intersection(sp_a, sp_b)
264    self._assert_set_operation(
265        expected_indices,
266        expected_values,
267        expected_shape,
268        intersection,
269        dtype=dtype)
270    self.assertAllEqual(expected_counts,
271                        self._set_intersection_count(sp_a, sp_b))
272
273  def test_set_intersection_3d(self):
274    for dtype in _DTYPES:
275      self._test_set_intersection_3d(dtype=dtype)
276
277  def test_set_intersection_3d_invalid_indices(self):
278    for dtype in _DTYPES:
279      self._test_set_intersection_3d(dtype=dtype, invalid_indices=True)
280
281  def _test_set_intersection_3d(self, dtype, invalid_indices=False):
282    if invalid_indices:
283      indices = constant_op.constant(
284          [
285              [0, 1, 0],
286              [0, 1, 1],  # 0,1
287              [1, 0, 0],  # 1,0
288              [1, 1, 0],
289              [1, 1, 1],
290              [1, 1, 2],  # 1,1
291              [0, 0, 0],
292              [0, 0, 2],  # 0,0
293              # 2,0
294              [2, 1, 1]  # 2,1
295              # 3,*
296          ],
297          dtypes.int64)
298    else:
299      indices = constant_op.constant(
300          [
301              [0, 0, 0],
302              [0, 0, 2],  # 0,0
303              [0, 1, 0],
304              [0, 1, 1],  # 0,1
305              [1, 0, 0],  # 1,0
306              [1, 1, 0],
307              [1, 1, 1],
308              [1, 1, 2],  # 1,1
309              # 2,0
310              [2, 1, 1]  # 2,1
311              # 3,*
312          ],
313          dtypes.int64)
314    sp_a = sparse_tensor_lib.SparseTensor(
315        indices,
316        _constant(
317            [
318                1,
319                9,  # 0,0
320                3,
321                3,  # 0,1
322                1,  # 1,0
323                9,
324                7,
325                8,  # 1,1
326                # 2,0
327                5  # 2,1
328                # 3,*
329            ],
330            dtype),
331        constant_op.constant([4, 2, 3], dtypes.int64))
332    sp_b = sparse_tensor_lib.SparseTensor(
333        constant_op.constant(
334            [
335                [0, 0, 0],
336                [0, 0, 3],  # 0,0
337                # 0,1
338                [1, 0, 0],  # 1,0
339                [1, 1, 0],
340                [1, 1, 1],  # 1,1
341                [2, 0, 1],  # 2,0
342                [2, 1, 1],  # 2,1
343                [3, 0, 0],  # 3,0
344                [3, 1, 0]  # 3,1
345            ],
346            dtypes.int64),
347        _constant(
348            [
349                1,
350                3,  # 0,0
351                # 0,1
352                3,  # 1,0
353                7,
354                8,  # 1,1
355                2,  # 2,0
356                5,  # 2,1
357                4,  # 3,0
358                4  # 3,1
359            ],
360            dtype),
361        constant_op.constant([4, 2, 4], dtypes.int64))
362
363    if invalid_indices:
364      with self.assertRaisesRegexp(errors_impl.OpError, "out of order"):
365        self._set_intersection(sp_a, sp_b)
366    else:
367      expected_indices = [
368          [0, 0, 0],  # 0,0
369          # 0,1
370          # 1,0
371          [1, 1, 0],
372          [1, 1, 1],  # 1,1
373          # 2,0
374          [2, 1, 0],  # 2,1
375          # 3,*
376      ]
377      expected_values = _values(
378          [
379              1,  # 0,0
380              # 0,1
381              # 1,0
382              7,
383              8,  # 1,1
384              # 2,0
385              5,  # 2,1
386              # 3,*
387          ],
388          dtype)
389      expected_shape = [4, 2, 2]
390      expected_counts = [
391          [
392              1,  # 0,0
393              0  # 0,1
394          ],
395          [
396              0,  # 1,0
397              2  # 1,1
398          ],
399          [
400              0,  # 2,0
401              1  # 2,1
402          ],
403          [
404              0,  # 3,0
405              0  # 3,1
406          ]
407      ]
408
409      # Sparse to sparse.
410      intersection = self._set_intersection(sp_a, sp_b)
411      self._assert_set_operation(
412          expected_indices,
413          expected_values,
414          expected_shape,
415          intersection,
416          dtype=dtype)
417      self.assertAllEqual(expected_counts,
418                          self._set_intersection_count(sp_a, sp_b))
419
420      # NOTE: sparse_to_dense doesn't support uint8 and uint16.
421      if dtype not in [dtypes.uint8, dtypes.uint16]:
422        # Dense to sparse.
423        a = math_ops.cast(
424            sparse_ops.sparse_to_dense(
425                sp_a.indices,
426                sp_a.dense_shape,
427                sp_a.values,
428                default_value="-1" if dtype == dtypes.string else -1),
429            dtype=dtype)
430        intersection = self._set_intersection(a, sp_b)
431        self._assert_set_operation(
432            expected_indices,
433            expected_values,
434            expected_shape,
435            intersection,
436            dtype=dtype)
437        self.assertAllEqual(expected_counts,
438                            self._set_intersection_count(a, sp_b))
439
440        # Dense to dense.
441        b = math_ops.cast(
442            sparse_ops.sparse_to_dense(
443                sp_b.indices,
444                sp_b.dense_shape,
445                sp_b.values,
446                default_value="-2" if dtype == dtypes.string else -2),
447            dtype=dtype)
448        intersection = self._set_intersection(a, b)
449        self._assert_set_operation(
450            expected_indices,
451            expected_values,
452            expected_shape,
453            intersection,
454            dtype=dtype)
455        self.assertAllEqual(expected_counts, self._set_intersection_count(a, b))
456
457  def _assert_static_shapes(self, input_tensor, result_sparse_tensor):
458    if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
459      sparse_shape_dims = input_tensor.dense_shape.get_shape().dims
460      if sparse_shape_dims is None:
461        expected_rank = None
462      else:
463        expected_rank = sparse_shape_dims[0].value
464    else:
465      expected_rank = input_tensor.get_shape().ndims
466    self.assertAllEqual((None, expected_rank),
467                        result_sparse_tensor.indices.get_shape().as_list())
468    self.assertAllEqual((None,),
469                        result_sparse_tensor.values.get_shape().as_list())
470    self.assertAllEqual((expected_rank,),
471                        result_sparse_tensor.dense_shape.get_shape().as_list())
472
473  def _run_equivalent_set_ops(self, ops):
474    """Assert all ops return the same shapes, and return 1st result."""
475    # Collect shapes and results for all ops, and assert static shapes match.
476    dynamic_indices_shape_ops = []
477    dynamic_values_shape_ops = []
478    static_indices_shape = None
479    static_values_shape = None
480    with self.test_session() as sess:
481      for op in ops:
482        if static_indices_shape is None:
483          static_indices_shape = op.indices.get_shape()
484        else:
485          self.assertAllEqual(
486              static_indices_shape.as_list(), op.indices.get_shape().as_list())
487        if static_values_shape is None:
488          static_values_shape = op.values.get_shape()
489        else:
490          self.assertAllEqual(
491              static_values_shape.as_list(), op.values.get_shape().as_list())
492        dynamic_indices_shape_ops.append(array_ops.shape(op.indices))
493        dynamic_values_shape_ops.append(array_ops.shape(op.values))
494      results = sess.run(
495          list(ops) + dynamic_indices_shape_ops + dynamic_values_shape_ops)
496      op_count = len(ops)
497      op_results = results[0:op_count]
498      dynamic_indices_shapes = results[op_count:2 * op_count]
499      dynamic_values_shapes = results[2 * op_count:3 * op_count]
500
501    # Assert static and dynamic tensor shapes, and result shapes, are all
502    # consistent.
503    static_indices_shape.assert_is_compatible_with(dynamic_indices_shapes[0])
504    static_values_shape.assert_is_compatible_with(dynamic_values_shapes[0])
505    self.assertAllEqual(dynamic_indices_shapes[0], op_results[0].indices.shape)
506    self.assertAllEqual(dynamic_values_shapes[0], op_results[0].values.shape)
507
508    # Assert dynamic shapes and values are the same for all ops.
509    for i in range(1, len(ops)):
510      self.assertAllEqual(dynamic_indices_shapes[0], dynamic_indices_shapes[i])
511      self.assertAllEqual(dynamic_values_shapes[0], dynamic_values_shapes[i])
512      self.assertAllEqual(op_results[0].indices, op_results[i].indices)
513      self.assertAllEqual(op_results[0].values, op_results[i].values)
514      self.assertAllEqual(op_results[0].dense_shape, op_results[i].dense_shape)
515
516    return op_results[0]
517
518  def _set_intersection(self, a, b):
519    # Validate that we get the same results with or without `validate_indices`,
520    # and with a & b swapped.
521    ops = (
522        sets.set_intersection(
523            a, b, validate_indices=True),
524        sets.set_intersection(
525            a, b, validate_indices=False),
526        sets.set_intersection(
527            b, a, validate_indices=True),
528        sets.set_intersection(
529            b, a, validate_indices=False),)
530    for op in ops:
531      self._assert_static_shapes(a, op)
532    return self._run_equivalent_set_ops(ops)
533
534  def _set_intersection_count(self, a, b):
535    op = sets.set_size(sets.set_intersection(a, b))
536    with self.test_session() as sess:
537      return sess.run(op)
538
539  def test_set_difference_multirow_2d(self):
540    for dtype in _DTYPES:
541      self._test_set_difference_multirow_2d(dtype)
542
543  def _test_set_difference_multirow_2d(self, dtype):
544    a_values = [[1, 1, 1], [1, 5, 9], [4, 5, 3], [5, 5, 1]]
545    b_values = [[], [1, 2], [1, 2, 2], []]
546
547    # a - b.
548    expected_indices = [[0, 0], [1, 0], [1, 1], [2, 0], [2, 1], [2, 2], [3, 0],
549                        [3, 1]]
550    expected_values = _values([1, 5, 9, 3, 4, 5, 1, 5], dtype)
551    expected_shape = [4, 3]
552    expected_counts = [1, 2, 3, 2]
553
554    # Dense to sparse.
555    a = _constant(a_values, dtype=dtype)
556    sp_b = _dense_to_sparse(b_values, dtype=dtype)
557    difference = self._set_difference(a, sp_b, True)
558    self._assert_set_operation(
559        expected_indices,
560        expected_values,
561        expected_shape,
562        difference,
563        dtype=dtype)
564    self.assertAllEqual(expected_counts,
565                        self._set_difference_count(a, sp_b, True))
566
567    # Sparse to sparse.
568    sp_a = _dense_to_sparse(a_values, dtype=dtype)
569    difference = self._set_difference(sp_a, sp_b, True)
570    self._assert_set_operation(
571        expected_indices,
572        expected_values,
573        expected_shape,
574        difference,
575        dtype=dtype)
576    self.assertAllEqual(expected_counts,
577                        self._set_difference_count(sp_a, sp_b, True))
578
579    # b - a.
580    expected_indices = [[1, 0], [2, 0], [2, 1]]
581    expected_values = _values([2, 1, 2], dtype)
582    expected_shape = [4, 2]
583    expected_counts = [0, 1, 2, 0]
584
585    # Dense to sparse.
586    difference = self._set_difference(a, sp_b, False)
587    self._assert_set_operation(
588        expected_indices,
589        expected_values,
590        expected_shape,
591        difference,
592        dtype=dtype)
593    self.assertAllEqual(expected_counts,
594                        self._set_difference_count(a, sp_b, False))
595
596    # Sparse to sparse.
597    difference = self._set_difference(sp_a, sp_b, False)
598    self._assert_set_operation(
599        expected_indices,
600        expected_values,
601        expected_shape,
602        difference,
603        dtype=dtype)
604    self.assertAllEqual(expected_counts,
605                        self._set_difference_count(sp_a, sp_b, False))
606
607  def test_dense_set_difference_multirow_2d(self):
608    for dtype in _DTYPES:
609      self._test_dense_set_difference_multirow_2d(dtype)
610
611  def _test_dense_set_difference_multirow_2d(self, dtype):
612    a_values = [[1, 5, 9], [4, 5, 3]]
613    b_values = [[1, 2, 6], [1, 2, 2]]
614
615    # a - b.
616    expected_indices = [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]]
617    expected_values = _values([5, 9, 3, 4, 5], dtype)
618    expected_shape = [2, 3]
619    expected_counts = [2, 3]
620
621    # Dense to dense.
622    a = _constant(a_values, dtype=dtype)
623    b = _constant(b_values, dtype=dtype)
624    difference = self._set_difference(a, b, True)
625    self._assert_set_operation(
626        expected_indices,
627        expected_values,
628        expected_shape,
629        difference,
630        dtype=dtype)
631    self.assertAllEqual(expected_counts, self._set_difference_count(a, b, True))
632
633    # b - a.
634    expected_indices = [[0, 0], [0, 1], [1, 0], [1, 1]]
635    expected_values = _values([2, 6, 1, 2], dtype)
636    expected_shape = [2, 2]
637    expected_counts = [2, 2]
638
639    # Dense to dense.
640    difference = self._set_difference(a, b, False)
641    self._assert_set_operation(
642        expected_indices,
643        expected_values,
644        expected_shape,
645        difference,
646        dtype=dtype)
647    self.assertAllEqual(expected_counts,
648                        self._set_difference_count(a, b, False))
649
650  def test_sparse_set_difference_multirow_2d(self):
651    for dtype in _DTYPES:
652      self._test_sparse_set_difference_multirow_2d(dtype)
653
654  def _test_sparse_set_difference_multirow_2d(self, dtype):
655    sp_a = _dense_to_sparse(
656        [[], [1, 5, 9], [4, 5, 3, 3, 4, 5], [5, 1]], dtype=dtype)
657    sp_b = _dense_to_sparse([[], [1, 2], [1, 2, 2], []], dtype=dtype)
658
659    # a - b.
660    expected_indices = [[1, 0], [1, 1], [2, 0], [2, 1], [2, 2], [3, 0], [3, 1]]
661    expected_values = _values([5, 9, 3, 4, 5, 1, 5], dtype)
662    expected_shape = [4, 3]
663    expected_counts = [0, 2, 3, 2]
664
665    difference = self._set_difference(sp_a, sp_b, True)
666    self._assert_set_operation(
667        expected_indices,
668        expected_values,
669        expected_shape,
670        difference,
671        dtype=dtype)
672    self.assertAllEqual(expected_counts,
673                        self._set_difference_count(sp_a, sp_b, True))
674
675    # b - a.
676    expected_indices = [[1, 0], [2, 0], [2, 1]]
677    expected_values = _values([2, 1, 2], dtype)
678    expected_shape = [4, 2]
679    expected_counts = [0, 1, 2, 0]
680
681    difference = self._set_difference(sp_a, sp_b, False)
682    self._assert_set_operation(
683        expected_indices,
684        expected_values,
685        expected_shape,
686        difference,
687        dtype=dtype)
688    self.assertAllEqual(expected_counts,
689                        self._set_difference_count(sp_a, sp_b, False))
690
691  def test_set_difference_duplicates_2d(self):
692    for dtype in _DTYPES:
693      self._test_set_difference_duplicates_2d(dtype)
694
695  def _test_set_difference_duplicates_2d(self, dtype):
696    a_values = [[1, 1, 3]]
697    b_values = [[1, 2, 2]]
698
699    # a - b.
700    expected_indices = [[0, 0]]
701    expected_values = _values([3], dtype)
702    expected_shape = [1, 1]
703    expected_counts = [1]
704
705    # Dense to sparse.
706    a = _constant(a_values, dtype=dtype)
707    sp_b = _dense_to_sparse(b_values, dtype=dtype)
708    difference = self._set_difference(a, sp_b, True)
709    self._assert_set_operation(
710        expected_indices,
711        expected_values,
712        expected_shape,
713        difference,
714        dtype=dtype)
715    self.assertAllEqual(expected_counts,
716                        self._set_difference_count(a, sp_b, True))
717
718    # Sparse to sparse.
719    sp_a = _dense_to_sparse(a_values, dtype=dtype)
720    difference = self._set_difference(sp_a, sp_b, True)
721    self._assert_set_operation(
722        expected_indices,
723        expected_values,
724        expected_shape,
725        difference,
726        dtype=dtype)
727    self.assertAllEqual(expected_counts,
728                        self._set_difference_count(a, sp_b, True))
729
730    # b - a.
731    expected_indices = [[0, 0]]
732    expected_values = _values([2], dtype)
733    expected_shape = [1, 1]
734    expected_counts = [1]
735
736    # Dense to sparse.
737    difference = self._set_difference(a, sp_b, False)
738    self._assert_set_operation(
739        expected_indices,
740        expected_values,
741        expected_shape,
742        difference,
743        dtype=dtype)
744    self.assertAllEqual(expected_counts,
745                        self._set_difference_count(a, sp_b, False))
746
747    # Sparse to sparse.
748    difference = self._set_difference(sp_a, sp_b, False)
749    self._assert_set_operation(
750        expected_indices,
751        expected_values,
752        expected_shape,
753        difference,
754        dtype=dtype)
755    self.assertAllEqual(expected_counts,
756                        self._set_difference_count(a, sp_b, False))
757
758  def test_sparse_set_difference_3d(self):
759    for dtype in _DTYPES:
760      self._test_sparse_set_difference_3d(dtype)
761
762  def test_sparse_set_difference_3d_invalid_indices(self):
763    for dtype in _DTYPES:
764      self._test_sparse_set_difference_3d(dtype, invalid_indices=True)
765
766  def _test_sparse_set_difference_3d(self, dtype, invalid_indices=False):
767    if invalid_indices:
768      indices = constant_op.constant(
769          [
770              [0, 1, 0],
771              [0, 1, 1],  # 0,1
772              [1, 0, 0],  # 1,0
773              [1, 1, 0],
774              [1, 1, 1],
775              [1, 1, 2],  # 1,1
776              [0, 0, 0],
777              [0, 0, 2],  # 0,0
778              # 2,0
779              [2, 1, 1]  # 2,1
780              # 3,*
781          ],
782          dtypes.int64)
783    else:
784      indices = constant_op.constant(
785          [
786              [0, 0, 0],
787              [0, 0, 2],  # 0,0
788              [0, 1, 0],
789              [0, 1, 1],  # 0,1
790              [1, 0, 0],  # 1,0
791              [1, 1, 0],
792              [1, 1, 1],
793              [1, 1, 2],  # 1,1
794              # 2,0
795              [2, 1, 1]  # 2,1
796              # 3,*
797          ],
798          dtypes.int64)
799    sp_a = sparse_tensor_lib.SparseTensor(
800        indices,
801        _constant(
802            [
803                1,
804                9,  # 0,0
805                3,
806                3,  # 0,1
807                1,  # 1,0
808                9,
809                7,
810                8,  # 1,1
811                # 2,0
812                5  # 2,1
813                # 3,*
814            ],
815            dtype),
816        constant_op.constant([4, 2, 3], dtypes.int64))
817    sp_b = sparse_tensor_lib.SparseTensor(
818        constant_op.constant(
819            [
820                [0, 0, 0],
821                [0, 0, 3],  # 0,0
822                # 0,1
823                [1, 0, 0],  # 1,0
824                [1, 1, 0],
825                [1, 1, 1],  # 1,1
826                [2, 0, 1],  # 2,0
827                [2, 1, 1],  # 2,1
828                [3, 0, 0],  # 3,0
829                [3, 1, 0]  # 3,1
830            ],
831            dtypes.int64),
832        _constant(
833            [
834                1,
835                3,  # 0,0
836                # 0,1
837                3,  # 1,0
838                7,
839                8,  # 1,1
840                2,  # 2,0
841                5,  # 2,1
842                4,  # 3,0
843                4  # 3,1
844            ],
845            dtype),
846        constant_op.constant([4, 2, 4], dtypes.int64))
847
848    if invalid_indices:
849      with self.assertRaisesRegexp(errors_impl.OpError, "out of order"):
850        self._set_difference(sp_a, sp_b, False)
851      with self.assertRaisesRegexp(errors_impl.OpError, "out of order"):
852        self._set_difference(sp_a, sp_b, True)
853    else:
854      # a-b
855      expected_indices = [
856          [0, 0, 0],  # 0,0
857          [0, 1, 0],  # 0,1
858          [1, 0, 0],  # 1,0
859          [1, 1, 0],  # 1,1
860          # 2,*
861          # 3,*
862      ]
863      expected_values = _values(
864          [
865              9,  # 0,0
866              3,  # 0,1
867              1,  # 1,0
868              9,  # 1,1
869              # 2,*
870              # 3,*
871          ],
872          dtype)
873      expected_shape = [4, 2, 1]
874      expected_counts = [
875          [
876              1,  # 0,0
877              1  # 0,1
878          ],
879          [
880              1,  # 1,0
881              1  # 1,1
882          ],
883          [
884              0,  # 2,0
885              0  # 2,1
886          ],
887          [
888              0,  # 3,0
889              0  # 3,1
890          ]
891      ]
892
893      difference = self._set_difference(sp_a, sp_b, True)
894      self._assert_set_operation(
895          expected_indices,
896          expected_values,
897          expected_shape,
898          difference,
899          dtype=dtype)
900      self.assertAllEqual(expected_counts,
901                          self._set_difference_count(sp_a, sp_b))
902
903      # b-a
904      expected_indices = [
905          [0, 0, 0],  # 0,0
906          # 0,1
907          [1, 0, 0],  # 1,0
908          # 1,1
909          [2, 0, 0],  # 2,0
910          # 2,1
911          [3, 0, 0],  # 3,0
912          [3, 1, 0]  # 3,1
913      ]
914      expected_values = _values(
915          [
916              3,  # 0,0
917              # 0,1
918              3,  # 1,0
919              # 1,1
920              2,  # 2,0
921              # 2,1
922              4,  # 3,0
923              4,  # 3,1
924          ],
925          dtype)
926      expected_shape = [4, 2, 1]
927      expected_counts = [
928          [
929              1,  # 0,0
930              0  # 0,1
931          ],
932          [
933              1,  # 1,0
934              0  # 1,1
935          ],
936          [
937              1,  # 2,0
938              0  # 2,1
939          ],
940          [
941              1,  # 3,0
942              1  # 3,1
943          ]
944      ]
945
946      difference = self._set_difference(sp_a, sp_b, False)
947      self._assert_set_operation(
948          expected_indices,
949          expected_values,
950          expected_shape,
951          difference,
952          dtype=dtype)
953      self.assertAllEqual(expected_counts,
954                          self._set_difference_count(sp_a, sp_b, False))
955
956  def _set_difference(self, a, b, aminusb=True):
957    # Validate that we get the same results with or without `validate_indices`,
958    # and with a & b swapped.
959    ops = (
960        sets.set_difference(
961            a, b, aminusb=aminusb, validate_indices=True),
962        sets.set_difference(
963            a, b, aminusb=aminusb, validate_indices=False),
964        sets.set_difference(
965            b, a, aminusb=not aminusb, validate_indices=True),
966        sets.set_difference(
967            b, a, aminusb=not aminusb, validate_indices=False),)
968    for op in ops:
969      self._assert_static_shapes(a, op)
970    return self._run_equivalent_set_ops(ops)
971
972  def _set_difference_count(self, a, b, aminusb=True):
973    op = sets.set_size(sets.set_difference(a, b, aminusb))
974    with self.test_session() as sess:
975      return sess.run(op)
976
977  def test_set_union_multirow_2d(self):
978    for dtype in _DTYPES:
979      self._test_set_union_multirow_2d(dtype)
980
981  def _test_set_union_multirow_2d(self, dtype):
982    a_values = [[9, 1, 5], [2, 4, 3]]
983    b_values = [[1, 9], [1]]
984    expected_indices = [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2], [1, 3]]
985    expected_values = _values([1, 5, 9, 1, 2, 3, 4], dtype)
986    expected_shape = [2, 4]
987    expected_counts = [3, 4]
988
989    # Dense to sparse.
990    a = _constant(a_values, dtype=dtype)
991    sp_b = _dense_to_sparse(b_values, dtype=dtype)
992    union = self._set_union(a, sp_b)
993    self._assert_set_operation(
994        expected_indices, expected_values, expected_shape, union, dtype=dtype)
995    self.assertAllEqual(expected_counts, self._set_union_count(a, sp_b))
996
997    # Sparse to sparse.
998    sp_a = _dense_to_sparse(a_values, dtype=dtype)
999    union = self._set_union(sp_a, sp_b)
1000    self._assert_set_operation(
1001        expected_indices, expected_values, expected_shape, union, dtype=dtype)
1002    self.assertAllEqual(expected_counts, self._set_union_count(sp_a, sp_b))
1003
1004  def test_dense_set_union_multirow_2d(self):
1005    for dtype in _DTYPES:
1006      self._test_dense_set_union_multirow_2d(dtype)
1007
1008  def _test_dense_set_union_multirow_2d(self, dtype):
1009    a_values = [[9, 1, 5], [2, 4, 3]]
1010    b_values = [[1, 9], [1, 2]]
1011    expected_indices = [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2], [1, 3]]
1012    expected_values = _values([1, 5, 9, 1, 2, 3, 4], dtype)
1013    expected_shape = [2, 4]
1014    expected_counts = [3, 4]
1015
1016    # Dense to dense.
1017    a = _constant(a_values, dtype=dtype)
1018    b = _constant(b_values, dtype=dtype)
1019    union = self._set_union(a, b)
1020    self._assert_set_operation(
1021        expected_indices, expected_values, expected_shape, union, dtype=dtype)
1022    self.assertAllEqual(expected_counts, self._set_union_count(a, b))
1023
1024  def test_set_union_duplicates_2d(self):
1025    for dtype in _DTYPES:
1026      self._test_set_union_duplicates_2d(dtype)
1027
1028  def _test_set_union_duplicates_2d(self, dtype):
1029    a_values = [[1, 1, 3]]
1030    b_values = [[1]]
1031    expected_indices = [[0, 0], [0, 1]]
1032    expected_values = _values([1, 3], dtype)
1033    expected_shape = [1, 2]
1034
1035    # Dense to sparse.
1036    a = _constant(a_values, dtype=dtype)
1037    sp_b = _dense_to_sparse(b_values, dtype=dtype)
1038    union = self._set_union(a, sp_b)
1039    self._assert_set_operation(
1040        expected_indices, expected_values, expected_shape, union, dtype=dtype)
1041    self.assertAllEqual([2], self._set_union_count(a, sp_b))
1042
1043    # Sparse to sparse.
1044    sp_a = _dense_to_sparse(a_values, dtype=dtype)
1045    union = self._set_union(sp_a, sp_b)
1046    self._assert_set_operation(
1047        expected_indices, expected_values, expected_shape, union, dtype=dtype)
1048    self.assertAllEqual([2], self._set_union_count(sp_a, sp_b))
1049
1050  def test_sparse_set_union_3d(self):
1051    for dtype in _DTYPES:
1052      self._test_sparse_set_union_3d(dtype)
1053
1054  def test_sparse_set_union_3d_invalid_indices(self):
1055    for dtype in _DTYPES:
1056      self._test_sparse_set_union_3d(dtype, invalid_indices=True)
1057
1058  def _test_sparse_set_union_3d(self, dtype, invalid_indices=False):
1059    if invalid_indices:
1060      indices = constant_op.constant(
1061          [
1062              [0, 1, 0],
1063              [0, 1, 1],  # 0,1
1064              [1, 0, 0],  # 1,0
1065              [0, 0, 0],
1066              [0, 0, 2],  # 0,0
1067              [1, 1, 0],
1068              [1, 1, 1],
1069              [1, 1, 2],  # 1,1
1070              # 2,0
1071              [2, 1, 1]  # 2,1
1072              # 3,*
1073          ],
1074          dtypes.int64)
1075    else:
1076      indices = constant_op.constant(
1077          [
1078              [0, 0, 0],
1079              [0, 0, 2],  # 0,0
1080              [0, 1, 0],
1081              [0, 1, 1],  # 0,1
1082              [1, 0, 0],  # 1,0
1083              [1, 1, 0],
1084              [1, 1, 1],
1085              [1, 1, 2],  # 1,1
1086              # 2,0
1087              [2, 1, 1]  # 2,1
1088              # 3,*
1089          ],
1090          dtypes.int64)
1091    sp_a = sparse_tensor_lib.SparseTensor(
1092        indices,
1093        _constant(
1094            [
1095                1,
1096                9,  # 0,0
1097                3,
1098                3,  # 0,1
1099                1,  # 1,0
1100                9,
1101                7,
1102                8,  # 1,1
1103                # 2,0
1104                5  # 2,1
1105                # 3,*
1106            ],
1107            dtype),
1108        constant_op.constant([4, 2, 3], dtypes.int64))
1109    sp_b = sparse_tensor_lib.SparseTensor(
1110        constant_op.constant(
1111            [
1112                [0, 0, 0],
1113                [0, 0, 3],  # 0,0
1114                # 0,1
1115                [1, 0, 0],  # 1,0
1116                [1, 1, 0],
1117                [1, 1, 1],  # 1,1
1118                [2, 0, 1],  # 2,0
1119                [2, 1, 1],  # 2,1
1120                [3, 0, 0],  # 3,0
1121                [3, 1, 0]  # 3,1
1122            ],
1123            dtypes.int64),
1124        _constant(
1125            [
1126                1,
1127                3,  # 0,0
1128                # 0,1
1129                3,  # 1,0
1130                7,
1131                8,  # 1,1
1132                2,  # 2,0
1133                5,  # 2,1
1134                4,  # 3,0
1135                4  # 3,1
1136            ],
1137            dtype),
1138        constant_op.constant([4, 2, 4], dtypes.int64))
1139
1140    if invalid_indices:
1141      with self.assertRaisesRegexp(errors_impl.OpError, "out of order"):
1142        self._set_union(sp_a, sp_b)
1143    else:
1144      expected_indices = [
1145          [0, 0, 0],
1146          [0, 0, 1],
1147          [0, 0, 2],  # 0,0
1148          [0, 1, 0],  # 0,1
1149          [1, 0, 0],
1150          [1, 0, 1],  # 1,0
1151          [1, 1, 0],
1152          [1, 1, 1],
1153          [1, 1, 2],  # 1,1
1154          [2, 0, 0],  # 2,0
1155          [2, 1, 0],  # 2,1
1156          [3, 0, 0],  # 3,0
1157          [3, 1, 0],  # 3,1
1158      ]
1159      expected_values = _values(
1160          [
1161              1,
1162              3,
1163              9,  # 0,0
1164              3,  # 0,1
1165              1,
1166              3,  # 1,0
1167              7,
1168              8,
1169              9,  # 1,1
1170              2,  # 2,0
1171              5,  # 2,1
1172              4,  # 3,0
1173              4,  # 3,1
1174          ],
1175          dtype)
1176      expected_shape = [4, 2, 3]
1177      expected_counts = [
1178          [
1179              3,  # 0,0
1180              1  # 0,1
1181          ],
1182          [
1183              2,  # 1,0
1184              3  # 1,1
1185          ],
1186          [
1187              1,  # 2,0
1188              1  # 2,1
1189          ],
1190          [
1191              1,  # 3,0
1192              1  # 3,1
1193          ]
1194      ]
1195
1196      intersection = self._set_union(sp_a, sp_b)
1197      self._assert_set_operation(
1198          expected_indices,
1199          expected_values,
1200          expected_shape,
1201          intersection,
1202          dtype=dtype)
1203      self.assertAllEqual(expected_counts, self._set_union_count(sp_a, sp_b))
1204
1205  def _set_union(self, a, b):
1206    # Validate that we get the same results with or without `validate_indices`,
1207    # and with a & b swapped.
1208    ops = (
1209        sets.set_union(
1210            a, b, validate_indices=True),
1211        sets.set_union(
1212            a, b, validate_indices=False),
1213        sets.set_union(
1214            b, a, validate_indices=True),
1215        sets.set_union(
1216            b, a, validate_indices=False),)
1217    for op in ops:
1218      self._assert_static_shapes(a, op)
1219    return self._run_equivalent_set_ops(ops)
1220
1221  def _set_union_count(self, a, b):
1222    op = sets.set_size(sets.set_union(a, b))
1223    with self.test_session() as sess:
1224      return sess.run(op)
1225
1226  def _assert_set_operation(self, expected_indices, expected_values,
1227                            expected_shape, sparse_tensor_value, dtype):
1228    self.assertAllEqual(expected_indices, sparse_tensor_value.indices)
1229    self.assertAllEqual(len(expected_indices), len(expected_values))
1230    self.assertAllEqual(len(expected_values), len(sparse_tensor_value.values))
1231    expected_set = set()
1232    actual_set = set()
1233    last_indices = None
1234    for indices, expected_value, actual_value in zip(
1235        expected_indices, expected_values, sparse_tensor_value.values):
1236      if dtype == dtypes.string:
1237        actual_value = actual_value.decode("utf-8")
1238      if last_indices and (last_indices[:-1] != indices[:-1]):
1239        self.assertEqual(expected_set, actual_set,
1240                         "Expected %s, got %s, at %s." % (expected_set,
1241                                                          actual_set, indices))
1242        expected_set.clear()
1243        actual_set.clear()
1244      expected_set.add(expected_value)
1245      actual_set.add(actual_value)
1246      last_indices = indices
1247    self.assertEqual(expected_set, actual_set,
1248                     "Expected %s, got %s, at %s." % (expected_set, actual_set,
1249                                                      last_indices))
1250    self.assertAllEqual(expected_shape, sparse_tensor_value.dense_shape)
1251
1252
if __name__ == "__main__":
  # Run the tests in this module when the file is executed as a script.
  googletest.main()
1255