1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Implementation of tf.sets."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import ops
24from tensorflow.python.framework import sparse_tensor
25from tensorflow.python.ops import gen_set_ops
26from tensorflow.python.util.tf_export import tf_export
27
28
# Element dtypes that the functions in this module accept for set values.
_VALID_DTYPES = {
    dtypes.int8,
    dtypes.int16,
    dtypes.int32,
    dtypes.int64,
    dtypes.uint8,
    dtypes.uint16,
    dtypes.string,
}
32
33
@tf_export("sets.set_size")
def set_size(a, validate_indices=True):
  """Count the unique elements along the last dimension of `a`.

  Args:
    a: `SparseTensor`, with indices sorted in row-major order.
    validate_indices: Whether to validate the order and range of the sparse
      indices in `a`.

  Returns:
    `int32` `Tensor` of set sizes. For `a` ranked `n`, this is a `Tensor` with
    rank `n-1`, and the same 1st `n-1` dimensions as `a`. Each value is the
    number of unique elements in the corresponding `[0...n-1]` dimension of `a`.

  Raises:
    TypeError: If `a` is not a `SparseTensor`, or if the dtype of its values
      is not one of the supported set dtypes.
  """
  sparse_a = sparse_tensor.convert_to_tensor_or_sparse_tensor(a, name="a")

  # Only sparse inputs are accepted; a dense conversion result is rejected.
  if not isinstance(sparse_a, sparse_tensor.SparseTensor):
    raise TypeError("Expected `SparseTensor`, got %s." % sparse_a)

  values_dtype = sparse_a.values.dtype
  if values_dtype.base_dtype not in _VALID_DTYPES:
    raise TypeError("Invalid dtype %s." % values_dtype)

  return gen_set_ops.set_size(
      sparse_a.indices,
      sparse_a.values,
      sparse_a.dense_shape,
      validate_indices)
59
# Mark every set op used by this module as not differentiable.
for _op_type in (
    "SetSize",
    "DenseToDenseSetOperation",
    "DenseToSparseSetOperation",
    "SparseToSparseSetOperation",
):
  ops.NotDifferentiable(_op_type)
del _op_type  # Keep the loop variable out of the module namespace.
66
67
def _convert_to_tensors_or_sparse_tensors(a, b):
  """Convert both inputs to tensor types, swapping them if required.

  Args:
    a: `Tensor` or `SparseTensor` of the same type as `b`.
    b: `Tensor` or `SparseTensor` of the same type as `a`.

  Returns:
    Tuple of `(a, b, flipped)`, where `a` and `b` have been converted to
    `Tensor` or `SparseTensor`, and `flipped` indicates whether the order has
    been flipped to make it dense,sparse instead of sparse,dense (since the set
    ops do not support the latter).

  Raises:
    TypeError: If the dtype of `a` is not a supported set dtype, or if the
      base dtypes of `a` and `b` differ.
  """
  convert = sparse_tensor.convert_to_tensor_or_sparse_tensor

  # `a` is converted and validated before `b` is touched.
  a = convert(a, name="a")
  a_dtype = a.dtype
  if a_dtype.base_dtype not in _VALID_DTYPES:
    raise TypeError("'a' invalid dtype %s." % a_dtype)

  b = convert(b, name="b")
  b_dtype = b.dtype
  if b_dtype.base_dtype != a_dtype.base_dtype:
    raise TypeError("Types don't match, %s vs %s." % (a_dtype, b_dtype))

  # Only the sparse,dense combination needs its operands swapped.
  a_is_sparse = isinstance(a, sparse_tensor.SparseTensor)
  b_is_sparse = isinstance(b, sparse_tensor.SparseTensor)
  needs_flip = a_is_sparse and not b_is_sparse
  return (b, a, True) if needs_flip else (a, b, False)
91
92
def _set_operation(a, b, set_operation, validate_indices=True):
  """Apply a set operation to the last dimension of `a` and `b`.

  All but the last dimension of `a` and `b` must match.

  Args:
    a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
        must be sorted in row-major order.
    b: `Tensor` or `SparseTensor` of the same type as `a`. Must be
        `SparseTensor` if `a` is `SparseTensor`. If sparse, indices must be
        sorted in row-major order.
    set_operation: String indicating set operation. See
        SetOperationOp::SetOperationFromContext for valid values.
    validate_indices: Whether to validate the order and range of sparse indices
       in `a` and `b`.

  Returns:
    A `SparseTensor` with the same rank as `a` and `b`, and all but the last
    dimension the same. Elements along the last dimension contain the results
    of the set operation.

  Raises:
    TypeError: If inputs are invalid types.
    ValueError: If `a` is sparse and `b` is dense.
  """
  a_is_sparse = isinstance(a, sparse_tensor.SparseTensor)
  b_is_sparse = isinstance(b, sparse_tensor.SparseTensor)

  # Reject the one unsupported combination before dispatching to a kernel.
  if a_is_sparse and not b_is_sparse:
    raise ValueError("Sparse,Dense is not supported, but Dense,Sparse is. "
                     "Please flip the order of your inputs.")

  if a_is_sparse:
    # Sparse, sparse.
    result = gen_set_ops.sparse_to_sparse_set_operation(
        a.indices, a.values, a.dense_shape,
        b.indices, b.values, b.dense_shape,
        set_operation, validate_indices)
  elif b_is_sparse:
    # Dense, sparse.
    result = gen_set_ops.dense_to_sparse_set_operation(
        a, b.indices, b.values, b.dense_shape, set_operation, validate_indices)
  else:
    # Dense, dense.
    result = gen_set_ops.dense_to_dense_set_operation(
        a, b, set_operation, validate_indices)

  result_indices, result_values, result_shape = result
  return sparse_tensor.SparseTensor(
      result_indices, result_values, result_shape)
134
135
@tf_export("sets.set_intersection")
def set_intersection(a, b, validate_indices=True):
  """Intersect the sets stored in the last dimension of `a` and `b`.

  All but the last dimension of `a` and `b` must match.

  Example:

  ```python
    import tensorflow as tf
    import collections

    # Represent the following array of sets as a sparse tensor:
    # a = np.array([[{1, 2}, {3}], [{4}, {5, 6}]])
    a = collections.OrderedDict([
        ((0, 0, 0), 1),
        ((0, 0, 1), 2),
        ((0, 1, 0), 3),
        ((1, 0, 0), 4),
        ((1, 1, 0), 5),
        ((1, 1, 1), 6),
    ])
    a = tf.SparseTensor(list(a.keys()), list(a.values()), dense_shape=[2, 2, 2])

    # b = np.array([[{1}, {}], [{4}, {5, 6, 7, 8}]])
    b = collections.OrderedDict([
        ((0, 0, 0), 1),
        ((1, 0, 0), 4),
        ((1, 1, 0), 5),
        ((1, 1, 1), 6),
        ((1, 1, 2), 7),
        ((1, 1, 3), 8),
    ])
    b = tf.SparseTensor(list(b.keys()), list(b.values()), dense_shape=[2, 2, 4])

    # `tf.sets.set_intersection` is applied to each aligned pair of sets.
    tf.sets.set_intersection(a, b)

    # The result will be equivalent to either of:
    #
    # np.array([[{1}, {}], [{4}, {5, 6}]])
    #
    # collections.OrderedDict([
    #     ((0, 0, 0), 1),
    #     ((1, 0, 0), 4),
    #     ((1, 1, 0), 5),
    #     ((1, 1, 1), 6),
    # ])
  ```

  Args:
    a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
        must be sorted in row-major order.
    b: `Tensor` or `SparseTensor` of the same type as `a`. If sparse, indices
        must be sorted in row-major order.
    validate_indices: Whether to validate the order and range of sparse indices
       in `a` and `b`.

  Returns:
    A `SparseTensor` whose shape is the same rank as `a` and `b`, and all but
    the last dimension the same. Elements along the last dimension contain the
    intersections.
  """
  # The "flipped" flag is discarded: intersection does not depend on the
  # order of its operands.
  converted = _convert_to_tensors_or_sparse_tensors(a, b)
  first, second = converted[0], converted[1]
  return _set_operation(first, second, "intersection", validate_indices)
201
202
@tf_export("sets.set_difference")
def set_difference(a, b, aminusb=True, validate_indices=True):
  """Compute set difference of elements in last dimension of `a` and `b`.

  All but the last dimension of `a` and `b` must match.

  Example:

  ```python
    import tensorflow as tf
    import collections

    # Represent the following array of sets as a sparse tensor:
    # a = np.array([[{1, 2}, {3}], [{4}, {5, 6}]])
    a = collections.OrderedDict([
        ((0, 0, 0), 1),
        ((0, 0, 1), 2),
        ((0, 1, 0), 3),
        ((1, 0, 0), 4),
        ((1, 1, 0), 5),
        ((1, 1, 1), 6),
    ])
    a = tf.SparseTensor(list(a.keys()), list(a.values()), dense_shape=[2, 2, 2])

    # b = np.array([[{1, 3}, {2}], [{4, 5}, {5, 6, 7, 8}]])
    b = collections.OrderedDict([
        ((0, 0, 0), 1),
        ((0, 0, 1), 3),
        ((0, 1, 0), 2),
        ((1, 0, 0), 4),
        ((1, 0, 1), 5),
        ((1, 1, 0), 5),
        ((1, 1, 1), 6),
        ((1, 1, 2), 7),
        ((1, 1, 3), 8),
    ])
    b = tf.SparseTensor(list(b.keys()), list(b.values()), dense_shape=[2, 2, 4])

    # `set_difference` is applied to each aligned pair of sets.
    tf.sets.set_difference(a, b)

    # The result will be equivalent to either of:
    #
    # np.array([[{2}, {3}], [{}, {}]])
    #
    # collections.OrderedDict([
    #     ((0, 0, 0), 2),
    #     ((0, 1, 0), 3),
    # ])
  ```

  Args:
    a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
        must be sorted in row-major order.
    b: `Tensor` or `SparseTensor` of the same type as `a`. If sparse, indices
        must be sorted in row-major order.
    aminusb: Whether to subtract `b` from `a`, vs vice versa.
    validate_indices: Whether to validate the order and range of sparse indices
       in `a` and `b`.

  Returns:
    A `SparseTensor` whose shape is the same rank as `a` and `b`, and all but
    the last dimension the same. Elements along the last dimension contain the
    differences.
  """
  a, b, flipped = _convert_to_tensors_or_sparse_tensors(a, b)
  # Difference is order-sensitive: when the operands were swapped to obtain a
  # dense,sparse pair, the direction of the subtraction is inverted so that
  # the result still matches what the caller asked for.
  if flipped:
    aminusb = not aminusb
  return _set_operation(a, b, "a-b" if aminusb else "b-a", validate_indices)
272
273
@tf_export("sets.set_union")
def set_union(a, b, validate_indices=True):
  """Take the union of the sets stored in the last dimension of `a` and `b`.

  All but the last dimension of `a` and `b` must match.

  Example:

  ```python
    import tensorflow as tf
    import collections

    # [[{1, 2}, {3}], [{4}, {5, 6}]]
    a = collections.OrderedDict([
        ((0, 0, 0), 1),
        ((0, 0, 1), 2),
        ((0, 1, 0), 3),
        ((1, 0, 0), 4),
        ((1, 1, 0), 5),
        ((1, 1, 1), 6),
    ])
    a = tf.SparseTensor(list(a.keys()), list(a.values()), dense_shape=[2, 2, 2])

    # [[{1, 3}, {2}], [{4, 5}, {5, 6, 7, 8}]]
    b = collections.OrderedDict([
        ((0, 0, 0), 1),
        ((0, 0, 1), 3),
        ((0, 1, 0), 2),
        ((1, 0, 0), 4),
        ((1, 0, 1), 5),
        ((1, 1, 0), 5),
        ((1, 1, 1), 6),
        ((1, 1, 2), 7),
        ((1, 1, 3), 8),
    ])
    b = tf.SparseTensor(list(b.keys()), list(b.values()), dense_shape=[2, 2, 4])

    # `set_union` is applied to each aligned pair of sets.
    tf.sets.set_union(a, b)

    # The result will be equivalent to either of:
    #
    # np.array([[{1, 2, 3}, {2, 3}], [{4, 5}, {5, 6, 7, 8}]])
    #
    # collections.OrderedDict([
    #     ((0, 0, 0), 1),
    #     ((0, 0, 1), 2),
    #     ((0, 0, 2), 3),
    #     ((0, 1, 0), 2),
    #     ((0, 1, 1), 3),
    #     ((1, 0, 0), 4),
    #     ((1, 0, 1), 5),
    #     ((1, 1, 0), 5),
    #     ((1, 1, 1), 6),
    #     ((1, 1, 2), 7),
    #     ((1, 1, 3), 8),
    # ])
  ```

  Args:
    a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
        must be sorted in row-major order.
    b: `Tensor` or `SparseTensor` of the same type as `a`. If sparse, indices
        must be sorted in row-major order.
    validate_indices: Whether to validate the order and range of sparse indices
       in `a` and `b`.

  Returns:
    A `SparseTensor` whose shape is the same rank as `a` and `b`, and all but
    the last dimension the same. Elements along the last dimension contain the
    unions.
  """
  # The "flipped" flag is discarded: union does not depend on the order of
  # its operands.
  converted = _convert_to_tensors_or_sparse_tensors(a, b)
  first, second = converted[0], converted[1]
  return _set_operation(first, second, "union", validate_indices)
348