1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python dataset sparse tensor utility functitons."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python.data.util import nest
21from tensorflow.python.framework import dtypes
22from tensorflow.python.framework import ops
23from tensorflow.python.framework import sparse_tensor
24from tensorflow.python.framework import tensor_shape
25from tensorflow.python.ops import sparse_ops
26
27
def any_sparse(classes):
  """Checks whether a structure of classes contains `tf.SparseTensor`.

  Args:
    classes: a structure of objects that identify the dataset item classes.

  Returns:
    `True` if any leaf of `classes` is `sparse_tensor.SparseTensor` and
    `False` otherwise.
  """
  # A generator (rather than a list comprehension) avoids materializing a
  # temporary list and lets `any` stop at the first sparse class it finds.
  return any(c is sparse_tensor.SparseTensor for c in nest.flatten(classes))
38
39
def as_dense_shapes(shapes, classes):
  """Replaces the shape of every sparse component with an unknown shape.

  Args:
    shapes: a structure of shapes to convert.
    classes: a structure of objects that identify the dataset item classes.
      Its flattened leaves are paired positionally with those of `shapes`.

  Returns:
    a structure with the same nesting as `shapes`. Each leaf whose paired
    class is `tf.SparseTensor` becomes `tensor_shape.unknown_shape()`; every
    other leaf is the corresponding entry of `shapes`, unchanged.
  """
  flat_shapes = nest.flatten(shapes)
  flat_classes = nest.flatten(classes)
  dense_shapes = []
  for shape, cls in zip(flat_shapes, flat_classes):
    if cls is sparse_tensor.SparseTensor:
      dense_shapes.append(tensor_shape.unknown_shape())
    else:
      dense_shapes.append(shape)
  return nest.pack_sequence_as(shapes, dense_shapes)
57
58
def as_dense_types(types, classes):
  """Replaces the dtype of every sparse component with `dtypes.variant`.

  Args:
    types: a structure of types to convert.
    classes: a structure of objects that identify the dataset item classes.
      Its flattened leaves are paired positionally with those of `types`.

  Returns:
    a structure with the same nesting as `types`. Each leaf whose paired
    class is `tf.SparseTensor` becomes `dtypes.variant`; every other leaf is
    the corresponding entry of `types`, unchanged.
  """
  paired = zip(nest.flatten(types), nest.flatten(classes))
  dense_types = [
      dtypes.variant if cls is sparse_tensor.SparseTensor else dtype
      for dtype, cls in paired
  ]
  return nest.pack_sequence_as(types, dense_types)
76
77
def deserialize_sparse_tensors(tensors, types, shapes, classes):
  """Turns serialized sparse components back into sparse tensors.

  Args:
    tensors: a structure of tensors to deserialize.
    types: a structure that holds information about types of `tensors`.
    shapes: a structure that holds information about shapes of `tensors`.
    classes: a structure of objects that identify the dataset item classes.

  Returns:
    a structure packed with the nesting of `types`. Each component whose
    class is `tf.SparseTensor` is passed through
    `sparse_ops.deserialize_sparse` with its dtype and its shape's rank
    (`shape.ndims`); every other component is returned as-is.
  """
  flat_components = zip(
      nest.flatten(tensors), nest.flatten(types), nest.flatten(shapes),
      nest.flatten(classes))
  restored = []
  for tensor, dtype, shape, cls in flat_components:
    if cls is not sparse_tensor.SparseTensor:
      restored.append(tensor)
      continue
    restored.append(
        sparse_ops.deserialize_sparse(tensor, dtype=dtype, rank=shape.ndims))
  # The result is packed against `types` (not `tensors`), matching the
  # original implementation's choice of template structure.
  return nest.pack_sequence_as(types, restored)
99
100
def get_classes(tensors):
  """Maps each component of a tensor structure to its class.

  Args:
    tensors: the tensor structure to get classes for.

  Returns:
    a structure with the same nesting as `tensors`. A leaf is
    `tf.SparseTensor` when the corresponding component is an instance of
    `sparse_tensor.SparseTensor`, and `tf.Tensor` (`ops.Tensor`) otherwise.
  """

  def _class_of(component):
    # Only genuine `SparseTensor` instances are classified as sparse.
    if isinstance(component, sparse_tensor.SparseTensor):
      return sparse_tensor.SparseTensor
    return ops.Tensor

  return nest.pack_sequence_as(
      tensors, [_class_of(component) for component in nest.flatten(tensors)])
117
118
def serialize_many_sparse_tensors(tensors):
  """Serializes every sparse component of a structure as a batch.

  Args:
    tensors: a tensor structure to serialize.

  Returns:
    a structure with the same nesting as `tensors`. Each component for which
    `sparse_tensor.is_sparse` is true is replaced by the `dtypes.variant`
    result of `sparse_ops.serialize_many_sparse`; other components are
    returned unchanged.
  """
  # NOTE(review): this uses `sparse_tensor.is_sparse`, whereas
  # `serialize_sparse_tensors` and `get_classes` use
  # `isinstance(..., SparseTensor)`; confirm the broader check is intended.
  serialized = []
  for component in nest.flatten(tensors):
    if not sparse_tensor.is_sparse(component):
      serialized.append(component)
    else:
      serialized.append(
          sparse_ops.serialize_many_sparse(component, out_type=dtypes.variant))
  return nest.pack_sequence_as(tensors, serialized)
135
136
def serialize_sparse_tensors(tensors):
  """Serializes each sparse component of a structure individually.

  Args:
    tensors: a tensor structure to serialize.

  Returns:
    a structure with the same nesting as `tensors`. Each
    `sparse_tensor.SparseTensor` instance is replaced by the `dtypes.variant`
    result of `sparse_ops.serialize_sparse`; other components are returned
    unchanged.
  """

  def _maybe_serialize(component):
    if not isinstance(component, sparse_tensor.SparseTensor):
      return component
    return sparse_ops.serialize_sparse(component, out_type=dtypes.variant)

  flat = nest.flatten(tensors)
  return nest.pack_sequence_as(tensors, [_maybe_serialize(c) for c in flat])
153