# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""An in-process, local XLA client in Python, supporting AOT compilation."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import enum  # pylint: disable=g-bad-import-order
import inspect
import itertools
import os

import numpy as np

from tensorflow.compiler.xla import xla_data_pb2
from tensorflow.compiler.xla.python import pywrap_xla as c_api


# Most functions are snake_case for consistency with other modules,
# whereas method names of ComputationBuilder and LocalComputation are
# CamelCase for consistency with XLA.
# pylint: disable=invalid-name


_OP_METADATA_FIELDS = [
    'op_type',
    'op_name',
    'source_file',
    'source_line',
]
OpMetadata = collections.namedtuple('OpMetadata', _OP_METADATA_FIELDS)


def OpMetadataToProto(pyobj):
  proto = xla_data_pb2.OpMetadata()
  for field in _OP_METADATA_FIELDS:
    attr = getattr(pyobj, field)
    if attr is not None:
      setattr(proto, field, attr)
  return proto

def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1):
  """Helper for use in source mapping that returns an OpMetadata object."""
  full_filename, lineno = inspect.stack()[skip_frames][1:3]
  filename = os.path.basename(full_filename)
  return OpMetadata(
      op_type=op_type,
      op_name=op_name,
      source_file=filename,
      source_line=lineno)
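
# For illustration only: a call from line 42 of a hypothetical caller file
# example.py would produce roughly
#   CurrentSourceInfoMetadata(op_type='Add', op_name='my_add')
#   => OpMetadata(op_type='Add', op_name='my_add',
#                 source_file='example.py', source_line=42)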


class PaddingType(enum.Enum):
  VALID = 1
  SAME = 2


def _convert_padding_type_to_pad_values(padding_type, lhs_dims, rhs_dims,
                                        window_strides):
  """Maps PaddingType (VALID or SAME) to pad values (list of pairs of ints)."""
  if padding_type == PaddingType.VALID:
    return [(0, 0)] * len(window_strides)

  out_shape = np.ceil(np.true_divide(lhs_dims, window_strides)).astype(int)
  pad_sizes = [max((out_size - 1) * stride + filter_size - in_size, 0)
               for out_size, stride, filter_size, in_size
               in zip(out_shape, window_strides, rhs_dims, lhs_dims)]
  return [(pad_size // 2, pad_size - pad_size // 2)
          for pad_size in pad_sizes]
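
# A worked SAME-padding example (values chosen here purely for illustration):
# for an input dim of 10, a window size of 3, and a stride of 2, the output
# dim is ceil(10 / 2) = 5, so the total padding needed is
# max((5 - 1) * 2 + 3 - 10, 0) = 1, split low/high as (0, 1):
#   _convert_padding_type_to_pad_values(
#       PaddingType.SAME, lhs_dims=[10], rhs_dims=[3], window_strides=[2])
#   => [(0, 1)]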


_UNARY_OPS = [
    'Not',
    'Abs',
    'Exp',
    'Floor',
    'Round',
    'Ceil',
    'Log',
    'Sign',
    'Cos',
    'Sin',
    'Tanh',
    'SqrtF32',
    'SquareF32',
    'IsFinite',
    'ReciprocalF32',
    'Neg',
    'Sort',
]

_BINARY_OPS = [
    'Eq',
    'Ne',
    'Ge',
    'Gt',
    'Lt',
    'Le',
    'Add',
    'Sub',
    'Mul',
    'Div',
    'Rem',
    'Max',
    'Min',
    'And',
    'Or',
    'Pow',
]

XLA_ELEMENT_TYPE_TO_DTYPE = {
    xla_data_pb2.F32: np.dtype(np.float32),
    xla_data_pb2.F64: np.dtype(np.float64),
    xla_data_pb2.S32: np.dtype(np.int32),
    xla_data_pb2.S64: np.dtype(np.int64),
    xla_data_pb2.U32: np.dtype(np.uint32),
    xla_data_pb2.U64: np.dtype(np.uint64),
    xla_data_pb2.PRED: np.dtype(np.bool),
    xla_data_pb2.TUPLE: np.dtype(np.object),
}

# Note the conversion on the key. Numpy has a known issue wherein dtype hashing
# doesn't work as expected (https://github.com/numpy/numpy/issues/7242). Thus,
# when keying by dtype in this dict, we use the string form of dtypes.
DTYPE_TO_XLA_ELEMENT_TYPE = {
    str(v): k
    for k, v in XLA_ELEMENT_TYPE_TO_DTYPE.items()
}
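
# For example (illustrative only), both lookups below resolve F32:
#   XLA_ELEMENT_TYPE_TO_DTYPE[xla_data_pb2.F32]     => dtype('float32')
#   DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype('f4'))]  => xla_data_pb2.F32
# since str(np.dtype('f4')) == 'float32', sidestepping the dtype-hashing
# issue noted above.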


class LocalBuffer(object):
  """Represents a handle to data owned by XLA.

  The referent is ready for use in executing a local, compiled
  Computation. On XLA platforms involving a device (e.g. GPU), this
  means the referent is in device memory.
  """

  def __init__(self, c_local_shaped_buffer):
    self.c_local_shaped_buffer = c_local_shaped_buffer
    self._delete = c_api.DeleteLocalShapedBuffer

  @staticmethod
  def from_py(npval, layout_fn=None):
    npval = require_numpy_array_layout(npval)
    if layout_fn:
      shape = Shape.from_numpy(npval)
      shape = shape.map_leaves(layout_fn)
    else:
      shape = None
    return LocalBuffer(c_api.LocalShapedBuffer.FromLiteral(npval, shape))

  def to_py(self):
    return self.c_local_shaped_buffer.ToLiteral()

  def delete(self):
    if self.c_local_shaped_buffer is not None:
      self._delete(self.c_local_shaped_buffer)
      self.c_local_shaped_buffer = None

  def is_deleted(self):
    return self.c_local_shaped_buffer is None

  def __del__(self):
    self.delete()


class Shape(object):
  """XLA shape.

  Represents an XLA shape by a corresponding Python/Numpy type and a
  list of dimensions, which are themselves Shapes in case this one
  represents an XLA tuple.
  """

  def __init__(self, np_dtype, dimensions, minor_to_major=None):
    assert isinstance(dimensions, tuple)
    self.np_dtype = np_dtype
    self._dimensions = dimensions
    self._minor_to_major = minor_to_major
    self._check_minor_to_major()

  def __eq__(self, other):
    # pylint: disable=protected-access
    return (self.np_dtype == other.np_dtype and
            self._dimensions == other._dimensions and
            self._minor_to_major == other._minor_to_major)

  def __repr__(self):
    return ('xla_client.Shape(np_dtype={!r}, dimensions={!r}, '
            'minor_to_major={!r})').format(self.np_dtype, self._dimensions,
                                           self._minor_to_major)

  def element_type(self):
    return DTYPE_TO_XLA_ELEMENT_TYPE[str(self.np_dtype)]

  def is_tuple(self):
    return self.element_type() == xla_data_pb2.TUPLE

  def dimensions(self):
    if self.is_tuple():
      raise ValueError('Tuple shape has no dimensions')
    return self._dimensions

  def minor_to_major(self):
    return self._minor_to_major

  def tuple_shapes(self):
    if not self.is_tuple():
      raise ValueError('Shape is not a tuple shape')
    return self._dimensions

  def rank(self):
    return len(self.dimensions())

  def map_leaves(self, f):
    """Map f over each leaf-level array subshape.

    Args:
      f: The function to apply. Whenever f returns None, the identity is
        applied instead.

    Returns:
      A new Shape with the mapped leaves.
    """
    if self.is_tuple():
      children = tuple(child.map_leaves(f) for child in self.tuple_shapes())
      return Shape(np.dtype('O'), children)
    else:
      mapped = f(self)
      return self if mapped is None else mapped

  def _check_minor_to_major(self):
    mtm = self._minor_to_major
    if self.is_tuple():
      assert mtm is None, self
    if mtm is not None:
      assert self.rank() == len(mtm), self
      # Wrap range() in list() so the comparison also holds under Python 3,
      # where range() is not a list.
      assert sorted(mtm) == list(range(len(mtm))), self

  def update_minor_to_major(self, minor_to_major):
    if not isinstance(minor_to_major, tuple):
      raise TypeError('minor_to_major must be a tuple')
    updated = Shape(self.np_dtype, tuple(self.dimensions()), minor_to_major)
    updated._check_minor_to_major()  # pylint: disable=protected-access
    return updated

  @staticmethod
  def from_numpy(npval):

    def convert(npval):
      if isinstance(npval, tuple):
        return Shape(np.dtype('O'), tuple(convert(elt) for elt in npval))
      else:
        return Shape(npval.dtype, np.shape(npval))

    return convert(require_numpy_array_layout(npval))
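
  # For illustration, an array maps to a leaf Shape and a tuple of arrays
  # maps to a tuple Shape whose "dimensions" are the child Shapes:
  #   Shape.from_numpy(np.zeros((2, 3), dtype=np.float32))
  #   => Shape(np_dtype=dtype('float32'), dimensions=(2, 3))
  #   Shape.from_numpy((np.zeros(2, np.float32), np.zeros(3, np.int32)))
  #   => a tuple Shape with two leaf children.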


def _wrap_shape(shape_info):
  dtype, dims = shape_info
  element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(dtype)]
  if element_type == xla_data_pb2.TUPLE:
    dims = tuple(_wrap_shape(subshape_info) for subshape_info in dims)
  return Shape(dtype, dims)


def _wrap_data_handle(handle):
  cdh = xla_data_pb2.ComputationDataHandle()
  cdh.handle = handle
  return cdh


def _unwrap_data_handle(handle_proto):
  return handle_proto.handle


def _unwrap_data_handles(handle_protos):
  return [_unwrap_data_handle(cdh) for cdh in handle_protos]


def require_numpy_array_layout(value):
  if isinstance(value, tuple):
    return tuple(require_numpy_array_layout(x) for x in value)
  else:
    return np.require(value, requirements=['C', 'A'])
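
# For illustration, np.require with the 'C' and 'A' requirements copies
# non-C-contiguous inputs into C-contiguous (row-major), aligned storage:
#   x = np.arange(6, dtype=np.float32).reshape(2, 3).T  # F-contiguous view
#   require_numpy_array_layout(x).flags['C_CONTIGUOUS']  # => True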


class CompileOptions(object):
  """Python object for XLA compile options.

  These options can be passed to the 'compile' step when using a local XLA
  client.
  """

  def __init__(self):
    self.generate_hlo_graph = None


def transfer_to_infeed(value, replica_number=None):
  """Transfers the given value into the XLA infeed queue.

  XLA's infeed queue is a single queue that feeds the "XLA virtual machine" with
  a totally ordered stream of values. This is dequeued from XLA computations via
  the Infeed() operation.

  Args:
    value: the value that the caller would like to enqueue into the XLA infeed
      queue
    replica_number: the replica number to infeed the value to -- if not
      provided, then the default replica (trivially replica 0) is used.
  """
  if replica_number is None:
    c_api.TransferToInfeedLocal(require_numpy_array_layout(value))
  else:
    c_api.TransferToInfeedLocalReplica(
        require_numpy_array_layout(value), replica_number)


def transfer_from_outfeed(shape, replica_number=None):
  """Transfers a literal of the given shape from replica_number's outfeed.

  Args:
    shape: The shape of the value to transfer from outfeed.
    replica_number: The replica number ordinal to transfer the outfeed value
      from. (Each replica has a distinct outfeed queue.)

  Returns:
    The literal value that is produced from the outfeed queue.
  """
  return c_api.TransferFromOutfeedLocalReplica(shape, replica_number or 0)
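
# A minimal usage sketch (illustrative only): a computation containing an
# Infeed op can be fed from the host, and one containing an Outfeed op can be
# read back, e.g. assuming a scalar F32 value on replica 0:
#   transfer_to_infeed(np.float32(42.0))
#   result = transfer_from_outfeed(
#       Shape(np.dtype(np.float32), ()), replica_number=0)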


class LocalComputation(object):
  """Python wrapper for a local XLA Computation.

  A LocalComputation can be executed if it is compiled. Otherwise, it
  can still be used as a Computation where required by the
  ComputationBuilder methods.
  """

  def __init__(self, c_local_computation, is_compiled):
    self.c_local_computation = c_local_computation
    self.is_compiled = is_compiled

    # Ensure a reference to C-based destructor for use in __del__.
    if is_compiled:
      assert isinstance(c_local_computation, c_api.CompiledLocalComputation)
      self._delete = c_api.DeleteCompiledLocalComputation
    else:
      assert isinstance(c_local_computation, c_api.LocalComputation)
      self._delete = c_api.DeleteLocalComputation

  def Compile(self, argument_shapes=(), compile_options=None, layout_fn=None):
    """Compiles an un-compiled local computation.

    Local computations are the result of a "LocalComputationBuild'ing" process
    -- they start in uncompiled form, and via a call to Compile() turn into a
    compiled local computation.

    Raises:
      ValueError: if this is already a compiled local computation.

    Arguments:
      argument_shapes: parameter shapes -- they are first laid out by layout_fn
        if layout_fn is provided. Otherwise, the default layout for those shapes
        will be used.
      compile_options: options to use for compilation, includes an optional
        laid out result shape for the computation.
      layout_fn: lambda that is used to lay out the argument/result shapes.

    Returns:
      A newly *compiled* local computation instance.
    """
    if self.is_compiled:
      raise ValueError('Attempt to compile a compiled local XLA computation.')

    if layout_fn:
      argument_shapes = [
          shape.map_leaves(layout_fn) for shape in argument_shapes
      ]
      result_shape = _wrap_shape(self.c_local_computation.GetReturnValueShape())
      result_shape = result_shape.map_leaves(layout_fn)
      compile_options = compile_options or CompileOptions()
      compile_options.result_shape = result_shape
    return LocalComputation(
        self.c_local_computation.Compile(argument_shapes, compile_options),
        is_compiled=True)

  def CompileWithExampleArguments(self,
                                  arguments=(),
                                  compile_options=None,
                                  layout_fn=None):
    return self.Compile(
        argument_shapes=[Shape.from_numpy(arg) for arg in arguments],
        compile_options=compile_options,
        layout_fn=layout_fn)

  def Execute(self, arguments=(), layout_fn=None):
    """Execute with Python values as arguments and return value."""
    if not self.is_compiled:
      raise ValueError('Cannot execute an uncompiled local XLA computation.')
    argument_shapes = [Shape.from_numpy(arg) for arg in arguments]
    if layout_fn:
      argument_shapes = [
          shape.map_leaves(layout_fn) for shape in argument_shapes
      ]
    else:
      argument_shapes = [None for shape in argument_shapes]
    arguments = tuple(map(require_numpy_array_layout, arguments))
    return self.c_local_computation.Execute(arguments, argument_shapes)

  def ExecuteWithLocalBuffers(self, arguments=()):
    """Execute with LocalBuffer arguments and return value."""
    if not self.is_compiled:
      raise ValueError('Cannot execute an uncompiled local XLA computation.')
    arguments = tuple(arguments)
    if any(arg.is_deleted() for arg in arguments):
      raise ValueError('Executing with deleted local buffer argument')
    return LocalBuffer(
        self.c_local_computation.ExecuteWithShapedBuffers(
            [arg.c_local_shaped_buffer for arg in arguments]))
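
  # A minimal usage sketch (illustrative only), assuming `compiled` is a
  # compiled LocalComputation taking one F32 argument:
  #   arg = LocalBuffer.from_py(np.float32(1.0))
  #   out = compiled.ExecuteWithLocalBuffers((arg,))
  #   out.to_py()  # => the result as a numpy value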

  def __del__(self):
    self._delete(self.c_local_computation)


class ComputationBuilder(object):
  """XLA computation builder.

  Enqueues XLA ops in sequence and in order to build a
  LocalComputation, which in turn can be compiled into a
  CompiledLocalComputation, which in turn can be locally executed.
  """

  # The methods of this class map 1-to-1 onto the XLA C++
  # computation builder API. Therefore, there's no need to laboriously list
  # arguments and return values for every method, especially where it's obvious.
  #
  # pylint: disable=g-doc-return-or-yield
  # pylint: disable=g-doc-args

  def __init__(self, name):
    self._client = c_api.LocalComputationBuilder(name.encode('utf8'))
    self._parameter_numbering = itertools.count()

  def Build(self):
    return LocalComputation(self._client.Build(), is_compiled=False)

  def SetOpMetadata(self, op_metadata):
    """Set metadata for operations that are about to be enqueued."""
    self._client.SetOpMetadata(op_metadata)

  def ClearOpMetadata(self):
    """Clear metadata for operations that are about to be enqueued."""
    self._client.ClearOpMetadata()

  def Infeed(self, shape):
    """Enqueues an infeed op onto the computation.

    Infeed operations dequeue data of the given shape from the device's infeed
    queue for subsequent use in the computation.

    Returns:
      A ComputationDataHandle message.
    """
    return _wrap_data_handle(self._client.Infeed(shape))

  def Outfeed(self, operand):
    """Enqueues an outfeed op onto the computation.

    Outfeed operations enqueue data, using the given operand, onto the XLA
    outfeed queue for subsequent dequeue via the client API.
    """
    self._client.Outfeed(
        _unwrap_data_handle(operand), self.GetShape(operand),
        ''.encode('utf-8'))

  def Constant(self, value):
    """Enqueues a constant op onto the computation.

    Args:
      value: value for the constant, as a np.array with an explicit dtype set
             to one of the supported types.

    Returns:
      A ComputationDataHandle message.
    """
    value = require_numpy_array_layout(value)
    return _wrap_data_handle(self._client.ConstantLiteral(value))

  def ConstantF32Scalar(self, value):
    """Convenience method to enqueue a scalar F32 constant op.

    Args:
      value: a floating-point number.

    Returns:
      A ComputationDataHandle message.
    """
    return self.Constant(np.array(value, dtype=np.float32))

  def ConstantF64Scalar(self, value):
    """Convenience method to enqueue a scalar F64 constant op.

    Args:
      value: a floating-point number.

    Returns:
      A ComputationDataHandle message.
    """
    return self.Constant(np.array(value, dtype=np.float64))

  def ConstantS32Scalar(self, value):
    """Convenience method to enqueue a scalar S32 constant op.

    Args:
      value: an integer.

    Returns:
      A ComputationDataHandle message.
    """
    return self.Constant(np.array(value, dtype=np.int32))

  def ConstantS64Scalar(self, value):
    """Convenience method to enqueue a scalar S64 constant op.

    Args:
      value: an integer.

    Returns:
      A ComputationDataHandle message.
    """
    return self.Constant(np.array(value, dtype=np.int64))

  def ConstantPredScalar(self, value):
    """Convenience method to enqueue a scalar PRED constant op.

    Args:
      value: a boolean value.

    Returns:
      A ComputationDataHandle message.
    """
    return self.Constant(np.array(value, dtype=np.bool))

  def ParameterWithShape(self, shape, name=None, parameter_num=None):
    """Enqueues a Parameter op onto the computation, given a shape.

    Args:
      shape: the parameter's shape as a Shape object.
      name: optional string name for the parameter.
      parameter_num: parameter number in the computation function. If None,
        the next linear parameter number is used. The default value capability
        can be used for auto-numbering. If you're using auto-numbering for some
        parameters, use it for *all* parameters to avoid clashes.

    Returns:
      A ComputationDataHandle message.
    """
    if name is None:
      name = ''
    if parameter_num is None:
      parameter_num = next(self._parameter_numbering)

    return _wrap_data_handle(
        self._client.Parameter(parameter_num, shape, name.encode('utf8')))

  def ParameterFromNumpy(self, value, name=None, parameter_num=None):
    """Enqueues a Parameter op onto the computation.

    Args:
      value: a Numpy array, or a nested tuple thereof, from which the
        shape is inferred.
      name: as in ParameterWithShape.
      parameter_num: as in ParameterWithShape.

    Returns:
      A ComputationDataHandle message.
    """
    return self.ParameterWithShape(
        Shape.from_numpy(value), name=name, parameter_num=parameter_num)

  def Broadcast(self, operand, sizes):
    """Enqueues a broadcast operation onto the computation.

    Args:
      operand: the operand ComputationDataHandle to broadcast.
      sizes: an iterable of broadcast sizes.

    Returns:
      A ComputationDataHandle representing the added broadcast op.
    """
    return _wrap_data_handle(
        self._client.Broadcast(_unwrap_data_handle(operand), sizes))

  def Concatenate(self, operands, dimension):
    """Enqueues a concatenate operation onto the computation.

    Args:
      operands: the operands to concatenate.
      dimension: the dimension in which to perform the concatenation.

    Returns:
      A ComputationDataHandle representing the added concatenate op.
    """
    return _wrap_data_handle(
        self._client.ConcatInDim(_unwrap_data_handles(operands), dimension))

  def ConvertElementType(self, operand, new_element_type):
    """Enqueues an element type conversion operation onto the computation.

    Args:
      operand: the operand to convert.
      new_element_type: the target primitive type.

    Returns:
      A ComputationDataHandle representing the added conversion op.
    """
    return _wrap_data_handle(
        self._client.ConvertElementType(
            _unwrap_data_handle(operand), new_element_type))

  def GetShape(self, operand):
    return _wrap_shape(self._client.GetShape(_unwrap_data_handle(operand)))

  def GetReturnValueShape(self):
    return _wrap_shape(self._client.GetReturnValueShape())

  def GetComputationStats(self):
    raise NotImplementedError()

  def Pad(self, operand, padding_value, padding_config):
    """Enqueues a Pad operation onto the computation.

    Args:
      operand: ComputationDataHandle representing the array to pad.
      padding_value: ComputationDataHandle representing the scalar pad value.
      padding_config: either an xla_data_pb2.PaddingConfig or a list of integer
        triples (edge_padding_low, edge_padding_high, interior_padding)
        representing the configuration of the padding operation.

    Returns:
      A ComputationDataHandle representing the added pad op.
    """
    if not isinstance(padding_config, xla_data_pb2.PaddingConfig):
      padding_config = GetPaddingConfigFromTriples(padding_config)
    return _wrap_data_handle(
        self._client.Pad(_unwrap_data_handle(operand),
                         _unwrap_data_handle(padding_value),
                         padding_config))

  def Reshape(self, operand, dimensions, new_sizes):
    """Reshape op."""
    return _wrap_data_handle(
        self._client.Reshape(
            _unwrap_data_handle(operand), dimensions, new_sizes))

  def CrossReplicaSum(self, operand):
    """CrossReplicaSum op.

    Args:
      operand: the operand to sum across replica instances.

    Returns:
      A ComputationDataHandle that has the sum of the value among all replicas.
    """
    return _wrap_data_handle(
        self._client.CrossReplicaSum(_unwrap_data_handle(operand)))

  def Collapse(self, operand, dimensions):
    """Collapse op."""
    return _wrap_data_handle(
        self._client.Collapse(_unwrap_data_handle(operand), dimensions))

  def Trans(self, operand):
    """Specialized matrix transpose op."""
    return _wrap_data_handle(
        self._client.Transpose(_unwrap_data_handle(operand), [1, 0]))

  def Transpose(self, operand, permutation):
    """Transpose op."""
    return _wrap_data_handle(
        self._client.Transpose(_unwrap_data_handle(operand), permutation))

  def Rev(self, operand, dimensions):
    """Rev op."""
    return _wrap_data_handle(
        self._client.Rev(_unwrap_data_handle(operand), dimensions))

  def Clamp(self, min, operand, max):  # pylint: disable=redefined-builtin
    """Clamp op."""
    return _wrap_data_handle(
        self._client.Clamp(_unwrap_data_handle(min),
                           _unwrap_data_handle(operand),
                           _unwrap_data_handle(max)))

  def SelectAndScatter(self, operand, select, window_dimensions, window_strides,
                       padding, source, init_value, scatter):
    """Select and scatter op, used by the gradient of ReduceWindow.

    Args:
      operand: ComputationDataHandle for array of dimension N and type T over
        which the windows slide.
      select: Computation of type (T, T) -> Pred to apply to the elements of
        each window to indicate which element is selected.
      window_dimensions: sequence of N integers for dimensions of the window.
      window_strides: sequence of N integers for the strides of the window.
      padding: PaddingType representing either 'SAME' or 'VALID' padding.
      source: ComputationDataHandle for array of type T with values to scatter.
      init_value: ComputationDataHandle of scalar type T for initial out value.
      scatter: Computation of type (T, T) -> T to apply to each scatter source
        element with its destination element.

    Returns:
      A ComputationDataHandle representing the added SelectAndScatter op.
    """
    pads = _convert_padding_type_to_pad_values(
        padding, self.GetShape(operand).dimensions(),
        window_dimensions, window_strides)
    return _wrap_data_handle(
        self._client.SelectAndScatterWithGeneralPadding(
            _unwrap_data_handle(operand), select.c_local_computation,
            window_dimensions, window_strides, pads,
            _unwrap_data_handle(source), _unwrap_data_handle(init_value),
            scatter.c_local_computation))

  def Select(self, pred, on_true, on_false):
    """Element-wise selection op.

    Constructs an output array from elements of two input arrays, based on the
    values of a predicate array.
    """
    return _wrap_data_handle(
        self._client.Select(
            _unwrap_data_handle(pred),
            _unwrap_data_handle(on_true),
            _unwrap_data_handle(on_false)))

  def Slice(self, operand, start_indices, limit_indices, strides=None):
    """Enqueues a slice operation onto the computation.

    Args:
      operand: ComputationDataHandle for the N dimensional array to be sliced.
      start_indices: iterable of N integers containing the starting indices of
        the slice for each dimension.
      limit_indices: iterable of N integers containing the ending indices
        (exclusive) of the slice for each dimension.
      strides: optional iterable of N integers containing the stride sizes for
        each dimension.

    Returns:
      A ComputationDataHandle representing the added Slice op.
    """
    if strides is None:
      start_indices = list(start_indices)
      strides = [1] * len(start_indices)
    return _wrap_data_handle(
        self._client.Slice(
            _unwrap_data_handle(operand),
            start_indices,
            limit_indices,
            strides))

  def DynamicSlice(self, operand, start_indices, slice_sizes):
    """Enqueues a slice op with dynamic start indices onto the computation.

    Args:
      operand: ComputationDataHandle for the N dimensional array to be sliced.
      start_indices: ComputationDataHandle for the 1D array of N integers
        containing the starting indices of the slice.
      slice_sizes: iterable of N integers containing the slice sizes in each
        dimension.

    Returns:
      A ComputationDataHandle representing the added DynamicSlice op.
    """
    return _wrap_data_handle(
        self._client.DynamicSlice(
            _unwrap_data_handle(operand),
            _unwrap_data_handle(start_indices),
            slice_sizes))

  def DynamicUpdateSlice(self, operand, update, start_indices):
    """Enqueues a dynamic update slice operation onto the computation.

    Args:
      operand: ComputationDataHandle for the N dimensional array to be updated.
      update: N dimensional array comprising the slice update.
      start_indices: Rank-1 array of N integers comprising the starting indices
        of the slice along each dimension.

    Returns:
      A ComputationDataHandle representing the added DynamicUpdateSlice op.
    """
    return _wrap_data_handle(
        self._client.DynamicUpdateSlice(
            _unwrap_data_handle(operand),
            _unwrap_data_handle(update),
            _unwrap_data_handle(start_indices)))

  def Tuple(self, *ops):
    """Enqueues a tuple operation onto the computation.

    Args:
      ops: a sequence of tuple operands (each a ComputationDataHandle).

    Returns:
      A ComputationDataHandle representing the added Tuple op.
    """
    return _wrap_data_handle(self._client.Tuple(_unwrap_data_handles(ops)))

  def GetTupleElement(self, tup, index):
    """Enqueues a 'get tuple element' operation onto the computation.

    Args:
      tup: the tuple operand (a ComputationDataHandle).
      index: numeric index to select from the tuple.

    Returns:
      A ComputationDataHandle representing the added GetTupleElement op.
    """
    return _wrap_data_handle(
        self._client.GetTupleElement(_unwrap_data_handle(tup), index))

  def Call(self, computation_to_apply, operands):
    """Enqueues a call operation onto the computation.

    Args:
      computation_to_apply: a Computation object.
      operands: an iterable of ComputationDataHandle. The number and types of
        operands must match the arity of computation_to_apply.

    Returns:
      A ComputationDataHandle representing the added call op.
    """
    return _wrap_data_handle(
        self._client.Call(computation_to_apply.c_local_computation,
                          _unwrap_data_handles(operands)))

  def Map(self, operands, computation_to_apply, dimensions, static_operands=()):
    """Enqueues a map operation onto the computation.

    Args:
      operands: an iterable of ComputationDataHandle.
      computation_to_apply: a Computation object.
      dimensions: dimensions over which to map the function.
      static_operands: auxiliary arguments passed to the applied computation.

    Returns:
      A ComputationDataHandle representing the added Map op.
    """
    return _wrap_data_handle(
        self._client.Map(
            _unwrap_data_handles(operands),
            computation_to_apply.c_local_computation,
            dimensions,
            _unwrap_data_handles(static_operands)))

  def Reduce(self, operand, init_value, computation_to_apply, dimensions):
    """Enqueues a reduction operation onto the computation.

    Args:
      operand: reduction operand (ComputationDataHandle).
      init_value: reduction initial value (ComputationDataHandle).
      computation_to_apply: a Computation object - binary reduction function.
      dimensions: sequence of dimensions (integers) to reduce on.

    Returns:
      A ComputationDataHandle representing the added Reduce op.
    """
    return _wrap_data_handle(
        self._client.Reduce(
            _unwrap_data_handle(operand),
            _unwrap_data_handle(init_value),
            computation_to_apply.c_local_computation,
            dimensions))

  def ReduceWindow(self, operand, init_value, computation_to_apply,
                   window_dimensions, window_strides, padding):
    """Enqueues a windowed reduction operation onto the computation.

    Args:
      operand: reduction operand (ComputationDataHandle).
      init_value: reduction initial value (ComputationDataHandle).
      computation_to_apply: a binary reduction function (Computation).
      window_dimensions: dimensions of window (sequence of integers).
      window_strides: strides for window (sequence of integers).
      padding: PaddingType representing either 'SAME' or 'VALID' padding.

    Returns:
      A ComputationDataHandle representing the added ReduceWindow op.
    """
    pads = _convert_padding_type_to_pad_values(
        padding, self.GetShape(operand).dimensions(), window_dimensions,
        window_strides)
    return _wrap_data_handle(
        self._client.ReduceWindowWithGeneralPadding(
            _unwrap_data_handle(operand),
            _unwrap_data_handle(init_value),
            computation_to_apply.c_local_computation,
            window_dimensions, window_strides, pads))

  def RngNormal(self, mu, sigma, dims):
    """Enqueues an RngNormal operation onto the computation.

    Args:
      mu: A ComputationDataHandle to an F32 scalar specifying the mean.
      sigma: A ComputationDataHandle to an F32 scalar specifying the standard
        deviation.
      dims: A 1D array-like of nonnegative integers specifying the dimensions.

    Returns: a ComputationDataHandle to the generated array of F32 values.
    """
    shape = Shape(self.GetShape(mu).np_dtype, dims)
    return _wrap_data_handle(
        self._client.RngNormal(
            _unwrap_data_handle(mu), _unwrap_data_handle(sigma), shape))

  def RngUniform(self, a, b, dims):
    """Enqueues an RngUniform operation onto the computation.

    Args:
      a: a ComputationDataHandle to an F32, S32, or U32 scalar (consistent with
        the type of b) specifying the low end of the interval [a, b) over which
        values are generated.
      b: a ComputationDataHandle to an F32, S32, or U32 scalar (consistent with
        the type of a) specifying the high end of the interval [a, b) over which
        values are generated.
      dims: A 1D array-like of nonnegative integers specifying the dimensions.

    Returns: a ComputationDataHandle to the generated array of values with the
      same numeric type (F32, S32, or U32) as the arguments a and b.
    """
    shape = Shape(self.GetShape(a).np_dtype, dims)
    return _wrap_data_handle(
        self._client.RngUniform(
            _unwrap_data_handle(a), _unwrap_data_handle(b), shape))

  def While(self, cond, body, init):
    """Enqueues a While operation onto the computation.

    Args:
      cond: a Computation for the loop condition, which has type T -> PRED
      body: a Computation for the loop body, which has type T -> T
      init: a ComputationDataHandle for the initial parameter, which has type T

    Returns: a ComputationDataHandle representing the While operation.
    """
    return _wrap_data_handle(
        self._client.While(cond.c_local_computation,
                           body.c_local_computation,
                           _unwrap_data_handle(init)))

  def Conditional(self, pred, true_operand, true_computation, false_operand,
                  false_computation):
    """Enqueues a Conditional operation onto the computation.

    Args:
      pred: a ComputationDataHandle to test, which has scalar type PRED
      true_operand: a ComputationDataHandle of type T_0
      true_computation: a Computation to apply to true_operand, type T_0 -> S
      false_operand: a ComputationDataHandle of type T_1
      false_computation: a Computation to apply to false_operand, type T_1 -> S

    Returns: a ComputationDataHandle representing the Conditional operation.
    """
    return _wrap_data_handle(
        self._client.Conditional(
            _unwrap_data_handle(pred), _unwrap_data_handle(true_operand),
            true_computation.c_local_computation,
            _unwrap_data_handle(false_operand),
            false_computation.c_local_computation))

  def Dot(self, lhs, rhs):
    """Enqueues a dot operation onto the computation.

    Args:
      lhs: ComputationDataHandle for the rank 1 or rank 2 left-hand-side array.
      rhs: ComputationDataHandle for the rank 1 or rank 2 right-hand-side array.

    Returns: a ComputationDataHandle representing the Dot operation.
    """
    return _wrap_data_handle(
        self._client.Dot(_unwrap_data_handle(lhs), _unwrap_data_handle(rhs)))

  def DotGeneral(self, lhs, rhs, dimension_numbers):
    """Enqueues a general dot operation onto the computation.

    Args:
      lhs: ComputationDataHandle for the left-hand-side array.
      rhs: ComputationDataHandle for the right-hand-side array.
      dimension_numbers: either an xla_data_pb2.DotDimensionNumbers or a nested
        tuple ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) of lists of
        integers representing the dimensions to treat as contracting dimensions
        and batch dimensions on each input operand.

    Returns: a ComputationDataHandle representing the DotGeneral operation.
    """
    if not isinstance(dimension_numbers, xla_data_pb2.DotDimensionNumbers):
      dimension_numbers = GetDotDimensionsFromLists(dimension_numbers)
    return _wrap_data_handle(
        self._client.DotGeneral(
            _unwrap_data_handle(lhs), _unwrap_data_handle(rhs),
            dimension_numbers))

  def Conv(self, lhs, rhs, window_strides, padding):
    """Enqueues a Conv operation onto the computation.

    Args:
      lhs: ComputationDataHandle for the rank N+2 array of inputs.
      rhs: ComputationDataHandle for the rank N+2 array of kernel weights.
      window_strides: length-N array-like of integer kernel strides.
      padding: PaddingType representing either 'SAME' or 'VALID' padding.

    Returns: a ComputationDataHandle representing the Conv operation.
    """
    pads = _convert_padding_type_to_pad_values(
        padding, self.GetShape(lhs).dimensions()[2:],
        self.GetShape(rhs).dimensions()[2:], window_strides)
    dimension_numbers = self._GetConvDimensionNumbers(len(window_strides))
    return _wrap_data_handle(
        self._client.ConvGeneralDilated(_unwrap_data_handle(lhs),
                                        _unwrap_data_handle(rhs),
                                        window_strides,
                                        pads,
                                        (),
                                        (),
                                        dimension_numbers))

  def ConvWithGeneralPadding(self, lhs, rhs, window_strides, padding,
                             lhs_dilation, rhs_dilation):
    """Enqueues a ConvWithGeneralPadding operation onto the computation.

    Args:
      lhs: ComputationDataHandle for the rank N+2 array of inputs.
      rhs: ComputationDataHandle for the rank N+2 array of kernel weights.
      window_strides: length-N array-like of kernel strides.
      padding: length-N array-like of pairs of integers of (low, high) padding.
      lhs_dilation: length-N array-like of dilation factors.
      rhs_dilation: length-N array-like of dilation factors.

    Returns:
      A ComputationDataHandle representing the added ConvWithGeneralPadding op.
    """
    dimension_numbers = self._GetConvDimensionNumbers(len(window_strides))
    return _wrap_data_handle(
        self._client.ConvGeneralDilated(_unwrap_data_handle(lhs),
                                        _unwrap_data_handle(rhs),
                                        window_strides,
                                        padding,
                                        lhs_dilation,
                                        rhs_dilation,
                                        dimension_numbers))

  def _GetConvDimensionNumbers(self, num_spatial_dims):
    """Create ConvolutionDimensionNumbers proto for convolutions."""
    nd = num_spatial_dims
    dimension_numbers = xla_data_pb2.ConvolutionDimensionNumbers()
    dimension_numbers.input_batch_dimension = 0
    dimension_numbers.input_feature_dimension = 1
    dimension_numbers.output_batch_dimension = 0
    dimension_numbers.output_feature_dimension = 1
    dimension_numbers.kernel_output_feature_dimension = 0
    dimension_numbers.kernel_input_feature_dimension = 1
    dimension_numbers.input_spatial_dimensions.extend(range(2, 2 + nd))
    dimension_numbers.kernel_spatial_dimensions.extend(range(2, 2 + nd))
    dimension_numbers.output_spatial_dimensions.extend(range(2, 2 + nd))
    return dimension_numbers

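
# A minimal end-to-end sketch (illustrative only): build a computation that
# adds one to its parameter, compile it, and run it. Add is one of the
# forwarded binary ops installed below by _forward_methods_to_local_builder.
#
#   builder = ComputationBuilder('add_one')
#   x = builder.ParameterFromNumpy(np.float32(0))
#   builder.Add(x, builder.ConstantF32Scalar(1.0))
#   compiled = builder.Build().CompileWithExampleArguments(
#       arguments=[np.float32(0)])
#   compiled.Execute(arguments=[np.float32(5)])  # => 6.0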

def _forward_methods_to_local_builder():
  """Forward remaining ComputationBuilder methods to the C API.

  Set up methods, corresponding to unary and binary XLA operations,
  whose calls are forwarded in a boilerplate manner to the underlying
  LocalComputationBuilder C-extension API.
  """

  def forward_to_local_builder_with_handles(target_method, is_binop=False):
    """Generate a forwarding method that wraps/unwraps data handles."""

    def forward(self, *args, **kwargs):
      unwrapped_args = [_unwrap_data_handle(arg) for arg in args]

      if is_binop and len(unwrapped_args) < 3:
        unwrapped_args.append(kwargs.get('broadcast_dimensions', ()))

      return _wrap_data_handle(
          target_method(
              self._client,  # pylint: disable=protected-access
              *unwrapped_args))

    return forward

  for method_name in _UNARY_OPS:
    forward = forward_to_local_builder_with_handles(
        getattr(c_api.LocalComputationBuilder, method_name))
    forward.__name__ = method_name
    setattr(ComputationBuilder, method_name, forward)

  for method_name in _BINARY_OPS:
    forward = forward_to_local_builder_with_handles(
        getattr(c_api.LocalComputationBuilder, method_name), is_binop=True)
    forward.__name__ = method_name
    setattr(ComputationBuilder, method_name, forward)


_forward_methods_to_local_builder()
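
# For illustration, after the loops above ComputationBuilder gains methods
# such as Add; the forwarded binary ops accept an optional
# broadcast_dimensions keyword (operand names here are made up):
#   builder.Add(matrix, vector, broadcast_dimensions=(1,))
# matches the vector's single dimension with dimension 1 of the matrix,
# adding the vector to every row.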


def initialize_replica_count(replica_count):
  """Initializes the desired replica count to use on XLA service init.

  Args:
    replica_count: number of replicas that are desired for set up during XLA
      initialization.

  Raises:
    A runtime exception if the XLA service has already been initialized.
  """
  c_api.InitializeReplicaCount(replica_count)


def get_replica_count():
  """Returns the current replica count used for the XLA service.

  Note: this will return a value whether the XLA service has been initialized
  yet or not.
  """
  return c_api.GetReplicaCount()


def GetPaddingConfigFromTriples(triples):
  """Create PaddingConfig proto from list of triples of integers."""
  padding_config = xla_data_pb2.PaddingConfig()
  for lo, hi, interior in triples:
    dimension = padding_config.dimensions.add()
    dimension.edge_padding_low = lo
    dimension.edge_padding_high = hi
    dimension.interior_padding = interior
  return padding_config
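
# For illustration, triples of (edge_padding_low, edge_padding_high,
# interior_padding) become one PaddingConfig dimension apiece:
#   GetPaddingConfigFromTriples([(1, 2, 0), (0, 0, 1)])
# yields a two-dimension PaddingConfig proto whose first dimension has
# edge_padding_low=1 and edge_padding_high=2, and whose second dimension
# has interior_padding=1.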


def GetDotDimensionsFromLists(dimension_numbers):
  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
  dot_dims_proto = xla_data_pb2.DotDimensionNumbers()
  dot_dims_proto.lhs_contracting_dimensions.extend(lhs_contract)
  dot_dims_proto.rhs_contracting_dimensions.extend(rhs_contract)
  dot_dims_proto.lhs_batch_dimensions.extend(lhs_batch)
  dot_dims_proto.rhs_batch_dimensions.extend(rhs_batch)
  return dot_dims_proto
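
# For illustration, a batched matmul that contracts dimension 2 of lhs with
# dimension 1 of rhs, batching over dimension 0 of each operand:
#   GetDotDimensionsFromLists((([2], [1]), ([0], [0])))
# fills lhs_contracting_dimensions=[2], rhs_contracting_dimensions=[1],
# lhs_batch_dimensions=[0], and rhs_batch_dimensions=[0].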