# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

16"""Operations for TPUs."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import platform
23
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import ops
26
27if platform.system() != "Windows":
28  # pylint: disable=wildcard-import,unused-import,g-import-not-at-top
29  from tensorflow.contrib.tpu.ops import gen_tpu_ops
30  from tensorflow.contrib.tpu.ops.gen_tpu_ops import *
31
32  from tensorflow.contrib.util import loader
33  from tensorflow.python.platform import resource_loader
34  # pylint: enable=wildcard-import,unused-import,g-import-not-at-top
35
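  # Load the dynamically linked TPU op kernels that back the generated
  # wrappers imported above.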
  _tpu_ops = loader.load_op_library(
      resource_loader.get_path_to_datafile("_tpu_ops.so"))

  @ops.RegisterGradient("CrossReplicaSum")
  def _cross_replica_sum_grad(op, grad):
    del op  # Unused
    # The gradient of a cross-replica sum is also a cross-replica sum.
    return gen_tpu_ops.cross_replica_sum(grad)
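
  # Why this holds, as a sketch: if every replica computes
  # y = cross_replica_sum(x) = sum_j x_j, then for a scalar loss L,
  # dL/dx_i = sum_j dL/dy_j, since dy_j/dx_i == 1 for every replica j. The
  # backward pass is therefore a cross-replica sum of the incoming gradients.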

  # This extra type checking exists to give a more helpful error message in
  # the common case that uint8 and int64 values are fed via infeed. Remove
  # when both types are supported.

  _SUPPORTED_INFEED_DTYPES = set([
      dtypes.bool, dtypes.int32, dtypes.bfloat16, dtypes.float32,
      dtypes.complex64
  ])

  def infeed_dequeue(dtype, shape, name=None):
    """A placeholder op for a value that will be fed into the computation.

    Args:
      dtype: A `tf.DType`. The type of elements in the tensor.
      shape: A `tf.TensorShape` or list of `ints`. The shape of the tensor.
      name: A name for the operation (optional).

    Returns:
      A `Tensor` of type `dtype`.
      A tensor that will be provided using the infeed mechanism.

    Raises:
      TypeError: If `dtype` is not a supported infeed type.
    """
    if dtype not in _SUPPORTED_INFEED_DTYPES:
      raise TypeError(
          "{} is not a supported TPU infeed type. Supported types are: "
          "{}".format(dtype, list(_SUPPORTED_INFEED_DTYPES)))

    return gen_tpu_ops.infeed_dequeue(dtype, shape, name=name)

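  # A minimal usage sketch (hypothetical shape; meaningful only inside a TPU
  # computation whose host side enqueues matching data via infeed):
  #
  #   x = infeed_dequeue(dtypes.float32, shape=[128, 224, 224, 3])
  #
  # Unsupported element types fail fast here; for example, a hypothetical
  # infeed_dequeue(dtypes.int64, shape=[8]) raises TypeError rather than
  # failing later inside the TPU runtime.
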
  # pylint: disable=redefined-outer-name
  def infeed_dequeue_tuple(dtypes, shapes, name=None):
    """A placeholder op for values fed into the TPU simultaneously as a tuple.

    Args:
      dtypes: A list of `tf.DType`s that has length `>= 1`.
        The element types of each tensor in `outputs`.
      shapes: A list of shapes (each a `tf.TensorShape` or list of `ints`).
        The shapes of each tensor in `outputs`.
      name: A name for the operation (optional).

    Returns:
      A list of `Tensor` objects whose types are given by `dtypes`.
      A list of tensors that will be provided using the infeed mechanism.

    Raises:
      TypeError: If a type in `dtypes` is not a supported infeed type.
    """
    for dtype in dtypes:
      if dtype not in _SUPPORTED_INFEED_DTYPES:
        raise TypeError(
            "{} is not a supported TPU infeed type. Supported types are: "
            "{}".format(dtype, list(_SUPPORTED_INFEED_DTYPES)))
    return gen_tpu_ops.infeed_dequeue_tuple(dtypes, shapes, name=name)
  # pylint: enable=redefined-outer-name
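
  # A minimal usage sketch (hypothetical shapes; the host must enqueue a
  # matching tuple of tensors via infeed):
  #
  #   images, labels = infeed_dequeue_tuple(
  #       dtypes=[dtypes.float32, dtypes.int32],
  #       shapes=[[128, 224, 224, 3], [128]])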

else:
  # If contrib was built, the appropriate TPU libraries were already compiled
  # into the binary via CMake, so there is nothing to load here.
  pass