1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""An in-process, local XLA client in Python, supporting AOT compilation.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import enum # pylint: disable=g-bad-import-order 23import inspect 24import itertools 25import os 26 27import numpy as np 28 29from tensorflow.compiler.xla import xla_data_pb2 30from tensorflow.compiler.xla.python import pywrap_xla as c_api 31 32 33# Most functions are snake_case for consistency with other modules, 34# whereas method names of ComputationBuilder and LocalComputation are 35# CamelCase for consistency with XLA. 36# pylint: disable=invalid-name 37 38 39_OP_METADATA_FIELDS = [ 40 'op_type', 41 'op_name', 42 'source_file', 43 'source_line', 44] 45OpMetadata = collections.namedtuple('OpMetadata', _OP_METADATA_FIELDS) 46 47 48def OpMetadataToProto(pyobj): 49 proto = xla_data_pb2.OpMetadata() 50 for field in _OP_METADATA_FIELDS: 51 attr = getattr(pyobj, field) 52 if attr is not None: 53 setattr(proto, field, attr) 54 return proto 55 56 57def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): 58 """Helper for use in source mapping that returns an OpMetadata object.""" 59 full_filename, lineno = inspect.stack()[skip_frames][1:3] 60 filename = os.path.basename(full_filename) 61 return OpMetadata( 62 op_type=op_type, 63 op_name=op_name, 64 source_file=filename, 65 source_line=lineno) 66 67 68class PaddingType(enum.Enum): 69 VALID = 1 70 SAME = 2 71 72 73def _convert_padding_type_to_pad_values(padding_type, lhs_dims, rhs_dims, 74 window_strides): 75 """Maps PaddingType (VALID or SAME) to pad values (list of pairs of ints).""" 76 if padding_type == PaddingType.VALID: 77 return [(0, 0)] * len(window_strides) 78 79 out_shape = np.ceil(np.true_divide(lhs_dims, window_strides)).astype(int) 80 pad_sizes = [max((out_size - 1) * stride + filter_size - in_size, 0) 81 for out_size, stride, filter_size, in_size 82 in zip(out_shape, window_strides, rhs_dims, lhs_dims)] 83 return [(pad_size // 2, pad_size - pad_size // 2) 84 for pad_size in pad_sizes] 85 86 87_UNARY_OPS = [ 88 'Not', 89 'Abs', 90 'Exp', 91 'Floor', 92 'Round', 93 'Ceil', 94 'Log', 95 'Sign', 96 'Cos', 97 'Sin', 98 'Tanh', 99 'SqrtF32', 100 'SquareF32', 101 'IsFinite', 102 'ReciprocalF32', 103 'Neg', 104 'Sort', 105] 106 107_BINARY_OPS = [ 108 'Eq', 109 'Ne', 110 'Ge', 111 'Gt', 112 'Lt', 113 'Le', 114 'Add', 115 'Sub', 116 'Mul', 117 'Div', 118 'Rem', 119 'Max', 120 'Min', 121 'And', 122 'Or', 123 'Pow', 124] 125 126XLA_ELEMENT_TYPE_TO_DTYPE = { 127 xla_data_pb2.F32: np.dtype(np.float32), 128 xla_data_pb2.F64: np.dtype(np.float64), 129 xla_data_pb2.S32: np.dtype(np.int32), 130 xla_data_pb2.S64: np.dtype(np.int64), 131 xla_data_pb2.U32: np.dtype(np.uint32), 132 xla_data_pb2.U64: np.dtype(np.uint64), 133 xla_data_pb2.PRED: np.dtype(np.bool), 134 xla_data_pb2.TUPLE: np.dtype(np.object), 135} 136 137# Note the conversion on the key. Numpy has a known issue wherein dtype hashing 138# doesn't work as expected (https://github.com/numpy/numpy/issues/7242). Thus, 139# when keying by dtype in this dict, we use the string form of dtypes. 140DTYPE_TO_XLA_ELEMENT_TYPE = { 141 str(v): k 142 for k, v in XLA_ELEMENT_TYPE_TO_DTYPE.items() 143} 144 145 146class LocalBuffer(object): 147 """Represents a handle to data owned by XLA. 148 149 The referent is ready for use in executing a local, compiled 150 Computation. On XLA platforms involving a device (e.g. GPU), this 151 means the referent is in device memory. 152 """ 153 154 def __init__(self, c_local_shaped_buffer): 155 self.c_local_shaped_buffer = c_local_shaped_buffer 156 self._delete = c_api.DeleteLocalShapedBuffer 157 158 @staticmethod 159 def from_py(npval, layout_fn=None): 160 npval = require_numpy_array_layout(npval) 161 if layout_fn: 162 shape = Shape.from_numpy(npval) 163 shape = shape.map_leaves(layout_fn) 164 else: 165 shape = None 166 return LocalBuffer(c_api.LocalShapedBuffer.FromLiteral(npval, shape)) 167 168 def to_py(self): 169 return self.c_local_shaped_buffer.ToLiteral() 170 171 def delete(self): 172 if self.c_local_shaped_buffer is not None: 173 self._delete(self.c_local_shaped_buffer) 174 self.c_local_shaped_buffer = None 175 176 def is_deleted(self): 177 return self.c_local_shaped_buffer is None 178 179 def __del__(self): 180 self.delete() 181 182 183class Shape(object): 184 """XLA shape. 185 186 Represents an XLA shape by a corresponding Python/Numpy type and a 187 list of dimensions, which are themselves Shapes in case this one 188 represents an XLA tuple. 189 """ 190 191 def __init__(self, np_dtype, dimensions, minor_to_major=None): 192 assert isinstance(dimensions, tuple) 193 self.np_dtype = np_dtype 194 self._dimensions = dimensions 195 self._minor_to_major = minor_to_major 196 self._check_minor_to_major() 197 198 def __eq__(self, other): 199 # pylint: disable=protected-access 200 return (self.np_dtype == other.np_dtype and 201 self._dimensions == other._dimensions and 202 self._minor_to_major == other._minor_to_major) 203 204 def __repr__(self): 205 return ('xla_client.Shape(np_dtype={!r}, dimensions={!r}, ' 206 'minor_to_major={!r})').format(self.np_dtype, self._dimensions, 207 self._minor_to_major) 208 209 def element_type(self): 210 return DTYPE_TO_XLA_ELEMENT_TYPE[str(self.np_dtype)] 211 212 def is_tuple(self): 213 return self.element_type() == xla_data_pb2.TUPLE 214 215 def dimensions(self): 216 if self.is_tuple(): 217 raise ValueError('Tuple shape has no dimensions') 218 return self._dimensions 219 220 def minor_to_major(self): 221 return self._minor_to_major 222 223 def tuple_shapes(self): 224 if not self.is_tuple(): 225 raise ValueError('Shape is not a tuple shape') 226 return self._dimensions 227 228 def rank(self): 229 return len(self.dimensions()) 230 231 def map_leaves(self, f): 232 """Map f over each leaf-level array subshape. 233 234 Args: 235 f: The function to apply. Whenever f returns None, the identity is 236 applied instead. 237 238 Returns: 239 A new Shape with the mapped leaves. 240 """ 241 if self.is_tuple(): 242 children = tuple(child.map_leaves(f) for child in self.tuple_shapes()) 243 return Shape(np.dtype('O'), children) 244 else: 245 mapped = f(self) 246 return self if mapped is None else mapped 247 248 def _check_minor_to_major(self): 249 mtm = self._minor_to_major 250 if self.is_tuple(): 251 assert mtm is None, self 252 if mtm is not None: 253 assert self.rank() == len(mtm), self 254 assert sorted(mtm) == range(len(mtm)), self 255 256 def update_minor_to_major(self, minor_to_major): 257 if not isinstance(minor_to_major, tuple): 258 raise TypeError('minor_to_major must be a tuple') 259 updated = Shape(self.np_dtype, tuple(self.dimensions()), minor_to_major) 260 updated._check_minor_to_major() # pylint: disable=protected-access 261 return updated 262 263 @staticmethod 264 def from_numpy(npval): 265 266 def convert(npval): 267 if isinstance(npval, tuple): 268 return Shape(np.dtype('O'), tuple(convert(elt) for elt in npval)) 269 else: 270 return Shape(npval.dtype, np.shape(npval)) 271 272 return convert(require_numpy_array_layout(npval)) 273 274 275def _wrap_shape(shape_info): 276 dtype, dims = shape_info 277 element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(dtype)] 278 if element_type == xla_data_pb2.TUPLE: 279 dims = tuple(_wrap_shape(subshape_info) for subshape_info in dims) 280 return Shape(dtype, dims) 281 282 283def _wrap_data_handle(handle): 284 cdh = xla_data_pb2.ComputationDataHandle() 285 cdh.handle = handle 286 return cdh 287 288 289def _unwrap_data_handle(handle_proto): 290 return handle_proto.handle 291 292 293def _unwrap_data_handles(handle_protos): 294 return [_unwrap_data_handle(cdh) for cdh in handle_protos] 295 296 297def require_numpy_array_layout(value): 298 if isinstance(value, tuple): 299 return tuple(require_numpy_array_layout(x) for x in value) 300 else: 301 return np.require(value, requirements=['C', 'A']) 302 303 304class CompileOptions(object): 305 """Python object for XLA compile options. 306 307 These options can be passed to the 'compile' step when using a local XLA 308 client. 309 """ 310 311 def __init__(self): 312 self.generate_hlo_graph = None 313 314 315def transfer_to_infeed(value, replica_number=None): 316 """Transfers the given value into the XLA infeed queue. 317 318 XLA's infeed queue is a single queue that feeds the "XLA virtual machine" with 319 a totally ordered stream of values. This is dequeued from XLA computations via 320 the Infeed() operation. 321 322 Args: 323 value: the value that the caller would like to enqueue into the XLA infeed 324 queue 325 replica_number: the replica number to infeed the value to -- if not 326 provided, then the default replica (trivially replica 0) is used. 327 """ 328 if replica_number is None: 329 c_api.TransferToInfeedLocal(require_numpy_array_layout(value)) 330 else: 331 c_api.TransferToInfeedLocalReplica( 332 require_numpy_array_layout(value), replica_number) 333 334 335def transfer_from_outfeed(shape, replica_number=None): 336 """Transfers a literal of the given shape from replica_number's outfeed. 337 338 Args: 339 shape: The shape of the value to transfer from outfeed. 340 replica_number: The replica number ordinal to transfer the outfeed value 341 from. (Each replica has a distinct outfeed queue.) 342 343 Returns: 344 The literal value that is produced from the outfeed queue. 345 """ 346 return c_api.TransferFromOutfeedLocalReplica(shape, replica_number or 0) 347 348 349class LocalComputation(object): 350 """Python wrapper for a local XLA Computation. 351 352 A LocalComputation can be executed if it is compiled. Otherwise, it 353 can still be used as a Computation where required by the 354 ComputationBuilder methods. 355 """ 356 357 def __init__(self, c_local_computation, is_compiled): 358 self.c_local_computation = c_local_computation 359 self.is_compiled = is_compiled 360 361 # Ensure a reference to C-based destructor for use in __del__. 362 if is_compiled: 363 assert isinstance(c_local_computation, c_api.CompiledLocalComputation) 364 self._delete = c_api.DeleteCompiledLocalComputation 365 else: 366 assert isinstance(c_local_computation, c_api.LocalComputation) 367 self._delete = c_api.DeleteLocalComputation 368 369 def Compile(self, argument_shapes=(), compile_options=None, layout_fn=None): 370 """Compiles an un-compiled local computation. 371 372 Local computations are the result of a "LocalComputationBuild'ing" process 373 -- they start in uncompiled form, and via a call to Compile() turn into a 374 compiled local computation. 375 376 Raises: 377 ValueError: if this is already a compiled local computation. 378 379 Arguments: 380 argument_shapes: parameter shapes -- they are first laid out by layout_fn 381 if layout_fn is provided. Otherwise, the default layout for those shapes 382 will be used. 383 compile_options: options to use for compilation, includes an optional 384 laid out result shape for the computation. 385 layout_fn: lambda that is used to lay out the argument/result shapes. 386 387 Returns: 388 A newly *compiled* local computation instance. 389 """ 390 if self.is_compiled: 391 raise ValueError('Attempt to compile a compiled local XLA computation.') 392 393 if layout_fn: 394 argument_shapes = [ 395 shape.map_leaves(layout_fn) for shape in argument_shapes 396 ] 397 result_shape = _wrap_shape(self.c_local_computation.GetReturnValueShape()) 398 result_shape = result_shape.map_leaves(layout_fn) 399 compile_options = compile_options or CompileOptions() 400 compile_options.result_shape = result_shape 401 return LocalComputation( 402 self.c_local_computation.Compile(argument_shapes, compile_options), 403 is_compiled=True) 404 405 def CompileWithExampleArguments(self, 406 arguments=(), 407 compile_options=None, 408 layout_fn=None): 409 return self.Compile( 410 argument_shapes=[Shape.from_numpy(arg) for arg in arguments], 411 compile_options=compile_options, 412 layout_fn=layout_fn) 413 414 def Execute(self, arguments=(), layout_fn=None): 415 """Execute with Python values as arguments and return value.""" 416 if not self.is_compiled: 417 raise ValueError('Cannot execute an uncompiled local XLA computation.') 418 argument_shapes = [Shape.from_numpy(arg) for arg in arguments] 419 if layout_fn: 420 argument_shapes = [ 421 shape.map_leaves(layout_fn) for shape in argument_shapes 422 ] 423 else: 424 argument_shapes = [None for shape in argument_shapes] 425 arguments = tuple(map(require_numpy_array_layout, arguments)) 426 return self.c_local_computation.Execute(arguments, argument_shapes) 427 428 def ExecuteWithLocalBuffers(self, arguments=()): 429 """Execute with LocalBuffer arguments and return value.""" 430 if not self.is_compiled: 431 raise ValueError('Cannot execute an uncompiled local XLA computation.') 432 arguments = tuple(arguments) 433 if any(arg.is_deleted() for arg in arguments): 434 raise ValueError('Executing with deleted local buffer argument') 435 return LocalBuffer( 436 self.c_local_computation.ExecuteWithShapedBuffers( 437 [arg.c_local_shaped_buffer for arg in arguments])) 438 439 def __del__(self): 440 self._delete(self.c_local_computation) 441 442 443class ComputationBuilder(object): 444 """XLA computation builder. 445 446 Enqueues XLA ops in sequence and in order to build a 447 LocalComputation, which in turn can be compiled into a 448 CompiledLocalComputation, which in turn can be locally executed. 449 """ 450 451 # The methods of this class map 1-to-1 onto the XLA C++ 452 # computation builder API. Therefore, there's no need to laboriously list 453 # arguments and return values for every method, especially where it's obvious. 454 # 455 # pylint: disable=g-doc-return-or-yield 456 # pylint: disable=g-doc-args 457 458 def __init__(self, name): 459 self._client = c_api.LocalComputationBuilder(name.encode('utf8')) 460 self._parameter_numbering = itertools.count() 461 462 def Build(self): 463 return LocalComputation(self._client.Build(), is_compiled=False) 464 465 def SetOpMetadata(self, op_metadata): 466 """Set metadata for operations that are about to be enqueued.""" 467 self._client.SetOpMetadata(op_metadata) 468 469 def ClearOpMetadata(self): 470 """Clear metadata for operations that are about to be enqueued.""" 471 self._client.ClearOpMetadata() 472 473 def Infeed(self, shape): 474 """Enqueues an infeed op onto the computation. 475 476 Infeed operations dequeue data of the given shape from the device's infeed 477 queue for subsequent use in the computation. 478 479 Returns: 480 A ComputationDataHandle message. 481 """ 482 return _wrap_data_handle(self._client.Infeed(shape)) 483 484 def Outfeed(self, operand): 485 """Enqueues an outfeed op onto the computation. 486 487 Outfeed operations enqueue data, using the given operand, onto the XLA 488 outfeed queue for subsequent dequeue via the client API. 489 """ 490 self._client.Outfeed( 491 _unwrap_data_handle(operand), self.GetShape(operand), 492 ''.encode('utf-8')) 493 494 def Constant(self, value): 495 """Enqueues a constant op onto the computation. 496 497 Args: 498 value: value for the constant, as a np.array with an explicit dtype set 499 to one of the supported types. 500 501 Returns: 502 A ComputationDataHandle message. 503 """ 504 value = require_numpy_array_layout(value) 505 return _wrap_data_handle(self._client.ConstantLiteral(value)) 506 507 def ConstantF32Scalar(self, value): 508 """Convenience method to enqueue a scalar F32 constant op. 509 510 Args: 511 value: a floating-point number. 512 513 Returns: 514 A ComputationDataHandle message. 515 """ 516 return self.Constant(np.array(value, dtype=np.float32)) 517 518 def ConstantF64Scalar(self, value): 519 """Convenience method to enqueue a scalar F32 constant op. 520 521 Args: 522 value: a floating-point number. 523 524 Returns: 525 A ComputationDataHandle message. 526 """ 527 return self.Constant(np.array(value, dtype=np.float64)) 528 529 def ConstantS32Scalar(self, value): 530 """Convenience method to enqueue a scalar S32 constant op. 531 532 Args: 533 value: a floating-point number. 534 535 Returns: 536 A ComputationDataHandle message. 537 """ 538 return self.Constant(np.array(value, dtype=np.int32)) 539 540 def ConstantS64Scalar(self, value): 541 """Convenience method to enqueue a scalar S64 constant op. 542 543 Args: 544 value: a floating-point number. 545 546 Returns: 547 A ComputationDataHandle message. 548 """ 549 return self.Constant(np.array(value, dtype=np.int64)) 550 551 def ConstantPredScalar(self, value): 552 """Convenience method to enqueue a scalar PRED constant op. 553 554 Args: 555 value: a boolean value. 556 557 Returns: 558 A ComputationDataHandle message. 559 """ 560 return self.Constant(np.array(value, dtype=np.bool)) 561 562 def ParameterWithShape(self, shape, name=None, parameter_num=None): 563 """Enqueues a Parameter op onto the computation, given a shape. 564 565 Args: 566 shape: the parameter's shape as a Shape object. 567 name: optional string name for the parameter. 568 parameter_num: parameter number in the computation function. If None, 569 the next linear parameter number is used. The default value capability 570 can be used for auto-numbering. If you're using auto-numbering for some 571 parameters, use it for *all* parameters to avoid clashes. 572 573 Returns: 574 A ComputationDataHandle message. 575 """ 576 if name is None: 577 name = '' 578 if parameter_num is None: 579 parameter_num = next(self._parameter_numbering) 580 581 return _wrap_data_handle( 582 self._client.Parameter(parameter_num, shape, name.encode('utf8'))) 583 584 def ParameterFromNumpy(self, value, name=None, parameter_num=None): 585 """Enqueues a Parameter op onto the computation. 586 587 Args: 588 value: a Numpy array, or a nested tuple thereof, from which the 589 shape is inferred. 590 name: as in ParameterWithShape. 591 parameter_num: as in ParameterWithShape. 592 593 Returns: 594 A ComputationDataHandle message. 595 """ 596 return self.ParameterWithShape( 597 Shape.from_numpy(value), name=name, parameter_num=parameter_num) 598 599 def Broadcast(self, operand, sizes): 600 """Enqueues a broadcast operation onto the computation. 601 602 Args: 603 operand: the operand ComputationDataHandle to broadcast. 604 sizes: an iterable of broadcast sizes. 605 606 Returns: 607 A ComputationDataHandle representing the added broadcast op. 608 """ 609 return _wrap_data_handle( 610 self._client.Broadcast(_unwrap_data_handle(operand), sizes)) 611 612 def Concatenate(self, operands, dimension): 613 """Enqueues a concatenate operation onto the computation. 614 615 Args: 616 operands: the operands to concatenate. 617 dimension: the dimension in which to perform the concatenation. 618 619 Returns: 620 A ComputationDataHandle representing the added concatenate op. 621 """ 622 return _wrap_data_handle( 623 self._client.ConcatInDim(_unwrap_data_handles(operands), dimension)) 624 625 def ConvertElementType(self, operand, new_element_type): 626 """Enqueues an element type conversion operation onto the computation. 627 628 Args: 629 operand: the operand to convert. 630 new_element_type: the target primitive type. 631 632 Returns: 633 A ComputationDataHandle representing the added conversion op. 634 """ 635 return _wrap_data_handle( 636 self._client.ConvertElementType( 637 _unwrap_data_handle(operand), new_element_type)) 638 639 def GetShape(self, operand): 640 return _wrap_shape(self._client.GetShape(_unwrap_data_handle(operand))) 641 642 def GetReturnValueShape(self): 643 return _wrap_shape(self._client.GetReturnValueShape()) 644 645 def GetComputationStats(self): 646 raise NotImplementedError() 647 648 def Pad(self, operand, padding_value, padding_config): 649 """Enqueues a Pad operation onto the computation. 650 651 Args: 652 operand: ComputationDataHandle representing the array to pad. 653 padding_value: ComputationDataHandle representing the scalar pad value. 654 padding_config: either an xla_data_pb2.PaddingConfig or a list of integer 655 triples (edge_padding_low, edge_padding_high, interior_padding) 656 representing the configuration of the padding operation. 657 658 Returns: 659 A ComputationDataHandle representing the added pad op. 660 """ 661 if not isinstance(padding_config, xla_data_pb2.PaddingConfig): 662 padding_config = GetPaddingConfigFromTriples(padding_config) 663 return _wrap_data_handle( 664 self._client.Pad(_unwrap_data_handle(operand), 665 _unwrap_data_handle(padding_value), 666 padding_config)) 667 668 def Reshape(self, operand, dimensions, new_sizes): 669 """Reshape op.""" 670 return _wrap_data_handle( 671 self._client.Reshape( 672 _unwrap_data_handle(operand), dimensions, new_sizes)) 673 674 def CrossReplicaSum(self, operand): 675 """CrossReplicaSum op. 676 677 Args: 678 operand: the operand to sum across replica instances. 679 680 Returns: 681 A ComputationDataHandle that has the sum of the value among all replicas. 682 """ 683 return _wrap_data_handle( 684 self._client.CrossReplicaSum(_unwrap_data_handle(operand))) 685 686 def Collapse(self, operand, dimensions): 687 """Collapse op.""" 688 return _wrap_data_handle( 689 self._client.Collapse(_unwrap_data_handle(operand), dimensions)) 690 691 def Trans(self, operand): 692 """Specialized matrix transpose op.""" 693 return _wrap_data_handle( 694 self._client.Transpose(_unwrap_data_handle(operand), [1, 0])) 695 696 def Transpose(self, operand, permutation): 697 """Transpose op.""" 698 return _wrap_data_handle( 699 self._client.Transpose(_unwrap_data_handle(operand), permutation)) 700 701 def Rev(self, operand, dimensions): 702 """Rev op.""" 703 return _wrap_data_handle( 704 self._client.Rev(_unwrap_data_handle(operand), dimensions)) 705 706 def Clamp(self, min, operand, max): # pylint: disable=redefined-builtin 707 """Clamp op.""" 708 return _wrap_data_handle( 709 self._client.Clamp(_unwrap_data_handle(min), 710 _unwrap_data_handle(operand), 711 _unwrap_data_handle(max))) 712 713 def SelectAndScatter(self, operand, select, window_dimensions, window_strides, 714 padding, source, init_value, scatter): 715 """Select and scatter op, used by the gradient of ReduceWindow. 716 717 Args: 718 operand: ComputationDataHandle for array of dimension N and type T over 719 which the windows slide. 720 select: Computation of type (T, T) -> Pred to apply to the elements of 721 each window to indicate which element is selected. 722 window_dimensions: sequence of N integers for dimensions of the window. 723 window_strides: sequence of N integers for the strides of the window. 724 padding: PaddingType representing either 'SAME' or 'VALID ' padding. 725 source: ComputationDataHandle for array of type T with values to scatter. 726 init_value: ComputationDataHandle of scalar type T for initial out value. 727 scatter: Computation of type (T, T) -> T to apply to each scatter source 728 element with its destination element. 729 730 Returns: 731 A ComputationDataHandle representing the added SelectAndScatter op. 732 """ 733 pads = _convert_padding_type_to_pad_values( 734 padding, self.GetShape(operand).dimensions(), 735 window_dimensions, window_strides) 736 return _wrap_data_handle( 737 self._client.SelectAndScatterWithGeneralPadding( 738 _unwrap_data_handle(operand), select.c_local_computation, 739 window_dimensions, window_strides, pads, 740 _unwrap_data_handle(source), _unwrap_data_handle(init_value), 741 scatter.c_local_computation)) 742 743 def Select(self, pred, on_true, on_false): 744 """Element-wise selection op. 745 746 Constructs an output array from elements of two input arrays, based on the 747 values of a predicate array. 748 """ 749 return _wrap_data_handle( 750 self._client.Select( 751 _unwrap_data_handle(pred), 752 _unwrap_data_handle(on_true), 753 _unwrap_data_handle(on_false))) 754 755 def Slice(self, operand, start_indices, limit_indices, strides=None): 756 """Enqueues a slice operation onto the computation. 757 758 Args: 759 operand: ComputationDataHandle for the N dimensional array to be sliced. 760 start_indices: iterable of N integers containing the starting indices of 761 the slice for each dimension. 762 limit_indices: iterable of N integers containing the ending indices 763 (exclusive) of the slice for each dimension. 764 strides: optional iterable of N integers containing the stride sizes for 765 each dimension. 766 767 Returns: 768 A ComputationDataHandle representing the added Slice op. 769 """ 770 if strides is None: 771 start_indices = list(start_indices) 772 strides = [1] * len(start_indices) 773 return _wrap_data_handle( 774 self._client.Slice( 775 _unwrap_data_handle(operand), 776 start_indices, 777 limit_indices, 778 strides)) 779 780 def DynamicSlice(self, operand, start_indices, slice_sizes): 781 """Enqueues a slice op with dynamic start indices onto the computation. 782 783 Args: 784 operand: ComputationDataHandle for the N dimensional array to be sliced. 785 start_indices: ComputationDataHandle for the 1D array of N integers 786 containing the starting indices of the slice. 787 slice_sizes: iterable of N integers containing the slice sizes in each 788 dimension. 789 790 Returns: 791 A ComputationDataHandle representing the added DynamicSlice op. 792 """ 793 return _wrap_data_handle( 794 self._client.DynamicSlice( 795 _unwrap_data_handle(operand), 796 _unwrap_data_handle(start_indices), 797 slice_sizes)) 798 799 def DynamicUpdateSlice(self, operand, update, start_indices): 800 """Enqueues a dynamic update slice operation onto the computation. 801 802 Args: 803 operand: ComputationDataHandle for the N dimensional array to be updated. 804 update: N dimensional array comprising the slice update. 805 start_indices: Rank-1 array of N integers comprising the starting indices 806 of the slice along each dimension. 807 Returns: 808 A ComputationDataHandle representing the added DynamicUpdateSlice op. 809 """ 810 return _wrap_data_handle( 811 self._client.DynamicUpdateSlice( 812 _unwrap_data_handle(operand), 813 _unwrap_data_handle(update), 814 _unwrap_data_handle(start_indices))) 815 816 def Tuple(self, *ops): 817 """Enqueues a tuple operation onto the computation. 818 819 Args: 820 ops: a sequence of tuple operands (each a ComputationDataHandle). 821 822 Returns: 823 A ComputationDataHandle representing the added Tuple op. 824 """ 825 return _wrap_data_handle(self._client.Tuple(_unwrap_data_handles(ops))) 826 827 def GetTupleElement(self, tup, index): 828 """Enqueues a 'get tuple element' operation onto the computation. 829 830 Args: 831 tup: the tuple operand (a ComputationDataHandle). 832 index: numeric index to select from the tuple. 833 834 Returns: 835 A ComputationDataHandle representing the added GetTupleElement op. 836 """ 837 return _wrap_data_handle( 838 self._client.GetTupleElement(_unwrap_data_handle(tup), index)) 839 840 def Call(self, computation_to_apply, operands): 841 """Enqueues a call operation onto the computation. 842 843 Args: 844 computation_to_apply: a Computation object. 845 operands: an iterable of ComputationDataHandle. The number and types of 846 operands must match the arity of computation_to_apply. 847 848 Returns: 849 A ComputationDataHandle representing the added call op. 850 """ 851 return _wrap_data_handle( 852 self._client.Call(computation_to_apply.c_local_computation, 853 _unwrap_data_handles(operands))) 854 855 def Map(self, operands, computation_to_apply, dimensions, static_operands=()): 856 """Enqueues a map operation onto the computation. 857 858 Args: 859 operands: an iterable of ComputationDataHandle. 860 computation_to_apply: a Computation object. 861 dimensions: dimensions over which to apply map the function. 862 static_operands: auxiliary arguments passed to the applied computation. 863 864 Returns: 865 A ComputationDataHandle representing the added Map op. 866 """ 867 return _wrap_data_handle( 868 self._client.Map( 869 _unwrap_data_handles(operands), 870 computation_to_apply.c_local_computation, 871 dimensions, 872 _unwrap_data_handles(static_operands))) 873 874 def Reduce(self, operand, init_value, computation_to_apply, dimensions): 875 """Enqueues a reduction operation onto the computation. 876 877 Args: 878 operand: reduction operand (ComputationDataHandle). 879 init_value: reduction initial value (ComputationDataHandle). 880 computation_to_apply: a Computation object - binary reduction function. 881 dimensions: sequence of dimensions (integers) to reduce on. 882 883 Returns: 884 A ComputationDataHandle representing the added Reduce op. 885 """ 886 return _wrap_data_handle( 887 self._client.Reduce( 888 _unwrap_data_handle(operand), 889 _unwrap_data_handle(init_value), 890 computation_to_apply.c_local_computation, 891 dimensions)) 892 893 def ReduceWindow(self, operand, init_value, computation_to_apply, 894 window_dimensions, window_strides, padding): 895 """Enqueues a windowed reduction operation onto the computation. 896 897 Args: 898 operand: reduction operand (ComputationDataHandle). 899 init_value: reduction initial value (ComputationDataHandle). 900 computation_to_apply: a binary reduction function (Computation). 901 window_dimensions: dimensions of window (sequence of integers). 902 window_strides: strides for window (sequence of integers). 903 padding: PaddingType representing either 'SAME' or 'VALID' padding. 904 905 Returns: 906 A ComputationDataHandle representing the added ReduceWindow op. 907 """ 908 pads = _convert_padding_type_to_pad_values( 909 padding, self.GetShape(operand).dimensions(), window_dimensions, 910 window_strides) 911 return _wrap_data_handle( 912 self._client.ReduceWindowWithGeneralPadding( 913 _unwrap_data_handle(operand), 914 _unwrap_data_handle(init_value), 915 computation_to_apply.c_local_computation, 916 window_dimensions, window_strides, pads)) 917 918 def RngNormal(self, mu, sigma, dims): 919 """Enqueues an RngNormal operation onto the computation. 920 921 Args: 922 mu: A ComputationDataHandle to an F32 scalar specifying the mean. 923 sigma: A ComputationDataHandle to an F32 scalar specifying the standard 924 deviation. 925 dims: A 1D array-like of nonnegative integers specifying the dimensions. 926 927 Returns: a ComputationDataHandle to the generated array of F32 values. 928 """ 929 shape = Shape(self.GetShape(mu).np_dtype, dims) 930 return _wrap_data_handle( 931 self._client.RngNormal( 932 _unwrap_data_handle(mu), _unwrap_data_handle(sigma), shape)) 933 934 def RngUniform(self, a, b, dims): 935 """Enqueues an RngUniform operation onto the computation. 936 937 Args: 938 a: a ComputationDataHandle to an F32, S32, or U32 scalar (consistent with 939 the type of b) specifying the low end of the interval [a, b) over which 940 values are generated. 941 b: a ComputationDataHandle to an F32, S32, or U32 scalar (consistent with 942 the type of a) specifying the high end of the interval [a, b) over which 943 values are generated. 944 dims: A 1D array-like of nonnegative integers specifying the dimensions. 945 946 Returns: a ComputationDataHandle to the generated array of values with the 947 same numeric type (F32, S32, or U32) as the arguments a and b. 948 """ 949 shape = Shape(self.GetShape(a).np_dtype, dims) 950 return _wrap_data_handle( 951 self._client.RngUniform( 952 _unwrap_data_handle(a), _unwrap_data_handle(b), shape)) 953 954 def While(self, cond, body, init): 955 """Enqueues a While operation onto the computation. 956 957 Args: 958 cond: a Computation for the loop condition, which has type T -> PRED 959 body: a Computation for the loop body, which has type T -> T 960 init: a ComputationDataHandle for the initial parameter, which has type T 961 962 Returns: a ComputationDataHandle representing the While operation. 963 """ 964 return _wrap_data_handle( 965 self._client.While(cond.c_local_computation, 966 body.c_local_computation, 967 _unwrap_data_handle(init))) 968 969 def Conditional(self, pred, true_operand, true_computation, false_operand, 970 false_computation): 971 """Enqueues a Conditional operation onto the computation. 972 973 Args: 974 predicate: a ComputationDataHandle to test, which has scalar type PRED 975 true_operand: a ComputationDataHandle of type T_0 976 true_computation: a Computation to apply to true_operand, type T_0 -> S 977 false_operand: a ComputationDatahandle of type T_1 978 false_computation: a Computation to apply to false_operand, type T_1 -> S 979 980 Returns: a ComputationDataHandle representing the Conditional operation. 981 """ 982 return _wrap_data_handle( 983 self._client.Conditional( 984 _unwrap_data_handle(pred), _unwrap_data_handle(true_operand), 985 true_computation.c_local_computation, 986 _unwrap_data_handle(false_operand), 987 false_computation.c_local_computation)) 988 989 def Dot(self, lhs, rhs): 990 """Enqueues a dot operation onto the computation. 991 992 Args: 993 lhs: ComputationDataHandle for the rank 1 or rank 2 left-hand-side array. 994 rhs: ComputationDataHandle for the rank 1 or rank 2 right-hand-side array. 995 996 Returns: a ComputationDataHandle representing the Dot operation. 997 """ 998 return _wrap_data_handle( 999 self._client.Dot(_unwrap_data_handle(lhs), _unwrap_data_handle(rhs))) 1000 1001 def DotGeneral(self, lhs, rhs, dimension_numbers): 1002 """Enqueues a general dot operation onto the computation. 1003 1004 Args: 1005 lhs: ComputationDataHandle for the left-hand-side array. 1006 rhs: ComputationDataHandle for the right-hand-side array. 1007 dimension_numbers: either an xla_data_pb2.DotDimensionNumbers or a nested 1008 tuple ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) of lists of 1009 integers representing the dimensions to treat as contracting dimensions 1010 and batch dimensions on each input operand. 1011 1012 Returns: a ComputationDataHandle representing the DotGeneral operation. 1013 """ 1014 if not isinstance(dimension_numbers, xla_data_pb2.DotDimensionNumbers): 1015 dimension_numbers = GetDotDimensionsFromLists(dimension_numbers) 1016 return _wrap_data_handle( 1017 self._client.DotGeneral( 1018 _unwrap_data_handle(lhs), _unwrap_data_handle(rhs), 1019 dimension_numbers)) 1020 1021 def Conv(self, lhs, rhs, window_strides, padding): 1022 """Enqueues a Conv operation onto the computation. 1023 1024 Args: 1025 lhs: ComputationDataHandle for the rank N+2 array of inputs. 1026 rhs: ComputationDataHandle for the rank N+2 array of kernel weights. 1027 window_strides: length-N array-like of integer kernel strides. 1028 padding: PaddingType representing either 'SAME' or 'VALID' padding. 1029 1030 Returns: a ComputationDataHandle representing the Conv operation. 1031 """ 1032 pads = _convert_padding_type_to_pad_values( 1033 padding, self.GetShape(lhs).dimensions()[2:], 1034 self.GetShape(rhs).dimensions()[2:], window_strides) 1035 dimension_numbers = self._GetConvDimensionNumbers(len(window_strides)) 1036 return _wrap_data_handle( 1037 self._client.ConvGeneralDilated(_unwrap_data_handle(lhs), 1038 _unwrap_data_handle(rhs), 1039 window_strides, 1040 pads, 1041 (), 1042 (), 1043 dimension_numbers)) 1044 1045 def ConvWithGeneralPadding(self, lhs, rhs, window_strides, padding, 1046 lhs_dilation, rhs_dilation): 1047 """Enqueues a ConvWithGeneralPadding operation onto the computation. 1048 1049 Args: 1050 lhs: ComputationDataHandle for the rank N+2 array of inputs. 1051 rhs: ComputationDataHandle for the rank N+2 array of kernel weights. 1052 window_strides: length-N array-like of kernel strides. 1053 padding: length-N array-like of pairs of integers of (low, high) padding. 1054 lhs_dilation: length-N array-like of dilation factors. 1055 rhs_dilation: length-N array-like of dilation factors. 1056 1057 Returns: 1058 A ComputationdataHandle representing the added ConvWithGeneralPadding op. 1059 """ 1060 dimension_numbers = self._GetConvDimensionNumbers(len(window_strides)) 1061 return _wrap_data_handle( 1062 self._client.ConvGeneralDilated(_unwrap_data_handle(lhs), 1063 _unwrap_data_handle(rhs), 1064 window_strides, 1065 padding, 1066 lhs_dilation, 1067 rhs_dilation, 1068 dimension_numbers)) 1069 1070 def _GetConvDimensionNumbers(self, num_spatial_dims): 1071 """Create ConvolutionDimensionNumbers proto for convolutions.""" 1072 nd = num_spatial_dims 1073 dimension_numbers = xla_data_pb2.ConvolutionDimensionNumbers() 1074 dimension_numbers.input_batch_dimension = 0 1075 dimension_numbers.input_feature_dimension = 1 1076 dimension_numbers.output_batch_dimension = 0 1077 dimension_numbers.output_feature_dimension = 1 1078 dimension_numbers.kernel_output_feature_dimension = 0 1079 dimension_numbers.kernel_input_feature_dimension = 1 1080 dimension_numbers.input_spatial_dimensions.extend(range(2, 2 + nd)) 1081 dimension_numbers.kernel_spatial_dimensions.extend(range(2, 2 + nd)) 1082 dimension_numbers.output_spatial_dimensions.extend(range(2, 2 + nd)) 1083 return dimension_numbers 1084 1085 1086def _forward_methods_to_local_builder(): 1087 """Forward remaining ComputationBuilder methods to the C API. 1088 1089 Set up methods, corresponding to unary and binary XLA operations, 1090 whose calls are forwarded in a boilerplate manner to the underlying 1091 LocalComputationBuilder C-extension API. 1092 """ 1093 1094 def forward_to_local_builder_with_handles(target_method, is_binop=False): 1095 """Generate a forwarding method that wraps/unwraps data handles.""" 1096 1097 def forward(self, *args, **kwargs): 1098 unwrapped_args = [_unwrap_data_handle(arg) for arg in args] 1099 1100 if is_binop and len(unwrapped_args) < 3: 1101 unwrapped_args.append(kwargs.get('broadcast_dimensions', ())) 1102 1103 return _wrap_data_handle( 1104 target_method( 1105 self._client, # pylint: disable=protected-access 1106 *unwrapped_args)) 1107 1108 return forward 1109 1110 for method_name in _UNARY_OPS: 1111 forward = forward_to_local_builder_with_handles( 1112 getattr(c_api.LocalComputationBuilder, method_name)) 1113 forward.__name__ = method_name 1114 setattr(ComputationBuilder, method_name, forward) 1115 1116 for method_name in _BINARY_OPS: 1117 forward = forward_to_local_builder_with_handles( 1118 getattr(c_api.LocalComputationBuilder, method_name), is_binop=True) 1119 forward.__name__ = method_name 1120 setattr(ComputationBuilder, method_name, forward) 1121 1122 1123_forward_methods_to_local_builder() 1124 1125 1126def initialize_replica_count(replica_count): 1127 """Initializes the desired replica count to use on XLA service init. 1128 1129 Args: 1130 replica_count: number of replicas that are desired for set up during XLA 1131 initialization. 1132 1133 Raises: 1134 A runtime exception if the XLA service has already been initialized. 1135 """ 1136 c_api.InitializeReplicaCount(replica_count) 1137 1138 1139def get_replica_count(): 1140 """Returns the current replica count used for the XLA service. 1141 1142 Note: this will return a value whether the XLA service has been initialized 1143 yet or not. 1144 """ 1145 return c_api.GetReplicaCount() 1146 1147 1148def GetPaddingConfigFromTriples(triples): 1149 """Create PaddingConfig proto from list of triples of integers.""" 1150 padding_config = xla_data_pb2.PaddingConfig() 1151 for lo, hi, interior in triples: 1152 dimension = padding_config.dimensions.add() 1153 dimension.edge_padding_low = lo 1154 dimension.edge_padding_high = hi 1155 dimension.interior_padding = interior 1156 return padding_config 1157 1158 1159def GetDotDimensionsFromLists(dimension_numbers): 1160 (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers 1161 dot_dims_proto = xla_data_pb2.DotDimensionNumbers() 1162 dot_dims_proto.lhs_contracting_dimensions.extend(lhs_contract) 1163 dot_dims_proto.rhs_contracting_dimensions.extend(rhs_contract) 1164 dot_dims_proto.lhs_batch_dimensions.extend(lhs_batch) 1165 dot_dims_proto.rhs_batch_dimensions.extend(rhs_batch) 1166 return dot_dims_proto 1167