1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16// See docs in ../ops/nn_ops.cc. 17 18#define EIGEN_USE_THREADS 19 20#include "tensorflow/core/kernels/batch_norm_op.h" 21#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 22#include "tensorflow/core/framework/numeric_op.h" 23#include "tensorflow/core/framework/op_kernel.h" 24#include "tensorflow/core/framework/register_types.h" 25#include "tensorflow/core/framework/tensor.h" 26 27namespace tensorflow { 28 29typedef Eigen::ThreadPoolDevice CPUDevice; 30typedef Eigen::GpuDevice GPUDevice; 31#ifdef TENSORFLOW_USE_SYCL 32typedef Eigen::SyclDevice SYCLDevice; 33#endif // TENSORFLOW_USE_SYCL 34 35template <typename Device, typename T> 36class BatchNormOp : public OpKernel { 37 public: 38 explicit BatchNormOp(OpKernelConstruction* context) : OpKernel(context) { 39 float variance_epsilon; 40 OP_REQUIRES_OK(context, 41 context->GetAttr("variance_epsilon", &variance_epsilon)); 42 variance_epsilon_ = T(variance_epsilon); 43 OP_REQUIRES_OK(context, context->GetAttr("scale_after_normalization", 44 &scale_after_normalization_)); 45 } 46 47 void Compute(OpKernelContext* context) override { 48 const Tensor& input = context->input(0); 49 const Tensor& mean = context->input(1); 50 const Tensor& var = context->input(2); 51 const Tensor& beta = context->input(3); 52 const Tensor& gamma = context->input(4); 53 54 OP_REQUIRES(context, input.dims() == 4, 55 errors::InvalidArgument("input must be 4-dimensional", 56 input.shape().DebugString())); 57 OP_REQUIRES(context, mean.dims() == 1, 58 errors::InvalidArgument("mean must be 1-dimensional", 59 mean.shape().DebugString())); 60 OP_REQUIRES(context, var.dims() == 1, 61 errors::InvalidArgument("var must be 1-dimensional", 62 var.shape().DebugString())); 63 OP_REQUIRES(context, beta.dims() == 1, 64 errors::InvalidArgument("beta must be 1-dimensional", 65 beta.shape().DebugString())); 66 OP_REQUIRES(context, gamma.dims() == 1, 67 errors::InvalidArgument("gamma must be 1-dimensional", 68 gamma.shape().DebugString())); 69 70 Tensor* output = nullptr; 71 OP_REQUIRES_OK(context, 72 context->allocate_output(0, input.shape(), &output)); 73 74 functor::BatchNorm<Device, T>()( 75 context->eigen_device<Device>(), input.tensor<T, 4>(), mean.vec<T>(), 76 var.vec<T>(), beta.vec<T>(), gamma.vec<T>(), variance_epsilon_, 77 scale_after_normalization_, output->tensor<T, 4>()); 78 } 79 80 private: 81 T variance_epsilon_; 82 bool scale_after_normalization_; 83}; 84 85template <typename Device, typename T> 86class BatchNormGradOp : public OpKernel { 87 public: 88 explicit BatchNormGradOp(OpKernelConstruction* context) : OpKernel(context) { 89 float variance_epsilon; 90 OP_REQUIRES_OK(context, 91 context->GetAttr("variance_epsilon", &variance_epsilon)); 92 variance_epsilon_ = T(variance_epsilon); 93 OP_REQUIRES_OK(context, context->GetAttr("scale_after_normalization", 94 &scale_after_normalization_)); 95 } 96 97 void Compute(OpKernelContext* context) override { 98 const Tensor& input = context->input(0); 99 const Tensor& mean = context->input(1); 100 const Tensor& var = context->input(2); 101 const Tensor& gamma = context->input(3); 102 const Tensor& out_backprop = context->input(4); 103 104 OP_REQUIRES(context, input.dims() == 4, 105 errors::InvalidArgument("input must be 4-dimensional", 106 input.shape().DebugString())); 107 OP_REQUIRES(context, mean.dims() == 1, 108 errors::InvalidArgument("mean must be 1-dimensional", 109 mean.shape().DebugString())); 110 OP_REQUIRES(context, var.dims() == 1, 111 errors::InvalidArgument("var must be 1-dimensional", 112 var.shape().DebugString())); 113 OP_REQUIRES(context, gamma.dims() == 1, 114 errors::InvalidArgument("gamma must be 1-dimensional", 115 gamma.shape().DebugString())); 116 OP_REQUIRES(context, out_backprop.dims() == 4, 117 errors::InvalidArgument("out_backprop must be 4-dimensional", 118 out_backprop.shape().DebugString())); 119 120 Tensor* dx = nullptr; 121 OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( 122 {0, 4}, 0, input.shape(), &dx)); 123 Tensor* dm = nullptr; 124 OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( 125 {1}, 1, mean.shape(), &dm)); 126 Tensor* dv = nullptr; 127 OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( 128 {2}, 2, var.shape(), &dv)); 129 Tensor* db = nullptr; 130 OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( 131 {3}, 3, mean.shape(), &db)); 132 Tensor* dg = nullptr; 133 OP_REQUIRES_OK(context, context->allocate_output(4, gamma.shape(), &dg)); 134 135 // Scratch buffer of [depth] dimension, aka the 4th dimension of input, 136 // which is dim_size(3), for calculating various combinations of 137 // (var + epsilon). 138 Tensor scratch1; 139 OP_REQUIRES_OK(context, context->allocate_temp( 140 DataTypeToEnum<T>::value, 141 TensorShape({input.dim_size(3)}), &scratch1)); 142 143 // Scratch buffer of [depth] dimension for saving intermediate calculation 144 // values. 145 Tensor scratch2; 146 OP_REQUIRES_OK(context, context->allocate_temp( 147 DataTypeToEnum<T>::value, 148 TensorShape({input.dim_size(3)}), &scratch2)); 149 150 functor::BatchNormGrad<Device, T>()( 151 context->eigen_device<Device>(), input.tensor<T, 4>(), mean.vec<T>(), 152 var.vec<T>(), gamma.vec<T>(), out_backprop.tensor<T, 4>(), 153 variance_epsilon_, scale_after_normalization_, dx->tensor<T, 4>(), 154 dm->vec<T>(), dv->vec<T>(), db->vec<T>(), dg->vec<T>(), 155 scratch1.vec<T>(), scratch2.vec<T>()); 156 } 157 158 private: 159 T variance_epsilon_; 160 bool scale_after_normalization_; 161}; 162 163#define REGISTER_KERNEL(T) \ 164 REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalization") \ 165 .Device(DEVICE_CPU) \ 166 .TypeConstraint<T>("T"), \ 167 BatchNormOp<CPUDevice, T>); 168 169TF_CALL_half(REGISTER_KERNEL); 170TF_CALL_float(REGISTER_KERNEL); 171TF_CALL_double(REGISTER_KERNEL); 172#undef REGISTER_KERNEL 173 174#if GOOGLE_CUDA 175// Forward declarations of the functor specializations for GPU. 176namespace functor { 177#define DECLARE_GPU_SPEC(T) \ 178 template <> \ 179 void BatchNorm<GPUDevice, T>::operator()( \ 180 const GPUDevice& d, typename TTypes<T, 4>::ConstTensor input, \ 181 typename TTypes<T>::ConstVec mean, typename TTypes<T>::ConstVec var, \ 182 typename TTypes<T>::ConstVec beta, typename TTypes<T>::ConstVec gamma, \ 183 T variance_epsilon, bool scale_after_normalization, \ 184 typename TTypes<T, 4>::Tensor output); \ 185 extern template struct BatchNorm<GPUDevice, T>; 186 187#define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T); 188 189TF_CALL_half(DECLARE_GPU_SPECS); 190TF_CALL_float(DECLARE_GPU_SPECS); 191#undef DECLARE_GPU_SPEC 192} // namespace functor 193 194// Registration of the GPU implementations. 195#define REGISTER_GPU_KERNEL(T) \ 196 REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalization") \ 197 .Device(DEVICE_GPU) \ 198 .TypeConstraint<T>("T"), \ 199 BatchNormOp<GPUDevice, T>); 200 201TF_CALL_half(REGISTER_GPU_KERNEL); 202TF_CALL_float(REGISTER_GPU_KERNEL); 203#undef REGISTER_GPU_KERNEL 204 205#endif // GOOGLE_CUDA 206 207#if TENSORFLOW_USE_SYCL 208#define REGISTER_KERNEL(T) \ 209 REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalization") \ 210 .Device(DEVICE_SYCL) \ 211 .TypeConstraint<T>("T"), \ 212 BatchNormOp<SYCLDevice, T>); 213 214TF_CALL_float(REGISTER_KERNEL); 215TF_CALL_double(REGISTER_KERNEL); 216#undef REGISTER_KERNEL 217#endif // TENSORFLOW_USE_SYCL 218 219#define REGISTER_KERNEL(T) \ 220 REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalizationGrad") \ 221 .Device(DEVICE_CPU) \ 222 .TypeConstraint<T>("T"), \ 223 BatchNormGradOp<CPUDevice, T>); 224 225TF_CALL_half(REGISTER_KERNEL); 226TF_CALL_float(REGISTER_KERNEL); 227TF_CALL_double(REGISTER_KERNEL); 228#undef REGISTER_KERNEL 229 230#if GOOGLE_CUDA 231// Forward declarations of the functor specializations for GPU. 232namespace functor { 233#define DECLARE_GPU_SPEC(T) \ 234 template <> \ 235 void BatchNormGrad<GPUDevice, T>::operator()( \ 236 const GPUDevice& d, typename TTypes<T, 4>::ConstTensor input, \ 237 typename TTypes<T>::ConstVec mean, typename TTypes<T>::ConstVec var, \ 238 typename TTypes<T>::ConstVec gamma, \ 239 typename TTypes<T, 4>::ConstTensor out_backprop, T variance_epsilon, \ 240 bool scale_after_normalization, typename TTypes<T, 4>::Tensor dx, \ 241 typename TTypes<T>::Vec dm, typename TTypes<T>::Vec dv, \ 242 typename TTypes<T>::Vec db, typename TTypes<T>::Vec dg, \ 243 typename TTypes<T>::Vec scratch1, typename TTypes<T>::Vec scratch2); \ 244 extern template struct BatchNormGrad<GPUDevice, T>; 245 246#define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T); 247 248TF_CALL_half(DECLARE_GPU_SPECS); 249TF_CALL_float(DECLARE_GPU_SPECS); 250#undef DECLARE_GPU_SPEC 251} // namespace functor 252 253// Registration of the GPU implementations. 254#define REGISTER_GPU_KERNEL(T) \ 255 REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalizationGrad") \ 256 .Device(DEVICE_GPU) \ 257 .TypeConstraint<T>("T"), \ 258 BatchNormGradOp<GPUDevice, T>); 259 260TF_CALL_half(REGISTER_GPU_KERNEL); 261TF_CALL_float(REGISTER_GPU_KERNEL); 262#undef REGISTER_GPU_KERNEL 263 264#endif // GOOGLE_CUDA 265 266#if TENSORFLOW_USE_SYCL 267#define REGISTER_KERNEL(T) \ 268 REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalizationGrad") \ 269 .Device(DEVICE_SYCL) \ 270 .TypeConstraint<T>("T"), \ 271 BatchNormGradOp<SYCLDevice, T>); 272 273TF_CALL_float(REGISTER_KERNEL); 274TF_CALL_double(REGISTER_KERNEL); 275#undef REGISTER_KERNEL 276 277#endif // TENSORFLOW_USE_SYCL 278 279} // namespace tensorflow 280