fused_batchnorm_test.py revision ad7eeec1cc06d7fdba6ee404f03a35fab9cd3e6a
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional tests for fused batch norm operations."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import nn
from tensorflow.python.platform import test


class FusedBatchNormTest(XLATestCase):

  def _reference_training(self, x, scale, offset, epsilon, data_format):
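    """Computes a NumPy reference for batch norm in training mode."""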
    if data_format != "NHWC":
      raise ValueError("data_format must be NHWC, got %s." % data_format)
    x_square = x * x
    x_square_sum = np.sum(x_square, axis=(0, 1, 2))
    x_sum = np.sum(x, axis=(0, 1, 2))
    element_count = np.size(x) / int(np.shape(x)[-1])
    mean = x_sum / element_count
    var = x_square_sum / element_count - mean * mean
    normalized = (x - mean) / np.sqrt(var + epsilon)
    return (normalized * scale + offset), mean, var

  def _reference_grad(self, x, grad_y, scale, mean, var, epsilon, data_format):
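    """Computes NumPy reference gradients for fused batch norm."""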
    # Use the following formulas to calculate gradients:
    # grad_scale =
    #   sum(grad_y * (x - mean)) * rsqrt(var + epsilon)
    #
    # grad_offset = sum(grad_y)
    #
    # grad_x =
    #   1/N * scale * rsqrt(var + epsilon) * (N * grad_y - sum(grad_y) -
    #   (x - mean) * sum(grad_y * (x - mean)) / (var + epsilon))
    if data_format != "NHWC":
      raise ValueError("data_format must be NHWC, got %s." % data_format)
    grad_x = scale * (grad_y - np.mean(grad_y, axis=(0, 1, 2)) -
                      (x - mean) * np.mean(grad_y *
                                           (x - mean), axis=(0, 1, 2)) /
                      (var + epsilon)) / np.sqrt(var + epsilon)
    grad_scale = np.sum(
        grad_y * (x - mean) / np.sqrt(var + epsilon), axis=(0, 1, 2))
    grad_offset = np.sum(grad_y, axis=(0, 1, 2))
    return grad_x, grad_scale, grad_offset

  def testInference(self):
    channel = 3
    x_shape = [2, 2, 6, channel]
    scale_shape = [channel]
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    scale_val = np.random.random_sample(scale_shape).astype(np.float32)
    offset_val = np.random.random_sample(scale_shape).astype(np.float32)
    data_format = "NHWC"
    with self.test_session() as sess, self.test_scope():
      # Use placeholders so the inputs are not constant-folded away.
      t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x")
      scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
      offset = array_ops.placeholder(
          np.float32, shape=scale_shape, name="offset")
      epsilon = 0.001
      y_ref, mean_ref, var_ref = self._reference_training(
          x_val, scale_val, offset_val, epsilon, data_format)
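      # In inference mode the op normalizes with the supplied mean and
      # variance rather than statistics computed from the batch.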
      y, mean, variance = nn.fused_batch_norm(
          t_val,
          scale,
          offset,
          mean=mean_ref,
          variance=var_ref,
          epsilon=epsilon,
          data_format=data_format,
          is_training=False)

      y_val, _, _ = sess.run(
          [y, mean, variance],
          {t_val: x_val, scale: scale_val, offset: offset_val})
      self.assertAllClose(y_val, y_ref, atol=1e-3)

  def _testLearning(self, use_gradient_checker):
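    """Checks training-mode outputs, optionally checking gradients too."""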
    channel = 3
    x_shape = [2, 2, 6, channel]
    scale_shape = [channel]
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    scale_val = np.random.random_sample(scale_shape).astype(np.float32)
    offset_val = np.random.random_sample(scale_shape).astype(np.float32)
    data_format = "NHWC"
    with self.test_session() as sess, self.test_scope():
      # Use placeholders so the inputs are not constant-folded away.
      t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x")
      scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
      offset = array_ops.placeholder(
          np.float32, shape=scale_shape, name="offset")
      epsilon = 0.001
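      # In training mode the op computes the batch statistics itself, so no
      # mean or variance is passed in.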
      y, mean, var = nn.fused_batch_norm(
          t_val,
          scale,
          offset,
          mean=None,
          variance=None,
          epsilon=epsilon,
          data_format=data_format,
          is_training=True)
      # Check gradient.
      if use_gradient_checker:
        err = gradient_checker.compute_gradient_error(
            t_val,
            x_shape,
            y,
            x_shape,
            extra_feed_dict={
                t_val: x_val,
                scale: scale_val,
                offset: offset_val
            })
        self.assertLess(err, 1e-3)

      y_val, mean_val, var_val = sess.run(
          [y, mean, var],
          {t_val: x_val, scale: scale_val, offset: offset_val})
      y_ref, mean_ref, var_ref = self._reference_training(
          x_val, scale_val, offset_val, epsilon, data_format)
      self.assertAllClose(y_val, y_ref, atol=1e-3)
      self.assertAllClose(mean_val, mean_ref, atol=1e-3)
      self.assertAllClose(var_val, var_ref, atol=1e-3)

  def testLearning(self):
    self._testLearning(False)

  def testLearningWithGradientChecker(self):
    self._testLearning(True)

  def testGradient(self):
    # TODO(b/64270657): Use gradient_checker here in addition to comparing with
    # this reference implementation.
    channel = 3
    x_shape = [2, 2, 6, channel]
    scale_shape = [channel]
    grad_val = np.random.random_sample(x_shape).astype(np.float32)
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    scale_val = np.random.random_sample(scale_shape).astype(np.float32)
    mean_val = np.random.random_sample(scale_shape).astype(np.float32)
    var_val = np.random.random_sample(scale_shape).astype(np.float32)
    epsilon = 0.001

    with self.test_session() as sess, self.test_scope():
      grad = array_ops.placeholder(np.float32, shape=x_shape, name="grad")
      x = array_ops.placeholder(np.float32, shape=x_shape, name="x")
      mean = array_ops.placeholder(np.float32, shape=scale_shape, name="mean")
      var = array_ops.placeholder(np.float32, shape=scale_shape, name="var")
      scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
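      # fused_batch_norm_grad returns five tensors; the last two are reserved
      # outputs that this test does not use.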
      grad_x, grad_scale, grad_offset, _, _ = gen_nn_ops.fused_batch_norm_grad(
          grad, x, scale, mean, var, epsilon=epsilon, data_format="NHWC")

      grad_x_val, grad_scale_val, grad_offset_val = sess.run(
          [grad_x, grad_scale, grad_offset], {
              grad: grad_val,
              x: x_val,
              mean: mean_val,
              var: var_val,
              scale: scale_val
          })

      grad_x_ref, grad_scale_ref, grad_offset_ref = self._reference_grad(
          x_val, grad_val, scale_val, mean_val, var_val, epsilon, "NHWC")

      self.assertAllClose(grad_x_val, grad_x_ref, atol=1e-2)
      self.assertAllClose(grad_scale_val, grad_scale_ref, atol=1e-2)
      self.assertAllClose(grad_offset_val, grad_offset_ref, atol=1e-3)


if __name__ == "__main__":
  test.main()