fused_batchnorm_test.py revision ad7eeec1cc06d7fdba6ee404f03a35fab9cd3e6a
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Functional tests for fused batch norm operations.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.compiler.tests.xla_test import XLATestCase 24from tensorflow.python.ops import array_ops 25from tensorflow.python.ops import gen_nn_ops 26from tensorflow.python.ops import gradient_checker 27from tensorflow.python.ops import nn 28from tensorflow.python.platform import test 29 30 31class FusedBatchNormTest(XLATestCase): 32 33 def _reference_training(self, x, scale, offset, epsilon, data_format): 34 if data_format != "NHWC": 35 raise ValueError("data_format must be NHWC, got %s." % data_format) 36 x_square = x * x 37 x_square_sum = np.sum(x_square, (0, 1, 2)) 38 x_sum = np.sum(x, axis=(0, 1, 2)) 39 element_count = np.size(x) / int(np.shape(x)[-1]) 40 mean = x_sum / element_count 41 var = x_square_sum / element_count - mean * mean 42 normalized = (x - mean) / np.sqrt(var + epsilon) 43 return (normalized * scale + offset), mean, var 44 45 def _reference_grad(self, x, grad_y, scale, mean, var, epsilon, data_format): 46 # Use the following formulas to calculate gradients: 47 # grad_scale = 48 # sum(grad_y * (x - mean)) * rsqrt(var + epsilon) 49 # 50 # grad_offset = sum(output_y) 51 # 52 # grad_x = 53 # 1/N * scale * rsqrt(var + epsilon) * (N * grad_y - sum(grad_y) - 54 # (x - mean) * sum(grad_y * (x - mean)) / (var + epsilon)) 55 if data_format != "NHWC": 56 raise ValueError("data_format must be NHWC, got %s." % data_format) 57 grad_x = scale * (grad_y - np.mean(grad_y, axis=(0, 1, 2)) - 58 (x - mean) * np.mean(grad_y * 59 (x - mean), axis=(0, 1, 2)) / 60 (var + epsilon)) / np.sqrt(var + epsilon) 61 grad_scale = np.sum( 62 grad_y * (x - mean) / np.sqrt(var + epsilon), axis=(0, 1, 2)) 63 grad_offset = np.sum(grad_y, axis=(0, 1, 2)) 64 return grad_x, grad_scale, grad_offset 65 66 def testInference(self): 67 channel = 3 68 x_shape = [2, 2, 6, channel] 69 scale_shape = [channel] 70 x_val = np.random.random_sample(x_shape).astype(np.float32) 71 scale_val = np.random.random_sample(scale_shape).astype(np.float32) 72 73 offset_val = np.random.random_sample(scale_shape).astype(np.float32) 74 data_format = "NHWC" 75 with self.test_session() as sess, self.test_scope(): 76 # To avoid constant folding 77 t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x") 78 scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale") 79 offset = array_ops.placeholder(np.float32, shape=scale_shape, name="offset") 80 epsilon = 0.001 81 y_ref, mean_ref, var_ref = self._reference_training( 82 x_val, scale_val, offset_val, epsilon, data_format) 83 y, mean, variance = nn.fused_batch_norm( 84 t_val, 85 scale, 86 offset, 87 mean=mean_ref, 88 variance=var_ref, 89 epsilon=epsilon, 90 data_format=data_format, 91 is_training=False) 92 93 y_val, _, _ = sess.run( 94 [y, mean, 95 variance], {t_val: x_val, 96 scale: scale_val, 97 offset: offset_val}) 98 self.assertAllClose(y_val, y_ref, atol=1e-3) 99 100 def _testLearning(self, use_gradient_checker): 101 channel = 3 102 x_shape = [2, 2, 6, channel] 103 scale_shape = [channel] 104 x_val = np.random.random_sample(x_shape).astype(np.float32) 105 scale_val = np.random.random_sample(scale_shape).astype(np.float32) 106 107 offset_val = np.random.random_sample(scale_shape).astype(np.float32) 108 mean_val = np.random.random_sample(scale_shape).astype(np.float32) 109 var_val = np.random.random_sample(scale_shape).astype(np.float32) 110 data_format = "NHWC" 111 with self.test_session() as sess, self.test_scope(): 112 # To avoid constant folding 113 t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x") 114 scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale") 115 offset = array_ops.placeholder(np.float32, shape=scale_shape, name="offset") 116 epsilon = 0.001 117 y, mean, var = nn.fused_batch_norm( 118 t_val, 119 scale, 120 offset, 121 mean=None, 122 variance=None, 123 epsilon=epsilon, 124 data_format=data_format, 125 is_training=True) 126 # Check gradient. 127 if use_gradient_checker: 128 err = gradient_checker.compute_gradient_error( 129 t_val, 130 x_shape, 131 y, 132 x_shape, 133 extra_feed_dict={ 134 t_val: x_val, 135 scale: scale_val, 136 offset: offset_val 137 }) 138 self.assertLess(err, 1e-3) 139 140 y_val, mean_val, var_val = sess.run( 141 [y, mean, var], {t_val: x_val, 142 scale: scale_val, 143 offset: offset_val}) 144 y_ref, mean_ref, var_ref = self._reference_training( 145 x_val, scale_val, offset_val, epsilon, data_format) 146 self.assertAllClose(mean_val, mean_ref, atol=1e-3) 147 self.assertAllClose(y_val, y_ref, atol=1e-3) 148 self.assertAllClose(var_val, var_ref, atol=1e-3) 149 150 def testLearning(self): 151 self._testLearning(False) 152 153 def testLearningWithGradientChecker(self): 154 self._testLearning(True) 155 156 def testGradient(self): 157 # TODO(b/64270657): Use gradient_checker here in addition to comparing with 158 # this reference implementation. 159 channel = 3 160 x_shape = [2, 2, 6, channel] 161 scale_shape = [channel] 162 grad_val = np.random.random_sample(x_shape).astype(np.float32) 163 x_val = np.random.random_sample(x_shape).astype(np.float32) 164 scale_val = np.random.random_sample(scale_shape).astype(np.float32) 165 mean_val = np.random.random_sample(scale_shape).astype(np.float32) 166 var_val = np.random.random_sample(scale_shape).astype(np.float32) 167 epsilon = 0.001 168 169 with self.test_session() as sess, self.test_scope(): 170 grad = array_ops.placeholder(np.float32, shape=x_shape, name="grad") 171 x = array_ops.placeholder(np.float32, shape=x_shape, name="x") 172 mean = array_ops.placeholder(np.float32, shape=scale_shape, name="mean") 173 var = array_ops.placeholder(np.float32, shape=scale_shape, name="var") 174 scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale") 175 grad_x, grad_scale, grad_offset, _, _ = gen_nn_ops.fused_batch_norm_grad( 176 grad, x, scale, mean, var, data_format="NHWC") 177 178 grad_x_val, grad_scale_val, grad_offset_val = sess.run( 179 [grad_x, grad_scale, grad_offset], { 180 grad: grad_val, 181 x: x_val, 182 mean: mean_val, 183 var: var_val, 184 scale: scale_val 185 }) 186 187 grad_x_ref, grad_scale_ref, grad_offset_ref = self._reference_grad( 188 x_val, grad_val, scale_val, mean_val, var_val, epsilon, "NHWC") 189 190 self.assertAllClose(grad_x_val, grad_x_ref, atol=1e-2) 191 self.assertAllClose(grad_scale_val, grad_scale_ref, atol=1e-2) 192 self.assertAllClose(grad_offset_val, grad_offset_ref, atol=1e-3) 193 194 195if __name__ == "__main__": 196 test.main() 197