# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of tf.sets."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import gen_set_ops
from tensorflow.python.util.tf_export import tf_export


# Element dtypes accepted by the set ops in this module. Checked (by base
# dtype) in `set_size` and `_convert_to_tensors_or_sparse_tensors` before any
# op is constructed, so unsupported inputs fail early with a TypeError.
_VALID_DTYPES = set([
    dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
    dtypes.uint8, dtypes.uint16, dtypes.string])


@tf_export("sets.set_size")
def set_size(a, validate_indices=True):
  """Compute number of unique elements along last dimension of `a`.

  Args:
    a: `SparseTensor`, with indices sorted in row-major order.
    validate_indices: Whether to validate the order and range of sparse indices
       in `a`.

  Returns:
    `int32` `Tensor` of set sizes. For `a` ranked `n`, this is a `Tensor` with
    rank `n-1`, and the same 1st `n-1` dimensions as `a`. Each value is the
    number of unique elements in the corresponding `[0...n-1]` dimension of `a`.

  Raises:
    TypeError: If `a` is not a `SparseTensor` after conversion, or if the
      base dtype of its values is not one of int8, int16, int32, int64, uint8,
      uint16 or string.
  """
  a = sparse_tensor.convert_to_tensor_or_sparse_tensor(a, name="a")
  # Only sparse inputs are supported; a dense `Tensor` is rejected here.
  if not isinstance(a, sparse_tensor.SparseTensor):
    raise TypeError("Expected `SparseTensor`, got %s." % a)
  if a.values.dtype.base_dtype not in _VALID_DTYPES:
    raise TypeError("Invalid dtype %s." % a.values.dtype)
  return gen_set_ops.set_size(
      a.indices, a.values, a.dense_shape, validate_indices)

ops.NotDifferentiable("SetSize")


ops.NotDifferentiable("DenseToDenseSetOperation")
ops.NotDifferentiable("DenseToSparseSetOperation")
ops.NotDifferentiable("SparseToSparseSetOperation")


def _convert_to_tensors_or_sparse_tensors(a, b):
  """Convert to tensor types, and flip order if necessary.

  Args:
    a: `Tensor` or `SparseTensor` of the same type as `b`.
    b: `Tensor` or `SparseTensor` of the same type as `a`.

  Returns:
    Tuple of `(a, b, flipped)`, where `a` and `b` have been converted to
    `Tensor` or `SparseTensor`, and `flipped` indicates whether the order has
    been flipped to make it dense,sparse instead of sparse,dense (since the set
    ops do not support the latter).

  Raises:
    TypeError: If the base dtype of `a` is not in `_VALID_DTYPES`, or if the
      base dtypes of `a` and `b` differ.
  """
  a = sparse_tensor.convert_to_tensor_or_sparse_tensor(a, name="a")
  if a.dtype.base_dtype not in _VALID_DTYPES:
    raise TypeError("'a' invalid dtype %s." % a.dtype)
  b = sparse_tensor.convert_to_tensor_or_sparse_tensor(b, name="b")
  # `b` only needs to match `a`; that implies it is also a valid dtype.
  if b.dtype.base_dtype != a.dtype.base_dtype:
    raise TypeError("Types don't match, %s vs %s." % (a.dtype, b.dtype))
  # The kernels have no sparse,dense variant, so swap to dense,sparse and let
  # the caller know (it matters for non-commutative ops like set difference).
  if (isinstance(a, sparse_tensor.SparseTensor) and
      not isinstance(b, sparse_tensor.SparseTensor)):
    return b, a, True
  return a, b, False


def _set_operation(a, b, set_operation, validate_indices=True):
  """Compute set operation of elements in last dimension of `a` and `b`.

  All but the last dimension of `a` and `b` must match.

  Args:
    a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
       must be sorted in row-major order.
    b: `Tensor` or `SparseTensor` of the same type as `a`. Must be
       `SparseTensor` if `a` is `SparseTensor`. If sparse, indices must be
       sorted in row-major order.
    set_operation: String indicating set operation. See
       SetOperationOp::SetOperationFromContext for valid values. The callers
       in this module pass "intersection", "union", "a-b" and "b-a".
    validate_indices: Whether to validate the order and range of sparse indices
       in `a` and `b`.

  Returns:
    A `SparseTensor` with the same rank as `a` and `b`, and all but the last
    dimension the same. Elements along the last dimension contain the results
    of the set operation. The result is sparse even when both inputs are dense.

  Raises:
    TypeError: If inputs are invalid types.
    ValueError: If `a` is sparse and `b` is dense.
  """
  # Dispatch to the kernel matching the (dense|sparse, dense|sparse) pairing.
  if isinstance(a, sparse_tensor.SparseTensor):
    if isinstance(b, sparse_tensor.SparseTensor):
      indices, values, shape = gen_set_ops.sparse_to_sparse_set_operation(
          a.indices, a.values, a.dense_shape,
          b.indices, b.values, b.dense_shape,
          set_operation, validate_indices)
    else:
      raise ValueError("Sparse,Dense is not supported, but Dense,Sparse is. "
                       "Please flip the order of your inputs.")
  elif isinstance(b, sparse_tensor.SparseTensor):
    indices, values, shape = gen_set_ops.dense_to_sparse_set_operation(
        a, b.indices, b.values, b.dense_shape, set_operation, validate_indices)
  else:
    indices, values, shape = gen_set_ops.dense_to_dense_set_operation(
        a, b, set_operation, validate_indices)
  return sparse_tensor.SparseTensor(indices, values, shape)


@tf_export("sets.set_intersection")
def set_intersection(a, b, validate_indices=True):
  """Compute set intersection of elements in last dimension of `a` and `b`.

  All but the last dimension of `a` and `b` must match.

  Example:

  ```python
  import tensorflow as tf
  import collections

  # Represent the following array of sets as a sparse tensor:
  # a = np.array([[{1, 2}, {3}], [{4}, {5, 6}]])
  a = collections.OrderedDict([
      ((0, 0, 0), 1),
      ((0, 0, 1), 2),
      ((0, 1, 0), 3),
      ((1, 0, 0), 4),
      ((1, 1, 0), 5),
      ((1, 1, 1), 6),
  ])
  a = tf.SparseTensor(list(a.keys()), list(a.values()), dense_shape=[2, 2, 2])

  # b = np.array([[{1}, {}], [{4}, {5, 6, 7, 8}]])
  b = collections.OrderedDict([
      ((0, 0, 0), 1),
      ((1, 0, 0), 4),
      ((1, 1, 0), 5),
      ((1, 1, 1), 6),
      ((1, 1, 2), 7),
      ((1, 1, 3), 8),
  ])
  b = tf.SparseTensor(list(b.keys()), list(b.values()), dense_shape=[2, 2, 4])

  # `tf.sets.set_intersection` is applied to each aligned pair of sets.
  tf.sets.set_intersection(a, b)

  # The result will be equivalent to either of:
  #
  # np.array([[{1}, {}], [{4}, {5, 6}]])
  #
  # collections.OrderedDict([
  #     ((0, 0, 0), 1),
  #     ((1, 0, 0), 4),
  #     ((1, 1, 0), 5),
  #     ((1, 1, 1), 6),
  # ])
  ```

  Args:
    a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
       must be sorted in row-major order.
    b: `Tensor` or `SparseTensor` of the same type as `a`. If sparse, indices
       must be sorted in row-major order.
    validate_indices: Whether to validate the order and range of sparse indices
       in `a` and `b`.

  Returns:
    A `SparseTensor` whose shape is the same rank as `a` and `b`, and all but
    the last dimension the same. Elements along the last dimension contain the
    intersections.
  """
  # Intersection is commutative, so a sparse,dense -> dense,sparse swap done
  # by the conversion helper needs no compensation.
  a, b, _ = _convert_to_tensors_or_sparse_tensors(a, b)
  return _set_operation(a, b, "intersection", validate_indices)


@tf_export("sets.set_difference")
def set_difference(a, b, aminusb=True, validate_indices=True):
  """Compute set difference of elements in last dimension of `a` and `b`.

  All but the last dimension of `a` and `b` must match.

  If `a` is sparse and `b` is dense, the inputs are swapped internally and the
  direction of the difference is inverted, so the result still honors
  `aminusb`.

  Example:

  ```python
  import tensorflow as tf
  import collections

  # Represent the following array of sets as a sparse tensor:
  # a = np.array([[{1, 2}, {3}], [{4}, {5, 6}]])
  a = collections.OrderedDict([
      ((0, 0, 0), 1),
      ((0, 0, 1), 2),
      ((0, 1, 0), 3),
      ((1, 0, 0), 4),
      ((1, 1, 0), 5),
      ((1, 1, 1), 6),
  ])
  a = tf.SparseTensor(list(a.keys()), list(a.values()), dense_shape=[2, 2, 2])

  # b = np.array([[{1, 3}, {2}], [{4, 5}, {5, 6, 7, 8}]])
  b = collections.OrderedDict([
      ((0, 0, 0), 1),
      ((0, 0, 1), 3),
      ((0, 1, 0), 2),
      ((1, 0, 0), 4),
      ((1, 0, 1), 5),
      ((1, 1, 0), 5),
      ((1, 1, 1), 6),
      ((1, 1, 2), 7),
      ((1, 1, 3), 8),
  ])
  b = tf.SparseTensor(list(b.keys()), list(b.values()), dense_shape=[2, 2, 4])

  # `set_difference` is applied to each aligned pair of sets.
  tf.sets.set_difference(a, b)

  # The result will be equivalent to either of:
  #
  # np.array([[{2}, {3}], [{}, {}]])
  #
  # collections.OrderedDict([
  #     ((0, 0, 0), 2),
  #     ((0, 1, 0), 3),
  # ])
  ```

  Args:
    a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
       must be sorted in row-major order.
    b: `Tensor` or `SparseTensor` of the same type as `a`. If sparse, indices
       must be sorted in row-major order.
    aminusb: Whether to subtract `b` from `a`, vs vice versa.
    validate_indices: Whether to validate the order and range of sparse indices
       in `a` and `b`.

  Returns:
    A `SparseTensor` whose shape is the same rank as `a` and `b`, and all but
    the last dimension the same. Elements along the last dimension contain the
    differences.
  """
  a, b, flipped = _convert_to_tensors_or_sparse_tensors(a, b)
  # Difference is not commutative: if the helper swapped the operands, invert
  # the direction so the caller still gets the difference they asked for.
  if flipped:
    aminusb = not aminusb
  return _set_operation(a, b, "a-b" if aminusb else "b-a", validate_indices)


@tf_export("sets.set_union")
def set_union(a, b, validate_indices=True):
  """Compute set union of elements in last dimension of `a` and `b`.

  All but the last dimension of `a` and `b` must match.

  Example:

  ```python
  import tensorflow as tf
  import collections

  # a = np.array([[{1, 2}, {3}], [{4}, {5, 6}]])
  a = collections.OrderedDict([
      ((0, 0, 0), 1),
      ((0, 0, 1), 2),
      ((0, 1, 0), 3),
      ((1, 0, 0), 4),
      ((1, 1, 0), 5),
      ((1, 1, 1), 6),
  ])
  a = tf.SparseTensor(list(a.keys()), list(a.values()), dense_shape=[2, 2, 2])

  # b = np.array([[{1, 3}, {2}], [{4, 5}, {5, 6, 7, 8}]])
  b = collections.OrderedDict([
      ((0, 0, 0), 1),
      ((0, 0, 1), 3),
      ((0, 1, 0), 2),
      ((1, 0, 0), 4),
      ((1, 0, 1), 5),
      ((1, 1, 0), 5),
      ((1, 1, 1), 6),
      ((1, 1, 2), 7),
      ((1, 1, 3), 8),
  ])
  b = tf.SparseTensor(list(b.keys()), list(b.values()), dense_shape=[2, 2, 4])

  # `set_union` is applied to each aligned pair of sets.
  tf.sets.set_union(a, b)

  # The result will be equivalent to either of:
  #
  # np.array([[{1, 2, 3}, {2, 3}], [{4, 5}, {5, 6, 7, 8}]])
  #
  # collections.OrderedDict([
  #     ((0, 0, 0), 1),
  #     ((0, 0, 1), 2),
  #     ((0, 0, 2), 3),
  #     ((0, 1, 0), 2),
  #     ((0, 1, 1), 3),
  #     ((1, 0, 0), 4),
  #     ((1, 0, 1), 5),
  #     ((1, 1, 0), 5),
  #     ((1, 1, 1), 6),
  #     ((1, 1, 2), 7),
  #     ((1, 1, 3), 8),
  # ])
  ```

  Args:
    a: `Tensor` or `SparseTensor` of the same type as `b`. If sparse, indices
       must be sorted in row-major order.
    b: `Tensor` or `SparseTensor` of the same type as `a`. If sparse, indices
       must be sorted in row-major order.
    validate_indices: Whether to validate the order and range of sparse indices
       in `a` and `b`.

  Returns:
    A `SparseTensor` whose shape is the same rank as `a` and `b`, and all but
    the last dimension the same. Elements along the last dimension contain the
    unions.
  """
  # Union is commutative, so a sparse,dense -> dense,sparse swap done by the
  # conversion helper needs no compensation.
  a, b, _ = _convert_to_tensors_or_sparse_tensors(a, b)
  return _set_operation(a, b, "union", validate_indices)