nn_test.py revision b1b2dc893d616c024c5390dae8b2f932c917d7f8
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for miscellaneous functionality in tensorflow.ops.nn."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf


class ZeroFractionTest(tf.test.TestCase):

  def _ZeroFraction(self, x):
    assert x.shape
    total_elements = np.prod(x.shape)
    nonzeros = np.count_nonzero(x.flatten())
    return 1.0 - nonzeros / total_elements

  def testZeroFraction(self):
    x_shape = [5, 17]
    x_np = np.random.randint(0, 2, size=x_shape).astype(np.float32)
    y_np = self._ZeroFraction(x_np)
    with self.test_session():
      x_tf = tf.constant(x_np)
      x_tf.set_shape(x_shape)
      y_tf = tf.nn.zero_fraction(x_tf)
      y_tf_np = y_tf.eval()
    eps = 1e-8
    self.assertAllClose(y_tf_np, y_np, eps)

  def testZeroFractionEmpty(self):
    with self.test_session():
      x = np.zeros(0)
      y = tf.nn.zero_fraction(x).eval()
      self.assertTrue(np.isnan(y))


class SoftmaxTest(tf.test.TestCase):

  def _softmax(self, x):
    assert len(x.shape) == 2
    m = x.max(1)[:, np.newaxis]
    u = np.exp(x - m)
    z = u.sum(1)[:, np.newaxis]
    return u / z

  def testSoftmax(self):
    x_shape = [5, 10]
    x_np = np.random.randn(*x_shape).astype(np.float32)
    y_np = self._softmax(x_np)
    with self.test_session():
      x_tf = tf.constant(x_np)
      y_tf = tf.nn.softmax(x_tf)
      y_tf_np = y_tf.eval()
    eps = 1e-3
    self.assertAllClose(y_tf_np, y_np, eps)

  def testGradient(self):
    x_shape = [5, 10]
    x_np = np.random.randn(*x_shape).astype(np.float64)
    with self.test_session():
      x_tf = tf.constant(x_np)
      y_tf = tf.nn.softmax(x_tf)
      err = tf.test.compute_gradient_error(x_tf, x_shape, y_tf, x_shape)
    eps = 1e-8
    self.assertLess(err, eps)
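

# A brief numeric sketch of why the reference implementations here subtract
# the row max before exponentiating: softmax(x) = softmax(x - max(x)), so the
# result is unchanged, but np.exp can no longer overflow. For example, with
# x = np.array([[1000., 1001.]]):
#   np.exp(x)                                # overflows to inf
#   u = np.exp(x - x.max(1, keepdims=True))  # [[0.3679, 1.]]
#   u / u.sum(1, keepdims=True)              # [[0.2689, 0.7311]]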


class LogSoftmaxTest(tf.test.TestCase):

  def _log_softmax(self, x):
    assert len(x.shape) == 2
    m = x.max(1)[:, np.newaxis]
    u = x - m
    return u - np.log(np.sum(np.exp(u), 1, keepdims=True))

  def testLogSoftmax(self):
    x_shape = [5, 10]
    x_np = np.random.randn(*x_shape).astype(np.float32)
    y_np = self._log_softmax(x_np)
    with self.test_session():
      x_tf = tf.constant(x_np)
      y_tf = tf.nn.log_softmax(x_tf)
      y_tf_np = y_tf.eval()
    eps = 1e-3
    self.assertAllClose(y_tf_np, y_np, eps)

  def testGradient(self):
    x_shape = [5, 10]
    x_np = np.random.randn(*x_shape).astype(np.float64)
    with self.test_session():
      x_tf = tf.constant(x_np)
      y_tf = tf.nn.log_softmax(x_tf)
      err = tf.test.compute_gradient_error(x_tf, x_shape, y_tf, x_shape)
    eps = 1e-7
    self.assertLess(err, eps)


class L2LossTest(tf.test.TestCase):

  def testL2Loss(self):
    for dtype in [tf.float32, tf.float64]:
      with self.test_session():
        x = tf.constant([1.0, 0.0, 3.0, 2.0], shape=[2, 2], name="x",
                        dtype=dtype)
        l2loss = tf.nn.l2_loss(x)
        value = l2loss.eval()
        self.assertAllClose(7.0, value)

  def testGradient(self):
    x_shape = [20, 7, 3]
    np.random.seed(1)  # Make it reproducible.
    x_val = np.random.random_sample(x_shape).astype(np.float64)
    with self.test_session():
      x = tf.constant(x_val, name="x")
      output = tf.nn.l2_loss(x)
      err = tf.test.compute_gradient_error(x, x_shape, output, [1])
    print("L2Loss gradient err = %g " % err)
    err_tolerance = 1e-11
    self.assertLess(err, err_tolerance)


class L2NormalizeTest(tf.test.TestCase):

  def _l2Normalize(self, x, dim):
    norm = np.apply_along_axis(np.linalg.norm, dim, x)
    return x / np.expand_dims(norm, dim)

  def testL2Normalize(self):
    x_shape = [20, 7, 3]
    np.random.seed(1)
    x_np = np.random.random_sample(x_shape).astype(np.float32)
    for dim in range(len(x_shape)):
      y_np = self._l2Normalize(x_np, dim)
      with self.test_session():
        x_tf = tf.constant(x_np, name="x")
        y_tf = tf.nn.l2_normalize(x_tf, dim)
        self.assertAllClose(y_np, y_tf.eval())

  def testL2NormalizeGradient(self):
    x_shape = [20, 7, 3]
    np.random.seed(1)
    x_np = np.random.random_sample(x_shape).astype(np.float64)
    for dim in range(len(x_shape)):
      with self.test_session():
        x_tf = tf.constant(x_np, name="x")
        y_tf = tf.nn.l2_normalize(x_tf, dim)
        err = tf.test.compute_gradient_error(x_tf, x_shape, y_tf, x_shape)
      print("L2Normalize gradient err = %g " % err)
      self.assertLess(err, 1e-4)
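

# The dropout tests below assume "inverted" dropout: kept elements are scaled
# by 1 / keep_prob, so a kept entry of an all-ones input becomes exactly
# 1 / keep_prob and the expected sum of the output matches the input. A worked
# instance of the count check used repeatedly below:
#   x_dim, y_dim, num_iter, keep_prob = 40, 30, 10, 0.5
#   expected_count = x_dim * y_dim * keep_prob * num_iter  # = 6000
# and the observed nonzero count must fall within 15% of that value.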


class DropoutTest(tf.test.TestCase):

  def testDropout(self):
    # Runs dropout with a 0-1 tensor 10 times, sums the number of ones, and
    # validates that it is producing approximately the right number of ones
    # over a large number of samples, based on the keep probability.
    x_dim = 40
    y_dim = 30
    num_iter = 10
    for keep_prob in [0.1, 0.5, 0.8]:
      with self.test_session():
        t = tf.constant(1.0, shape=[x_dim, y_dim], dtype=tf.float32)
        dropout = tf.nn.dropout(t, keep_prob)
        final_count = 0
        self.assertEqual([x_dim, y_dim], dropout.get_shape())
        for _ in xrange(0, num_iter):
          value = dropout.eval()
          final_count += np.count_nonzero(value)
          # Verifies that there are only two values: 0 and 1/keep_prob.
          sorted_value = np.unique(np.sort(value))
          self.assertEqual(0, sorted_value[0])
          self.assertAllClose(1 / keep_prob, sorted_value[1])
      # Check that we are in the 15% error range.
      expected_count = x_dim * y_dim * keep_prob * num_iter
      rel_error = math.fabs(final_count - expected_count) / expected_count
      print(rel_error)
      self.assertTrue(rel_error < 0.15)

  def testShapedDropout(self):
    # Runs dropout with a 0-1 tensor 10 times, sums the number of ones, and
    # validates that it is producing approximately the right number of ones
    # over a large number of samples, based on the keep probability. This time
    # with shaped noise.
    x_dim = 40 * 30
    y_dim = 3
    num_iter = 10
    for keep_prob in [0.1, 0.5, 0.8]:
      with self.test_session():
        t = tf.constant(1.0, shape=[x_dim, y_dim], dtype=tf.float32)
        dropout = tf.nn.dropout(t, keep_prob, noise_shape=[x_dim, 1])
        self.assertEqual([x_dim, y_dim], dropout.get_shape())
        final_count = 0
        for _ in xrange(0, num_iter):
          value = dropout.eval()
          final_count += np.count_nonzero(value)
          # Verifies that there are only two values: 0 and 1/keep_prob.
          sorted_value = np.unique(np.sort(value))
          self.assertEqual(0, sorted_value[0])
          self.assertAllClose(1 / keep_prob, sorted_value[1])
      # Check that we are in the 15% error range.
      expected_count = x_dim * y_dim * keep_prob * num_iter
      rel_error = math.fabs(final_count - expected_count) / expected_count
      print(rel_error)
      self.assertTrue(rel_error < 0.15)

  def testShapedDropoutCorrelation(self):
    # Runs a shaped dropout and tests that the correlations are correct.
    x_dim = 40
    y_dim = 30
    num_iter = 10
    for keep_prob in [0.1, 0.5, 0.8]:
      with self.test_session():
        t = tf.constant(1.0, shape=[x_dim, y_dim], dtype=tf.float32)
        dropout = tf.nn.dropout(t, keep_prob, noise_shape=[x_dim, 1])
        self.assertEqual([x_dim, y_dim], dropout.get_shape())
        for _ in xrange(0, num_iter):
          value = dropout.eval()
          # Verifies that each row has only one type of activation.
          for i in xrange(x_dim):
            sorted_value = np.unique(np.sort(value[i, :]))
            self.assertEqual(sorted_value.size, 1)

  def testDropoutPlaceholderKeepProb(self):
    # Runs dropout with a 0-1 tensor 10 times, sums the number of ones, and
    # validates that it is producing approximately the right number of ones
    # over a large number of samples, based on the keep probability.
    x_dim = 40
    y_dim = 30
    num_iter = 10
    for keep_prob in [0.1, 0.5, 0.8]:
      with self.test_session():
        t = tf.constant(1.0, shape=[x_dim, y_dim], dtype=tf.float32)
        keep_prob_placeholder = tf.placeholder(tf.float32)
        dropout = tf.nn.dropout(t, keep_prob_placeholder)
        final_count = 0
        self.assertEqual([x_dim, y_dim], dropout.get_shape())
        for _ in xrange(0, num_iter):
          value = dropout.eval(feed_dict={keep_prob_placeholder: keep_prob})
          final_count += np.count_nonzero(value)
          # Verifies that there are only two values: 0 and 1/keep_prob.
          sorted_value = np.unique(np.sort(value))
          self.assertEqual(0, sorted_value[0])
          self.assertAllClose(1 / keep_prob, sorted_value[1])
      # Check that we are in the 15% error range.
      expected_count = x_dim * y_dim * keep_prob * num_iter
      rel_error = math.fabs(final_count - expected_count) / expected_count
      print(rel_error)
      self.assertTrue(rel_error < 0.15)
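
  # The remaining tests exercise noise_shape handling. The expectation
  # (matching testShapedDropoutShapeError below) is NumPy-style broadcasting
  # against the trailing dimensions of the input: each noise_shape entry must
  # equal the corresponding input dimension or be 1, where a 1 shares a single
  # keep/drop decision across that axis.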

  def testShapedDropoutUnknownShape(self):
    x_dim = 40
    y_dim = 30
    keep_prob = 0.5
    x = tf.constant(1.0, shape=[x_dim, y_dim], dtype=tf.float32)
    dropout_x = tf.nn.dropout(x,
                              keep_prob,
                              noise_shape=tf.placeholder(tf.int32))
    self.assertEqual(x.get_shape(), dropout_x.get_shape())

  def testInvalidKeepProb(self):
    x_dim = 40
    y_dim = 30
    t = tf.constant(1.0, shape=[x_dim, y_dim], dtype=tf.float32)
    with self.assertRaises(ValueError):
      tf.nn.dropout(t, -1.0)
    with self.assertRaises(ValueError):
      tf.nn.dropout(t, 1.1)
    with self.assertRaises(ValueError):
      tf.nn.dropout(t, [0.0, 1.0])
    with self.assertRaises(ValueError):
      tf.nn.dropout(t, tf.placeholder(tf.float64))
    with self.assertRaises(ValueError):
      tf.nn.dropout(t, tf.placeholder(tf.float32, shape=[2]))

  def testShapedDropoutShapeError(self):
    # Runs shaped dropout and verifies an error is thrown on misshapen noise.
    x_dim = 40
    y_dim = 30
    keep_prob = 0.5
    t = tf.constant(1.0, shape=[x_dim, y_dim], dtype=tf.float32)
    with self.assertRaises(ValueError):
      _ = tf.nn.dropout(t, keep_prob, noise_shape=[x_dim, y_dim + 10])
    with self.assertRaises(ValueError):
      _ = tf.nn.dropout(t, keep_prob, noise_shape=[x_dim, y_dim, 5])
    with self.assertRaises(ValueError):
      _ = tf.nn.dropout(t, keep_prob, noise_shape=[x_dim + 3])
    with self.assertRaises(ValueError):
      _ = tf.nn.dropout(t, keep_prob, noise_shape=[x_dim])
    # Test that broadcasting proceeds.
    _ = tf.nn.dropout(t, keep_prob, noise_shape=[y_dim])
    _ = tf.nn.dropout(t, keep_prob, noise_shape=[1, y_dim])
    _ = tf.nn.dropout(t, keep_prob, noise_shape=[x_dim, 1])
    _ = tf.nn.dropout(t, keep_prob, noise_shape=[1, 1])
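

# For the candidate-sampling tests below, the NumPy reference computes, for
# hidden activations h of shape (batch, dim), one logit per candidate class c:
#   logit(h, c) = dot(h, w_c) + b_c
# and, when subtract_log_q is set, subtracts log(expected_count(c)) from each
# logit, mirroring the sampled-loss formulation used by tf.nn.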


class ComputeSampledLogitsTest(tf.test.TestCase):

  def setUp(self):
    self._num_classes = 5
    self._dim = 10
    self._batch_size = 3
    self._num_shards = 3

  def _GenerateTestInputs(self):
    np.random.seed(0)
    weights = np.random.randn(self._num_classes, self._dim).astype(np.float32)
    biases = np.random.randn(self._num_classes).astype(np.float32)
    hidden_acts = np.random.randn(self._batch_size, self._dim).astype(
        np.float32)
    sharded_weights = [
        weights[[row for row in range(self._num_classes)
                 if row % self._num_shards == shard]]
        for shard in range(self._num_shards)]
    return weights, biases, hidden_acts, sharded_weights

  def _ComputeSampledLogitsNP(self, true_w, true_b, sampled_w, sampled_b,
                              hidden_acts,
                              num_true=1,
                              true_expected=None,
                              sampled_expected=None):
    batch_size, dim = hidden_acts.shape
    true_logits = np.sum(
        hidden_acts.reshape((batch_size, 1, dim)) * true_w.reshape(
            (batch_size, num_true, dim)),
        axis=2)
    true_b = true_b.reshape((batch_size, num_true))
    true_logits += true_b
    sampled_logits = np.dot(hidden_acts, sampled_w.T) + sampled_b

    if true_expected is not None:
      true_logits -= np.log(true_expected)
    if sampled_expected is not None:
      sampled_logits -= np.log(sampled_expected[np.newaxis, :])

    out_logits = np.concatenate([true_logits, sampled_logits], axis=1)
    out_labels = np.hstack((np.ones_like(true_logits) / num_true,
                            np.zeros_like(sampled_logits)))

    return out_logits, out_labels

  def _ComputeSampledLogitsTF(self, weights, biases, hidden_acts, labels,
                              num_sampled, num_classes, num_true, sampled_vals,
                              subtract_log_q, remove_accidental_hits,
                              name="sampled_loss_TF"):
    # Should be called from within a `with test_session():` block.
    if isinstance(weights, list):
      weights_tf = [tf.constant(shard) for shard in weights]
    else:
      weights_tf = tf.constant(weights)
    biases_tf = tf.constant(biases)
    hidden_acts_tf = tf.constant(hidden_acts,
                                 shape=(self._batch_size, self._dim))
    labels_tf = tf.constant(labels,
                            dtype=tf.int64,
                            shape=(self._batch_size, num_true))

    pred_logits_tf, pred_labels_tf = tf.nn._compute_sampled_logits(
        weights_tf,
        biases_tf,
        hidden_acts_tf,
        labels_tf,
        num_sampled,
        num_classes,
        num_true,
        sampled_vals,
        subtract_log_q=subtract_log_q,
        remove_accidental_hits=remove_accidental_hits,
        name=name)
    return pred_logits_tf, pred_labels_tf

  def testComputeSampledLogitsShapes(self):
    # We just check that the shapes of the returned values are correct.
    weights, biases, hidden_acts, _ = self._GenerateTestInputs()
    sampled = [1, 0, 2, 3]
    num_sampled = len(sampled)
    true_exp = sampled_exp = [1., 1., 1., 1.]
    test_sampled_vals = (sampled, true_exp, sampled_exp)
    sampled_w, sampled_b = weights[sampled], biases[sampled]

    with self.test_session() as sess:
      for num_true_test in range(1, 5):
        labels = np.random.randint(low=0, high=self._num_classes,
                                   size=self._batch_size * num_true_test)
        true_w, true_b = weights[labels], biases[labels]

        logits_np, labels_np = self._ComputeSampledLogitsNP(
            true_w, true_b, sampled_w, sampled_b, hidden_acts,
            num_true=num_true_test)

        logits_tf, labels_tf = self._ComputeSampledLogitsTF(
            weights, biases, hidden_acts, labels, num_sampled,
            self._num_classes,
            num_true=num_true_test,
            sampled_vals=test_sampled_vals,
            remove_accidental_hits=True,
            subtract_log_q=False)

        logits_tf_val, labels_tf_val = sess.run([logits_tf, labels_tf])
        self.assertEqual(logits_np.shape, logits_tf_val.shape)
        self.assertEqual(labels_np.shape, labels_tf_val.shape)

  def testComputeSampledLogitsValues(self):
    # Here we check the actual numerics.
    weights, biases, hidden_acts, sharded_weights = self._GenerateTestInputs()
    eps = 1e-3
    sampled = [1, 0, 2, 3]
    num_sampled = len(sampled)
    true_exp = np.empty([self._batch_size, 1], dtype=np.float32)
    true_exp.fill(0.5)
    sampled_exp = np.empty([num_sampled], dtype=np.float32)
    sampled_exp.fill(0.5)
    sampled_w, sampled_b = weights[sampled], biases[sampled]
    test_sampled_vals = (sampled, true_exp, sampled_exp)

    with self.test_session() as sess:
      for num_true_test in range(1, 5):
        # Generate test data for this run.
        labels = np.random.randint(low=0, high=self._num_classes,
                                   size=self._batch_size * num_true_test)
        true_w, true_b = weights[labels], biases[labels]

        # Test 1: Without accidental hit removal or subtract_log_q.
        logits_np, labels_np = self._ComputeSampledLogitsNP(
            true_w, true_b, sampled_w, sampled_b, hidden_acts,
            num_true=num_true_test)
        logits_tf, labels_tf = self._ComputeSampledLogitsTF(
            weights, biases, hidden_acts, labels, num_sampled,
            self._num_classes,
            num_true=num_true_test,
            sampled_vals=test_sampled_vals,
            subtract_log_q=False,
            remove_accidental_hits=False,
            name="sampled_loss_test1_num_true%d" % num_true_test)

        logits_tf_val, labels_tf_val = sess.run([logits_tf, labels_tf])
        self.assertAllClose(logits_np, logits_tf_val, eps)
        self.assertAllClose(labels_np, labels_tf_val, eps)

        # Test 2: With accidental hit removal, no subtract_log_q.
        logits_tf, labels_tf = self._ComputeSampledLogitsTF(
            weights, biases, hidden_acts, labels, num_sampled,
            self._num_classes,
            num_true=num_true_test,
            sampled_vals=test_sampled_vals,
            subtract_log_q=False,
            remove_accidental_hits=True,
            name="sampled_loss_test2_num_true%d" % num_true_test)

        # Test that the exponentiated logits of accidental hits are near 0.
        # First we need to find the hits in this random test run:
        labels_reshape = labels.reshape((self._batch_size, num_true_test))
        logits_tf_np = logits_tf.eval()
        for row in xrange(self._batch_size):
          row_labels = labels_reshape[row, :]
          for col in xrange(num_sampled):
            if sampled[col] in row_labels:
              # We need to add the num_true_test offset into logits_*.
              self.assertNear(
                  np.exp(logits_tf_np[row, col + num_true_test]), 0., eps)

        # Test 3: With subtract_log_q, no accidental hit removal.
        logits_np, labels_np = self._ComputeSampledLogitsNP(
            true_w, true_b, sampled_w, sampled_b, hidden_acts,
            num_true=num_true_test,
            true_expected=true_exp,
            sampled_expected=sampled_exp)
        logits_tf, labels_tf = self._ComputeSampledLogitsTF(
            weights, biases, hidden_acts, labels, num_sampled,
            self._num_classes,
            num_true=num_true_test,
            sampled_vals=test_sampled_vals,
            subtract_log_q=True,
            remove_accidental_hits=False,
            name="sampled_loss_test3_num_true%d" % num_true_test)

        logits_tf_val, labels_tf_val = sess.run([logits_tf, labels_tf])
        self.assertAllClose(logits_np, logits_tf_val, eps)
        self.assertAllClose(labels_np, labels_tf_val, eps)

        # Test 4: Same as Test 1, but with sharded weights.
        logits_np, labels_np = self._ComputeSampledLogitsNP(
            true_w, true_b, sampled_w, sampled_b, hidden_acts,
            num_true=num_true_test)
        logits_tf, labels_tf = self._ComputeSampledLogitsTF(
            sharded_weights, biases, hidden_acts, labels, num_sampled,
            self._num_classes,
            num_true=num_true_test,
            sampled_vals=test_sampled_vals,
            subtract_log_q=False,
            remove_accidental_hits=False,
            name="sampled_loss_test4_num_true%d" % num_true_test)

        logits_tf_val, labels_tf_val = sess.run([logits_tf, labels_tf])
        self.assertAllClose(logits_np, logits_tf_val, eps)
        self.assertAllClose(labels_np, labels_tf_val, eps)

  def testNCELoss(self):
    # A simple test to verify the numerics.

    def _SigmoidCrossEntropyWithLogits(logits, targets):
      # logits, targets: float arrays of the same shape.
      assert logits.shape == targets.shape
      pred = 1. / (1. + np.exp(-logits))
      eps = 0.0001
      pred = np.minimum(np.maximum(pred, eps), 1 - eps)
      return -targets * np.log(pred) - (1. - targets) * np.log(1. - pred)
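
    # NCE treats every true and sampled candidate as an independent binary
    # classification, so the reference loss for one example is the sum of
    # sigmoid cross-entropy terms over its row of logits l and labels y:
    #   loss_i = sum_j [-y_ij * log(sigmoid(l_ij))
    #                   - (1 - y_ij) * log(1 - sigmoid(l_ij))]
    # which is what np.sum(_SigmoidCrossEntropyWithLogits(...), 1) computes.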

    weights, biases, hidden_acts, sharded_weights = self._GenerateTestInputs()
    labels = [0, 1, 2]
    true_w, true_b = weights[labels], biases[labels]
    sampled = [1, 0, 2, 3]
    num_sampled = len(sampled)
    true_exp = np.empty([self._batch_size, 1], dtype=np.float32)
    true_exp.fill(0.5)
    sampled_exp = np.empty([num_sampled], dtype=np.float32)
    sampled_exp.fill(0.5)
    sampled_w, sampled_b = weights[sampled], biases[sampled]
    test_sampled_vals = (sampled, true_exp, sampled_exp)

    with self.test_session():
      logits_np, labels_np = self._ComputeSampledLogitsNP(
          true_w, true_b, sampled_w, sampled_b, hidden_acts,
          true_expected=true_exp,
          sampled_expected=sampled_exp)
      nce_loss_np = np.sum(
          _SigmoidCrossEntropyWithLogits(logits_np, labels_np), 1)

      labels_tf = tf.constant(labels, shape=(self._batch_size, 1))
      weights_tf = tf.constant(weights)
      biases_tf = tf.constant(biases)
      inputs_tf = tf.constant(hidden_acts)

      nce_loss_tf = tf.nn.nce_loss(weights_tf,
                                   biases_tf,
                                   inputs_tf,
                                   labels_tf,
                                   num_sampled=1,
                                   num_classes=self._num_classes,
                                   num_true=1,
                                   sampled_values=test_sampled_vals)

      self.assertAllClose(nce_loss_np, nce_loss_tf.eval(), 1e-4)

      # Test with sharded weights.
      nce_loss_tf = tf.nn.nce_loss(
          [tf.constant(shard) for shard in sharded_weights],
          biases_tf,
          inputs_tf,
          labels_tf,
          num_sampled=1,
          num_classes=self._num_classes,
          num_true=1,
          sampled_values=test_sampled_vals)

      self.assertAllClose(nce_loss_np, nce_loss_tf.eval(), 1e-4)

  def testSampledSoftmaxLoss(self):
    # A simple test to verify the numerics.

    def _SoftmaxCrossEntropyWithLogits(logits, targets):
      # logits, targets: float arrays of the same shape.
      assert logits.shape == targets.shape
      stable_exp_logits = np.exp(logits - np.amax(
          logits, axis=1, keepdims=True))
      pred = stable_exp_logits / np.sum(stable_exp_logits, 1, keepdims=True)
      return -np.sum(targets * np.log(pred + 1.0e-20), axis=1)
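
    # Sampled softmax instead treats the concatenated (true + sampled) logits
    # as one softmax problem over that restricted class set, so the reference
    # loss is ordinary softmax cross-entropy (the helper above) applied to the
    # log(expected_count)-corrected logits from _ComputeSampledLogitsNP.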

    weights, biases, hidden_acts, sharded_weights = self._GenerateTestInputs()
    labels = [0, 1, 2]
    true_w, true_b = weights[labels], biases[labels]
    sampled = [1, 0, 2, 3]
    num_sampled = len(sampled)
    true_exp = np.full([self._batch_size, 1], fill_value=0.5, dtype=np.float32)
    sampled_exp = np.full([num_sampled], fill_value=0.5, dtype=np.float32)
    sampled_w, sampled_b = weights[sampled], biases[sampled]
    test_sampled_vals = (sampled, true_exp, sampled_exp)

    with self.test_session():
      logits_np, labels_np = self._ComputeSampledLogitsNP(
          true_w, true_b, sampled_w, sampled_b, hidden_acts,
          true_expected=true_exp,
          sampled_expected=sampled_exp)
      sampled_softmax_loss_np = _SoftmaxCrossEntropyWithLogits(logits_np,
                                                               labels_np)

      labels_tf = tf.constant(labels, shape=(self._batch_size, 1))
      weights_tf = tf.constant(weights)
      biases_tf = tf.constant(biases)
      inputs_tf = tf.constant(hidden_acts)

      sampled_softmax_loss_tf = tf.nn.sampled_softmax_loss(
          weights_tf,
          biases_tf,
          inputs_tf,
          labels_tf,
          num_sampled=1,
          num_classes=self._num_classes,
          num_true=1,
          sampled_values=test_sampled_vals,
          remove_accidental_hits=False)

      self.assertAllClose(
          sampled_softmax_loss_np, sampled_softmax_loss_tf.eval(), 1e-4)

      # Test with sharded weights.
      sampled_softmax_loss_tf = tf.nn.sampled_softmax_loss(
          [tf.constant(shard) for shard in sharded_weights],
          biases_tf,
          inputs_tf,
          labels_tf,
          num_sampled=1,
          num_classes=self._num_classes,
          num_true=1,
          sampled_values=test_sampled_vals,
          remove_accidental_hits=False)

      self.assertAllClose(
          sampled_softmax_loss_np, sampled_softmax_loss_tf.eval(), 1e-4)


if __name__ == "__main__":
  tf.test.main()