/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.
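// This file provides the CPU/GPU/SYCL kernels for the deprecated
// BatchNormWithGlobalNormalization op and its gradient, which normalize an
// NHWC input using precomputed per-depth mean and variance (see the op
// registration in ../ops/nn_ops.cc for the authoritative spec). At the
// Python level the forward op is exposed as
// tf.nn.batch_norm_with_global_normalization.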

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/batch_norm_op.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL
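// The Device template parameter picks the Eigen backend the functors in
// batch_norm_op.h evaluate on: a thread pool on CPU, CUDA streams on GPU,
// and a SYCL queue when SYCL support is compiled in.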

template <typename Device, typename T>
class BatchNormOp : public OpKernel {
 public:
  explicit BatchNormOp(OpKernelConstruction* context) : OpKernel(context) {
    float variance_epsilon;
    OP_REQUIRES_OK(context,
                   context->GetAttr("variance_epsilon", &variance_epsilon));
    variance_epsilon_ = T(variance_epsilon);
    OP_REQUIRES_OK(context, context->GetAttr("scale_after_normalization",
                                             &scale_after_normalization_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& mean = context->input(1);
    const Tensor& var = context->input(2);
    const Tensor& beta = context->input(3);
    const Tensor& gamma = context->input(4);

    OP_REQUIRES(context, input.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional: ",
                                        input.shape().DebugString()));
    OP_REQUIRES(context, mean.dims() == 1,
                errors::InvalidArgument("mean must be 1-dimensional: ",
                                        mean.shape().DebugString()));
    OP_REQUIRES(context, var.dims() == 1,
                errors::InvalidArgument("var must be 1-dimensional: ",
                                        var.shape().DebugString()));
    OP_REQUIRES(context, beta.dims() == 1,
                errors::InvalidArgument("beta must be 1-dimensional: ",
                                        beta.shape().DebugString()));
    OP_REQUIRES(context, gamma.dims() == 1,
                errors::InvalidArgument("gamma must be 1-dimensional: ",
                                        gamma.shape().DebugString()));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));

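    // A sketch of the math the functor evaluates (the authoritative
    // expression lives in batch_norm_op.h): the per-depth vectors are
    // broadcast across the N, H, W dimensions of the NHWC input, giving
    //   output = (input - mean) * rsqrt(var + variance_epsilon) * gamma + beta
    // with the gamma factor dropped when scale_after_normalization_ is false.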
    functor::BatchNorm<Device, T>()(
        context->eigen_device<Device>(), input.tensor<T, 4>(), mean.vec<T>(),
        var.vec<T>(), beta.vec<T>(), gamma.vec<T>(), variance_epsilon_,
        scale_after_normalization_, output->tensor<T, 4>());
  }

 private:
  T variance_epsilon_;
  bool scale_after_normalization_;
};

template <typename Device, typename T>
class BatchNormGradOp : public OpKernel {
 public:
  explicit BatchNormGradOp(OpKernelConstruction* context) : OpKernel(context) {
    float variance_epsilon;
    OP_REQUIRES_OK(context,
                   context->GetAttr("variance_epsilon", &variance_epsilon));
    variance_epsilon_ = T(variance_epsilon);
    OP_REQUIRES_OK(context, context->GetAttr("scale_after_normalization",
                                             &scale_after_normalization_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& mean = context->input(1);
    const Tensor& var = context->input(2);
    const Tensor& gamma = context->input(3);
    const Tensor& out_backprop = context->input(4);

    OP_REQUIRES(context, input.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional: ",
                                        input.shape().DebugString()));
    OP_REQUIRES(context, mean.dims() == 1,
                errors::InvalidArgument("mean must be 1-dimensional: ",
                                        mean.shape().DebugString()));
    OP_REQUIRES(context, var.dims() == 1,
                errors::InvalidArgument("var must be 1-dimensional: ",
                                        var.shape().DebugString()));
    OP_REQUIRES(context, gamma.dims() == 1,
                errors::InvalidArgument("gamma must be 1-dimensional: ",
                                        gamma.shape().DebugString()));
    OP_REQUIRES(context, out_backprop.dims() == 4,
                errors::InvalidArgument("out_backprop must be 4-dimensional: ",
                                        out_backprop.shape().DebugString()));

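    // Where possible, reuse an input buffer for the corresponding gradient
    // output: forward_input_or_allocate_output() hands back one of the listed
    // input buffers when it is exclusively held and shape-compatible, and
    // falls back to a fresh allocation otherwise.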
    Tensor* dx = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {0, 4}, 0, input.shape(), &dx));
    Tensor* dm = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {1}, 1, mean.shape(), &dm));
    Tensor* dv = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {2}, 2, var.shape(), &dv));
    Tensor* db = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {3}, 3, mean.shape(), &db));
    Tensor* dg = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(4, gamma.shape(), &dg));

    // Scratch buffer of [depth] dimension, i.e. the 4th dimension of input
    // (dim_size(3)), for computing various combinations of (var + epsilon).
    Tensor scratch1;
    OP_REQUIRES_OK(context, context->allocate_temp(
                                DataTypeToEnum<T>::value,
                                TensorShape({input.dim_size(3)}), &scratch1));

    // Scratch buffer of [depth] dimension for intermediate calculation values.
    Tensor scratch2;
    OP_REQUIRES_OK(context, context->allocate_temp(
                                DataTypeToEnum<T>::value,
                                TensorShape({input.dim_size(3)}), &scratch2));

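    // Because mean and var arrive as inputs (they are not recomputed from the
    // batch), the backprop is a direct differentiation of
    //   y = (x - m) * rsqrt(v + eps) * gamma + beta.
    // Informally, writing g for out_backprop and summing over N, H, W:
    //   dx = g * gamma * rsqrt(v + eps)
    //   dm = -sum(g) * gamma * rsqrt(v + eps)
    //   dv = -0.5 * sum(g * (x - m)) * gamma * (v + eps)^(-3/2)
    //   db = sum(g)
    //   dg = sum(g * (x - m) * rsqrt(v + eps))
    // This is a sketch (the gamma factors drop out when
    // scale_after_normalization_ is false); the exact evaluation order is in
    // batch_norm_op.h.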
    functor::BatchNormGrad<Device, T>()(
        context->eigen_device<Device>(), input.tensor<T, 4>(), mean.vec<T>(),
        var.vec<T>(), gamma.vec<T>(), out_backprop.tensor<T, 4>(),
        variance_epsilon_, scale_after_normalization_, dx->tensor<T, 4>(),
        dm->vec<T>(), dv->vec<T>(), db->vec<T>(), dg->vec<T>(),
        scratch1.vec<T>(), scratch2.vec<T>());
  }

 private:
  T variance_epsilon_;
  bool scale_after_normalization_;
};

#define REGISTER_KERNEL(T)                                         \
  REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalization") \
                              .Device(DEVICE_CPU)                  \
                              .TypeConstraint<T>("T"),             \
                          BatchNormOp<CPUDevice, T>);

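// The TF_CALL_* helpers expand their argument once per C++ type, e.g.
// TF_CALL_float(REGISTER_KERNEL) becomes REGISTER_KERNEL(float) and
// TF_CALL_half(REGISTER_KERNEL) becomes REGISTER_KERNEL(Eigen::half).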
TF_CALL_half(REGISTER_KERNEL);
TF_CALL_float(REGISTER_KERNEL);
TF_CALL_double(REGISTER_KERNEL);
#undef REGISTER_KERNEL

#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
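// The definitions behind these declarations are compiled by the CUDA
// toolchain (see batch_norm_op_gpu.cu.cc), so this translation unit only
// needs to link against the already-instantiated specializations.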
namespace functor {
#define DECLARE_GPU_SPEC(T)                                                  \
  template <>                                                                \
  void BatchNorm<GPUDevice, T>::operator()(                                  \
      const GPUDevice& d, typename TTypes<T, 4>::ConstTensor input,          \
      typename TTypes<T>::ConstVec mean, typename TTypes<T>::ConstVec var,   \
      typename TTypes<T>::ConstVec beta, typename TTypes<T>::ConstVec gamma, \
      T variance_epsilon, bool scale_after_normalization,                    \
      typename TTypes<T, 4>::Tensor output);                                 \
  extern template struct BatchNorm<GPUDevice, T>;

#define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T);

TF_CALL_half(DECLARE_GPU_SPECS);
TF_CALL_float(DECLARE_GPU_SPECS);
#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_SPEC
}  // namespace functor

// Registration of the GPU implementations.
#define REGISTER_GPU_KERNEL(T)                                     \
  REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalization") \
                              .Device(DEVICE_GPU)                  \
                              .TypeConstraint<T>("T"),             \
                          BatchNormOp<GPUDevice, T>);

TF_CALL_half(REGISTER_GPU_KERNEL);
TF_CALL_float(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL

#endif  // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_KERNEL(T)                                         \
  REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalization") \
                              .Device(DEVICE_SYCL)                 \
                              .TypeConstraint<T>("T"),             \
                          BatchNormOp<SYCLDevice, T>);

TF_CALL_float(REGISTER_KERNEL);
TF_CALL_double(REGISTER_KERNEL);
#undef REGISTER_KERNEL
#endif  // TENSORFLOW_USE_SYCL

#define REGISTER_KERNEL(T)                                             \
  REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalizationGrad") \
                              .Device(DEVICE_CPU)                      \
                              .TypeConstraint<T>("T"),                 \
                          BatchNormGradOp<CPUDevice, T>);

TF_CALL_half(REGISTER_KERNEL);
TF_CALL_float(REGISTER_KERNEL);
TF_CALL_double(REGISTER_KERNEL);
#undef REGISTER_KERNEL

#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T)                                                \
  template <>                                                              \
  void BatchNormGrad<GPUDevice, T>::operator()(                            \
      const GPUDevice& d, typename TTypes<T, 4>::ConstTensor input,        \
      typename TTypes<T>::ConstVec mean, typename TTypes<T>::ConstVec var, \
      typename TTypes<T>::ConstVec gamma,                                  \
      typename TTypes<T, 4>::ConstTensor out_backprop, T variance_epsilon, \
      bool scale_after_normalization, typename TTypes<T, 4>::Tensor dx,    \
      typename TTypes<T>::Vec dm, typename TTypes<T>::Vec dv,              \
      typename TTypes<T>::Vec db, typename TTypes<T>::Vec dg,              \
      typename TTypes<T>::Vec scratch1, typename TTypes<T>::Vec scratch2); \
  extern template struct BatchNormGrad<GPUDevice, T>;

#define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T);

TF_CALL_half(DECLARE_GPU_SPECS);
TF_CALL_float(DECLARE_GPU_SPECS);
#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_SPEC
}  // namespace functor

// Registration of the GPU implementations.
#define REGISTER_GPU_KERNEL(T)                                         \
  REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalizationGrad") \
                              .Device(DEVICE_GPU)                      \
                              .TypeConstraint<T>("T"),                 \
                          BatchNormGradOp<GPUDevice, T>);

TF_CALL_half(REGISTER_GPU_KERNEL);
TF_CALL_float(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL

#endif  // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_KERNEL(T)                                             \
  REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalizationGrad") \
                              .Device(DEVICE_SYCL)                     \
                              .TypeConstraint<T>("T"),                 \
                          BatchNormGradOp<SYCLDevice, T>);

TF_CALL_float(REGISTER_KERNEL);
TF_CALL_double(REGISTER_KERNEL);
#undef REGISTER_KERNEL

#endif  // TENSORFLOW_USE_SYCL

}  // namespace tensorflow