nn_test.py revision d53f06b6a3c77160ad955b718a04260674752298
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for tensorflow.ops.nn."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

import tensorflow.python.platform

import tensorflow as tf
import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.python.ops import gen_nn_ops

exp = math.exp
log = math.log


class SigmoidCrossEntropyWithLogitsTest(tf.test.TestCase):

  def _SigmoidCrossEntropyWithLogits(self, logits, targets):
    assert len(logits) == len(targets)
    pred = [1 / (1 + exp(-x)) for x in logits]
    eps = 0.0001
    pred = [min(max(p, eps), 1 - eps) for p in pred]
    return [-z * log(y) - (1 - z) * log(1 - y) for y, z in zip(pred, targets)]

  def _Inputs(self, x=None, y=None, dtype=tf.float64, sizes=None):
    x = [-100, -2, -2, 0, 2, 2, 2, 100] if x is None else x
    y = [0, 0, 1, 0, 0, 1, 0.5, 1] if y is None else y
    assert len(x) == len(y)
    sizes = sizes if sizes else [len(x)]
    logits = tf.constant(x, shape=sizes, dtype=dtype, name="logits")
    targets = tf.constant(y, shape=sizes, dtype=dtype, name="targets")
    losses = np.array(self._SigmoidCrossEntropyWithLogits(x, y)).reshape(*sizes)
    return logits, targets, losses

  def testConstructionNamed(self):
    with self.test_session():
      logits, targets, _ = self._Inputs()
      loss = tf.nn.sigmoid_cross_entropy_with_logits(logits,
                                                     targets,
                                                     name="mylogistic")
      self.assertEqual("mylogistic", loss.op.name)

  def testLogisticOutput(self):
    for use_gpu in [True, False]:
      with self.test_session(use_gpu=use_gpu):
        logits, targets, losses = self._Inputs(dtype=tf.float32)
        loss = tf.nn.sigmoid_cross_entropy_with_logits(logits, targets)
        np_loss = np.array(losses).astype(np.float32)
        tf_loss = loss.eval()
        self.assertAllClose(np_loss, tf_loss, atol=0.001)

  def testLogisticOutputMultiDim(self):
    for use_gpu in [True, False]:
      with self.test_session(use_gpu=use_gpu):
        logits, targets, losses = self._Inputs(dtype=tf.float32,
                                               sizes=[2, 2, 2])
        loss = tf.nn.sigmoid_cross_entropy_with_logits(logits, targets)
        np_loss = np.array(losses).astype(np.float32)
        tf_loss = loss.eval()
        self.assertAllClose(np_loss, tf_loss, atol=0.001)

  def testGradient(self):
    sizes = [4, 2]
    with self.test_session():
      logits, targets, _ = self._Inputs(sizes=sizes)
      loss = tf.nn.sigmoid_cross_entropy_with_logits(logits, targets)
      err = tf.test.compute_gradient_error(logits, sizes, loss, sizes)
      print("logistic loss gradient err = ", err)
      self.assertLess(err, 1e-7)


class ZeroFractionTest(tf.test.TestCase):

  def _ZeroFraction(self, x):
    assert x.shape
    total_elements = np.prod(x.shape)
    nonzeros = np.count_nonzero(x.flatten())
    return 1.0 - nonzeros / total_elements

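  # zero_fraction(x) reports the fraction of entries in x that are exactly
  # zero; the NumPy reference above computes 1 - count_nonzero / size.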
  def testZeroFraction(self):
    x_shape = [5, 17]
    x_np = np.random.randint(0, 2, size=x_shape).astype(np.float32)
    y_np = self._ZeroFraction(x_np)
    with self.test_session():
      x_tf = tf.constant(x_np)
      x_tf.set_shape(x_shape)
      y_tf = tf.nn.zero_fraction(x_tf)
      y_tf_np = y_tf.eval()
      eps = 1e-8
      self.assertAllClose(y_tf_np, y_np, eps)

  def testZeroFractionEmpty(self):
    with self.test_session():
      x = np.zeros(0)
      y = tf.nn.zero_fraction(x).eval()
      self.assertTrue(np.isnan(y))


class SoftmaxTest(tf.test.TestCase):

  def _softmax(self, x):
    assert len(x.shape) == 2
    m = x.max(1)[:, np.newaxis]
    u = np.exp(x - m)
    z = u.sum(1)[:, np.newaxis]
    return u / z

  def testSoftmax(self):
    x_shape = [5, 10]
    x_np = np.random.randn(*x_shape).astype(np.float32)
    y_np = self._softmax(x_np)
    with self.test_session():
      x_tf = tf.constant(x_np)
      y_tf = tf.nn.softmax(x_tf)
      y_tf_np = y_tf.eval()
      eps = 1e-3
      self.assertAllClose(y_tf_np, y_np, eps)

  def testGradient(self):
    x_shape = [5, 10]
    x_np = np.random.randn(*x_shape).astype(np.float64)
    with self.test_session():
      x_tf = tf.constant(x_np)
      y_tf = tf.nn.softmax(x_tf)
      err = tf.test.compute_gradient_error(x_tf, x_shape, y_tf, x_shape)
      eps = 1e-8
      self.assertLess(err, eps)


class Conv2DTransposeTest(tf.test.TestCase):

  def testConv2DTransposeSingleStride(self):
    with self.test_session():
      strides = [1, 1, 1, 1]

      # Input, output: [batch, height, width, depth]
      x_shape = [2, 6, 4, 3]
      y_shape = [2, 6, 4, 2]

      # Filter: [kernel_height, kernel_width, output_depth, input_depth]
      f_shape = [3, 3, 2, 3]

      x = tf.constant(1.0, shape=x_shape, name="x", dtype=tf.float32)
      f = tf.constant(1.0, shape=f_shape, name="filter", dtype=tf.float32)
      output = tf.nn.conv2d_transpose(x, f, y_shape, strides=strides,
                                      padding="SAME")
      value = output.eval()

      # We count the number of cells being added at the locations in the
      # output.
      # At the center, #cells = kernel_height * kernel_width
      # At the corners, #cells = ceil(kernel_height/2) * ceil(kernel_width/2)
      # At the borders, #cells = ceil(kernel_height/2) * kernel_width or
      #                          kernel_height * ceil(kernel_width/2)

      for n in xrange(x_shape[0]):
        for k in xrange(f_shape[2]):
          for w in xrange(y_shape[2]):
            for h in xrange(y_shape[1]):
              target = 4 * 3.0
              h_in = h > 0 and h < y_shape[1] - 1
              w_in = w > 0 and w < y_shape[2] - 1
              if h_in and w_in:
                target += 5 * 3.0
              elif h_in or w_in:
                target += 2 * 3.0
              self.assertAllClose(target, value[n, h, w, k])

  def testConv2DTransposeSame(self):
    with self.test_session():
      strides = [1, 2, 2, 1]

      # Input, output: [batch, height, width, depth]
      x_shape = [2, 6, 4, 3]
      y_shape = [2, 12, 8, 2]

      # Filter: [kernel_height, kernel_width, output_depth, input_depth]
      f_shape = [3, 3, 2, 3]

      x = tf.constant(1.0, shape=x_shape, name="x", dtype=tf.float32)
      f = tf.constant(1.0, shape=f_shape, name="filter", dtype=tf.float32)
      output = tf.nn.conv2d_transpose(x, f, y_shape, strides=strides,
                                      padding="SAME")
      value = output.eval()

      for n in xrange(x_shape[0]):
        for k in xrange(f_shape[2]):
          for w in xrange(y_shape[2]):
            for h in xrange(y_shape[1]):
              target = 3.0
              # We add a case for locations divisible by the stride.
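              # Baseline: with an all-ones 3x3 filter, every output cell
              # receives at least one input tap, worth input_depth * 1.0
              # = 3.0 here; interior cells whose row/column index is a
              # multiple of the stride line up with additional kernel taps.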
              h_in = h % strides[1] == 0 and h > 0 and h < y_shape[1] - 1
              w_in = w % strides[2] == 0 and w > 0 and w < y_shape[2] - 1
              if h_in and w_in:
                target += 9.0
              elif h_in or w_in:
                target += 3.0
              self.assertAllClose(target, value[n, h, w, k])

  def testConv2DTransposeValid(self):
    with self.test_session():
      strides = [1, 2, 2, 1]

      # Input, output: [batch, height, width, depth]
      x_shape = [2, 6, 4, 3]
      y_shape = [2, 13, 9, 2]

      # Filter: [kernel_height, kernel_width, output_depth, input_depth]
      f_shape = [3, 3, 2, 3]

      x = tf.constant(1.0, shape=x_shape, name="x", dtype=tf.float32)
      f = tf.constant(1.0, shape=f_shape, name="filter", dtype=tf.float32)
      output = tf.nn.conv2d_transpose(x, f, y_shape, strides=strides,
                                      padding="VALID")
      value = output.eval()

      cache_values = np.zeros(y_shape, dtype=np.float32)

      # The amount of padding added.
      pad = 1

      for n in xrange(x_shape[0]):
        for k in xrange(f_shape[2]):
          for w in xrange(pad, y_shape[2] - pad):
            for h in xrange(pad, y_shape[1] - pad):
              target = 3.0
              # We add a case for locations divisible by the stride.
              h_in = (h % strides[1] == 0 and
                      h > pad and h < y_shape[1] - 1 - pad)
              w_in = (w % strides[2] == 0 and
                      w > pad and w < y_shape[2] - 1 - pad)
              if h_in and w_in:
                target += 9.0
              elif h_in or w_in:
                target += 3.0
              cache_values[n, h, w, k] = target

          # Copy the values in the border.
          cache_values[n, :, 0, k] = cache_values[n, :, 1, k]
          cache_values[n, :, -1, k] = cache_values[n, :, -2, k]
          cache_values[n, 0, :, k] = cache_values[n, 1, :, k]
          cache_values[n, -1, :, k] = cache_values[n, -2, :, k]

      self.assertAllClose(cache_values, value)

  def testGradient(self):
    x_shape = [2, 6, 4, 3]
    f_shape = [3, 3, 2, 3]
    y_shape = [2, 12, 8, 2]
    strides = [1, 2, 2, 1]
    np.random.seed(1)  # Make it reproducible.
    x_val = np.random.random_sample(x_shape).astype(np.float64)
    f_val = np.random.random_sample(f_shape).astype(np.float64)
    with self.test_session():
      x = tf.constant(x_val, name="x", dtype=tf.float32)
      f = tf.constant(f_val, name="f", dtype=tf.float32)
      output = tf.nn.conv2d_transpose(x, f, y_shape, strides=strides,
                                      padding="SAME")
      err = tf.test.compute_gradient_error(
          [x, f], [x_shape, f_shape], output, y_shape)
      print("DeConv gradient err = %g " % err)
      err_tolerance = 0.0005
      self.assertLess(err, err_tolerance)


class L2LossTest(tf.test.TestCase):

  def testL2Loss(self):
    with self.test_session():
      x = tf.constant([1.0, 0.0, 3.0, 2.0], shape=[2, 2], name="x")
      l2loss = tf.nn.l2_loss(x)
      value = l2loss.eval()
      self.assertAllClose(7.0, value)

  def testGradient(self):
    x_shape = [20, 7, 3]
    np.random.seed(1)  # Make it reproducible.
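    # l2_loss(x) computes sum(x**2) / 2 (hence the expected 7.0 above for
    # [1, 0, 3, 2]); its gradient is simply x, so a very tight error
    # tolerance is attainable below.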
    x_val = np.random.random_sample(x_shape).astype(np.float64)
    with self.test_session():
      x = tf.constant(x_val, name="x")
      output = tf.nn.l2_loss(x)
      err = tf.test.compute_gradient_error(x, x_shape, output, [1])
      print("L2Loss gradient err = %g " % err)
      err_tolerance = 1e-11
      self.assertLess(err, err_tolerance)


class L2NormalizeTest(tf.test.TestCase):

  def _l2Normalize(self, x, dim):
    norm = np.apply_along_axis(np.linalg.norm, dim, x)
    return x / np.expand_dims(norm, dim)

  def testL2Normalize(self):
    x_shape = [20, 7, 3]
    np.random.seed(1)
    x_np = np.random.random_sample(x_shape).astype(np.float32)
    for dim in range(len(x_shape)):
      y_np = self._l2Normalize(x_np, dim)
      with self.test_session():
        x_tf = tf.constant(x_np, name="x")
        y_tf = tf.nn.l2_normalize(x_tf, dim)
        self.assertAllClose(y_np, y_tf.eval())

  def testL2NormalizeGradient(self):
    x_shape = [20, 7, 3]
    np.random.seed(1)
    x_np = np.random.random_sample(x_shape).astype(np.float64)
    for dim in range(len(x_shape)):
      with self.test_session():
        x_tf = tf.constant(x_np, name="x")
        y_tf = tf.nn.l2_normalize(x_tf, dim)
        err = tf.test.compute_gradient_error(x_tf, x_shape, y_tf, x_shape)
        print("L2Normalize gradient err = %g " % err)
        self.assertLess(err, 1e-4)


class DropoutTest(tf.test.TestCase):

  def testDropout(self):
    # Runs dropout with a 0-1 tensor 10 times, sums the number of ones and
    # validates that it is producing approximately the right number of ones
    # over a large number of samples, based on the keep probability.
    x_dim = 40
    y_dim = 30
    num_iter = 10
    for keep_prob in [0.1, 0.5, 0.8]:
      with self.test_session():
        t = tf.constant(1.0, shape=[x_dim, y_dim], dtype=tf.float32)
        dropout = tf.nn.dropout(t, keep_prob)
        final_count = 0
        self.assertEqual([x_dim, y_dim], dropout.get_shape())
        for _ in xrange(0, num_iter):
          value = dropout.eval()
          final_count += np.count_nonzero(value)
          # Verifies that there are only two values: 0 and 1/keep_prob.
          sorted_value = np.unique(np.sort(value))
          self.assertEqual(0, sorted_value[0])
          self.assertAllClose(1 / keep_prob, sorted_value[1])
        # Check that we are in the 15% error range.
        expected_count = x_dim * y_dim * keep_prob * num_iter
        rel_error = math.fabs(final_count - expected_count) / expected_count
        print(rel_error)
        self.assertTrue(rel_error < 0.15)

  def testShapedDropout(self):
    # Runs dropout with a 0-1 tensor 10 times, sums the number of ones and
    # validates that it is producing approximately the right number of ones
    # over a large number of samples, based on the keep probability. This
    # time with shaped noise.
    x_dim = 40 * 30
    y_dim = 3
    num_iter = 10
    for keep_prob in [0.1, 0.5, 0.8]:
      with self.test_session():
        t = tf.constant(1.0, shape=[x_dim, y_dim], dtype=tf.float32)
        dropout = tf.nn.dropout(t, keep_prob, noise_shape=[x_dim, 1])
        self.assertEqual([x_dim, y_dim], dropout.get_shape())
        final_count = 0
        for _ in xrange(0, num_iter):
          value = dropout.eval()
          final_count += np.count_nonzero(value)
          # Verifies that there are only two values: 0 and 1/keep_prob.
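          # Kept elements are scaled by 1/keep_prob so that the expected sum
          # is unchanged, hence the only two possible output values.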
          sorted_value = np.unique(np.sort(value))
          self.assertEqual(0, sorted_value[0])
          self.assertAllClose(1 / keep_prob, sorted_value[1])
        # Check that we are in the 15% error range.
        expected_count = x_dim * y_dim * keep_prob * num_iter
        rel_error = math.fabs(final_count - expected_count) / expected_count
        print(rel_error)
        self.assertTrue(rel_error < 0.15)

  def testShapedDropoutCorrelation(self):
    # Runs a shaped dropout and tests that the correlations are correct.
    x_dim = 40
    y_dim = 30
    num_iter = 10
    for keep_prob in [0.1, 0.5, 0.8]:
      with self.test_session():
        t = tf.constant(1.0, shape=[x_dim, y_dim], dtype=tf.float32)
        dropout = tf.nn.dropout(t, keep_prob, noise_shape=[x_dim, 1])
        self.assertEqual([x_dim, y_dim], dropout.get_shape())
        for _ in xrange(0, num_iter):
          value = dropout.eval()
          # Verifies that each row has only one type of activation, since
          # the noise is broadcast along the y dimension.
          for i in xrange(x_dim):
            sorted_value = np.unique(np.sort(value[i, :]))
            self.assertEqual(sorted_value.size, 1)

  def testDropoutPlaceholderKeepProb(self):
    # Runs dropout with a 0-1 tensor 10 times, sums the number of ones and
    # validates that it is producing approximately the right number of ones
    # over a large number of samples, based on the keep probability.
    x_dim = 40
    y_dim = 30
    num_iter = 10
    for keep_prob in [0.1, 0.5, 0.8]:
      with self.test_session():
        t = tf.constant(1.0, shape=[x_dim, y_dim], dtype=tf.float32)
        keep_prob_placeholder = tf.placeholder(tf.float32)
        dropout = tf.nn.dropout(t, keep_prob_placeholder)
        final_count = 0
        self.assertEqual([x_dim, y_dim], dropout.get_shape())
        for _ in xrange(0, num_iter):
          value = dropout.eval(feed_dict={keep_prob_placeholder: keep_prob})
          final_count += np.count_nonzero(value)
          # Verifies that there are only two values: 0 and 1/keep_prob.
          sorted_value = np.unique(np.sort(value))
          self.assertEqual(0, sorted_value[0])
          self.assertAllClose(1 / keep_prob, sorted_value[1])
        # Check that we are in the 15% error range.
        expected_count = x_dim * y_dim * keep_prob * num_iter
        rel_error = math.fabs(final_count - expected_count) / expected_count
        print(rel_error)
        self.assertTrue(rel_error < 0.15)

  def testShapedDropoutUnknownShape(self):
    x_dim = 40
    y_dim = 30
    keep_prob = 0.5
    x = tf.constant(1.0, shape=[x_dim, y_dim], dtype=tf.float32)
    dropout_x = tf.nn.dropout(x,
                              keep_prob,
                              noise_shape=tf.placeholder(tf.int32))
    self.assertEqual(x.get_shape(), dropout_x.get_shape())

  def testInvalidKeepProb(self):
    x_dim = 40
    y_dim = 30
    t = tf.constant(1.0, shape=[x_dim, y_dim], dtype=tf.float32)
    with self.assertRaises(ValueError):
      tf.nn.dropout(t, -1.0)
    with self.assertRaises(ValueError):
      tf.nn.dropout(t, 1.1)
    with self.assertRaises(ValueError):
      tf.nn.dropout(t, [0.0, 1.0])
    with self.assertRaises(ValueError):
      tf.nn.dropout(t, tf.placeholder(tf.float64))
    with self.assertRaises(ValueError):
      tf.nn.dropout(t, tf.placeholder(tf.float32, shape=[2]))

  def testShapedDropoutShapeError(self):
    # Runs shaped dropout and verifies an error is thrown on misshapen noise.
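    # noise_shape has to broadcast against t's shape, with dimensions aligned
    # from the right: each entry must equal the matching input dimension or
    # be 1 (which is why [y_dim] below is accepted while [x_dim] is not).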
    x_dim = 40
    y_dim = 30
    keep_prob = 0.5
    t = tf.constant(1.0, shape=[x_dim, y_dim], dtype=tf.float32)
    with self.assertRaises(ValueError):
      _ = tf.nn.dropout(t, keep_prob, noise_shape=[x_dim, y_dim + 10])
    with self.assertRaises(ValueError):
      _ = tf.nn.dropout(t, keep_prob, noise_shape=[x_dim, y_dim, 5])
    with self.assertRaises(ValueError):
      _ = tf.nn.dropout(t, keep_prob, noise_shape=[x_dim + 3])
    with self.assertRaises(ValueError):
      _ = tf.nn.dropout(t, keep_prob, noise_shape=[x_dim])
    # Test that broadcasting proceeds.
    _ = tf.nn.dropout(t, keep_prob, noise_shape=[y_dim])
    _ = tf.nn.dropout(t, keep_prob, noise_shape=[1, y_dim])
    _ = tf.nn.dropout(t, keep_prob, noise_shape=[x_dim, 1])
    _ = tf.nn.dropout(t, keep_prob, noise_shape=[1, 1])


class BatchNormWithGlobalNormalizationTest(tf.test.TestCase):

  def _npBatchNorm(self, x, m, v, beta, gamma, epsilon,
                   scale_after_normalization):
    y = (x - m) / np.sqrt(v + epsilon)
    y = y * gamma if scale_after_normalization else y
    y += beta
    return y

  def _opsBatchNorm(self, x, m, v, beta, gamma, epsilon,
                    scale_after_normalization):
    y = (x - m) * tf.rsqrt(v + epsilon)
    if scale_after_normalization:
      y = gamma * y
    y += beta
    return y

  def testBatchNorm(self):
    x_shape = [3, 5, 4, 2]
    param_shape = [2]
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    m_val = np.random.random_sample(param_shape).astype(np.float32)
    v_val = np.random.random_sample(param_shape).astype(np.float32)
    beta_val = np.random.random_sample(param_shape).astype(np.float32)
    gamma_val = np.random.random_sample(param_shape).astype(np.float32)
    for use_gpu in [True, False]:
      with self.test_session(use_gpu=use_gpu) as sess:
        x = tf.constant(x_val, name="x")
        m = tf.constant(m_val, name="m")
        v = tf.constant(v_val, name="v")
        beta = tf.constant(beta_val, name="beta")
        gamma = tf.constant(gamma_val, name="gamma")
        epsilon = 0.001
        for scale_after_normalization in [True, False]:
          bn = tf.nn.batch_norm_with_global_normalization(
              x, m, v, beta, gamma, epsilon, scale_after_normalization)
          on = self._opsBatchNorm(
              x, m, v, beta, gamma, epsilon, scale_after_normalization)
          np_batch_norm = self._npBatchNorm(
              x_val, m_val, v_val, beta_val, gamma_val, epsilon,
              scale_after_normalization)
          tf_batch_norm, ops_batch_norm = sess.run([bn, on])
          self.assertAllClose(np_batch_norm, tf_batch_norm, atol=0.000001)
          self.assertAllClose(np_batch_norm, ops_batch_norm, atol=0.000001)
          self.assertAllClose(tf_batch_norm, ops_batch_norm, atol=0.000001)

  def _testBatchNormGradient(self, param_index, tag, scale_after_normalization,
                             err_tolerance=1e-11):
    x_shape = [3, 5, 4, 5]
    param_shape = [5]
    np.random.seed(1)  # Make it reproducible.
    x_val = np.random.random_sample(x_shape).astype(np.float64)
    m_val = np.random.random_sample(param_shape).astype(np.float64)
    v_val = np.random.random_sample(param_shape).astype(np.float64)
    beta_val = np.random.random_sample(param_shape).astype(np.float64)
    gamma_val = np.random.random_sample(param_shape).astype(np.float64)
    with self.test_session():
      x = tf.constant(x_val, name="x")
      m = tf.constant(m_val, name="m")
      v = tf.constant(v_val, name="v")
      beta = tf.constant(beta_val, name="beta")
      gamma = tf.constant(gamma_val, name="gamma")
      epsilon = 0.001
      # If scale_after_normalization is False, backprop for gamma
      # will be 0; gamma is unchanged.
      output = tf.nn.batch_norm_with_global_normalization(
          x, m, v, beta, gamma, epsilon, scale_after_normalization)
      all_params = [x, m, v, beta, gamma]
      all_shapes = [x_shape, param_shape, param_shape, param_shape,
                    param_shape]
      err = tf.test.compute_gradient_error(
          all_params[param_index], all_shapes[param_index], output, x_shape)
      print("Batch normalization %s gradient %s scale err = " %
            (tag, "with" if scale_after_normalization else "without"), err)
      self.assertLess(err, err_tolerance)

  def testBatchNormInputGradient(self):
    for scale_after_normalization in [True, False]:
      self._testBatchNormGradient(0, "x", scale_after_normalization)

  def testBatchNormMeanGradient(self):
    for scale_after_normalization in [True, False]:
      self._testBatchNormGradient(1, "mean", scale_after_normalization)

  def testBatchNormVarianceGradient(self):
    for scale_after_normalization in [True, False]:
      self._testBatchNormGradient(2, "variance", scale_after_normalization,
                                  err_tolerance=1e-03)

  def testBatchNormBetaGradient(self):
    for scale_after_normalization in [True, False]:
      self._testBatchNormGradient(3, "beta", scale_after_normalization)

  def testBatchNormGammaGradient(self):
    for scale_after_normalization in [True, False]:
      self._testBatchNormGradient(4, "gamma", scale_after_normalization)

  def testBatchNormGradImpl(self):
    x_shape = [7, 5, 4, 6]
    param_shape = [6]
    np.random.seed(1)  # Make it reproducible.
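    # Compares the fused gradient op against the gradients obtained by
    # differentiating the unfused ops with tf.gradients.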
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    m_val = np.random.random_sample(param_shape).astype(np.float32)
    v_val = np.random.random_sample(param_shape).astype(np.float32)
    beta_val = np.random.random_sample(param_shape).astype(np.float32)
    gamma_val = np.random.random_sample(param_shape).astype(np.float32)
    backprop_val = np.random.random_sample(x_shape).astype(np.float32)
    for use_gpu in [False, True]:
      with self.test_session(use_gpu=use_gpu) as sess:
        x = tf.constant(x_val, name="x")
        m = tf.constant(m_val, name="m")
        v = tf.constant(v_val, name="v")
        beta = tf.constant(beta_val, name="beta")
        gamma = tf.constant(gamma_val, name="gamma")
        backprop = tf.constant(backprop_val, name="backprop")
        epsilon = 0.001
        for scale_after_normalization in [True, False]:
          dx, dm, dv, db, dg = (
              gen_nn_ops._batch_norm_with_global_normalization_grad(
                  x, m, v, gamma, backprop, epsilon,
                  scale_after_normalization))
          on = self._opsBatchNorm(
              x, m, v, beta, gamma, epsilon, scale_after_normalization)
          odx, odm, odv, odb, odg = tf.gradients(
              [on], [x, m, v, beta, gamma], [backprop])
          if scale_after_normalization:
            all_grads = sess.run([dx, dm, dv, db, dg,
                                  odx, odm, odv, odb, odg])
            to_check = ["dx", "dm", "dv", "db", "dg"]
          else:
            all_grads = sess.run([dx, dm, dv, db, odx, odm, odv, odb])
            to_check = ["dx", "dm", "dv", "db"]
          for i, n in enumerate(to_check):
            print(n)
            self.assertAllClose(
                all_grads[i + len(to_check)], all_grads[i], atol=0.000001)


class MomentsTest(tf.test.TestCase):

  def RunMomentTestWithDynamicShape(self, shape, axes, keep_dims):
    with self.test_session():
      # shape = [batch, width, height, depth]
      assert len(shape) == 4

      x_numpy = np.random.normal(size=shape).astype(np.float32)
      x = tf.placeholder(tf.float32, shape=[None] * len(shape))

      mean, var = tf.nn.moments(x, axes, keep_dims=keep_dims)

      num_elements = np.prod([shape[i] for i in axes])

      ax = tuple(axes)
      expected_mean = np.sum(
          x_numpy, axis=ax, keepdims=keep_dims) / num_elements
      expected_mean_squared = np.multiply(expected_mean, expected_mean)
      expected_x_squared = np.sum(
          np.multiply(x_numpy, x_numpy),
          axis=ax,
          keepdims=keep_dims) / num_elements
      expected_variance = expected_x_squared - expected_mean_squared

      # Check that the moments are correct.
      self.assertAllClose(expected_mean, mean.eval(feed_dict={x: x_numpy}))
      self.assertAllClose(expected_variance, var.eval(feed_dict={x: x_numpy}))

  def RunMomentTest(self, shape, axes, keep_dims):
    with self.test_session():
      # shape = [batch, width, height, depth]
      assert len(shape) == 4

      x_numpy = np.random.normal(size=shape).astype(np.float32)
      x = tf.constant(x_numpy)

      mean, var = tf.nn.moments(x, axes, keep_dims=keep_dims)

      num_elements = np.prod([shape[i] for i in axes])

      ax = tuple(axes)
      expected_mean = np.sum(
          x_numpy, axis=ax, keepdims=keep_dims) / num_elements
      expected_mean_squared = np.multiply(expected_mean, expected_mean)
      expected_x_squared = np.sum(
          np.multiply(x_numpy, x_numpy),
          axis=ax,
          keepdims=keep_dims) / num_elements
      expected_variance = expected_x_squared - expected_mean_squared

      # Check that the moments are correct.
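      # The NumPy reference above relies on the identity
      # Var[x] = E[x**2] - (E[x])**2.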
      self.assertAllClose(expected_mean, mean.eval())
      self.assertAllClose(expected_variance, var.eval())

  def testBasic(self):
    for keep_dims in [False, True]:
      self.RunMomentTest(shape=[2, 3, 5, 4], axes=[0], keep_dims=keep_dims)
      self.RunMomentTestWithDynamicShape(
          shape=[2, 3, 5, 4], axes=[0], keep_dims=keep_dims)

  def testGlobalNormalization(self):
    for keep_dims in [False, True]:
      self.RunMomentTest(
          shape=[2, 3, 5, 4], axes=[0, 1, 2], keep_dims=keep_dims)
      self.RunMomentTestWithDynamicShape(
          shape=[2, 3, 5, 4], axes=[0, 1, 2], keep_dims=keep_dims)

  def testAxes(self):
    for keep_dims in [False, True]:
      self.RunMomentTest(
          shape=[2, 3, 5, 4], axes=[1, 2, 3], keep_dims=keep_dims)
      self.RunMomentTestWithDynamicShape(
          shape=[2, 3, 5, 4], axes=[1, 2, 3], keep_dims=keep_dims)

  def _testGlobalGradient(self, from_y="mean"):
    with self.test_session():
      x_shape = [3, 5, 4, 2]
      x_val = np.random.random_sample(x_shape).astype(np.float64)
      x = tf.constant(x_val)
      x.set_shape(x_shape)

      axes = [0, 1, 2]
      y_shape = [2]  # Depth of x.
      out_mean, out_var = tf.nn.moments(x, axes)
      if from_y == "mean":
        y = out_mean
      elif from_y == "var":
        y = out_var
      err = tf.test.compute_gradient_error(x, x_shape, y, y_shape)
      print("Moments %s gradient err = %g" % (from_y, err))
      self.assertLess(err, 1e-11)

  def testMeanGlobalGradient(self):
    self._testGlobalGradient(from_y="mean")

  def testVarGlobalGradient(self):
    self._testGlobalGradient(from_y="var")


class ComputeSampledLogitsTest(tf.test.TestCase):

  def setUp(self):
    self._num_classes = 5
    self._dim = 10
    self._batch_size = 3
    self._num_shards = 3

  def _GenerateTestInputs(self):
    np.random.seed(0)
    weights = np.random.randn(self._num_classes, self._dim).astype(np.float32)
    biases = np.random.randn(self._num_classes).astype(np.float32)
    hidden_acts = np.random.randn(self._batch_size, self._dim).astype(
        np.float32)
    sharded_weights = [
        weights[[row for row in range(self._num_classes)
                 if row % self._num_shards == shard]]
        for shard in range(self._num_shards)]
    return weights, biases, hidden_acts, sharded_weights

  def _ComputeSampledLogitsNP(self, true_w, true_b, sampled_w, sampled_b,
                              hidden_acts,
                              num_true=1,
                              true_expected=None,
                              sampled_expected=None):

    batch_size, dim = hidden_acts.shape
    true_logits = np.sum(
        hidden_acts.reshape((batch_size, 1, dim)) * true_w.reshape(
            (batch_size, num_true, dim)),
        axis=2)
    true_b = true_b.reshape((batch_size, num_true))
    true_logits += true_b
    sampled_logits = np.dot(hidden_acts, sampled_w.T) + sampled_b

    if true_expected is not None:
      true_logits -= np.log(true_expected)
    if sampled_expected is not None:
      sampled_logits -= np.log(sampled_expected[np.newaxis, :])

    out_logits = np.concatenate([true_logits, sampled_logits], axis=1)
    out_labels = np.hstack((np.ones_like(true_logits) / num_true,
                            np.zeros_like(sampled_logits)))

    return out_logits, out_labels

  def _ComputeSampledLogitsTF(self, weights, biases, hidden_acts, labels,
                              num_sampled, num_classes, num_true, sampled_vals,
                              subtract_log_q, remove_accidental_hits,
                              name="sampled_loss_TF"):
    # Should be called from within a `with test_session():` block.
    if isinstance(weights, list):
      weights_tf = [tf.constant(shard) for shard in weights]
    else:
      weights_tf = tf.constant(weights)
    biases_tf = tf.constant(biases)
    hidden_acts_tf = tf.constant(hidden_acts,
                                 shape=(self._batch_size, self._dim))
    labels_tf = tf.constant(labels,
                            dtype=tf.int64,
                            shape=(self._batch_size, num_true))

    pred_logits_tf, pred_labels_tf = tf.nn._compute_sampled_logits(
        weights_tf,
        biases_tf,
        hidden_acts_tf,
        labels_tf,
        num_sampled,
        num_classes,
        num_true,
        sampled_vals,
        subtract_log_q=subtract_log_q,
        remove_accidental_hits=remove_accidental_hits,
        name=name)
    return pred_logits_tf, pred_labels_tf

  def testComputeSampledLogitsShapes(self):
    # We just check that the shapes of the returned values are correct.
    weights, biases, hidden_acts, _ = self._GenerateTestInputs()
    sampled = [1, 0, 2, 3]
    num_sampled = len(sampled)
    true_exp = sampled_exp = [1., 1., 1., 1.]
    test_sampled_vals = (sampled, true_exp, sampled_exp)
    sampled_w, sampled_b = weights[sampled], biases[sampled]

    with self.test_session() as sess:
      for num_true_test in range(1, 5):
        labels = np.random.randint(low=0, high=self._num_classes,
                                   size=self._batch_size * num_true_test)
        true_w, true_b = weights[labels], biases[labels]

        logits_np, labels_np = self._ComputeSampledLogitsNP(
            true_w, true_b, sampled_w, sampled_b, hidden_acts,
            num_true=num_true_test)

        logits_tf, labels_tf = self._ComputeSampledLogitsTF(
            weights, biases, hidden_acts, labels, num_sampled,
            self._num_classes,
            num_true=num_true_test,
            sampled_vals=test_sampled_vals,
            remove_accidental_hits=True,
            subtract_log_q=False)

        logits_tf_val, labels_tf_val = sess.run([logits_tf, labels_tf])
        self.assertEqual(logits_np.shape, logits_tf_val.shape)
        self.assertEqual(labels_np.shape, labels_tf_val.shape)

  def testComputeSampledLogitsValues(self):
    # Here we check the actual numerics.
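    # When subtract_log_q=True, the logits are corrected by subtracting the
    # log of the expected sampling counts, mirroring the true_expected /
    # sampled_expected terms in _ComputeSampledLogitsNP above.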
    weights, biases, hidden_acts, sharded_weights = self._GenerateTestInputs()
    eps = 1e-3
    sampled = [1, 0, 2, 3]
    num_sampled = len(sampled)
    true_exp = np.empty([self._batch_size, 1], dtype=np.float32)
    true_exp.fill(0.5)
    sampled_exp = np.empty([num_sampled], dtype=np.float32)
    sampled_exp.fill(0.5)
    sampled_w, sampled_b = weights[sampled], biases[sampled]
    test_sampled_vals = (sampled, true_exp, sampled_exp)

    with self.test_session() as sess:
      for num_true_test in range(1, 5):
        # Generate test data for this run.
        labels = np.random.randint(low=0, high=self._num_classes,
                                   size=self._batch_size * num_true_test)
        true_w, true_b = weights[labels], biases[labels]

        # Test 1: Without accidental hit removal or subtract_log_q.
        logits_np, labels_np = self._ComputeSampledLogitsNP(
            true_w, true_b, sampled_w, sampled_b, hidden_acts,
            num_true=num_true_test)
        logits_tf, labels_tf = self._ComputeSampledLogitsTF(
            weights, biases, hidden_acts, labels, num_sampled,
            self._num_classes,
            num_true=num_true_test,
            sampled_vals=test_sampled_vals,
            subtract_log_q=False,
            remove_accidental_hits=False,
            name="sampled_loss_test1_num_true%d" % num_true_test)

        logits_tf_val, labels_tf_val = sess.run([logits_tf, labels_tf])
        self.assertAllClose(logits_np, logits_tf_val, eps)
        self.assertAllClose(labels_np, labels_tf_val, eps)

        # Test 2: With accidental hit removal, no subtract_log_q.
        logits_tf, labels_tf = self._ComputeSampledLogitsTF(
            weights, biases, hidden_acts, labels, num_sampled,
            self._num_classes,
            num_true=num_true_test,
            sampled_vals=test_sampled_vals,
            subtract_log_q=False,
            remove_accidental_hits=True,
            name="sampled_loss_test2_num_true%d" % num_true_test)

        # Test that the exponentiated logits of accidental hits are near 0.
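        # (remove_accidental_hits is expected to push the logit of any
        # sampled class that is also a true label strongly negative, so its
        # exponentiated value should vanish.)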
        # First we need to find the hits in this random test run:
        labels_reshape = labels.reshape((self._batch_size, num_true_test))
        logits_tf_np = logits_tf.eval()
        for row in xrange(self._batch_size):
          row_labels = labels_reshape[row, :]
          for col in xrange(num_sampled):
            if sampled[col] in row_labels:
              # We need to add the num_true_test offset into logits_*.
              self.assertNear(
                  np.exp(logits_tf_np[row, col + num_true_test]), 0., eps)

        # Test 3: With subtract_log_q, no accidental hit removal.
        logits_np, labels_np = self._ComputeSampledLogitsNP(
            true_w, true_b, sampled_w, sampled_b, hidden_acts,
            num_true=num_true_test,
            true_expected=true_exp,
            sampled_expected=sampled_exp)
        logits_tf, labels_tf = self._ComputeSampledLogitsTF(
            weights, biases, hidden_acts, labels, num_sampled,
            self._num_classes,
            num_true=num_true_test,
            sampled_vals=test_sampled_vals,
            subtract_log_q=True,
            remove_accidental_hits=False,
            name="sampled_loss_test3_num_true%d" % num_true_test)

        logits_tf_val, labels_tf_val = sess.run([logits_tf, labels_tf])
        self.assertAllClose(logits_np, logits_tf_val, eps)
        self.assertAllClose(labels_np, labels_tf_val, eps)

        # Test 4: Test 1, with sharded weights.
        logits_np, labels_np = self._ComputeSampledLogitsNP(
            true_w, true_b, sampled_w, sampled_b, hidden_acts,
            num_true=num_true_test)
        logits_tf, labels_tf = self._ComputeSampledLogitsTF(
            sharded_weights, biases, hidden_acts, labels, num_sampled,
            self._num_classes,
            num_true=num_true_test,
            sampled_vals=test_sampled_vals,
            subtract_log_q=False,
            remove_accidental_hits=False,
            name="sampled_loss_test1_num_true%d" % num_true_test)

        logits_tf_val, labels_tf_val = sess.run([logits_tf, labels_tf])
        self.assertAllClose(logits_np, logits_tf_val, eps)
        self.assertAllClose(labels_np, labels_tf_val, eps)

  def testNCELoss(self):
    # A simple test to verify the numerics.

    def _SigmoidCrossEntropyWithLogits(logits, targets):
      # logits, targets: float arrays of the same shape.
      assert logits.shape == targets.shape
      pred = 1. / (1. + np.exp(-logits))
      eps = 0.0001
      pred = np.minimum(np.maximum(pred, eps), 1 - eps)
      return -targets * np.log(pred) - (1. - targets) * np.log(1. - pred)

    weights, biases, hidden_acts, sharded_weights = self._GenerateTestInputs()
    labels = [0, 1, 2]
    true_w, true_b = weights[labels], biases[labels]
    sampled = [1, 0, 2, 3]
    num_sampled = len(sampled)
    true_exp = np.empty([self._batch_size, 1], dtype=np.float32)
    true_exp.fill(0.5)
    sampled_exp = np.empty([num_sampled], dtype=np.float32)
    sampled_exp.fill(0.5)
    sampled_w, sampled_b = weights[sampled], biases[sampled]
    test_sampled_vals = (sampled, true_exp, sampled_exp)

    with self.test_session():
      logits_np, labels_np = self._ComputeSampledLogitsNP(
          true_w, true_b, sampled_w, sampled_b, hidden_acts,
          true_expected=true_exp,
          sampled_expected=sampled_exp)
      nce_loss_np = np.sum(
          _SigmoidCrossEntropyWithLogits(logits_np, labels_np), 1)

      labels_tf = tf.constant(labels, shape=(self._batch_size, 1))
      weights_tf = tf.constant(weights)
      biases_tf = tf.constant(biases)
      inputs_tf = tf.constant(hidden_acts)

      nce_loss_tf = tf.nn.nce_loss(weights_tf,
                                   biases_tf,
                                   inputs_tf,
                                   labels_tf,
                                   num_sampled=1,
                                   num_classes=self._num_classes,
                                   num_true=1,
                                   sampled_values=test_sampled_vals)

      self.assertAllClose(nce_loss_np, nce_loss_tf.eval(), 1e-4)

      # Test with sharded weights.
      nce_loss_tf = tf.nn.nce_loss(
          [tf.constant(shard) for shard in sharded_weights],
          biases_tf,
          inputs_tf,
          labels_tf,
          num_sampled=1,
          num_classes=self._num_classes,
          num_true=1,
          sampled_values=test_sampled_vals)

      self.assertAllClose(nce_loss_np, nce_loss_tf.eval(), 1e-4)

  def testSampledSoftmaxLoss(self):
    # A simple test to verify the numerics.

    def _SoftmaxCrossEntropyWithLogits(logits, targets):
      # logits, targets: float arrays of the same shape.
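      # Subtracting the per-row max before exponentiating keeps the softmax
      # numerically stable without changing its value.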
      assert logits.shape == targets.shape
      stable_exp_logits = np.exp(logits - np.amax(
          logits, axis=1, keepdims=True))
      pred = stable_exp_logits / np.sum(stable_exp_logits, 1, keepdims=True)
      return -np.sum(targets * np.log(pred + 1.0e-20), axis=1)

    weights, biases, hidden_acts, sharded_weights = self._GenerateTestInputs()
    labels = [0, 1, 2]
    true_w, true_b = weights[labels], biases[labels]
    sampled = [1, 0, 2, 3]
    num_sampled = len(sampled)
    true_exp = np.full([self._batch_size, 1], fill_value=0.5, dtype=np.float32)
    sampled_exp = np.full([num_sampled], fill_value=0.5, dtype=np.float32)
    sampled_w, sampled_b = weights[sampled], biases[sampled]
    test_sampled_vals = (sampled, true_exp, sampled_exp)

    with self.test_session():
      logits_np, labels_np = self._ComputeSampledLogitsNP(
          true_w, true_b, sampled_w, sampled_b, hidden_acts,
          true_expected=true_exp,
          sampled_expected=sampled_exp)
      sampled_softmax_loss_np = _SoftmaxCrossEntropyWithLogits(logits_np,
                                                               labels_np)

      labels_tf = tf.constant(labels, shape=(self._batch_size, 1))
      weights_tf = tf.constant(weights)
      biases_tf = tf.constant(biases)
      inputs_tf = tf.constant(hidden_acts)

      sampled_softmax_loss_tf = tf.nn.sampled_softmax_loss(
          weights_tf,
          biases_tf,
          inputs_tf,
          labels_tf,
          num_sampled=1,
          num_classes=self._num_classes,
          num_true=1,
          sampled_values=test_sampled_vals,
          remove_accidental_hits=False)

      self.assertAllClose(
          sampled_softmax_loss_np, sampled_softmax_loss_tf.eval(), 1e-4)

      # Test with sharded weights.
      sampled_softmax_loss_tf = tf.nn.sampled_softmax_loss(
          [tf.constant(shard) for shard in sharded_weights],
          biases_tf,
          inputs_tf,
          labels_tf,
          num_sampled=1,
          num_classes=self._num_classes,
          num_true=1,
          sampled_values=test_sampled_vals,
          remove_accidental_hits=False)

      self.assertAllClose(
          sampled_softmax_loss_np, sampled_softmax_loss_tf.eval(), 1e-4)


if __name__ == "__main__":
  tf.test.main()