nn_test.py revision b1b2dc893d616c024c5390dae8b2f932c917d7f8
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for miscellaneous functionality in tensorflow.ops.nn."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf


class ZeroFractionTest(tf.test.TestCase):

  def _ZeroFraction(self, x):
    assert x.shape
    total_elements = np.prod(x.shape)
    nonzeros = np.count_nonzero(x.flatten())
    return 1.0 - nonzeros / total_elements

  def testZeroFraction(self):
    x_shape = [5, 17]
    x_np = np.random.randint(0, 2, size=x_shape).astype(np.float32)
    y_np = self._ZeroFraction(x_np)
    with self.test_session():
      x_tf = tf.constant(x_np)
      x_tf.set_shape(x_shape)
      y_tf = tf.nn.zero_fraction(x_tf)
      y_tf_np = y_tf.eval()
    eps = 1e-8
    self.assertAllClose(y_tf_np, y_np, eps)

  def testZeroFractionEmpty(self):
    with self.test_session():
      x = np.zeros(0)
      y = tf.nn.zero_fraction(x).eval()
      self.assertTrue(np.isnan(y))


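# A minimal NumPy sketch (an illustrative addition, not used by the tests) of
# the semantics the tests above assume for tf.nn.zero_fraction: the fraction
# of elements that are exactly zero. For an empty input this is a 0/0 ratio,
# hence the NaN asserted in testZeroFractionEmpty.
def _reference_zero_fraction(x):
  x = np.asarray(x)
  # Mean of the boolean zero-mask; np.mean of an empty array is NaN.
  return np.mean(x == 0)

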
class SoftmaxTest(tf.test.TestCase):

  def _softmax(self, x):
    assert len(x.shape) == 2
    m = x.max(1)[:, np.newaxis]
    u = np.exp(x - m)
    z = u.sum(1)[:, np.newaxis]
    return u / z

  def testSoftmax(self):
    x_shape = [5, 10]
    x_np = np.random.randn(*x_shape).astype(np.float32)
    y_np = self._softmax(x_np)
    with self.test_session():
      x_tf = tf.constant(x_np)
      y_tf = tf.nn.softmax(x_tf)
      y_tf_np = y_tf.eval()
    eps = 1e-3
    self.assertAllClose(y_tf_np, y_np, eps)

  def testGradient(self):
    x_shape = [5, 10]
    x_np = np.random.randn(*x_shape).astype(np.float64)
    with self.test_session():
      x_tf = tf.constant(x_np)
      y_tf = tf.nn.softmax(x_tf)
      err = tf.test.compute_gradient_error(x_tf, x_shape, y_tf, x_shape)
    eps = 1e-8
    self.assertLess(err, eps)


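# An illustrative helper (an addition, not used by the tests) that makes the
# stability trick in _softmax and _log_softmax above explicit: shifting by
# the row max before exponentiating avoids overflow while computing the same
# value, since logsumexp(x) = m + log(sum(exp(x - m))) for m = max(x).
def _reference_logsumexp(x, axis=1):
  m = np.max(x, axis=axis, keepdims=True)
  return m + np.log(np.sum(np.exp(x - m), axis=axis, keepdims=True))

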
class LogSoftmaxTest(tf.test.TestCase):

  def _log_softmax(self, x):
    assert len(x.shape) == 2
    m = x.max(1)[:, np.newaxis]
    u = x - m
    return u - np.log(np.sum(np.exp(u), 1, keepdims=True))

  def testLogSoftmax(self):
    x_shape = [5, 10]
    x_np = np.random.randn(*x_shape).astype(np.float32)
    y_np = self._log_softmax(x_np)
    with self.test_session():
      x_tf = tf.constant(x_np)
      y_tf = tf.nn.log_softmax(x_tf)
      y_tf_np = y_tf.eval()
    eps = 1e-3
    self.assertAllClose(y_tf_np, y_np, eps)

  def testGradient(self):
    x_shape = [5, 10]
    x_np = np.random.randn(*x_shape).astype(np.float64)
    with self.test_session():
      x_tf = tf.constant(x_np)
      y_tf = tf.nn.log_softmax(x_tf)
      err = tf.test.compute_gradient_error(x_tf, x_shape, y_tf, x_shape)
    eps = 1e-7
    self.assertLess(err, eps)


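# A worked reference (an illustrative addition) for the value testL2Loss
# below asserts: l2_loss computes sum(t**2) / 2, so for the constant
# [1.0, 0.0, 3.0, 2.0] the loss is (1 + 0 + 9 + 4) / 2 = 7.0.
def _reference_l2_loss(t):
  return np.sum(np.square(t)) / 2.0

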
class L2LossTest(tf.test.TestCase):

  def testL2Loss(self):
    for dtype in [tf.float32, tf.float64]:
      with self.test_session():
        x = tf.constant([1.0, 0.0, 3.0, 2.0], shape=[2, 2], name="x",
                        dtype=dtype)
        l2loss = tf.nn.l2_loss(x)
        value = l2loss.eval()
      self.assertAllClose(7.0, value)

  def testGradient(self):
    x_shape = [20, 7, 3]
    np.random.seed(1)  # Make it reproducible.
    x_val = np.random.random_sample(x_shape).astype(np.float64)
    with self.test_session():
      x = tf.constant(x_val, name="x")
      output = tf.nn.l2_loss(x)
      err = tf.test.compute_gradient_error(x, x_shape, output, [1])
    print("L2Loss gradient err = %g " % err)
    err_tolerance = 1e-11
    self.assertLess(err, err_tolerance)


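# An axis-based sketch of l2_normalize (an illustrative addition; the epsilon
# floor is an assumption about how the op guards zero-norm slices and is
# irrelevant to the tests below, whose random inputs have nonzero norms):
# each slice along `dim` is divided by its Euclidean norm.
def _reference_l2_normalize(x, dim, epsilon=1e-12):
  squared_norm = np.sum(np.square(x), axis=dim, keepdims=True)
  return x / np.sqrt(np.maximum(squared_norm, epsilon))

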
class L2NormalizeTest(tf.test.TestCase):

  def _l2Normalize(self, x, dim):
    norm = np.apply_along_axis(np.linalg.norm, dim, x)
    return x / np.expand_dims(norm, dim)

  def testL2Normalize(self):
    x_shape = [20, 7, 3]
    np.random.seed(1)
    x_np = np.random.random_sample(x_shape).astype(np.float32)
    for dim in range(len(x_shape)):
      y_np = self._l2Normalize(x_np, dim)
      with self.test_session():
        x_tf = tf.constant(x_np, name="x")
        y_tf = tf.nn.l2_normalize(x_tf, dim)
        self.assertAllClose(y_np, y_tf.eval())

  def testL2NormalizeGradient(self):
    x_shape = [20, 7, 3]
    np.random.seed(1)
    x_np = np.random.random_sample(x_shape).astype(np.float64)
    for dim in range(len(x_shape)):
      with self.test_session():
        x_tf = tf.constant(x_np, name="x")
        y_tf = tf.nn.l2_normalize(x_tf, dim)
        err = tf.test.compute_gradient_error(x_tf, x_shape, y_tf, x_shape)
      print("L2Normalize gradient err = %g " % err)
      self.assertLess(err, 1e-4)


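# A minimal NumPy sketch of inverted dropout (an illustrative addition; the
# RandomState seeding is an assumption, not how the TF op draws its noise):
# each element is kept with probability keep_prob and scaled by 1/keep_prob,
# so the output's expected value equals the input. With a noise_shape such as
# [x_dim, 1], one mask value broadcasts across each row, which is exactly the
# per-row correlation that testShapedDropoutCorrelation below checks.
def _reference_dropout(x, keep_prob, noise_shape=None, seed=None):
  rng = np.random.RandomState(seed)
  shape = x.shape if noise_shape is None else noise_shape
  keep_mask = rng.uniform(size=shape) < keep_prob
  return x * keep_mask.astype(x.dtype) / keep_prob

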
class DropoutTest(tf.test.TestCase):

  def testDropout(self):
    # Runs dropout on a 0-1 tensor 10 times, sums the number of ones, and
    # validates that it produces approximately the right number of ones over
    # a large number of samples, based on the keep probability.
    x_dim = 40
    y_dim = 30
    num_iter = 10
    for keep_prob in [0.1, 0.5, 0.8]:
      with self.test_session():
        t = tf.constant(1.0, shape=[x_dim, y_dim], dtype=tf.float32)
        dropout = tf.nn.dropout(t, keep_prob)
        final_count = 0
        self.assertEqual([x_dim, y_dim], dropout.get_shape())
        for _ in xrange(0, num_iter):
          value = dropout.eval()
          final_count += np.count_nonzero(value)
          # Verifies that there are only two values: 0 and 1/keep_prob.
          sorted_value = np.unique(np.sort(value))
          self.assertEqual(0, sorted_value[0])
          self.assertAllClose(1 / keep_prob, sorted_value[1])
      # Check that we are in the 15% error range.
      expected_count = x_dim * y_dim * keep_prob * num_iter
      rel_error = math.fabs(final_count - expected_count) / expected_count
      print(rel_error)
      self.assertTrue(rel_error < 0.15)

  def testShapedDropout(self):
    # Runs dropout on a 0-1 tensor 10 times, sums the number of ones, and
    # validates that it produces approximately the right number of ones over
    # a large number of samples, based on the keep probability. This time
    # with shaped noise.
    x_dim = 40 * 30
    y_dim = 3
    num_iter = 10
    for keep_prob in [0.1, 0.5, 0.8]:
      with self.test_session():
        t = tf.constant(1.0, shape=[x_dim, y_dim], dtype=tf.float32)
        dropout = tf.nn.dropout(t, keep_prob, noise_shape=[x_dim, 1])
        self.assertEqual([x_dim, y_dim], dropout.get_shape())
        final_count = 0
        for _ in xrange(0, num_iter):
          value = dropout.eval()
          final_count += np.count_nonzero(value)
          # Verifies that there are only two values: 0 and 1/keep_prob.
          sorted_value = np.unique(np.sort(value))
          self.assertEqual(0, sorted_value[0])
          self.assertAllClose(1 / keep_prob, sorted_value[1])
      # Check that we are in the 15% error range.
      expected_count = x_dim * y_dim * keep_prob * num_iter
      rel_error = math.fabs(final_count - expected_count) / expected_count
      print(rel_error)
      self.assertTrue(rel_error < 0.15)

  def testShapedDropoutCorrelation(self):
    # Runs a shaped dropout and tests that the correlations are correct.
    x_dim = 40
    y_dim = 30
    num_iter = 10
    for keep_prob in [0.1, 0.5, 0.8]:
      with self.test_session():
        t = tf.constant(1.0, shape=[x_dim, y_dim], dtype=tf.float32)
        dropout = tf.nn.dropout(t, keep_prob, noise_shape=[x_dim, 1])
        self.assertEqual([x_dim, y_dim], dropout.get_shape())
        for _ in xrange(0, num_iter):
          value = dropout.eval()
          # Verifies that each row has only one type of activation, since
          # the noise value is shared across the row.
          for i in xrange(x_dim):
            sorted_value = np.unique(np.sort(value[i, :]))
            self.assertEqual(sorted_value.size, 1)

  def testDropoutPlaceholderKeepProb(self):
    # Runs dropout with a placeholder keep_prob on a 0-1 tensor 10 times,
    # sums the number of ones, and validates that it produces approximately
    # the right number of ones over a large number of samples, based on the
    # keep probability.
    x_dim = 40
    y_dim = 30
    num_iter = 10
    for keep_prob in [0.1, 0.5, 0.8]:
      with self.test_session():
        t = tf.constant(1.0, shape=[x_dim, y_dim], dtype=tf.float32)
        keep_prob_placeholder = tf.placeholder(tf.float32)
        dropout = tf.nn.dropout(t, keep_prob_placeholder)
        final_count = 0
        self.assertEqual([x_dim, y_dim], dropout.get_shape())
        for _ in xrange(0, num_iter):
          value = dropout.eval(feed_dict={keep_prob_placeholder: keep_prob})
          final_count += np.count_nonzero(value)
          # Verifies that there are only two values: 0 and 1/keep_prob.
          sorted_value = np.unique(np.sort(value))
          self.assertEqual(0, sorted_value[0])
          self.assertAllClose(1 / keep_prob, sorted_value[1])
      # Check that we are in the 15% error range.
      expected_count = x_dim * y_dim * keep_prob * num_iter
      rel_error = math.fabs(final_count - expected_count) / expected_count
      print(rel_error)
      self.assertTrue(rel_error < 0.15)

  def testShapedDropoutUnknownShape(self):
    x_dim = 40
    y_dim = 30
    keep_prob = 0.5
    x = tf.constant(1.0, shape=[x_dim, y_dim], dtype=tf.float32)
    dropout_x = tf.nn.dropout(x,
                              keep_prob,
                              noise_shape=tf.placeholder(tf.int32))
    self.assertEqual(x.get_shape(), dropout_x.get_shape())

  def testInvalidKeepProb(self):
    x_dim = 40
    y_dim = 30
    t = tf.constant(1.0, shape=[x_dim, y_dim], dtype=tf.float32)
    with self.assertRaises(ValueError):
      tf.nn.dropout(t, -1.0)
    with self.assertRaises(ValueError):
      tf.nn.dropout(t, 1.1)
    with self.assertRaises(ValueError):
      tf.nn.dropout(t, [0.0, 1.0])
    with self.assertRaises(ValueError):
      tf.nn.dropout(t, tf.placeholder(tf.float64))
    with self.assertRaises(ValueError):
      tf.nn.dropout(t, tf.placeholder(tf.float32, shape=[2]))

  def testShapedDropoutShapeError(self):
    # Runs shaped dropout and verifies an error is thrown on misshapen noise.
    x_dim = 40
    y_dim = 30
    keep_prob = 0.5
    t = tf.constant(1.0, shape=[x_dim, y_dim], dtype=tf.float32)
    with self.assertRaises(ValueError):
      _ = tf.nn.dropout(t, keep_prob, noise_shape=[x_dim, y_dim + 10])
    with self.assertRaises(ValueError):
      _ = tf.nn.dropout(t, keep_prob, noise_shape=[x_dim, y_dim, 5])
    with self.assertRaises(ValueError):
      _ = tf.nn.dropout(t, keep_prob, noise_shape=[x_dim + 3])
    with self.assertRaises(ValueError):
      _ = tf.nn.dropout(t, keep_prob, noise_shape=[x_dim])
    # Test that broadcastable noise shapes are accepted without error.
    _ = tf.nn.dropout(t, keep_prob, noise_shape=[y_dim])
    _ = tf.nn.dropout(t, keep_prob, noise_shape=[1, y_dim])
    _ = tf.nn.dropout(t, keep_prob, noise_shape=[x_dim, 1])
    _ = tf.nn.dropout(t, keep_prob, noise_shape=[1, 1])


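# An illustrative note (an addition) on the subtract_log_q correction
# exercised by the tests below: when candidate classes are drawn from a
# proposal distribution Q, each logit is corrected to
# logit - log(E[count under Q]) so that the sampled loss approximates the
# full softmax loss. This mirrors the true_expected/sampled_expected
# handling inside _ComputeSampledLogitsNP.
def _reference_subtract_log_q(logits, expected_counts):
  return logits - np.log(expected_counts)

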
class ComputeSampledLogitsTest(tf.test.TestCase):

  def setUp(self):
    self._num_classes = 5
    self._dim = 10
    self._batch_size = 3
    self._num_shards = 3

  def _GenerateTestInputs(self):
    np.random.seed(0)
    weights = np.random.randn(self._num_classes, self._dim).astype(np.float32)
    biases = np.random.randn(self._num_classes).astype(np.float32)
    hidden_acts = np.random.randn(self._batch_size, self._dim).astype(
        np.float32)
    sharded_weights = [
        weights[[row for row in range(self._num_classes)
                 if row % self._num_shards == shard]]
        for shard in range(self._num_shards)]
    return weights, biases, hidden_acts, sharded_weights

  def _ComputeSampledLogitsNP(self, true_w, true_b, sampled_w, sampled_b,
                              hidden_acts,
                              num_true=1,
                              true_expected=None,
                              sampled_expected=None):
    batch_size, dim = hidden_acts.shape
    true_logits = np.sum(
        hidden_acts.reshape((batch_size, 1, dim)) * true_w.reshape(
            (batch_size, num_true, dim)),
        axis=2)
    true_b = true_b.reshape((batch_size, num_true))
    true_logits += true_b
    sampled_logits = np.dot(hidden_acts, sampled_w.T) + sampled_b

    if true_expected is not None:
      true_logits -= np.log(true_expected)
    if sampled_expected is not None:
      sampled_logits -= np.log(sampled_expected[np.newaxis, :])

    out_logits = np.concatenate([true_logits, sampled_logits], axis=1)
    out_labels = np.hstack((np.ones_like(true_logits) / num_true,
                            np.zeros_like(sampled_logits)))

    return out_logits, out_labels

  def _ComputeSampledLogitsTF(self, weights, biases, hidden_acts, labels,
                              num_sampled, num_classes, num_true, sampled_vals,
                              subtract_log_q, remove_accidental_hits,
                              name="sampled_loss_TF"):
    # Should be called from within a `with test_session():` block.
    if isinstance(weights, list):
      weights_tf = [tf.constant(shard) for shard in weights]
    else:
      weights_tf = tf.constant(weights)
    biases_tf = tf.constant(biases)
    hidden_acts_tf = tf.constant(hidden_acts,
                                 shape=(self._batch_size, self._dim))
    labels_tf = tf.constant(labels,
                            dtype=tf.int64,
                            shape=(self._batch_size, num_true))

    pred_logits_tf, pred_labels_tf = tf.nn._compute_sampled_logits(
        weights_tf,
        biases_tf,
        hidden_acts_tf,
        labels_tf,
        num_sampled,
        num_classes,
        num_true,
        sampled_vals,
        subtract_log_q=subtract_log_q,
        remove_accidental_hits=remove_accidental_hits,
        name=name)
    return pred_logits_tf, pred_labels_tf

  def testComputeSampledLogitsShapes(self):
    # We just check that the shapes of the returned values are correct.
    weights, biases, hidden_acts, _ = self._GenerateTestInputs()
    sampled = [1, 0, 2, 3]
    num_sampled = len(sampled)
    true_exp = sampled_exp = [1., 1., 1., 1.]
    test_sampled_vals = (sampled, true_exp, sampled_exp)
    sampled_w, sampled_b = weights[sampled], biases[sampled]

    with self.test_session() as sess:
      for num_true_test in range(1, 5):
        labels = np.random.randint(low=0, high=self._num_classes,
                                   size=self._batch_size * num_true_test)
        true_w, true_b = weights[labels], biases[labels]

        logits_np, labels_np = self._ComputeSampledLogitsNP(
            true_w, true_b, sampled_w, sampled_b, hidden_acts,
            num_true=num_true_test)

        logits_tf, labels_tf = self._ComputeSampledLogitsTF(
            weights, biases, hidden_acts, labels, num_sampled,
            self._num_classes,
            num_true=num_true_test,
            sampled_vals=test_sampled_vals,
            remove_accidental_hits=True,
            subtract_log_q=False)

        # Check the shapes inside the loop, so every num_true setting is
        # verified rather than only the last one.
        logits_tf_val, labels_tf_val = sess.run([logits_tf, labels_tf])
        self.assertEqual(logits_np.shape, logits_tf_val.shape)
        self.assertEqual(labels_np.shape, labels_tf_val.shape)

  def testComputeSampledLogitsValues(self):
    # Here we check the actual numerics.
    weights, biases, hidden_acts, sharded_weights = self._GenerateTestInputs()
    eps = 1e-3
    sampled = [1, 0, 2, 3]
    num_sampled = len(sampled)
    true_exp = np.empty([self._batch_size, 1], dtype=np.float32)
    true_exp.fill(0.5)
    sampled_exp = np.empty([num_sampled], dtype=np.float32)
    sampled_exp.fill(0.5)
    sampled_w, sampled_b = weights[sampled], biases[sampled]
    test_sampled_vals = (sampled, true_exp, sampled_exp)

    with self.test_session() as sess:
      for num_true_test in range(1, 5):
        # Generate test data for this run.
        labels = np.random.randint(low=0, high=self._num_classes,
                                   size=self._batch_size * num_true_test)
        true_w, true_b = weights[labels], biases[labels]

        # Test 1: Without accidental hit removal or subtract_log_q.
        logits_np, labels_np = self._ComputeSampledLogitsNP(
            true_w, true_b, sampled_w, sampled_b, hidden_acts,
            num_true=num_true_test)
        logits_tf, labels_tf = self._ComputeSampledLogitsTF(
            weights, biases, hidden_acts, labels, num_sampled,
            self._num_classes,
            num_true=num_true_test,
            sampled_vals=test_sampled_vals,
            subtract_log_q=False,
            remove_accidental_hits=False,
            name="sampled_loss_test1_num_true%d" % num_true_test)

        logits_tf_val, labels_tf_val = sess.run([logits_tf, labels_tf])
        self.assertAllClose(logits_np, logits_tf_val, eps)
        self.assertAllClose(labels_np, labels_tf_val, eps)

        # Test 2: With accidental hit removal, no subtract_log_q.
        logits_tf, labels_tf = self._ComputeSampledLogitsTF(
            weights, biases, hidden_acts, labels, num_sampled,
            self._num_classes,
            num_true=num_true_test,
            sampled_vals=test_sampled_vals,
            subtract_log_q=False,
            remove_accidental_hits=True,
            name="sampled_loss_test2_num_true%d" % num_true_test)

        # Test that the exponentiated logits of accidental hits are near 0.
        # First we need to find the hits in this random test run:
        labels_reshape = labels.reshape((self._batch_size, num_true_test))
        logits_tf_np = logits_tf.eval()
        for row in xrange(self._batch_size):
          row_labels = labels_reshape[row, :]
          for col in xrange(num_sampled):
            if sampled[col] in row_labels:
              # We need to add the num_true_test offset into logits_*.
              self.assertNear(
                  np.exp(logits_tf_np[row, col + num_true_test]), 0., eps)

        # Test 3: With subtract_log_q, no accidental hit removal.
        logits_np, labels_np = self._ComputeSampledLogitsNP(
            true_w, true_b, sampled_w, sampled_b, hidden_acts,
            num_true=num_true_test,
            true_expected=true_exp,
            sampled_expected=sampled_exp)
        logits_tf, labels_tf = self._ComputeSampledLogitsTF(
            weights, biases, hidden_acts, labels, num_sampled,
            self._num_classes,
            num_true=num_true_test,
            sampled_vals=test_sampled_vals,
            subtract_log_q=True,
            remove_accidental_hits=False,
            name="sampled_loss_test3_num_true%d" % num_true_test)

        logits_tf_val, labels_tf_val = sess.run([logits_tf, labels_tf])
        self.assertAllClose(logits_np, logits_tf_val, eps)
        self.assertAllClose(labels_np, labels_tf_val, eps)

        # Test 4: Same as Test 1, but with sharded weights.
        logits_np, labels_np = self._ComputeSampledLogitsNP(
            true_w, true_b, sampled_w, sampled_b, hidden_acts,
            num_true=num_true_test)
        logits_tf, labels_tf = self._ComputeSampledLogitsTF(
            sharded_weights, biases, hidden_acts, labels, num_sampled,
            self._num_classes,
            num_true=num_true_test,
            sampled_vals=test_sampled_vals,
            subtract_log_q=False,
            remove_accidental_hits=False,
            name="sampled_loss_test4_num_true%d" % num_true_test)

        logits_tf_val, labels_tf_val = sess.run([logits_tf, labels_tf])
        self.assertAllClose(logits_np, logits_tf_val, eps)
        self.assertAllClose(labels_np, labels_tf_val, eps)

  def testNCELoss(self):
    # A simple test to verify the numerics.

    def _SigmoidCrossEntropyWithLogits(logits, targets):
      # logits, targets: float arrays of the same shape.
      assert logits.shape == targets.shape
      pred = 1. / (1. + np.exp(-logits))
      eps = 0.0001
      pred = np.minimum(np.maximum(pred, eps), 1 - eps)
      return -targets * np.log(pred) - (1. - targets) * np.log(1. - pred)

    weights, biases, hidden_acts, sharded_weights = self._GenerateTestInputs()
    labels = [0, 1, 2]
    true_w, true_b = weights[labels], biases[labels]
    sampled = [1, 0, 2, 3]
    num_sampled = len(sampled)
    true_exp = np.empty([self._batch_size, 1], dtype=np.float32)
    true_exp.fill(0.5)
    sampled_exp = np.empty([num_sampled], dtype=np.float32)
    sampled_exp.fill(0.5)
    sampled_w, sampled_b = weights[sampled], biases[sampled]
    test_sampled_vals = (sampled, true_exp, sampled_exp)

    with self.test_session():
      logits_np, labels_np = self._ComputeSampledLogitsNP(
          true_w, true_b, sampled_w, sampled_b, hidden_acts,
          true_expected=true_exp,
          sampled_expected=sampled_exp)
      nce_loss_np = np.sum(
          _SigmoidCrossEntropyWithLogits(logits_np, labels_np), 1)

      labels_tf = tf.constant(labels, shape=(self._batch_size, 1))
      weights_tf = tf.constant(weights)
      biases_tf = tf.constant(biases)
      inputs_tf = tf.constant(hidden_acts)

      nce_loss_tf = tf.nn.nce_loss(weights_tf,
                                   biases_tf,
                                   inputs_tf,
                                   labels_tf,
                                   num_sampled=1,
                                   num_classes=self._num_classes,
                                   num_true=1,
                                   sampled_values=test_sampled_vals)

      self.assertAllClose(nce_loss_np, nce_loss_tf.eval(), 1e-4)

      # Test with sharded weights.
      nce_loss_tf = tf.nn.nce_loss(
          [tf.constant(shard) for shard in sharded_weights],
          biases_tf,
          inputs_tf,
          labels_tf,
          num_sampled=1,
          num_classes=self._num_classes,
          num_true=1,
          sampled_values=test_sampled_vals)

      self.assertAllClose(nce_loss_np, nce_loss_tf.eval(), 1e-4)

  def testSampledSoftmaxLoss(self):
    # A simple test to verify the numerics.

    def _SoftmaxCrossEntropyWithLogits(logits, targets):
      # logits, targets: float arrays of the same shape.
      assert logits.shape == targets.shape
      stable_exp_logits = np.exp(logits - np.amax(
          logits, axis=1, keepdims=True))
      pred = stable_exp_logits / np.sum(stable_exp_logits, 1, keepdims=True)
      return -np.sum(targets * np.log(pred + 1.0e-20), axis=1)

    weights, biases, hidden_acts, sharded_weights = self._GenerateTestInputs()
    labels = [0, 1, 2]
    true_w, true_b = weights[labels], biases[labels]
    sampled = [1, 0, 2, 3]
    num_sampled = len(sampled)
    true_exp = np.full([self._batch_size, 1], fill_value=0.5, dtype=np.float32)
    sampled_exp = np.full([num_sampled], fill_value=0.5, dtype=np.float32)
    sampled_w, sampled_b = weights[sampled], biases[sampled]
    test_sampled_vals = (sampled, true_exp, sampled_exp)

    with self.test_session():
      logits_np, labels_np = self._ComputeSampledLogitsNP(
          true_w, true_b, sampled_w, sampled_b, hidden_acts,
          true_expected=true_exp,
          sampled_expected=sampled_exp)
      sampled_softmax_loss_np = _SoftmaxCrossEntropyWithLogits(logits_np,
                                                               labels_np)

      labels_tf = tf.constant(labels, shape=(self._batch_size, 1))
      weights_tf = tf.constant(weights)
      biases_tf = tf.constant(biases)
      inputs_tf = tf.constant(hidden_acts)

      sampled_softmax_loss_tf = tf.nn.sampled_softmax_loss(
          weights_tf,
          biases_tf,
          inputs_tf,
          labels_tf,
          num_sampled=1,
          num_classes=self._num_classes,
          num_true=1,
          sampled_values=test_sampled_vals,
          remove_accidental_hits=False)

      self.assertAllClose(
          sampled_softmax_loss_np, sampled_softmax_loss_tf.eval(), 1e-4)

      # Test with sharded weights.
      sampled_softmax_loss_tf = tf.nn.sampled_softmax_loss(
          [tf.constant(shard) for shard in sharded_weights],
          biases_tf,
          inputs_tf,
          labels_tf,
          num_sampled=1,
          num_classes=self._num_classes,
          num_true=1,
          sampled_values=test_sampled_vals,
          remove_accidental_hits=False)

      self.assertAllClose(
          sampled_softmax_loss_np, sampled_softmax_loss_tf.eval(), 1e-4)


if __name__ == "__main__":
  tf.test.main()