1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# =================================================================== 15 16"""Helper library for handling infeed between hosts and TPUs. 17""" 18 19from __future__ import absolute_import 20from __future__ import division 21from __future__ import print_function 22 23from six.moves import xrange # pylint: disable=redefined-builtin 24 25from tensorflow.contrib.tpu.python.ops import tpu_ops 26from tensorflow.contrib.tpu.python.tpu import tpu_sharding 27 28from tensorflow.python.framework import dtypes 29from tensorflow.python.framework import ops 30from tensorflow.python.framework import tensor_shape 31from tensorflow.python.ops import array_ops 32 33 34class InfeedQueue(object): 35 """A helper object to build a device infeed queue. 36 37 The InfeedQueue builds the host-side and device-side Ops to enqueue and 38 dequeue elements, respectively, and ensures that their types and 39 shapes match. 40 """ 41 42 def __init__(self, 43 number_of_tuple_elements=None, 44 tuple_types=None, 45 tuple_shapes=None, 46 shard_dimensions=None, 47 name=None): 48 """Creates a new InfeedQueue with the given configuration. 49 50 The configuration need not be fully specified at creation since it 51 can be modified subsequently by methods that set the values 52 explicitly or infer them from the shapes of inputs. 53 54 Args: 55 number_of_tuple_elements: the number of Tensors fed atomically through the 56 queue, must be present unless it can be inferred from other arguments. 57 tuple_types: if not None, a list of types of the elements of the queue. 58 tuple_shapes: if not None, a list of shapes of the elements of the queue. 59 shard_dimensions: if not None, a list of dimensions on which the 60 elements of the queue should be sharded during automatic 61 parallelization. 62 name: the name of the queue. 63 64 Raises: 65 ValueError: if number_of_tuple_elements <= 0; or 66 number_of_tuple_arguments, tuple_types, tuple_shapes, and 67 shard_dimensions are all None; or the length of tuple_types, 68 tuple_shapes, or shard_dimensions is not equal to 69 number_of_tuple_elements; or any element of shard_dimensions 70 can't be converted to a Dimension. 71 TypeError: if any element of tuple_types or tuple_shapes can't 72 be converted to a dtype or TensorShape, respectively. 73 """ 74 self._frozen = False 75 self._generated_enqueue_ops = False 76 self._generated_dequeue_op = False 77 self._name = "InfeedQueue" if name is None else name 78 if number_of_tuple_elements is None: 79 if tuple_types is not None: 80 number_of_tuple_elements = len(tuple_types) 81 elif tuple_shapes is not None: 82 number_of_tuple_elements = len(tuple_shapes) 83 elif shard_dimensions is not None: 84 number_of_tuple_elements = len(shard_dimensions) 85 else: 86 raise ValueError( 87 "number of tuple elements cannot be inferred from InfeedQueue " 88 "constructor" 89 ) 90 if number_of_tuple_elements <= 0: 91 raise ValueError("number_of_tuple_elements %d must be > 0" % 92 number_of_tuple_elements) 93 # Make an empty sharding policy for each tuple element. 94 self._sharding_policies = [ 95 tpu_sharding.ShardingPolicy() 96 for _ in xrange(number_of_tuple_elements) 97 ] 98 if tuple_types is not None: 99 self.set_tuple_types(tuple_types) 100 else: 101 self._tuple_types = None 102 if tuple_shapes is not None: 103 self.set_tuple_shapes(tuple_shapes) 104 else: 105 self._tuple_shapes = None 106 if shard_dimensions is not None: 107 self.set_shard_dimensions(shard_dimensions) 108 self._validate() 109 110 def _validate(self): 111 """Checks that the configuration is self-consistent. 112 113 Raises: 114 ValueError: if the shapes and sharding policies don't match. 115 """ 116 if self.tuple_shapes is not None: 117 for (policy, shape) in zip(self._sharding_policies, self._tuple_shapes): 118 # Raise an error if the policy is incompatible with the shape. 119 _ = policy.get_sharded_shape(shape) 120 121 @property 122 def number_of_tuple_elements(self): 123 """Returns the number of InfeedQueue tuple elements.""" 124 return len(self._sharding_policies) 125 126 @property 127 def tuple_types(self): 128 """Returns the types of the InfeedQueue tuple elements.""" 129 return self._tuple_types 130 131 def set_tuple_types(self, tuple_types): 132 """Sets the type of each element of the queue. 133 134 tuple_types must be a list of length 135 self.number_of_tuple_elements, and each element must be 136 convertible to a dtype. 137 138 Args: 139 tuple_types: the types of each queue element. 140 141 Raises: 142 ValueError: if tuple_types is not of length 143 self.number_of_tuple_elements. 144 TypeError: if an element of tuple_types cannot be converted to a 145 dtype. 146 """ 147 if len(tuple_types) != self.number_of_tuple_elements: 148 raise ValueError("tuple_types is %s, but must be a list of length %d" % 149 (str(tuple_types), self.number_of_tuple_elements)) 150 if self._frozen: 151 for (frozen, updated) in zip(self._tuple_types, tuple_types): 152 if frozen != updated: 153 raise ValueError( 154 "Trying to update InfeedQueue with frozen configuration with an " 155 "incompatible type. Frozen types are %s, updated types are %s" % ( 156 str(self._tuple_types), str(tuple_types))) 157 else: 158 try: 159 self._tuple_types = [dtypes.as_dtype(t) for t in tuple_types] 160 except (TypeError) as e: 161 raise TypeError( 162 "tuple_types is %s, but must be a list of elements each " 163 "convertible to dtype: got error %s" % (str(tuple_types), str(e))) 164 165 @property 166 def tuple_shapes(self): 167 """Returns the shapes of the InfeedQueue tuple elements.""" 168 return self._tuple_shapes 169 170 def set_tuple_shapes(self, tuple_shapes): 171 """Sets the shape of each element of the queue. 172 173 tuple_shapes must be a list of length 174 self.number_of_tuple_elements, and each element must be 175 convertible to a TensorShape. 176 177 Args: 178 tuple_shapes: the shapes of each queue element. 179 180 Raises: 181 ValueError: if tuple_shapes is not of length 182 self.number_of_tuple_elements. 183 TypeError: if an element of tuple_shapes cannot be converted to 184 a TensorShape. 185 """ 186 if len(tuple_shapes) != self.number_of_tuple_elements: 187 raise ValueError("tuple_shapes is %s, but must be a list of length %d" % 188 (str(tuple_shapes), self.number_of_tuple_elements)) 189 try: 190 tuple_shapes = [tensor_shape.as_shape(shape) for shape in tuple_shapes] 191 except (ValueError, TypeError) as e: 192 raise TypeError( 193 "tuple_shapes is %s, but must be a list of elements each " 194 "convertible to TensorShape: got error %s" % (str(tuple_shapes), 195 str(e))) 196 if self._frozen: 197 for (frozen, updated) in zip(self._tuple_shapes, tuple_shapes): 198 if frozen != updated: 199 raise ValueError( 200 "Trying to update InfeedQueue with frozen configuration with an " 201 "incompatible shape. Frozen shapes are %s, updated shapes are %s" 202 % (str(self._tuple_shapes), str(tuple_shapes))) 203 else: 204 self._tuple_shapes = tuple_shapes 205 self._validate() 206 207 @property 208 def sharding_policies(self): 209 """Returns the sharding policies of the InfeedQueue tuple elements.""" 210 return self._sharding_policies 211 212 @property 213 def shard_dimensions(self): 214 """Gets the shard dimension of each tuple element. 215 216 Returns: 217 A list of length number_of_tuple_elements, where each list entry 218 is the shard dimension of that tuple element or None if the 219 shard dimension has not been set. 220 """ 221 # The number of shards is always the same for all the policies. 222 return [policy.shard_dimension for policy in self._sharding_policies] 223 224 def set_shard_dimensions(self, shard_dimensions): 225 """Sets the shard_dimension of each element of the queue. 226 227 shard_dimensions must be a list of length 228 self.number_of_tuple_elements, and each element must be 229 convertible to a Dimension compatible with self.tuple_shapes. 230 231 Args: 232 shard_dimensions: the dimensions of each queue element. 233 234 Raises: 235 ValueError: if shard_dimensions is not of length 236 self.number_of_tuple_elements; or an element of 237 shard_dimensions cannot be converted to a Dimension; or an 238 element of shard_dimensions is a Dimension that is out of 239 range for the corresponding tuple element shape. 240 """ 241 if len(shard_dimensions) != self.number_of_tuple_elements: 242 raise ValueError("shard_dimensions is %s, but must be a list of length %d" 243 % (str(shard_dimensions), 244 self.number_of_tuple_elements)) 245 for (policy, dimension) in zip(self._sharding_policies, shard_dimensions): 246 policy.set_shard_dimension(dimension) 247 self._validate() 248 249 @property 250 def number_of_shards(self): 251 """Gets the number of shards to use for the InfeedQueue. 252 253 Returns: 254 Number of shards or None if the number of shards has not been set. 255 """ 256 # The number of shards is always the same for all the policies. 257 return self._sharding_policies[0].number_of_shards 258 259 def set_number_of_shards(self, number_of_shards): 260 """Sets the number of shards to use for the InfeedQueue. 261 262 Args: 263 number_of_shards: number of ways to shard the InfeedQueue. 264 265 Raises: 266 ValueError: if number_of_shards is not > 0; or the policies have 267 been frozen and number_of_shards was already set to something 268 else. 269 """ 270 for policy in self._sharding_policies: 271 policy.set_number_of_shards(number_of_shards) 272 self._validate() 273 274 def set_configuration_from_input_tensors(self, input_tensors): 275 """Sets the shapes and types of the queue tuple elements. 276 277 input_tensors is a list of Tensors whose types and shapes are used 278 to set the queue configuration. 279 280 Args: 281 input_tensors: list of Tensors of the same types and shapes as 282 the desired queue Tuple. 283 284 Raises: 285 ValueError: if input_tensors is not a list of length 286 self.number_of_tuple_elements 287 """ 288 if len(input_tensors) != self.number_of_tuple_elements: 289 raise ValueError( 290 "input_tensors is %s, but should be a list of %d Tensors", ( 291 str(input_tensors), self.number_of_tuple_elements)) 292 self.set_tuple_shapes([t.shape for t in input_tensors]) 293 self.set_tuple_types([t.dtype for t in input_tensors]) 294 295 def set_configuration_from_sharded_input_tensors(self, input_tensors): 296 """Sets the shapes and types of the queue tuple elements. 297 298 input_tensors is a list of lists of Tensors whose types and shapes are used 299 to set the queue configuration. The length of the outer list is the number 300 of shards required, and each inner list is the tuple of Tensors to use to 301 determine the types and shapes of the corresponding shard. This method 302 depends on the shard dimension, and calling it freezes the shard policy. 303 304 Args: 305 input_tensors: list of lists of Tensors. The outer list length corresponds 306 to the desired number of shards, and each inner list is the size 307 and shape of the desired configuration of the corresponding shard. 308 309 Raises: 310 ValueError: if any inner list is not a list of length 311 self.number_of_tuple_elements; or the inner lists do not combine to 312 form a consistent unsharded shape. 313 TypeError: if the types of the Tensors in the inner lists do not match. 314 """ 315 if not self._frozen: 316 # Unset the tuple shapes in case the configuration becomes 317 # transiently inconsistent. 318 self._tuple_shapes = None 319 number_of_shards = len(input_tensors) 320 self.set_number_of_shards(number_of_shards) 321 for t in input_tensors: 322 if len(t) != self.number_of_tuple_elements: 323 raise ValueError( 324 "input_tensors is %s but must be a list of lists, where each inner" 325 " list has length number_of_tuple_elements=%d" % ( 326 str(input_tensors), self.number_of_tuple_elements)) 327 # Transpose the inputs to make a list of shard shapes for each tuple 328 # element. 329 sharded_shapes = [[t[i].shape for t in input_tensors] 330 for i in xrange(self.number_of_tuple_elements)] 331 # For each tuple, get the unsharded shape using that tuple's policy. 332 unsharded_shapes = [ 333 policy.get_unsharded_shape(s) 334 for (policy, s) in zip(self._sharding_policies, sharded_shapes) 335 ] 336 self.set_tuple_shapes(unsharded_shapes) 337 for i in xrange(1, self.number_of_shards): 338 for (t1, t2) in zip(input_tensors[0], input_tensors[i]): 339 if t1.dtype != t2.dtype: 340 raise TypeError( 341 "types of the tuple elements of input_tensors %s are not " 342 "consistent" % str(input_tensors)) 343 self.set_tuple_types([t.dtype for t in input_tensors[0]]) 344 345 def freeze(self): 346 """Freezes the InfeedQueue so it can no longer be modified. 347 348 The configuration is implicitly frozen before any host-side or 349 device-side Ops are generated. The configuration cannot be frozen 350 until the types and shapes of the tuple elements have been set. 351 352 Raises: 353 ValueError: if the types or shapes of the tuple elements have not been 354 set. 355 """ 356 self._frozen = True 357 if self._tuple_types is None: 358 raise ValueError( 359 "Can't freeze an InfeedQueue without setting all tuple types.") 360 if self._tuple_shapes is None: 361 raise ValueError( 362 "Can't freeze an InfeedQueue without setting all tuple shapes.") 363 for shape in self._tuple_shapes: 364 if shape.dims is None: 365 raise ValueError( 366 "Can't freeze an InfeedQueue without setting all tuple shapes.") 367 for policy in self._sharding_policies: 368 policy.freeze() 369 self._validate() 370 371 def generate_dequeue_op(self): 372 """Generates the device-side Op to dequeue a tuple from the queue. 373 374 Implicitly freezes the queue configuration if it is not already 375 frozen, which will raise errors if the shapes and types have not 376 been fully specified. 377 378 Returns: 379 A list of Outputs corresponding to a shard of infeed dequeued 380 into XLA, suitable for use within a replicated block. 381 382 Raises: 383 ValueError: if the types or shapes of the tuple elements have not been 384 set; or if a dequeue op has already been generated. 385 """ 386 self.freeze() 387 if self._generated_dequeue_op: 388 raise ValueError("Can't generate two dequeue Ops from the same queue") 389 self._generated_dequeue_op = True 390 full_name = "%s/dequeue" % self._name 391 sharded_shapes = [ 392 policy.get_sharded_shape(shape) 393 for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies) 394 ] 395 return tpu_ops.infeed_dequeue_tuple( 396 dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name) 397 398 def _generate_enqueue_op(self, 399 inputs, 400 name_prefix, 401 index, 402 device=None, 403 tpu_ordinal=-1): 404 """Generate a host-side Op to enqueue a tuple to the queue. 405 406 If device is None the inputs are all required to have the same 407 device specification, and the enqueue Op is colocated with 408 inputs[0]. Otherwise the enqueue Op is placed on 'device'. 409 410 Args: 411 inputs: a list of Tensors with the types and shapes of the tuple elements. 412 name_prefix: the base name for the Op. 413 index: the shard index, used to uniquify the Op name. 414 device: device to place the Op on, or None if it should be 415 colocated with the inputs. 416 tpu_ordinal: ordinal of the TPU device on the host to use for 417 infeed if device is a CPU device. Should be set to -1 if device 418 is a TPU device. 419 420 Returns: 421 An Op corresponding to a shard of infeed enqueued at the host, 422 suitable for use within a replicated block. 423 424 Raises: 425 ValueError: if device is None and inputs do not all have the 426 same device specification. 427 """ 428 full_name = "%s/%d" % (name_prefix, index) 429 shapes = [t.shape for t in inputs] 430 if device is None: 431 devices = [t.device for t in inputs] 432 for i in xrange(1, self.number_of_tuple_elements): 433 if devices[0] != devices[i]: 434 raise ValueError( 435 "input devices for shard %d are %s, but should all be the same", 436 index, str(devices)) 437 with ops.colocate_with(inputs[0]): 438 return tpu_ops.infeed_enqueue_tuple( 439 inputs=inputs, 440 shapes=shapes, 441 name=full_name, 442 device_ordinal=tpu_ordinal) 443 else: 444 with ops.device(device): 445 return tpu_ops.infeed_enqueue_tuple( 446 inputs=inputs, 447 shapes=shapes, 448 name=full_name, 449 device_ordinal=tpu_ordinal) 450 451 def generate_enqueue_ops(self, sharded_inputs, tpu_ordinal_function=None): 452 """Generates the host-side Ops to enqueue the shards of a tuple. 453 454 sharded_inputs is a list, one for each shard, of lists of 455 Tensors. sharded_inputs[0] is the tuple of Tensors to use to feed 456 shard 0 if the queue. Returns the host-side Ops that must be run to 457 enqueue the sharded tuple. The Op for shard i is colocated with the inputs 458 for shard i. 459 460 Implicitly freezes the queue configuration if it is not already 461 frozen. If the configuration has already been frozen, and is not 462 compatible with the types and shapes of sharded_inputs, an error 463 will be raised. 464 465 Args: 466 sharded_inputs: a list of lists of Tensors. The length of the outer list 467 determines the number of shards. Each inner list indicates the types 468 and shapes of the tuples in the corresponding shard. 469 tpu_ordinal_function: if not None, a function that takes the 470 shard index as input and returns the ordinal of the TPU device 471 the shard's infeed should be placed on. tpu_ordinal_function must be 472 set if the inputs are placed on CPU devices. 473 474 Returns: 475 A list of host-side Ops, one for each shard, that when executed together 476 will enqueue a full-size element of infeed. 477 478 Raises: 479 ValueError: if the queue configuration has previously been frozen and the 480 shapes of the elements of sharded_inputs are not compatible with the 481 frozen configuration; or if the shapes of the elements of sharded_inputs 482 don't form a consistent unsharded tuple; or if the elements of a tuple 483 have different device constraints. 484 TypeError: if the queue configuration has previously been frozen and the 485 types of the elements of sharded_inputs are not compatible with the 486 frozen configuration; or if the types of the elements of sharded_inputs 487 don't form a consistent unsharded tuple. 488 """ 489 self.set_configuration_from_sharded_input_tensors(sharded_inputs) 490 self.freeze() 491 if self._generated_enqueue_ops: 492 raise ValueError("Can't generate two enqueue Ops from the same queue") 493 self._generated_enqueue_ops = True 494 if tpu_ordinal_function is None: 495 tpu_ordinal_function = lambda index: -1 496 name_prefix = "%s/enqueue" % self._name 497 return [ 498 self._generate_enqueue_op(shard, name_prefix, index, 499 tpu_ordinal=tpu_ordinal_function(index)) 500 for (shard, index) in zip(sharded_inputs, xrange(self.number_of_shards)) 501 ] 502 503 # TODO(misard) Generalize this to the case of systems that don't 504 # have 8 devices per host, and figure out what to do with 505 # model-parallelism. 506 def _default_placement_function(self, index): 507 return "/task:%d/device:CPU:0" % (index / 8) 508 509 def _default_ordinal_function(self, index): 510 return index % 8 511 512 # TODO(b/36470756) remove this from tutorials once we have a better story 513 # for automatic placement of input pipelines. 514 def split_inputs_and_generate_enqueue_ops(self, 515 inputs, 516 device_assignment=None, 517 placement_function=None, 518 tpu_ordinal_function=None): 519 """POORLY-PERFORMING ON MULTI-HOST SYSTEMS. 520 521 Generates the host-side Ops to enqueue a tuple. 522 523 This method performs poorly because it takes an entire input on a single 524 host, splits it, and distributes it to all of the cores. It is present only 525 to simplify tutorial examples. 526 527 inputs is a list of Tensors to use to feed the queue. Each input is split 528 into self.number_of_shards shards. Returns an Op for each shard to enqueue 529 the shard. The Op for shard i is placed on device placement_function(i). 530 531 Implicitly freezes the queue configuration if it is not already 532 frozen. If the configuration has already been frozen, and is not 533 compatible with the types and shapes of inputs, an error 534 will be raised. 535 536 Args: 537 inputs: a list of Tensors which indicates the types and shapes of the 538 queue tuple. 539 device_assignment: if not `None`, a TPU `DeviceAssignment`. If 540 device_assignment is not `None`, but `placement_function` and 541 `ordinal_function` are None, then `device_assignment` will be used to 542 place infeeds on the first k TPU shards, where k is the number of shards 543 in the queue. If all three are `None`, then default placement and 544 ordinal functions are used. 545 placement_function: if not None, a function that takes the shard 546 index as input and returns a device string indicating which 547 device the shard's infeed should be placed on. If placement_function 548 and tpu_ordinal_function are None, inputs are sharded round-robin 549 across the devices in the system. 550 tpu_ordinal_function: if not None, a function that takes the 551 shard index as input and returns the ordinal of the TPU device 552 the shard's infeed should be placed on. If placement_function 553 and tpu_ordinal_function are None, inputs are sharded round-robin 554 across the devices in the system. 555 556 Returns: 557 A list of host-side Ops, one for each shard, that when executed together 558 will enqueue a full-size element of infeed. 559 560 Raises: 561 ValueError: if the queue configuration has previously been frozen and the 562 shapes of the elements of inputs are not compatible with the frozen 563 configuration. 564 TypeError: if the queue configuration has previously been frozen and the 565 types of the elements of inputs are not compatible with the frozen 566 configuration. 567 """ 568 if device_assignment is None: 569 if placement_function is None: 570 placement_function = self._default_placement_function 571 if tpu_ordinal_function is None: 572 tpu_ordinal_function = self._default_ordinal_function 573 else: 574 575 def _placement_function_from_map(index): 576 return device_assignment.host_device(replica=index) 577 578 def _ordinal_function_from_map(index): 579 return device_assignment.tpu_ordinal(replica=index) 580 581 if placement_function is None: 582 placement_function = _placement_function_from_map 583 if tpu_ordinal_function is None: 584 tpu_ordinal_function = _ordinal_function_from_map 585 self.set_configuration_from_input_tensors(inputs) 586 self.freeze() 587 if self._generated_enqueue_ops: 588 raise ValueError("Can't generate two enqueue Ops from the same queue") 589 self._generated_enqueue_ops = True 590 split_name_prefix = "%s/split" % self._name 591 if self.number_of_shards == 1: 592 transposed_sharded_inputs = [[inp] for inp in inputs] 593 else: 594 595 def split_fn(inp, num_shards, axis, name): 596 with ops.colocate_with(inp): 597 return array_ops.split(inp, num_shards, axis=axis, name=name) 598 599 transposed_sharded_inputs = [ 600 split_fn( 601 inp, 602 self.number_of_shards, 603 axis=policy.shard_dimension, 604 name="%s/%d" % (split_name_prefix, index)) 605 for (inp, policy, index) in zip(inputs, self._sharding_policies, 606 xrange(self.number_of_tuple_elements)) 607 ] 608 sharded_inputs = [[shard[i] for shard in transposed_sharded_inputs] 609 for i in xrange(self.number_of_shards)] 610 name_prefix = "%s/enqueue" % self._name 611 return [ 612 self._generate_enqueue_op( 613 shard, 614 name_prefix, 615 index, 616 device=placement_function(index), 617 tpu_ordinal=tpu_ordinal_function(index)) 618 for (shard, index) in zip(sharded_inputs, xrange(self.number_of_shards)) 619 ] 620