# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Operations for TPUs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import platform

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops

if platform.system() != "Windows":
  # pylint: disable=wildcard-import,unused-import,g-import-not-at-top
  from tensorflow.contrib.tpu.ops import gen_tpu_ops
  from tensorflow.contrib.tpu.ops.gen_tpu_ops import *

  from tensorflow.contrib.util import loader
  from tensorflow.python.platform import resource_loader
  # pylint: enable=wildcard-import,unused-import,g-import-not-at-top

  # Load the custom-op shared library that ships next to this module.
  _tpu_ops = loader.load_op_library(
      resource_loader.get_path_to_datafile("_tpu_ops.so"))

  @ops.RegisterGradient("CrossReplicaSum")
  def _cross_replica_sum_grad(op, grad):
    """Gradient function registered for the `CrossReplicaSum` op."""
    del op  # Unused
    # The gradient of a cross replica sum is also a cross-replica sum.
    return gen_tpu_ops.cross_replica_sum(grad)

  # This extra type checking exists to give a more helpful error message in
  # the common case that uint8 and int64 values are infed. Remove when both
  # types are supported.

  _SUPPORTED_INFEED_DTYPES = set([
      dtypes.bool, dtypes.int32, dtypes.bfloat16, dtypes.float32,
      dtypes.complex64
  ])

  def _check_infeed_dtype(dtype):
    """Normalizes `dtype` to a `tf.DType` and checks that it can be infed.

    Args:
      dtype: A `tf.DType`, or any value accepted by `dtypes.as_dtype`
        (for example a NumPy type or a type-name string).

    Returns:
      The `tf.DType` corresponding to `dtype`.

    Raises:
      TypeError: If `dtype` cannot be converted to a `tf.DType`, or if it is
        not a supported infeed type.
    """
    # Normalize before the set lookup: set membership is hash-based, and
    # equivalent non-DType values (NumPy types, strings) do not hash like the
    # DType objects stored in the set, so they would be wrongly rejected.
    dtype = dtypes.as_dtype(dtype)
    if dtype not in _SUPPORTED_INFEED_DTYPES:
      raise TypeError(
          "{} is not a supported TPU infeed type. Supported types are: "
          "{}".format(dtype, list(_SUPPORTED_INFEED_DTYPES)))
    return dtype

  def infeed_dequeue(dtype, shape, name=None):
    """A placeholder op for a value that will be fed into the computation.

    Args:
      dtype: A `tf.DType`. The type of elements in the tensor.
      shape: A `tf.TensorShape` or list of `ints`. The shape of the tensor.
      name: A name for the operation (optional).

    Returns:
      A `Tensor` of type `dtype`.
      A tensor that will be provided using the infeed mechanism.

    Raises:
      TypeError: If `dtype` is not a supported infeed type.
    """
    dtype = _check_infeed_dtype(dtype)
    return gen_tpu_ops.infeed_dequeue(dtype, shape, name=name)

  # pylint: disable=redefined-outer-name
  def infeed_dequeue_tuple(dtypes, shapes, name=None):
    """A placeholder op for values fed into the TPU simultaneously as a tuple.

    Args:
      dtypes: A list of `tf.DType`s that has length `>= 1`.
        The element types of each element in `outputs`.
      shapes: A list of shapes (each a `tf.TensorShape` or list of `ints`).
        The shapes of each tensor in `outputs`.
      name: A name for the operation (optional).

    Returns:
      A list of `Tensor` objects of type `dtypes`.
      A list of tensors that will be provided using the infeed mechanism.

    Raises:
      TypeError: If a type in `dtypes` is not a supported infeed type.
    """
    # `dtypes` shadows the module here, so validation goes through the
    # module-level helper. Building a list also means a one-shot iterable is
    # not exhausted before it reaches the generated op.
    dtypes = [_check_infeed_dtype(dtype) for dtype in dtypes]
    return gen_tpu_ops.infeed_dequeue_tuple(dtypes, shapes, name=name)
  # pylint: enable=redefined-outer-name

else:
  # We have already built the appropriate libraries into the binary via CMake
  # if we have built contrib, so we don't need this
  pass