# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================

"""Helper library for handling infeed between hosts and TPUs.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.contrib.tpu.python.tpu import tpu_sharding

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops


class InfeedQueue(object):
  """A helper object to build a device infeed queue.

  The InfeedQueue builds the host-side and device-side Ops to enqueue and
  dequeue elements, respectively, and ensures that their types and
  shapes match.
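
  A minimal usage sketch (illustrative only; the tensor names and the
  two-shard setup are assumptions made for this example, not part of the
  API):

    infeed = InfeedQueue(number_of_tuple_elements=2)
    # Host side: one enqueue Op per shard, fed with that shard's tuple.
    enqueue_ops = infeed.generate_enqueue_ops(
        [[images_0, labels_0], [images_1, labels_1]])
    # Device side: dequeue one shard's tuple inside the replicated block.
    images, labels = infeed.generate_dequeue_op()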
  """

  def __init__(self,
               number_of_tuple_elements=None,
               tuple_types=None,
               tuple_shapes=None,
               shard_dimensions=None,
               name=None):
    """Creates a new InfeedQueue with the given configuration.

    The configuration need not be fully specified at creation since it
    can be modified subsequently by methods that set the values
    explicitly or infer them from the shapes of inputs.

    Args:
      number_of_tuple_elements: the number of Tensors fed atomically through
        the queue. Must be present unless it can be inferred from other
        arguments.
      tuple_types: if not None, a list of types of the elements of the queue.
      tuple_shapes: if not None, a list of shapes of the elements of the queue.
      shard_dimensions: if not None, a list of dimensions on which the
        elements of the queue should be sharded during automatic
        parallelization.
      name: the name of the queue.

    Raises:
      ValueError: if number_of_tuple_elements <= 0; or
        number_of_tuple_elements, tuple_types, tuple_shapes, and
        shard_dimensions are all None; or the length of tuple_types,
        tuple_shapes, or shard_dimensions is not equal to
        number_of_tuple_elements; or any element of shard_dimensions
        can't be converted to a Dimension.
      TypeError: if any element of tuple_types or tuple_shapes can't
        be converted to a dtype or TensorShape, respectively.
    """
    self._frozen = False
    self._generated_enqueue_ops = False
    self._generated_dequeue_op = False
    self._name = "InfeedQueue" if name is None else name
    if number_of_tuple_elements is None:
      if tuple_types is not None:
        number_of_tuple_elements = len(tuple_types)
      elif tuple_shapes is not None:
        number_of_tuple_elements = len(tuple_shapes)
      elif shard_dimensions is not None:
        number_of_tuple_elements = len(shard_dimensions)
      else:
        raise ValueError(
            "number of tuple elements cannot be inferred from InfeedQueue "
            "constructor"
        )
    if number_of_tuple_elements <= 0:
      raise ValueError("number_of_tuple_elements %d must be > 0" %
                       number_of_tuple_elements)
    # Make an empty sharding policy for each tuple element.
    self._sharding_policies = [
        tpu_sharding.ShardingPolicy()
        for _ in xrange(number_of_tuple_elements)
    ]
    if tuple_types is not None:
      self.set_tuple_types(tuple_types)
    else:
      self._tuple_types = None
    if tuple_shapes is not None:
      self.set_tuple_shapes(tuple_shapes)
    else:
      self._tuple_shapes = None
    if shard_dimensions is not None:
      self.set_shard_dimensions(shard_dimensions)
    self._validate()

  def _validate(self):
    """Checks that the configuration is self-consistent.

    Raises:
      ValueError: if the shapes and sharding policies don't match.
    """
    if self.tuple_shapes is not None:
      for (policy, shape) in zip(self._sharding_policies, self._tuple_shapes):
        # Raise an error if the policy is incompatible with the shape.
        _ = policy.get_sharded_shape(shape)

  @property
  def number_of_tuple_elements(self):
    """Returns the number of InfeedQueue tuple elements."""
    return len(self._sharding_policies)

  @property
  def tuple_types(self):
    """Returns the types of the InfeedQueue tuple elements."""
    return self._tuple_types

  def set_tuple_types(self, tuple_types):
    """Sets the type of each element of the queue.

    tuple_types must be a list of length
    self.number_of_tuple_elements, and each element must be
    convertible to a dtype.

    Args:
      tuple_types: the types of each queue element.

    Raises:
      ValueError: if tuple_types is not of length
        self.number_of_tuple_elements.
      TypeError: if an element of tuple_types cannot be converted to a
        dtype.
    """
    if len(tuple_types) != self.number_of_tuple_elements:
      raise ValueError("tuple_types is %s, but must be a list of length %d" %
                       (str(tuple_types), self.number_of_tuple_elements))
    if self._frozen:
      for (frozen, updated) in zip(self._tuple_types, tuple_types):
        if frozen != updated:
          raise ValueError(
              "Trying to update InfeedQueue with frozen configuration with an "
              "incompatible type. Frozen types are %s, updated types are %s" % (
                  str(self._tuple_types), str(tuple_types)))
    else:
      try:
        self._tuple_types = [dtypes.as_dtype(t) for t in tuple_types]
      except TypeError as e:
        raise TypeError(
            "tuple_types is %s, but must be a list of elements each "
            "convertible to dtype: got error %s" % (str(tuple_types), str(e)))

  @property
  def tuple_shapes(self):
    """Returns the shapes of the InfeedQueue tuple elements."""
    return self._tuple_shapes

  def set_tuple_shapes(self, tuple_shapes):
    """Sets the shape of each element of the queue.

    tuple_shapes must be a list of length
    self.number_of_tuple_elements, and each element must be
    convertible to a TensorShape.

    Args:
      tuple_shapes: the shapes of each queue element.

    Raises:
      ValueError: if tuple_shapes is not of length
        self.number_of_tuple_elements.
      TypeError: if an element of tuple_shapes cannot be converted to
        a TensorShape.
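
    For illustration only (a sketch; the shapes are arbitrary), either
    TensorShape objects or anything convertible to one may be passed:

      queue.set_tuple_shapes([[128, 224, 224, 3], [128]])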
    """
    if len(tuple_shapes) != self.number_of_tuple_elements:
      raise ValueError("tuple_shapes is %s, but must be a list of length %d" %
                       (str(tuple_shapes), self.number_of_tuple_elements))
    try:
      tuple_shapes = [tensor_shape.as_shape(shape) for shape in tuple_shapes]
    except (ValueError, TypeError) as e:
      raise TypeError(
          "tuple_shapes is %s, but must be a list of elements each "
          "convertible to TensorShape: got error %s" % (str(tuple_shapes),
                                                        str(e)))
    if self._frozen:
      for (frozen, updated) in zip(self._tuple_shapes, tuple_shapes):
        if frozen != updated:
          raise ValueError(
              "Trying to update InfeedQueue with frozen configuration with an "
              "incompatible shape. Frozen shapes are %s, updated shapes are %s"
              % (str(self._tuple_shapes), str(tuple_shapes)))
    else:
      self._tuple_shapes = tuple_shapes
    self._validate()

  @property
  def sharding_policies(self):
    """Returns the sharding policies of the InfeedQueue tuple elements."""
    return self._sharding_policies

  @property
  def shard_dimensions(self):
    """Gets the shard dimension of each tuple element.

    Returns:
      A list of length number_of_tuple_elements, where each list entry
      is the shard dimension of that tuple element or None if the
      shard dimension has not been set.
    """
    # Each element's shard dimension is recorded by its sharding policy.
    return [policy.shard_dimension for policy in self._sharding_policies]

  def set_shard_dimensions(self, shard_dimensions):
    """Sets the shard_dimension of each element of the queue.

    shard_dimensions must be a list of length
    self.number_of_tuple_elements, and each element must be
    convertible to a Dimension compatible with self.tuple_shapes.

    Args:
      shard_dimensions: the dimensions of each queue element.

    Raises:
      ValueError: if shard_dimensions is not of length
        self.number_of_tuple_elements; or an element of
        shard_dimensions cannot be converted to a Dimension; or an
        element of shard_dimensions is a Dimension that is out of
        range for the corresponding tuple element shape.
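
    For illustration (a sketch; the two-element queue configuration is an
    assumption of the example):

      # Shard the first element along its batch dimension and the second
      # element along its second dimension.
      queue.set_shard_dimensions([0, 1])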
    """
    if len(shard_dimensions) != self.number_of_tuple_elements:
      raise ValueError("shard_dimensions is %s, but must be a list of length %d"
                       % (str(shard_dimensions),
                          self.number_of_tuple_elements))
    for (policy, dimension) in zip(self._sharding_policies, shard_dimensions):
      policy.set_shard_dimension(dimension)
    self._validate()

  @property
  def number_of_shards(self):
    """Gets the number of shards to use for the InfeedQueue.

    Returns:
      Number of shards or None if the number of shards has not been set.
    """
    # The number of shards is always the same for all the policies.
    return self._sharding_policies[0].number_of_shards

  def set_number_of_shards(self, number_of_shards):
    """Sets the number of shards to use for the InfeedQueue.

    Args:
      number_of_shards: number of ways to shard the InfeedQueue.

    Raises:
      ValueError: if number_of_shards is not > 0; or the policies have
        been frozen and number_of_shards was already set to something
        else.
    """
    for policy in self._sharding_policies:
      policy.set_number_of_shards(number_of_shards)
    self._validate()

  def set_configuration_from_input_tensors(self, input_tensors):
    """Sets the shapes and types of the queue tuple elements.

    input_tensors is a list of Tensors whose types and shapes are used
    to set the queue configuration.

    Args:
      input_tensors: list of Tensors of the same types and shapes as
        the desired queue Tuple.

    Raises:
      ValueError: if input_tensors is not a list of length
        self.number_of_tuple_elements
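
    For illustration (a sketch; `queue`, `images`, and `labels` are
    placeholder names):

      queue.set_configuration_from_input_tensors([images, labels])
      # Equivalent to:
      #   queue.set_tuple_shapes([images.shape, labels.shape])
      #   queue.set_tuple_types([images.dtype, labels.dtype])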
    """
    if len(input_tensors) != self.number_of_tuple_elements:
      raise ValueError(
          "input_tensors is %s, but should be a list of %d Tensors" % (
              str(input_tensors), self.number_of_tuple_elements))
    self.set_tuple_shapes([t.shape for t in input_tensors])
    self.set_tuple_types([t.dtype for t in input_tensors])

  def set_configuration_from_sharded_input_tensors(self, input_tensors):
    """Sets the shapes and types of the queue tuple elements.

    input_tensors is a list of lists of Tensors whose types and shapes are used
    to set the queue configuration. The length of the outer list is the number
    of shards required, and each inner list is the tuple of Tensors to use to
    determine the types and shapes of the corresponding shard. This method
    depends on the shard dimension, and calling it freezes the shard policy.

    Args:
      input_tensors: a list of lists of Tensors. The length of the outer list
        determines the number of shards, and each inner list supplies the
        Tensors whose types and shapes determine the configuration of the
        corresponding shard.

    Raises:
      ValueError: if any inner list is not a list of length
        self.number_of_tuple_elements; or the inner lists do not combine to
        form a consistent unsharded shape.
      TypeError: if the types of the Tensors in the inner lists do not match.
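
    For illustration (a sketch; the tensor names and two-shard layout are
    assumptions of the example):

      queue.set_configuration_from_sharded_input_tensors(
          [[images_0, labels_0], [images_1, labels_1]])
      # Records two shards, and recovers each element's unsharded shape from
      # the per-shard shapes using that element's sharding policy.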
    """
    if not self._frozen:
      # Unset the tuple shapes in case the configuration becomes
      # transiently inconsistent.
      self._tuple_shapes = None
    number_of_shards = len(input_tensors)
    self.set_number_of_shards(number_of_shards)
    for t in input_tensors:
      if len(t) != self.number_of_tuple_elements:
        raise ValueError(
            "input_tensors is %s but must be a list of lists, where each inner"
            " list has length number_of_tuple_elements=%d" % (
                str(input_tensors), self.number_of_tuple_elements))
    # Transpose the inputs to make a list of shard shapes for each tuple
    # element.
    sharded_shapes = [[t[i].shape for t in input_tensors]
                      for i in xrange(self.number_of_tuple_elements)]
    # For each tuple element, get the unsharded shape using that element's
    # sharding policy.
    unsharded_shapes = [
        policy.get_unsharded_shape(s)
        for (policy, s) in zip(self._sharding_policies, sharded_shapes)
    ]
    self.set_tuple_shapes(unsharded_shapes)
    for i in xrange(1, self.number_of_shards):
      for (t1, t2) in zip(input_tensors[0], input_tensors[i]):
        if t1.dtype != t2.dtype:
          raise TypeError(
              "types of the tuple elements of input_tensors %s are not "
              "consistent" % str(input_tensors))
    self.set_tuple_types([t.dtype for t in input_tensors[0]])

  def freeze(self):
    """Freezes the InfeedQueue so it can no longer be modified.

    The configuration is implicitly frozen before any host-side or
    device-side Ops are generated. The configuration cannot be frozen
    until the types and shapes of the tuple elements have been set.

    Raises:
      ValueError: if the types or shapes of the tuple elements have not been
        set.
    """
    self._frozen = True
    if self._tuple_types is None:
      raise ValueError(
          "Can't freeze an InfeedQueue without setting all tuple types.")
    if self._tuple_shapes is None:
      raise ValueError(
          "Can't freeze an InfeedQueue without setting all tuple shapes.")
    for shape in self._tuple_shapes:
      if shape.dims is None:
        raise ValueError(
            "Can't freeze an InfeedQueue without setting all tuple shapes.")
    for policy in self._sharding_policies:
      policy.freeze()
    self._validate()

  def generate_dequeue_op(self):
    """Generates the device-side Op to dequeue a tuple from the queue.

    Implicitly freezes the queue configuration if it is not already
    frozen, which will raise errors if the shapes and types have not
    been fully specified.

    Returns:
      A list of Outputs corresponding to a shard of infeed dequeued
      into XLA, suitable for use within a replicated block.

    Raises:
      ValueError: if the types or shapes of the tuple elements have not been
        set; or if a dequeue op has already been generated.
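
    For illustration (a sketch; `infeed` and `build_model` are placeholder
    names, and the computation is assumed to run under a TPU replication
    wrapper such as tpu.replicate):

      def computation():
        images, labels = infeed.generate_dequeue_op()
        # Each dequeued Tensor has the per-shard (sharded) shape.
        return build_model(images, labels)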
    """
    self.freeze()
    if self._generated_dequeue_op:
      raise ValueError("Can't generate two dequeue Ops from the same queue")
    self._generated_dequeue_op = True
    full_name = "%s/dequeue" % self._name
    sharded_shapes = [
        policy.get_sharded_shape(shape)
        for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies)
    ]
    return tpu_ops.infeed_dequeue_tuple(
        dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)

  def _generate_enqueue_op(self,
                           inputs,
                           name_prefix,
                           index,
                           device=None,
                           tpu_ordinal=-1):
    """Generate a host-side Op to enqueue a tuple to the queue.

    If device is None the inputs are all required to have the same
    device specification, and the enqueue Op is colocated with
    inputs[0]. Otherwise the enqueue Op is placed on 'device'.

    Args:
      inputs: a list of Tensors with the types and shapes of the tuple elements.
      name_prefix: the base name for the Op.
      index: the shard index, used to uniquify the Op name.
      device: device to place the Op on, or None if it should be
        colocated with the inputs.
      tpu_ordinal: ordinal of the TPU device on the host to use for
        infeed if device is a CPU device. Should be set to -1 if device
        is a TPU device.

    Returns:
      An Op corresponding to a shard of infeed enqueued at the host,
      suitable for use within a replicated block.

    Raises:
      ValueError: if device is None and inputs do not all have the
        same device specification.
    """
    full_name = "%s/%d" % (name_prefix, index)
    shapes = [t.shape for t in inputs]
    if device is None:
      devices = [t.device for t in inputs]
      for i in xrange(1, self.number_of_tuple_elements):
        if devices[0] != devices[i]:
          raise ValueError(
              "input devices for shard %d are %s, but should all be the same"
              % (index, str(devices)))
      with ops.colocate_with(inputs[0]):
        return tpu_ops.infeed_enqueue_tuple(
            inputs=inputs,
            shapes=shapes,
            name=full_name,
            device_ordinal=tpu_ordinal)
    else:
      with ops.device(device):
        return tpu_ops.infeed_enqueue_tuple(
            inputs=inputs,
            shapes=shapes,
            name=full_name,
            device_ordinal=tpu_ordinal)

  def generate_enqueue_ops(self, sharded_inputs, tpu_ordinal_function=None):
    """Generates the host-side Ops to enqueue the shards of a tuple.

    sharded_inputs is a list, one for each shard, of lists of
    Tensors. sharded_inputs[0] is the tuple of Tensors to use to feed
    shard 0 of the queue. Returns the host-side Ops that must be run to
    enqueue the sharded tuple. The Op for shard i is colocated with the inputs
    for shard i.

    Implicitly freezes the queue configuration if it is not already
    frozen. If the configuration has already been frozen, and is not
    compatible with the types and shapes of sharded_inputs, an error
    will be raised.

    Args:
      sharded_inputs: a list of lists of Tensors. The length of the outer list
        determines the number of shards. Each inner list indicates the types
        and shapes of the tuples in the corresponding shard.
      tpu_ordinal_function: if not None, a function that takes the
        shard index as input and returns the ordinal of the TPU device
        the shard's infeed should be placed on. tpu_ordinal_function must be
        set if the inputs are placed on CPU devices.

    Returns:
      A list of host-side Ops, one for each shard, that when executed together
      will enqueue a full-size element of infeed.

    Raises:
      ValueError: if the queue configuration has previously been frozen and the
        shapes of the elements of sharded_inputs are not compatible with the
        frozen configuration; or if the shapes of the elements of sharded_inputs
        don't form a consistent unsharded tuple; or if the elements of a tuple
        have different device constraints.
      TypeError: if the queue configuration has previously been frozen and the
        types of the elements of sharded_inputs are not compatible with the
        frozen configuration; or if the types of the elements of sharded_inputs
        don't form a consistent unsharded tuple.
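
    For illustration (a sketch; the tensor names, the two-shard layout, and
    the identity ordinal mapping are assumptions of the example):

      enqueue_ops = queue.generate_enqueue_ops(
          [[images_0, labels_0], [images_1, labels_1]],
          tpu_ordinal_function=lambda shard_index: shard_index)
      # Running all of enqueue_ops together feeds one full tuple to infeed.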
    """
    self.set_configuration_from_sharded_input_tensors(sharded_inputs)
    self.freeze()
    if self._generated_enqueue_ops:
      raise ValueError("Can't generate two enqueue Ops from the same queue")
    self._generated_enqueue_ops = True
    if tpu_ordinal_function is None:
      tpu_ordinal_function = lambda index: -1
    name_prefix = "%s/enqueue" % self._name
    return [
        self._generate_enqueue_op(shard, name_prefix, index,
                                  tpu_ordinal=tpu_ordinal_function(index))
        for (shard, index) in zip(sharded_inputs, xrange(self.number_of_shards))
    ]

  # TODO(misard) Generalize this to the case of systems that don't
  # have 8 devices per host, and figure out what to do with
  # model-parallelism.
  def _default_placement_function(self, index):
    return "/task:%d/device:CPU:0" % (index // 8)

  def _default_ordinal_function(self, index):
    return index % 8

  # TODO(b/36470756) remove this from tutorials once we have a better story
  # for automatic placement of input pipelines.
  def split_inputs_and_generate_enqueue_ops(self,
                                            inputs,
                                            device_assignment=None,
                                            placement_function=None,
                                            tpu_ordinal_function=None):
    """POORLY-PERFORMING ON MULTI-HOST SYSTEMS.

    Generates the host-side Ops to enqueue a tuple.

    This method performs poorly because it takes an entire input on a single
    host, splits it, and distributes it to all of the cores. It is present only
    to simplify tutorial examples.

    inputs is a list of Tensors to use to feed the queue. Each input is split
    into self.number_of_shards shards. Returns an Op for each shard to enqueue
    the shard. The Op for shard i is placed on device placement_function(i).

    Implicitly freezes the queue configuration if it is not already
    frozen. If the configuration has already been frozen, and is not
    compatible with the types and shapes of inputs, an error
    will be raised.

    Args:
      inputs: a list of Tensors which indicates the types and shapes of the
        queue tuple.
      device_assignment: if not `None`, a TPU `DeviceAssignment`. If
        device_assignment is not `None`, but `placement_function` and
        `tpu_ordinal_function` are None, then `device_assignment` will be used
        to place infeeds on the first k TPU shards, where k is the number of
        shards in the queue. If all three are `None`, then default placement
        and ordinal functions are used.
      placement_function: if not None, a function that takes the shard
        index as input and returns a device string indicating which
        device the shard's infeed should be placed on. If placement_function
        and tpu_ordinal_function are None, inputs are sharded round-robin
        across the devices in the system.
      tpu_ordinal_function: if not None, a function that takes the
        shard index as input and returns the ordinal of the TPU device
        the shard's infeed should be placed on. If placement_function
        and tpu_ordinal_function are None, inputs are sharded round-robin
        across the devices in the system.

    Returns:
      A list of host-side Ops, one for each shard, that when executed together
      will enqueue a full-size element of infeed.

    Raises:
      ValueError: if the queue configuration has previously been frozen and the
        shapes of the elements of inputs are not compatible with the frozen
        configuration.
      TypeError: if the queue configuration has previously been frozen and the
        types of the elements of inputs are not compatible with the frozen
        configuration.
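
    For illustration (a sketch; `queue`, `images`, `labels`, and the choice of
    8 shards are assumptions of the example):

      queue.set_number_of_shards(8)
      enqueue_ops = queue.split_inputs_and_generate_enqueue_ops(
          [images, labels])
      # Each input is split into 8 pieces along its shard dimension, and one
      # enqueue Op per shard is placed via the default placement function.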
    """
    if device_assignment is None:
      if placement_function is None:
        placement_function = self._default_placement_function
      if tpu_ordinal_function is None:
        tpu_ordinal_function = self._default_ordinal_function
    else:

      def _placement_function_from_map(index):
        return device_assignment.host_device(replica=index)

      def _ordinal_function_from_map(index):
        return device_assignment.tpu_ordinal(replica=index)

      if placement_function is None:
        placement_function = _placement_function_from_map
      if tpu_ordinal_function is None:
        tpu_ordinal_function = _ordinal_function_from_map
    self.set_configuration_from_input_tensors(inputs)
    self.freeze()
    if self._generated_enqueue_ops:
      raise ValueError("Can't generate two enqueue Ops from the same queue")
    self._generated_enqueue_ops = True
    split_name_prefix = "%s/split" % self._name
    if self.number_of_shards == 1:
      transposed_sharded_inputs = [[inp] for inp in inputs]
    else:

      def split_fn(inp, num_shards, axis, name):
        with ops.colocate_with(inp):
          return array_ops.split(inp, num_shards, axis=axis, name=name)

      transposed_sharded_inputs = [
          split_fn(
              inp,
              self.number_of_shards,
              axis=policy.shard_dimension,
              name="%s/%d" % (split_name_prefix, index))
          for (inp, policy, index) in zip(inputs, self._sharding_policies,
                                          xrange(self.number_of_tuple_elements))
      ]
    sharded_inputs = [[shard[i] for shard in transposed_sharded_inputs]
                      for i in xrange(self.number_of_shards)]
    name_prefix = "%s/enqueue" % self._name
    return [
        self._generate_enqueue_op(
            shard,
            name_prefix,
            index,
            device=placement_function(index),
            tpu_ordinal=tpu_ordinal_function(index))
        for (shard, index) in zip(sharded_inputs, xrange(self.number_of_shards))
    ]