# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

16"""Helper functions for creating partitioned variables.
17
18This is a convenient abstraction to partition a large variable across
19multiple smaller variables that can be assigned to different devices.
20
21The full variable can be reconstructed by concatenating the smaller variables.
22Using partitioned variables instead of a single variable is mostly a
23performance choice.  It however also has an impact on:
24
251. Random initialization, as the random number generator is called once per
26   slice
272. Updates, as they happen in parallel across slices
28
29A key design goal is to allow a different graph to repartition a variable
30with the same name but different slicings, including possibly no partitions.
31
32TODO(touts): If an initializer provides a seed, the seed must be changed
33deterministically for each slice, maybe by adding one to it, otherwise each
34slice will use the same values.  Maybe this can be done by passing the
35slice offsets to the initializer functions.
36
37Typical usage:
38
39```python
40# Create a list of partitioned variables with:
41vs = create_partitioned_variables(
42    <shape>, <slicing>, <initializer>, name=<optional-name>)
43
44# Pass the list as inputs to embedding_lookup for sharded, parallel lookup:
45y = embedding_lookup(vs, ids, partition_strategy="div")
46
47# Or fetch the variables in parallel to speed up large matmuls:
48z = matmul(x, concat(slice_dim, vs))
49```
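
A similar sharded layout can typically be obtained by setting a partitioner on
a variable scope and calling `get_variable` (a minimal sketch; the scope name,
vocabulary size, and embedding size below are illustrative):

```python
import tensorflow as tf

ids = tf.constant([0, 7, 42])
partitioner = tf.fixed_size_partitioner(num_shards=3)
with tf.variable_scope("embedding", partitioner=partitioner):
  # `embeddings` is a PartitionedVariable backed by 3 shards along axis 0.
  embeddings = tf.get_variable(
      "embeddings", shape=[10000, 64], dtype=tf.float32)

# embedding_lookup accepts the PartitionedVariable directly.
y = tf.nn.embedding_lookup(embeddings, ids, partition_strategy="div")
```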
50"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import tf_export

__all__ = [
    "create_partitioned_variables",
    "variable_axis_size_partitioner",
    "min_max_variable_partitioner",
    "fixed_size_partitioner",
]


@tf_export("variable_axis_size_partitioner")
def variable_axis_size_partitioner(
    max_shard_bytes, axis=0, bytes_per_string_element=16, max_shards=None):
  """Get a partitioner for VariableScope to keep shards below `max_shard_bytes`.

  This partitioner will shard a Variable along one axis, attempting to keep
  the maximum shard size below `max_shard_bytes`.  In practice, this is not
  always possible when sharding along only one axis.  When this happens,
  this axis is sharded as much as possible (i.e., every slice along the axis
  becomes a separate shard).

  If the partitioner hits the `max_shards` limit, then each shard may end up
  larger than `max_shard_bytes`. By default `max_shards` equals `None` and no
  limit on the number of shards is enforced.

  One reasonable value for `max_shard_bytes` is `(64 << 20) - 1`, or almost
  `64MB`, to keep below the protobuf byte limit.

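  For example, a minimal sketch (the scope name, variable name, and the
  `[1000000, 128]` shape are illustrative):

  ```python
  import tensorflow as tf

  # Keep every shard of the embedding matrix below ~64MB.
  partitioner = tf.variable_axis_size_partitioner((64 << 20) - 1)
  with tf.variable_scope("embeddings", partitioner=partitioner):
    embeddings = tf.get_variable(
        "embeddings", shape=[1000000, 128], dtype=tf.float32)
  ```
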
  Args:
    max_shard_bytes: The maximum size any given shard is allowed to be.
    axis: The axis to partition along.  Default: outermost axis.
    bytes_per_string_element: If the `Variable` is of type string, this provides
      an estimate of how large each scalar in the `Variable` is.
    max_shards: `int`, the maximum number of shards to create; takes precedence
      over `max_shard_bytes`.

  Returns:
    A partition function usable as the `partitioner` argument to
    `variable_scope`, `get_variable`, and `get_partitioned_variable_list`.

  Raises:
    ValueError: If any of the byte counts are non-positive.
  """
  if max_shard_bytes < 1 or bytes_per_string_element < 1:
    raise ValueError(
        "Both max_shard_bytes and bytes_per_string_element must be positive.")
  if max_shards and max_shards < 1:
    raise ValueError(
        "max_shards must be positive.")

  def _partitioner(shape, dtype):
    """Partitioner that keeps each shard below `max_shard_bytes` in total size.

    Args:
      shape: A `TensorShape`.
      dtype: A `DType`.

    Returns:
      A list representing how many partitions to create along each axis of
      `shape`.

    Raises:
      ValueError: If shape is not a fully defined `TensorShape` or dtype is not
        a `DType`.
    """
    if not isinstance(shape, tensor_shape.TensorShape):
      raise ValueError("shape is not a TensorShape: %s" % shape)
    if not shape.is_fully_defined():
      raise ValueError("shape is not fully defined: %s" % shape)
    if not isinstance(dtype, dtypes.DType):
      raise ValueError("dtype is not a DType: %s" % dtype)

    if dtype.base_dtype == dtypes.string:
      element_size = bytes_per_string_element
    else:
      element_size = dtype.size

    partitions = [1] * shape.ndims
    bytes_per_slice = 1.0 * (
        shape.num_elements() / shape[axis].value) * element_size
    # How many slices can we fit on one shard of size at most max_shard_bytes?
    # At least one slice is required.
    slices_per_shard = max(1, math.floor(max_shard_bytes / bytes_per_slice))
    # How many shards do we need for axis given that each shard fits
    # slices_per_shard slices from a total of shape[axis].value slices?
    axis_shards = int(math.ceil(1.0 * shape[axis].value / slices_per_shard))
    if max_shards:
      axis_shards = min(max_shards, axis_shards)

    partitions[axis] = axis_shards

    return partitions

  return _partitioner


@tf_export("min_max_variable_partitioner")
def min_max_variable_partitioner(max_partitions=1, axis=0,
                                 min_slice_size=256 << 10,
                                 bytes_per_string_element=16):
  """Partitioner to allocate minimum size per slice.

  Returns a partitioner that partitions the variable of given shape and dtype
  such that each partition holds a slice of at least `min_slice_size` bytes of
  the variable. The maximum number of such partitions (upper bound) is given by
  `max_partitions`.

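  For example, a minimal sketch (the choice of 8 partitions and the variable
  name/shape are illustrative):

  ```python
  import tensorflow as tf

  # Use at most 8 partitions, but only as many as keep each partition's slice
  # of the variable at or above 256K bytes.
  partitioner = tf.min_max_variable_partitioner(
      max_partitions=8, min_slice_size=256 << 10)
  with tf.variable_scope("softmax", partitioner=partitioner):
    weights = tf.get_variable("weights", shape=[1024, 1024], dtype=tf.float32)
  ```
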
  Args:
    max_partitions: Upper bound on the number of partitions. Defaults to 1.
    axis: Axis along which to partition the variable. Defaults to 0.
    min_slice_size: Minimum size, in bytes, of the variable slice per
      partition. Defaults to 256K.
    bytes_per_string_element: If the `Variable` is of type string, this provides
      an estimate of how large each scalar in the `Variable` is.

  Returns:
    A partition function usable as the `partitioner` argument to
    `variable_scope`, `get_variable`, and `get_partitioned_variable_list`.

  """
  def _partitioner(shape, dtype):
    """Returns the partitions list for a variable of given shape and dtype.

    Ex: Consider partitioning a variable of type float32 with
      shape=[1024, 1024].
      If `max_partitions` >= 16, this function would return
        [(1024 * 1024 * 4) / (256 * 1024), 1] = [16, 1].
      If `max_partitions` < 16, this function would return
        [`max_partitions`, 1].

    Args:
      shape: Shape of the variable.
      dtype: Type of the variable.

    Returns:
      List of partitions for each axis (currently only one axis can be
      partitioned).

    Raises:
      ValueError: If axis to partition along does not exist for the variable.
    """
    if axis >= len(shape):
      raise ValueError("Can not partition variable along axis %d when shape is "
                       "only %s" % (axis, shape))
    if dtype.base_dtype == dtypes.string:
      bytes_per_element = bytes_per_string_element
    else:
      bytes_per_element = dtype.size
    total_size_bytes = shape.num_elements() * bytes_per_element
    partitions = total_size_bytes / min_slice_size
    partitions_list = [1] * len(shape)
    # We can not partition the variable beyond what its shape or
    # `max_partitions` allows.
    partitions_list[axis] = max(1, min(shape[axis].value,
                                       max_partitions,
                                       int(math.ceil(partitions))))
    return partitions_list
  return _partitioner


@tf_export("fixed_size_partitioner")
def fixed_size_partitioner(num_shards, axis=0):
  """Partitioner to specify a fixed number of shards along given axis.

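  For example, a minimal sketch (the scope and variable names and the shape
  are illustrative):

  ```python
  import tensorflow as tf

  # Always split the variable into 4 shards along axis 0.
  partitioner = tf.fixed_size_partitioner(num_shards=4)
  with tf.variable_scope("layer", partitioner=partitioner):
    kernel = tf.get_variable("kernel", shape=[4096, 1024], dtype=tf.float32)
  ```
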
  Args:
    num_shards: `int`, number of shards to partition the variable into.
    axis: `int`, axis to partition on.

  Returns:
    A partition function usable as the `partitioner` argument to
    `variable_scope`, `get_variable`, and `get_partitioned_variable_list`.
  """
  def _partitioner(shape, **unused_args):
    partitions_list = [1] * len(shape)
    # Do not create more shards than there are slices along `axis`.
    partitions_list[axis] = min(num_shards, shape[axis].value)
    return partitions_list
  return _partitioner


@tf_export("create_partitioned_variables")
def create_partitioned_variables(
    shape, slicing, initializer, dtype=dtypes.float32,
    trainable=True, collections=None, name=None, reuse=None):
  """Create a list of partitioned variables according to the given `slicing`.

  Currently only one dimension of the full variable can be sliced, and the
  full variable can be reconstructed by the concatenation of the returned
  list along that dimension.

  Args:
    shape: List of integers.  The shape of the full variable.
    slicing: List of integers.  How to partition the variable.
      Must be of the same length as `shape`.  Each value
      indicates how many slices to create in the corresponding
      dimension.  Presently only one of the values can be more than 1;
      that is, the variable can only be sliced along one dimension.

      For convenience, the requested number of partitions does not have to
      divide the corresponding dimension evenly.  If it does not, the
      shapes of the partitions are incremented by 1 starting from partition
      0 until all slack is absorbed.  For example, slicing a dimension of
      size 43 into 3 partitions yields partition sizes of 15, 14, and 14.
      The adjustment rules may change in the future, but as you can
      save/restore these variables with different slicing specifications
      this should not be a problem.
    initializer: A `Tensor` of shape `shape` or a variable initializer
      function.  If a function, it will be called once for each slice,
      passing the shape and data type of the slice as parameters.  The
      function must return a tensor with the same shape as the slice.
    dtype: Type of the variables. Ignored if `initializer` is a `Tensor`.
    trainable: If `True`, also adds all the variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES`.
    collections: List of graph collection keys to add the variables to.
      Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
    name: Optional name for the full variable.  Defaults to
      `"PartitionedVariable"` and gets uniquified automatically.
    reuse: Boolean or `None`; if `True` and `name` is set, reuse previously
      created variables; if `False`, create new variables; if `None`, inherit
      the parent scope's reuse setting.

  Returns:
    A list of Variables corresponding to the slicing.

  Raises:
    ValueError: If any of the arguments is malformed.
  """
  logging.warn(
      "create_partitioned_variables is deprecated.  Use "
      "tf.get_variable with a partitioner set, or "
      "tf.get_partitioned_variable_list, instead.")

  if len(shape) != len(slicing):
    raise ValueError("The 'shape' and 'slicing' of a partitioned Variable "
                     "must have the same length: shape: %s, slicing: %s" %
                     (shape, slicing))
  if len(shape) < 1:
    raise ValueError("A partitioned Variable must have rank at least 1: "
                     "shape: %s" % shape)

  # Legacy: we are provided the slicing directly, so just pass it to
  # the partitioner.
  partitioner = lambda **unused_kwargs: slicing

  with variable_scope.variable_scope(
      name, "PartitionedVariable", reuse=reuse):
    # pylint: disable=protected-access
    partitioned_var = variable_scope._get_partitioned_variable(
        name=None,
        shape=shape,
        dtype=dtype,
        initializer=initializer,
        trainable=trainable,
        partitioner=partitioner,
        collections=collections)
    return list(partitioned_var)
    # pylint: enable=protected-access
