# nn_test.py at revision bce6216610d57f8f4b1e9e79836737df109c4e42
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for tensorflow.ops.nn."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

import tensorflow as tf
import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.python.ops import gen_nn_ops

exp = math.exp
log = math.log


class SigmoidCrossEntropyWithLogitsTest(tf.test.TestCase):

  def _SigmoidCrossEntropyWithLogits(self, logits, targets):
    assert len(logits) == len(targets)
    pred = [1 / (1 + exp(-x)) for x in logits]
    eps = 0.0001
    pred = [min(max(p, eps), 1 - eps) for p in pred]
    return [-z * log(y) - (1 - z) * log(1 - y) for y, z in zip(pred, targets)]

  def _Inputs(self, x=None, y=None, dtype=tf.float64, sizes=None):
    x = [-100, -2, -2, 0, 2, 2, 2, 100] if x is None else x
    y = [0, 0, 1, 0, 0, 1, 0.5, 1] if y is None else y
    assert len(x) == len(y)
    sizes = sizes if sizes else [len(x)]
    logits = tf.constant(x, shape=sizes, dtype=dtype, name="logits")
    targets = tf.constant(y, shape=sizes, dtype=dtype, name="targets")
    losses = np.array(self._SigmoidCrossEntropyWithLogits(x, y)).reshape(*sizes)
    return logits, targets, losses

  def testConstructionNamed(self):
    with self.test_session():
      logits, targets, _ = self._Inputs()
      loss = tf.nn.sigmoid_cross_entropy_with_logits(logits,
                                                     targets,
                                                     name="mylogistic")
    self.assertEqual("mylogistic", loss.op.name)

  def testLogisticOutput(self):
    for use_gpu in [True, False]:
      with self.test_session(use_gpu=use_gpu):
        logits, targets, losses = self._Inputs(dtype=tf.float32)
        loss = tf.nn.sigmoid_cross_entropy_with_logits(logits, targets)
        np_loss = np.array(losses).astype(np.float32)
        tf_loss = loss.eval()
      self.assertAllClose(np_loss, tf_loss, atol=0.001)

  def testLogisticOutputMultiDim(self):
    for use_gpu in [True, False]:
      with self.test_session(use_gpu=use_gpu):
        logits, targets, losses = self._Inputs(dtype=tf.float32,
                                               sizes=[2, 2, 2])
        loss = tf.nn.sigmoid_cross_entropy_with_logits(logits, targets)
        np_loss = np.array(losses).astype(np.float32)
        tf_loss = loss.eval()
      self.assertAllClose(np_loss, tf_loss, atol=0.001)

  def testGradient(self):
    sizes = [4, 2]
    with self.test_session():
      logits, targets, _ = self._Inputs(sizes=sizes)
      loss = tf.nn.sigmoid_cross_entropy_with_logits(logits, targets)
      err = tf.test.compute_gradient_error(logits, sizes, loss, sizes)
    print("logistic loss gradient err = ", err)
    self.assertLess(err, 1e-7)

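
# Illustrative sketch (not part of the original tests): the naive formula in
# _SigmoidCrossEntropyWithLogits above needs clipping to avoid log(0).
# TensorFlow documents the numerically stable rearrangement
#   max(x, 0) - x * z + log(1 + exp(-abs(x))),
# which needs no clipping. A NumPy version for comparison:
def _example_stable_sigmoid_xent(logits, targets):
  """Elementwise sigmoid cross-entropy in its numerically stable form."""
  logits = np.asarray(logits, dtype=np.float64)
  targets = np.asarray(targets, dtype=np.float64)
  return (np.maximum(logits, 0) - logits * targets +
          np.log1p(np.exp(-np.abs(logits))))
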

class ZeroFractionTest(tf.test.TestCase):

  def _ZeroFraction(self, x):
    assert x.shape
    total_elements = np.prod(x.shape)
    nonzeros = np.count_nonzero(x.flatten())
    return 1.0 - nonzeros / total_elements

  def testZeroFraction(self):
    x_shape = [5, 17]
    x_np = np.random.randint(0, 2, size=x_shape).astype(np.float32)
    y_np = self._ZeroFraction(x_np)
    with self.test_session():
      x_tf = tf.constant(x_np)
      x_tf.set_shape(x_shape)
      y_tf = tf.nn.zero_fraction(x_tf)
      y_tf_np = y_tf.eval()
    eps = 1e-8
    self.assertAllClose(y_tf_np, y_np, eps)

  def testZeroFractionEmpty(self):
    with self.test_session():
      x = np.zeros(0)
      y = tf.nn.zero_fraction(x).eval()
      self.assertTrue(np.isnan(y))


class SoftmaxTest(tf.test.TestCase):

  def _softmax(self, x):
    assert len(x.shape) == 2
    m = x.max(1)[:, np.newaxis]
    u = np.exp(x - m)
    z = u.sum(1)[:, np.newaxis]
    return u / z

  def testSoftmax(self):
    x_shape = [5, 10]
    x_np = np.random.randn(*x_shape).astype(np.float32)
    y_np = self._softmax(x_np)
    with self.test_session():
      x_tf = tf.constant(x_np)
      y_tf = tf.nn.softmax(x_tf)
      y_tf_np = y_tf.eval()
    eps = 1e-3
    self.assertAllClose(y_tf_np, y_np, eps)

  def testGradient(self):
    x_shape = [5, 10]
    x_np = np.random.randn(*x_shape).astype(np.float64)
    with self.test_session():
      x_tf = tf.constant(x_np)
      y_tf = tf.nn.softmax(x_tf)
      err = tf.test.compute_gradient_error(x_tf, x_shape, y_tf, x_shape)
    eps = 1e-8
    self.assertLess(err, eps)


class Conv2DTransposeTest(tf.test.TestCase):

  def testConv2DTransposeSingleStride(self):
    with self.test_session():
      strides = [1, 1, 1, 1]

      # Input, output: [batch, height, width, depth]
      x_shape = [2, 6, 4, 3]
      y_shape = [2, 6, 4, 2]

      # Filter: [kernel_height, kernel_width, output_depth, input_depth]
      f_shape = [3, 3, 2, 3]

      x = tf.constant(1.0, shape=x_shape, name="x", dtype=tf.float32)
      f = tf.constant(1.0, shape=f_shape, name="filter", dtype=tf.float32)
      output = tf.nn.conv2d_transpose(x, f, y_shape, strides=strides,
                                      padding="SAME")
      value = output.eval()

      # We count the number of cells being added at the locations in the output.
      # At the center, #cells=kernel_height * kernel_width
      # At the corners, #cells=ceil(kernel_height/2) * ceil(kernel_width/2)
      # At the borders, #cells=ceil(kernel_height/2)*kernel_width or
      #                        kernel_height * ceil(kernel_width/2)
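      #
      # Worked example for this test (all-ones 3x3 kernel, all-ones input of
      # depth 3, stride 1): an interior location sums all 9 kernel cells over
      # 3 input channels, 9 * 3.0 = 4 * 3.0 + 5 * 3.0; a corner sums
      # ceil(3/2) * ceil(3/2) = 4 cells, 4 * 3.0; a border location sums
      # 2 * 3 = 6 cells, 4 * 3.0 + 2 * 3.0. These are the targets below.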

      for n in xrange(x_shape[0]):
        for k in xrange(f_shape[2]):
          for w in xrange(y_shape[2]):
            for h in xrange(y_shape[1]):
              target = 4 * 3.0
              h_in = h > 0 and h < y_shape[1] - 1
              w_in = w > 0 and w < y_shape[2] - 1
              if h_in and w_in:
                target += 5 * 3.0
              elif h_in or w_in:
                target += 2 * 3.0
              self.assertAllClose(target, value[n, h, w, k])

  def testConv2DTransposeSame(self):
    with self.test_session():
      strides = [1, 2, 2, 1]

      # Input, output: [batch, height, width, depth]
      x_shape = [2, 6, 4, 3]
      y_shape = [2, 12, 8, 2]

      # Filter: [kernel_height, kernel_width, output_depth, input_depth]
      f_shape = [3, 3, 2, 3]

      x = tf.constant(1.0, shape=x_shape, name="x", dtype=tf.float32)
      f = tf.constant(1.0, shape=f_shape, name="filter", dtype=tf.float32)
      output = tf.nn.conv2d_transpose(x, f, y_shape, strides=strides,
                                      padding="SAME")
      value = output.eval()

      for n in xrange(x_shape[0]):
        for k in xrange(f_shape[2]):
          for w in xrange(y_shape[2]):
            for h in xrange(y_shape[1]):
              target = 3.0
              # We add a case for locations divisible by the stride.
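              # With stride 2 and a 3x3 kernel, a stride-aligned interior
              # location accumulates 4 input pixels when both h and w are
              # aligned (target 3.0 + 9.0), 2 when exactly one of them is
              # aligned (target 3.0 + 3.0), and 1 otherwise (target 3.0);
              # each pixel contributes 3.0 since the all-ones input has
              # depth 3.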
              h_in = h % strides[1] == 0 and h > 0 and h < y_shape[1] - 1
              w_in = w % strides[2] == 0 and w > 0 and w < y_shape[2] - 1
              if h_in and w_in:
                target += 9.0
              elif h_in or w_in:
                target += 3.0
              self.assertAllClose(target, value[n, h, w, k])

  def testConv2DTransposeValid(self):
    with self.test_session():
      strides = [1, 2, 2, 1]

      # Input, output: [batch, height, width, depth]
      x_shape = [2, 6, 4, 3]
      y_shape = [2, 13, 9, 2]

      # Filter: [kernel_height, kernel_width, output_depth, input_depth]
      f_shape = [3, 3, 2, 3]

      x = tf.constant(1.0, shape=x_shape, name="x", dtype=tf.float32)
      f = tf.constant(1.0, shape=f_shape, name="filter", dtype=tf.float32)
      output = tf.nn.conv2d_transpose(x, f, y_shape, strides=strides,
                                      padding="VALID")
      value = output.eval()

      cache_values = np.zeros(y_shape, dtype=np.float32)

      # The amount of padding added
      pad = 1

      for n in xrange(x_shape[0]):
        for k in xrange(f_shape[2]):
          for w in xrange(pad, y_shape[2] - pad):
            for h in xrange(pad, y_shape[1] - pad):
              target = 3.0
              # We add a case for locations divisible by the stride.
              h_in = (h % strides[1] == 0 and h > pad and
                      h < y_shape[1] - 1 - pad)
              w_in = (w % strides[2] == 0 and w > pad and
                      w < y_shape[2] - 1 - pad)
              if h_in and w_in:
                target += 9.0
              elif h_in or w_in:
                target += 3.0
              cache_values[n, h, w, k] = target

          # copy values in the border
          cache_values[n, :, 0, k] = cache_values[n, :, 1, k]
          cache_values[n, :, -1, k] = cache_values[n, :, -2, k]
          cache_values[n, 0, :, k] = cache_values[n, 1, :, k]
          cache_values[n, -1, :, k] = cache_values[n, -2, :, k]

    self.assertAllClose(cache_values, value)

  def testGradient(self):
    x_shape = [2, 6, 4, 3]
    f_shape = [3, 3, 2, 3]
    y_shape = [2, 12, 8, 2]
    strides = [1, 2, 2, 1]
    np.random.seed(1)  # Make it reproducible.
    x_val = np.random.random_sample(x_shape).astype(np.float64)
    f_val = np.random.random_sample(f_shape).astype(np.float64)
    with self.test_session():
      x = tf.constant(x_val, name="x", dtype=tf.float32)
      f = tf.constant(f_val, name="f", dtype=tf.float32)
      output = tf.nn.conv2d_transpose(x, f, y_shape, strides=strides,
                                      padding="SAME")
      err = tf.test.compute_gradient_error(
          [x, f], [x_shape, f_shape], output, y_shape)
    print("DeConv gradient err = %g " % err)
    err_tolerance = 0.0005
    self.assertLess(err, err_tolerance)

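
# Illustrative sketch (not part of the original tests): the y_shape values used
# above follow the standard transposed-convolution size formulas; the helper
# name below is hypothetical.
def _example_conv2d_transpose_size(input_size, kernel_size, stride, padding):
  """Spatial output size of conv2d_transpose along one dimension."""
  if padding == "SAME":
    return input_size * stride
  elif padding == "VALID":
    return (input_size - 1) * stride + kernel_size
  else:
    raise ValueError("padding must be 'SAME' or 'VALID'")
# E.g. SAME, stride 2: 6x4 -> 12x8; VALID, stride 2, kernel 3: 6x4 -> 13x9;
# SAME, stride 1 leaves 6x4 unchanged, matching the y_shape values above.
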

class L2LossTest(tf.test.TestCase):

  def testL2Loss(self):
    with self.test_session():
      x = tf.constant([1.0, 0.0, 3.0, 2.0], shape=[2, 2], name="x")
      l2loss = tf.nn.l2_loss(x)
      value = l2loss.eval()
    self.assertAllClose(7.0, value)

  def testGradient(self):
    x_shape = [20, 7, 3]
    np.random.seed(1)  # Make it reproducible.
    x_val = np.random.random_sample(x_shape).astype(np.float64)
    with self.test_session():
      x = tf.constant(x_val, name="x")
      output = tf.nn.l2_loss(x)
      err = tf.test.compute_gradient_error(x, x_shape, output, [1])
    print("L2Loss gradient err = %g " % err)
    err_tolerance = 1e-11
    self.assertLess(err, err_tolerance)

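
# Illustrative sketch (not part of the original tests): tf.nn.l2_loss computes
# sum(x ** 2) / 2 over all elements, so the [1.0, 0.0, 3.0, 2.0] constant in
# testL2Loss yields (1 + 0 + 9 + 4) / 2 = 7.0.
def _example_l2_loss(x):
  return np.sum(np.square(np.asarray(x, dtype=np.float64))) / 2.0
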

class L2NormalizeTest(tf.test.TestCase):

  def _l2Normalize(self, x, dim):
    norm = np.apply_along_axis(np.linalg.norm, dim, x)
    return x / np.expand_dims(norm, dim)

  def testL2Normalize(self):
    x_shape = [20, 7, 3]
    np.random.seed(1)
    x_np = np.random.random_sample(x_shape).astype(np.float32)
    for dim in range(len(x_shape)):
      y_np = self._l2Normalize(x_np, dim)
      with self.test_session():
        x_tf = tf.constant(x_np, name="x")
        y_tf = tf.nn.l2_normalize(x_tf, dim)
        self.assertAllClose(y_np, y_tf.eval())

  def testL2NormalizeGradient(self):
    x_shape = [20, 7, 3]
    np.random.seed(1)
    x_np = np.random.random_sample(x_shape).astype(np.float64)
    for dim in range(len(x_shape)):
      with self.test_session():
        x_tf = tf.constant(x_np, name="x")
        y_tf = tf.nn.l2_normalize(x_tf, dim)
        err = tf.test.compute_gradient_error(x_tf, x_shape, y_tf, x_shape)
      print("L2Normalize gradient err = %g " % err)
      self.assertLess(err, 1e-4)


class DropoutTest(tf.test.TestCase):

  def testDropout(self):
    # Runs dropout with a 0-1 tensor 10 times, sums the number of ones, and
    # validates that it produces approximately the right number of ones over
    # a large number of samples, based on the keep probability.
    x_dim = 40
    y_dim = 30
    num_iter = 10
    for keep_prob in [0.1, 0.5, 0.8]:
      with self.test_session():
        t = tf.constant(1.0, shape=[x_dim, y_dim], dtype=tf.float32)
        dropout = tf.nn.dropout(t, keep_prob)
        final_count = 0
        self.assertEqual([x_dim, y_dim], dropout.get_shape())
        for _ in xrange(0, num_iter):
          value = dropout.eval()
          final_count += np.count_nonzero(value)
          # Verifies that there are only two values: 0 and 1/keep_prob.
          sorted_value = np.unique(np.sort(value))
          self.assertEqual(0, sorted_value[0])
          self.assertAllClose(1 / keep_prob, sorted_value[1])
      # Check that we are in the 15% error range
      expected_count = x_dim * y_dim * keep_prob * num_iter
      rel_error = math.fabs(final_count - expected_count) / expected_count
      print(rel_error)
      self.assertTrue(rel_error < 0.15)

  def testShapedDropout(self):
    # Runs dropout with a 0-1 tensor 10 times, sums the number of ones, and
    # validates that it produces approximately the right number of ones over
    # a large number of samples, based on the keep probability. This time
    # with shaped noise.
    x_dim = 40 * 30
    y_dim = 3
    num_iter = 10
    for keep_prob in [0.1, 0.5, 0.8]:
      with self.test_session():
        t = tf.constant(1.0, shape=[x_dim, y_dim], dtype=tf.float32)
        dropout = tf.nn.dropout(t, keep_prob, noise_shape=[x_dim, 1])
        self.assertEqual([x_dim, y_dim], dropout.get_shape())
        final_count = 0
        for _ in xrange(0, num_iter):
          value = dropout.eval()
          final_count += np.count_nonzero(value)
          # Verifies that there are only two values: 0 and 1/keep_prob.
          sorted_value = np.unique(np.sort(value))
          self.assertEqual(0, sorted_value[0])
          self.assertAllClose(1 / keep_prob, sorted_value[1])
      # Check that we are in the 15% error range
      expected_count = x_dim * y_dim * keep_prob * num_iter
      rel_error = math.fabs(final_count - expected_count) / expected_count
      print(rel_error)
      self.assertTrue(rel_error < 0.15)

  def testShapedDropoutCorrelation(self):
    # Runs a shaped dropout and tests that the correlations are correct.
    x_dim = 40
    y_dim = 30
    num_iter = 10
    for keep_prob in [0.1, 0.5, 0.8]:
      with self.test_session():
        t = tf.constant(1.0, shape=[x_dim, y_dim], dtype=tf.float32)
        dropout = tf.nn.dropout(t, keep_prob, noise_shape=[x_dim, 1])
        self.assertEqual([x_dim, y_dim], dropout.get_shape())
        for _ in xrange(0, num_iter):
          value = dropout.eval()
          # Verifies that each y column has only one type of activation.
          for i in xrange(x_dim):
            sorted_value = np.unique(np.sort(value[i, :]))
            self.assertEqual(sorted_value.size, 1)

  def testDropoutPlaceholderKeepProb(self):
    # Runs dropout with a 0-1 tensor 10 times, sums the number of ones, and
    # validates that it produces approximately the right number of ones over
    # a large number of samples, based on the keep probability.
    x_dim = 40
    y_dim = 30
    num_iter = 10
    for keep_prob in [0.1, 0.5, 0.8]:
      with self.test_session():
        t = tf.constant(1.0, shape=[x_dim, y_dim], dtype=tf.float32)
        keep_prob_placeholder = tf.placeholder(tf.float32)
        dropout = tf.nn.dropout(t, keep_prob_placeholder)
        final_count = 0
        self.assertEqual([x_dim, y_dim], dropout.get_shape())
        for _ in xrange(0, num_iter):
          value = dropout.eval(feed_dict={keep_prob_placeholder: keep_prob})
          final_count += np.count_nonzero(value)
          # Verifies that there are only two values: 0 and 1/keep_prob.
          sorted_value = np.unique(np.sort(value))
          self.assertEqual(0, sorted_value[0])
          self.assertAllClose(1 / keep_prob, sorted_value[1])
      # Check that we are in the 15% error range
      expected_count = x_dim * y_dim * keep_prob * num_iter
      rel_error = math.fabs(final_count - expected_count) / expected_count
      print(rel_error)
      self.assertTrue(rel_error < 0.15)

  def testShapedDropoutUnknownShape(self):
    x_dim = 40
    y_dim = 30
    keep_prob = 0.5
    x = tf.constant(1.0, shape=[x_dim, y_dim], dtype=tf.float32)
    dropout_x = tf.nn.dropout(x,
                              keep_prob,
                              noise_shape=tf.placeholder(tf.int32))
    self.assertEqual(x.get_shape(), dropout_x.get_shape())

  def testInvalidKeepProb(self):
    x_dim = 40
    y_dim = 30
    t = tf.constant(1.0, shape=[x_dim, y_dim], dtype=tf.float32)
    with self.assertRaises(ValueError):
      tf.nn.dropout(t, -1.0)
    with self.assertRaises(ValueError):
      tf.nn.dropout(t, 1.1)
    with self.assertRaises(ValueError):
      tf.nn.dropout(t, [0.0, 1.0])
    with self.assertRaises(ValueError):
      tf.nn.dropout(t, tf.placeholder(tf.float64))
    with self.assertRaises(ValueError):
      tf.nn.dropout(t, tf.placeholder(tf.float32, shape=[2]))

  def testShapedDropoutShapeError(self):
    # Runs shaped dropout and verifies an error is thrown on misshapen noise.
    x_dim = 40
    y_dim = 30
    keep_prob = 0.5
    t = tf.constant(1.0, shape=[x_dim, y_dim], dtype=tf.float32)
    with self.assertRaises(ValueError):
      _ = tf.nn.dropout(t, keep_prob, noise_shape=[x_dim, y_dim + 10])
    with self.assertRaises(ValueError):
      _ = tf.nn.dropout(t, keep_prob, noise_shape=[x_dim, y_dim, 5])
    with self.assertRaises(ValueError):
      _ = tf.nn.dropout(t, keep_prob, noise_shape=[x_dim + 3])
    with self.assertRaises(ValueError):
      _ = tf.nn.dropout(t, keep_prob, noise_shape=[x_dim])
    # Test that broadcasting proceeds.
    _ = tf.nn.dropout(t, keep_prob, noise_shape=[y_dim])
    _ = tf.nn.dropout(t, keep_prob, noise_shape=[1, y_dim])
    _ = tf.nn.dropout(t, keep_prob, noise_shape=[x_dim, 1])
    _ = tf.nn.dropout(t, keep_prob, noise_shape=[1, 1])

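
# Illustrative sketch (not part of the original tests): inverted dropout as
# exercised above. Kept entries are scaled by 1/keep_prob (hence the
# sorted_value checks), and noise_shape broadcasts one keep/drop decision
# across the broadcast dimensions (hence the per-row correlation check).
def _example_dropout(x, keep_prob, noise_shape=None, seed=None):
  rng = np.random.RandomState(seed)
  x = np.asarray(x, dtype=np.float64)
  shape = noise_shape if noise_shape is not None else x.shape
  mask = (rng.uniform(size=shape) < keep_prob).astype(x.dtype)
  return x * mask / keep_prob
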

class BatchNormalizationTest(tf.test.TestCase):

  def _npBatchNorm(self, x, m, v, beta, gamma, epsilon,
                   scale_after_normalization, shift_after_normalization):
    y = (x - m) / np.sqrt(v + epsilon)
    y = y * gamma if scale_after_normalization else y
    return y + beta if shift_after_normalization else y

  def _opsBatchNorm(self, x, m, v, beta, gamma, epsilon,
                    scale_after_normalization, shift_after_normalization):
    y = (x - m) * tf.rsqrt(v + epsilon)
    if scale_after_normalization:
      y = gamma * y
    return y + beta if shift_after_normalization else y

  def _tfBatchNormV1(self, x, m, v, beta, gamma, epsilon,
                     scale_after_normalization):
    """Original implementation."""
    # _batch_norm_with_global_normalization is deprecated in v9
    tf.get_default_graph().graph_def_versions.producer = 8
    # pylint: disable=protected-access
    return gen_nn_ops._batch_norm_with_global_normalization(
        x, m, v, beta, gamma, epsilon, scale_after_normalization)
    # pylint: enable=protected-access

  def _tfBatchNormV1BW(self, x, m, v, beta, gamma, epsilon,
                       scale_after_normalization):
    """Re-implementation of the original kernel for backward compatibility."""
    return tf.nn.batch_norm_with_global_normalization(
        x, m, v, beta, gamma, epsilon, scale_after_normalization)

  def _tfBatchNormV2(self, x, m, v, beta, gamma, epsilon,
                     scale_after_normalization, shift_after_normalization):
    """New implementation."""
    return tf.nn.batch_normalization(
        x, m, v, beta if shift_after_normalization else None,
        gamma if scale_after_normalization else None, epsilon)

  def testBatchNorm(self):
    x_shape = [3, 5, 4, 2]
    param_shape = [2]
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    m_val = np.random.random_sample(param_shape).astype(np.float32)
    v_val = np.random.random_sample(param_shape).astype(np.float32)
    beta_val = np.random.random_sample(param_shape).astype(np.float32)
    gamma_val = np.random.random_sample(param_shape).astype(np.float32)
    for use_gpu in [True, False]:
      with self.test_session(use_gpu=use_gpu) as sess:
        x = tf.constant(x_val, name="x")
        m = tf.constant(m_val, name="m")
        v = tf.constant(v_val, name="v")
        beta = tf.constant(beta_val, name="beta")
        gamma = tf.constant(gamma_val, name="gamma")
        epsilon = 0.001
        for scale_after_normalization in [True, False]:
          for shift_after_normalization in [True, False]:
            bn2 = self._tfBatchNormV2(
                x, m, v, beta, gamma, epsilon, scale_after_normalization,
                shift_after_normalization)
            bn1bw = self._tfBatchNormV1BW(
                x, m, v, beta, gamma, epsilon, scale_after_normalization)
            bn1 = self._tfBatchNormV1(
                x, m, v, beta, gamma, epsilon, scale_after_normalization)
            on = self._opsBatchNorm(
                x, m, v, beta, gamma, epsilon, scale_after_normalization,
                shift_after_normalization)
            np_bn = self._npBatchNorm(
                x_val, m_val, v_val, beta_val, gamma_val, epsilon,
                scale_after_normalization, shift_after_normalization)
            tf_bn_v2, tf_bn_v1bw, tf_bn_v1, ops_bn = sess.run(
                [bn2, bn1bw, bn1, on])
            self.assertAllClose(np_bn, ops_bn, atol=0.000001)
            self.assertAllClose(np_bn, tf_bn_v2, atol=0.000001)
            self.assertAllClose(tf_bn_v2, ops_bn, atol=0.000001)
            # shift_after_normalization=False is not supported in v1.
            if shift_after_normalization:
              self.assertAllClose(np_bn, tf_bn_v1bw, atol=0.000001)
              self.assertAllClose(np_bn, tf_bn_v1, atol=0.000001)
              self.assertAllClose(tf_bn_v1, ops_bn, atol=0.000001)
              self.assertAllClose(tf_bn_v1bw, ops_bn, atol=0.000001)

  def _testBatchNormGradient(self, param_index, tag, scale_after_normalization,
                             shift_after_normalization, version,
                             err_tolerance=1e-11):
    x_shape = [3, 5, 4, 5]
    param_shape = [5]
    np.random.seed(1)  # Make it reproducible.
    x_val = np.random.random_sample(x_shape).astype(np.float64)
    m_val = np.random.random_sample(param_shape).astype(np.float64)
    v_val = np.random.random_sample(param_shape).astype(np.float64)
    beta_val = np.random.random_sample(param_shape).astype(np.float64)
    gamma_val = np.random.random_sample(param_shape).astype(np.float64)
    with self.test_session():
      x = tf.constant(x_val, name="x")
      m = tf.constant(m_val, name="m")
      v = tf.constant(v_val, name="v")
      beta = tf.constant(beta_val, name="beta")
      gamma = tf.constant(gamma_val, name="gamma")
      epsilon = 0.001
      if version == 1:
        output = self._tfBatchNormV1(
            x, m, v, beta, gamma, epsilon, scale_after_normalization)
      elif version == 2:
        output = self._tfBatchNormV2(
            x, m, v, beta, gamma, epsilon, scale_after_normalization,
            shift_after_normalization)
      else:
        raise ValueError("Invalid version: %s" % version)
      all_params = [x, m, v, beta, gamma]
      all_shapes = [x_shape, param_shape, param_shape, param_shape, param_shape]
      err = tf.test.compute_gradient_error(
          all_params[param_index], all_shapes[param_index], output, x_shape)
    print("Batch normalization v%d %s gradient %s scale and %s shift err = " %
          (version, tag, "with" if scale_after_normalization else "without",
           "with" if shift_after_normalization else "without"),
          err)
    self.assertLess(err, err_tolerance)

  def _testBatchNormGradientInAllNeedConfigs(
      self, param_index, tag, err_tolerance=1e-11):
    for scale_after_normalization in [True, False]:
      for shift_after_normalization in [True, False]:
        # shift_after_normalization=False is not supported in version 1.
        for v in ([1, 2] if shift_after_normalization else [2]):
          self._testBatchNormGradient(
              param_index, tag, scale_after_normalization,
              shift_after_normalization, v, err_tolerance)

  def testBatchNormInputGradient(self):
    self._testBatchNormGradientInAllNeedConfigs(0, "x")

  def testBatchNormMeanGradient(self):
    self._testBatchNormGradientInAllNeedConfigs(1, "mean")

  def testBatchNormVarianceGradient(self):
    self._testBatchNormGradientInAllNeedConfigs(2, "variance",
                                                err_tolerance=1e-03)

  def testBatchNormBetaGradient(self):
    # Since beta does not exist when shift_after_normalization=False, we only
    # test for shift_after_normalization=True.
    for scale_after_normalization in [True, False]:
      for v in [1, 2]:
        self._testBatchNormGradient(3, "beta", scale_after_normalization, True,
                                    v)

  def testBatchNormGammaGradient(self):
    # If scale_after_normalization is False, backprop for gamma in v1
    # will be 0. In version 2 of the API, if scale_after_normalization is False,
    # gamma is not used at all, and the gradient is None, which displeases the
    # gradient checker.
    for scale_after_normalization in [True, False]:
      self._testBatchNormGradient(4, "gamma", scale_after_normalization, True,
                                  1)
    for shift_after_normalization in [True, False]:
      self._testBatchNormGradient(4, "gamma", True, shift_after_normalization,
                                  2)

  def testBatchNormGradImpl(self):
    x_shape = [7, 5, 4, 6]
    param_shape = [6]
    np.random.seed(1)  # Make it reproducible.
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    m_val = np.random.random_sample(param_shape).astype(np.float32)
    v_val = np.random.random_sample(param_shape).astype(np.float32)
    beta_val = np.random.random_sample(param_shape).astype(np.float32)
    gamma_val = np.random.random_sample(param_shape).astype(np.float32)
    backprop_val = np.random.random_sample(x_shape).astype(np.float32)
    for use_gpu in [False, True]:
      with self.test_session(use_gpu=use_gpu) as sess:
        x = tf.constant(x_val, name="x")
        m = tf.constant(m_val, name="m")
        v = tf.constant(v_val, name="v")
        beta = tf.constant(beta_val, name="beta")
        gamma = tf.constant(gamma_val, name="gamma")
        backprop = tf.constant(backprop_val, name="backprop")
        epsilon = 0.001
        for scale_after_normalization in [True, False]:
          # _batch_norm_with_global_normalization_grad is deprecated in v9
          tf.get_default_graph().graph_def_versions.producer = 8
          dx, dm, dv, db, dg = (
              gen_nn_ops._batch_norm_with_global_normalization_grad(
              x, m, v, gamma, backprop, epsilon, scale_after_normalization))
          on = self._opsBatchNorm(
              x, m, v, beta, gamma, epsilon, scale_after_normalization, True)
          odx, odm, odv, odb, odg = tf.gradients(
              [on], [x, m, v, beta, gamma], [backprop])
          if scale_after_normalization:
            all_grads = sess.run([dx, dm, dv, db, dg, odx, odm, odv, odb, odg])
            to_check = ["dx", "dm", "dv", "db", "dg"]
          else:
            all_grads = sess.run([dx, dm, dv, db, odx, odm, odv, odb])
            to_check = ["dx", "dm", "dv", "db"]
          for i, _ in enumerate(to_check):
            self.assertAllClose(
                all_grads[i + len(to_check)], all_grads[i], atol=0.000001)

  def testBatchNormKeepDims(self):
    """Test for tf.nn.moments(..., keep_dims=True / False).

    Make sure that parameters with shape (1, 1, 1, depth) yield the same
    result as parameters with shape (depth)
    """
    x_shape = (3, 5, 4, 2)
    param_shape = (2,)
    keep_dims_param_shape = (1, 1, 1, 2)
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    m_val = np.random.random_sample(param_shape).astype(np.float32)
    v_val = np.random.random_sample(param_shape).astype(np.float32)
    beta_val = np.random.random_sample(param_shape).astype(np.float32)
    gamma_val = np.random.random_sample(param_shape).astype(np.float32)
    for use_gpu in [True, False]:
      with self.test_session(use_gpu=use_gpu) as sess:
        x = tf.constant(x_val, name="x")
        m = tf.constant(m_val, name="m")
        v = tf.constant(v_val, name="v")
        beta = tf.constant(beta_val, name="beta")
        gamma = tf.constant(gamma_val, name="gamma")
        keep_dims_m = tf.reshape(m, keep_dims_param_shape, name="keep_dims_m")
        keep_dims_v = tf.reshape(v, keep_dims_param_shape, name="keep_dims_v")
        keep_dims_beta = tf.reshape(
            beta, keep_dims_param_shape, name="keep_dims_beta")
        keep_dims_gamma = tf.reshape(
            gamma, keep_dims_param_shape, name="keep_dims_gamma")
        epsilon = 0.001
        for scale_after_normalization in [True, False]:
          for shift_after_normalization in [True, False]:
            bn = self._tfBatchNormV2(
                x, m, v, beta, gamma, epsilon, scale_after_normalization,
                shift_after_normalization)
            keep_dims_bn = self._tfBatchNormV2(
                x, keep_dims_m, keep_dims_v, keep_dims_beta,
                keep_dims_gamma, epsilon, scale_after_normalization,
                shift_after_normalization)
            tf_batch_norm, keep_dims_tf_batch_norm = sess.run(
                [bn, keep_dims_bn])
            self.assertEquals(x_shape, tf_batch_norm.shape)
            self.assertEquals(x_shape, keep_dims_tf_batch_norm.shape)
            self.assertAllClose(
                tf_batch_norm, keep_dims_tf_batch_norm, atol=0.000001)

  def _testBatchNormArbitraryShapes(self, x_shape, param_shape, atol=0.000001):
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    m_val = np.random.random_sample(param_shape).astype(np.float32)
    v_val = np.random.random_sample(param_shape).astype(np.float32)
    beta_val = np.random.random_sample(param_shape).astype(np.float32)
    gamma_val = np.random.random_sample(param_shape).astype(np.float32)
    for use_gpu in [True, False]:
      with self.test_session(use_gpu=use_gpu) as sess:
        x = tf.constant(x_val, name="x")
        m = tf.constant(m_val, name="m")
        v = tf.constant(v_val, name="v")
        beta = tf.constant(beta_val, name="beta")
        gamma = tf.constant(gamma_val, name="gamma")
        epsilon = 0.001
        for scale_after_normalization in [True, False]:
          for shift_after_normalization in [True, False]:
            bn = self._tfBatchNormV2(
                x, m, v, beta, gamma, epsilon, scale_after_normalization,
                shift_after_normalization)
            np_batch_norm = self._npBatchNorm(
                x_val, m_val, v_val, beta_val, gamma_val, epsilon,
                scale_after_normalization, shift_after_normalization)
            [tf_batch_norm] = sess.run([bn])
            self.assertEquals(x_shape, np_batch_norm.shape)
            self.assertEquals(x_shape, tf_batch_norm.shape)
            self.assertAllClose(np_batch_norm, tf_batch_norm, atol=atol)

  def testBatchNormArbitraryShapes(self):
    """Test for a variety of shapes and moments.

    Batch normalization is expected to work regardless of the position and
    dimensionality of the 'depth' axis/axes.
    """
    self._testBatchNormArbitraryShapes((3, 3), (1, 3))
    self._testBatchNormArbitraryShapes((3, 3), (3, 1))
    self._testBatchNormArbitraryShapes((3, 2, 4, 5), (1, 2, 1, 1))
    self._testBatchNormArbitraryShapes((2, 3, 2, 4, 5), (1, 1, 1, 4, 5),
                                       atol=0.005)

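
# Illustrative sketch (not part of the original tests): the typical pairing of
# the ops exercised above -- compute per-depth moments, then normalize. beta
# and gamma may be None to skip the shift/scale, mirroring the
# shift_after_normalization / scale_after_normalization cases.
def _example_batch_norm(x, beta, gamma, epsilon=0.001):
  # Assumes x is a 4-D [batch, height, width, depth] tensor.
  mean, variance = tf.nn.moments(x, [0, 1, 2])
  return tf.nn.batch_normalization(x, mean, variance, beta, gamma, epsilon)
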

class SufficientStatisticsTest(tf.test.TestCase):

  def _npSuffStats(self, x, axes, shift, keep_dims):
    axis = tuple(axes)
    if shift:
      shift_value = x[[slice(None) if i not in set(axis) else slice(0, 1)
                       for i in xrange(x.ndim)]]
      m_ss = np.sum(x - shift_value, axis=axis, keepdims=keep_dims)
      v_ss = np.sum(
          (x - shift_value) * (x - shift_value),
          axis=axis,
          keepdims=keep_dims)
    else:
      shift_value = None
      m_ss = np.sum(x, axis=axis, keepdims=keep_dims)
      v_ss = np.sum(x * x, axis=axis, keepdims=keep_dims)
    count = 1.0
    for d in xrange(x.ndim):
      if d in set(axes):
        count *= x.shape[d]
    if not keep_dims and shift_value is not None:
      shift_value = np.squeeze(shift_value, axis=axis)
    return count, m_ss, v_ss, shift_value

  def _opSuffStats(self, x, axes, shift, keep_dims):
    return tf.nn.sufficient_statistics(x, axes, shift, keep_dims)

  def _testSuffStats(self, x_shape, axes, shift, keep_dims, has_shape):
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    np_c, np_m, np_v, np_s = self._npSuffStats(x_val, axes, shift, keep_dims)
    for use_gpu in [True, False]:
      with self.test_session(use_gpu=use_gpu) as sess:
        if has_shape:
          x = tf.constant(x_val, name="x")
          x.set_shape(x_shape)
          op_c, op_m, op_v, op_s = self._opSuffStats(x, axes, shift, keep_dims)
          if shift:
            tf_c, tf_m, tf_v, tf_s = sess.run([op_c, op_m, op_v, op_s])
          else:
            tf_c, tf_m, tf_v = sess.run([op_c, op_m, op_v])
        else:
          x = tf.placeholder(dtype=tf.float32,
                             shape=[None] * len(x_shape),
                             name="x")
          op_c, op_m, op_v, op_s = self._opSuffStats(x, axes, shift, keep_dims)
          if shift:
            tf_c, tf_m, tf_v, tf_s = sess.run(
                [op_c, op_m, op_v, op_s],
                feed_dict={x: x_val})
          else:
            tf_c, tf_m, tf_v = sess.run(
                [op_c, op_m, op_v],
                feed_dict={x: x_val})
        self.assertAllClose(np_c, tf_c, atol=0.000001)
        self.assertAllClose(np_m, tf_m, atol=0.000001)
        self.assertAllClose(np_v, tf_v, atol=0.000001)
        if shift:
          self.assertAllClose(np_s, tf_s, atol=0.000001)

  def testSuffStats(self):
    for has_shape in [True, False]:
      for keep_dims in [True, False]:
        for shift in [True, False]:
          self._testSuffStats([2, 3], [1], shift, keep_dims, has_shape)
          self._testSuffStats([2, 3], [0], shift, keep_dims, has_shape)
          self._testSuffStats([1, 2, 3], [0, 2], shift, keep_dims, has_shape)


class AggregateMomentsTest(tf.test.TestCase):

  def _npAggregateMoments(self, counts, mean_ss, variance_ss, shift):
    mean = mean_ss / counts
    variance = variance_ss / counts - mean * mean
    if shift is not None:
      mean += shift
    return mean, variance

  def _opAggregateMoments(self, counts, mean_ss, variance_ss, shift):
    return tf.nn.aggregate_moments(counts, mean_ss, variance_ss, shift)

  def _testAggregateMoments(self, shape, shift):
    counts = np.ones([1]).astype(np.float32)
    mean_ss = np.random.random_sample(shape).astype(np.float32)
    variance_ss = np.random.random_sample(shape).astype(np.float32)
    variance_ss *= variance_ss
    if shift:
      shift_v = np.random.random_sample(shape).astype(np.float32)
    else:
      shift_v = None
    npm, npv = self._npAggregateMoments(counts, mean_ss, variance_ss, shift_v)
    for use_gpu in [True, False]:
      with self.test_session(use_gpu=use_gpu) as sess:
        tf_counts = tf.constant(counts, name="counts")
        tf_mean_ss = tf.constant(mean_ss, name="mean_ss")
        tf_variance_ss = tf.constant(variance_ss, name="variance_ss")
        if shift:
          tf_shift_v = tf.constant(shift_v, name="shift")
        else:
          tf_shift_v = None
        opm, opv = self._opAggregateMoments(tf_counts, tf_mean_ss,
                                            tf_variance_ss, tf_shift_v)
        tfm, tfv = sess.run([opm, opv])
        self.assertAllClose(npm, tfm, atol=0.000001)
        self.assertAllClose(npv, tfv, atol=0.000001)

  def testAggregateMoments(self):
    for shift in [True, False]:
      self._testAggregateMoments([3], shift)
      self._testAggregateMoments([2, 3], shift)


class MomentsTest(tf.test.TestCase):

  def RunMomentTestWithDynamicShape(self, shape, axes, keep_dims):
    with self.test_session():
      # shape = [batch, width, height, depth]
      assert len(shape) == 4

      x_numpy = np.random.normal(size=shape).astype(np.float32)
      x = tf.placeholder(tf.float32, shape=[None] * len(shape))

      mean, var = tf.nn.moments(x, axes, keep_dims=keep_dims)

      num_elements = np.prod([shape[i] for i in axes])

      ax = tuple(axes)
      expected_mean = np.sum(
          x_numpy, axis=ax, keepdims=keep_dims) / num_elements
      expected_mean_squared = np.multiply(expected_mean, expected_mean)
      expected_x_squared = np.sum(
          np.multiply(x_numpy, x_numpy),
          axis=ax,
          keepdims=keep_dims) / num_elements
      expected_variance = expected_x_squared - expected_mean_squared

      # Check that the moments are correct.
      self.assertAllClose(expected_mean, mean.eval(feed_dict={x: x_numpy}))
      self.assertAllClose(expected_variance, var.eval(feed_dict={x: x_numpy}))

  def RunMomentTest(self, shape, axes, keep_dims):
    with self.test_session():
      # shape = [batch, width, height, depth]
      assert len(shape) == 4

      x_numpy = np.random.normal(size=shape).astype(np.float32)
      x = tf.constant(x_numpy)

      mean, var = tf.nn.moments(x, axes, keep_dims=keep_dims)

      num_elements = np.prod([shape[i] for i in axes])

      ax = tuple(axes)
      expected_mean = np.sum(
          x_numpy, axis=ax, keepdims=keep_dims) / num_elements
      expected_mean_squared = np.multiply(expected_mean, expected_mean)
      expected_x_squared = np.sum(
          np.multiply(x_numpy, x_numpy),
          axis=ax,
          keepdims=keep_dims) / num_elements
      expected_variance = expected_x_squared - expected_mean_squared

      # Check that the moments are correct.
      self.assertAllClose(expected_mean, mean.eval())
      self.assertAllClose(expected_variance, var.eval())

  def testBasic(self):
    for keep_dims in [False, True]:
      self.RunMomentTest(shape=[2, 3, 5, 4], axes=[0], keep_dims=keep_dims)
      self.RunMomentTestWithDynamicShape(
          shape=[2, 3, 5, 4], axes=[0], keep_dims=keep_dims)

  def testGlobalNormalization(self):
    for keep_dims in [False, True]:
      self.RunMomentTest(
          shape=[2, 3, 5, 4], axes=[0, 1, 2], keep_dims=keep_dims)
      self.RunMomentTestWithDynamicShape(
          shape=[2, 3, 5, 4], axes=[0, 1, 2], keep_dims=keep_dims)

  def testAxes(self):
    for keep_dims in [False, True]:
      self.RunMomentTest(
          shape=[2, 3, 5, 4], axes=[1, 2, 3], keep_dims=keep_dims)
      self.RunMomentTestWithDynamicShape(
          shape=[2, 3, 5, 4], axes=[1, 2, 3], keep_dims=keep_dims)

  def _testGlobalGradient(self, from_y="mean"):
    with self.test_session():
      x_shape = [3, 5, 4, 2]
      x_val = np.random.random_sample(x_shape).astype(np.float64)
      x = tf.constant(x_val)
      x.set_shape(x_shape)

      axes = [0, 1, 2]
      y_shape = [2]  # Depth of x
      out_mean, out_var = tf.nn.moments(x, axes)
      if from_y == "mean":
        y = out_mean
      elif from_y == "var":
        y = out_var
      err = tf.test.compute_gradient_error(x, x_shape, y, y_shape)
      print("Moments %s gradient err = %g" % (from_y, err))
      self.assertLess(err, 1e-11)

  def testMeanGlobalGradient(self):
    self._testGlobalGradient(from_y="mean")

  def testVarGlobalGradient(self):
    self._testGlobalGradient(from_y="var")

  def testOutputNamesNoKeep(self):
    """Make sure the output names are stable."""
    with self.test_session():
      mean, var = tf.nn.moments(tf.constant([1]), [0], keep_dims=False)
      self.assertEquals(mean.op.name, "moments/aggregate/mean")
      self.assertEquals(var.op.name, "moments/aggregate/variance")

  def testOutputNamesKeep(self):
    """Make sure the output names are stable."""
    with self.test_session():
      mean, var = tf.nn.moments(tf.constant([1]), [0], keep_dims=True)
      self.assertEquals(mean.op.name, "moments/aggregate/mean")
      self.assertEquals(var.op.name, "moments/aggregate/variance")


class ComputeSampledLogitsTest(tf.test.TestCase):

  def setUp(self):
    self._num_classes = 5
    self._dim = 10
    self._batch_size = 3
    self._num_shards = 3

  def _GenerateTestInputs(self):
    np.random.seed(0)
    weights = np.random.randn(self._num_classes, self._dim).astype(np.float32)
    biases = np.random.randn(self._num_classes).astype(np.float32)
    hidden_acts = np.random.randn(self._batch_size, self._dim).astype(
        np.float32)
    sharded_weights = [
        weights[[row for row in range(self._num_classes)
                 if row % self._num_shards == shard]]
        for shard in range(self._num_shards)]
    return weights, biases, hidden_acts, sharded_weights

  def _ComputeSampledLogitsNP(self, true_w, true_b, sampled_w, sampled_b,
                              hidden_acts,
                              num_true=1,
                              true_expected=None,
                              sampled_expected=None):

    batch_size, dim = hidden_acts.shape
    true_logits = np.sum(
        hidden_acts.reshape((batch_size, 1, dim)) * true_w.reshape(
            (batch_size, num_true, dim)),
        axis=2)
    true_b = true_b.reshape((batch_size, num_true))
    true_logits += true_b
    sampled_logits = np.dot(hidden_acts, sampled_w.T) + sampled_b

    if true_expected is not None:
      true_logits -= np.log(true_expected)
    if sampled_expected is not None:
      sampled_logits -= np.log(sampled_expected[np.newaxis, :])

    out_logits = np.concatenate([true_logits, sampled_logits], axis=1)
    out_labels = np.hstack((np.ones_like(true_logits) / num_true,
                            np.zeros_like(sampled_logits)))

    return out_logits, out_labels

  def _ComputeSampledLogitsTF(self, weights, biases, hidden_acts, labels,
                              num_sampled, num_classes, num_true, sampled_vals,
                              subtract_log_q, remove_accidental_hits,
                              name="sampled_loss_TF"):
    # Should be called from within a `with test_session():` block
    if isinstance(weights, list):
      weights_tf = [tf.constant(shard) for shard in weights]
    else:
      weights_tf = tf.constant(weights)
    biases_tf = tf.constant(biases)
    hidden_acts_tf = tf.constant(hidden_acts,
                                 shape=(self._batch_size, self._dim))
    labels_tf = tf.constant(labels,
                            dtype=tf.int64,
                            shape=(self._batch_size, num_true))

    pred_logits_tf, pred_labels_tf = tf.nn._compute_sampled_logits(
        weights_tf,
        biases_tf,
        hidden_acts_tf,
        labels_tf,
        num_sampled,
        num_classes,
        num_true,
        sampled_vals,
        subtract_log_q=subtract_log_q,
        remove_accidental_hits=remove_accidental_hits,
        name=name)
    return pred_logits_tf, pred_labels_tf

  def testComputeSampledLogitsShapes(self):
    # We just check that the shapes of the returned values are correct.
    weights, biases, hidden_acts, _ = self._GenerateTestInputs()
    sampled = [1, 0, 2, 3]
    num_sampled = len(sampled)
    true_exp = sampled_exp = [1., 1., 1., 1.]
    test_sampled_vals = (sampled, true_exp, sampled_exp)
    sampled_w, sampled_b = weights[sampled], biases[sampled]

    with self.test_session() as sess:
      for num_true_test in range(1, 5):
        labels = np.random.randint(low=0, high=self._num_classes,
                                   size=self._batch_size * num_true_test)
        true_w, true_b = weights[labels], biases[labels]

        logits_np, labels_np = self._ComputeSampledLogitsNP(
            true_w, true_b, sampled_w, sampled_b, hidden_acts,
            num_true=num_true_test)

        logits_tf, labels_tf = self._ComputeSampledLogitsTF(
            weights, biases, hidden_acts, labels, num_sampled,
            self._num_classes,
            num_true=num_true_test,
            sampled_vals=test_sampled_vals,
            remove_accidental_hits=True,
            subtract_log_q=False)

      logits_tf_val, labels_tf_val = sess.run([logits_tf, labels_tf])
      self.assertEqual(logits_np.shape, logits_tf_val.shape)
      self.assertEqual(labels_np.shape, labels_tf_val.shape)

  def testComputeSampledLogitsValues(self):
    # Here we check the actual numerics.
    weights, biases, hidden_acts, sharded_weights = self._GenerateTestInputs()
    eps = 1e-3
    sampled = [1, 0, 2, 3]
    num_sampled = len(sampled)
    true_exp = np.empty([self._batch_size, 1], dtype=np.float32)
    true_exp.fill(0.5)
    sampled_exp = np.empty([num_sampled], dtype=np.float32)
    sampled_exp.fill(0.5)
    sampled_w, sampled_b = weights[sampled], biases[sampled]
    test_sampled_vals = (sampled, true_exp, sampled_exp)

    with self.test_session() as sess:
      for num_true_test in range(1, 5):
        # Generate test data for this run
        labels = np.random.randint(low=0, high=self._num_classes,
                                   size=self._batch_size * num_true_test)
        true_w, true_b = weights[labels], biases[labels]

        # Test 1: Without accidental hit removal or subtract_log_q
        logits_np, labels_np = self._ComputeSampledLogitsNP(
            true_w, true_b, sampled_w, sampled_b, hidden_acts,
            num_true=num_true_test)
        logits_tf, labels_tf = self._ComputeSampledLogitsTF(
            weights, biases, hidden_acts, labels, num_sampled,
            self._num_classes,
            num_true=num_true_test,
            sampled_vals=test_sampled_vals,
            subtract_log_q=False,
            remove_accidental_hits=False,
            name="sampled_loss_test1_num_true%d" % num_true_test)

        logits_tf_val, labels_tf_val = sess.run([logits_tf, labels_tf])
        self.assertAllClose(logits_np, logits_tf_val, eps)
        self.assertAllClose(labels_np, labels_tf_val, eps)

        # Test 2: With accidental hit removal, no subtract_log_q
        logits_tf, labels_tf = self._ComputeSampledLogitsTF(
            weights, biases, hidden_acts, labels, num_sampled,
            self._num_classes,
            num_true=num_true_test,
            sampled_vals=test_sampled_vals,
            subtract_log_q=False,
            remove_accidental_hits=True,
            name="sampled_loss_test2_num_true%d" % num_true_test)

        # Test that the exponentiated logits of accidental hits are near 0.
        # First we need to find the hits in this random test run:
        labels_reshape = labels.reshape((self._batch_size, num_true_test))
        logits_tf_np = logits_tf.eval()
        for row in xrange(self._batch_size):
          row_labels = labels_reshape[row, :]
          for col in xrange(num_sampled):
            if sampled[col] in row_labels:
              # We need to add the num_true_test offset into logits_*
              self.assertNear(
                  np.exp(logits_tf_np[row, col + num_true_test]), 0., eps)

        # Test 3: With subtract_log_q, no accidental hit removal
        logits_np, labels_np = self._ComputeSampledLogitsNP(
            true_w, true_b, sampled_w, sampled_b, hidden_acts,
            num_true=num_true_test,
            true_expected=true_exp,
            sampled_expected=sampled_exp)
        logits_tf, labels_tf = self._ComputeSampledLogitsTF(
            weights, biases, hidden_acts, labels, num_sampled,
            self._num_classes,
            num_true=num_true_test,
            sampled_vals=test_sampled_vals,
            subtract_log_q=True,
            remove_accidental_hits=False,
            name="sampled_loss_test3_num_true%d" % num_true_test)

        logits_tf_val, labels_tf_val = sess.run([logits_tf, labels_tf])
        self.assertAllClose(logits_np, logits_tf_val, eps)
        self.assertAllClose(labels_np, labels_tf_val, eps)

        # Test 4: Test 1, with sharded weights
        logits_np, labels_np = self._ComputeSampledLogitsNP(
            true_w, true_b, sampled_w, sampled_b, hidden_acts,
            num_true=num_true_test)
        logits_tf, labels_tf = self._ComputeSampledLogitsTF(
            sharded_weights, biases, hidden_acts, labels, num_sampled,
            self._num_classes,
            num_true=num_true_test,
            sampled_vals=test_sampled_vals,
            subtract_log_q=False,
            remove_accidental_hits=False,
            name="sampled_loss_test1_num_true%d" % num_true_test)

        logits_tf_val, labels_tf_val = sess.run([logits_tf, labels_tf])
        self.assertAllClose(logits_np, logits_tf_val, eps)
        self.assertAllClose(labels_np, labels_tf_val, eps)

  def testNCELoss(self):
    # A simple test to verify the numerics.

    def _SigmoidCrossEntropyWithLogits(logits, targets):
      # logits, targets: float arrays of the same shape.
      assert logits.shape == targets.shape
      pred = 1. / (1. + np.exp(-logits))
      eps = 0.0001
      pred = np.minimum(np.maximum(pred, eps), 1 - eps)
      return -targets * np.log(pred) - (1. - targets) * np.log(1. - pred)

    weights, biases, hidden_acts, sharded_weights = self._GenerateTestInputs()
    labels = [0, 1, 2]
    true_w, true_b = weights[labels], biases[labels]
    sampled = [1, 0, 2, 3]
    num_sampled = len(sampled)
    true_exp = np.empty([self._batch_size, 1], dtype=np.float32)
    true_exp.fill(0.5)
    sampled_exp = np.empty([num_sampled], dtype=np.float32)
    sampled_exp.fill(0.5)
    sampled_w, sampled_b = weights[sampled], biases[sampled]
    test_sampled_vals = (sampled, true_exp, sampled_exp)

    with self.test_session():
      logits_np, labels_np = self._ComputeSampledLogitsNP(
          true_w, true_b, sampled_w, sampled_b, hidden_acts,
          true_expected=true_exp,
          sampled_expected=sampled_exp)
      nce_loss_np = np.sum(
          _SigmoidCrossEntropyWithLogits(logits_np, labels_np), 1)

      labels_tf = tf.constant(labels, shape=(self._batch_size, 1))
      weights_tf = tf.constant(weights)
      biases_tf = tf.constant(biases)
      inputs_tf = tf.constant(hidden_acts)

      nce_loss_tf = tf.nn.nce_loss(weights_tf,
                                   biases_tf,
                                   inputs_tf,
                                   labels_tf,
                                   num_sampled=1,
                                   num_classes=self._num_classes,
                                   num_true=1,
                                   sampled_values=test_sampled_vals)

      self.assertAllClose(nce_loss_np, nce_loss_tf.eval(), 1e-4)

      # Test with sharded weights
      nce_loss_tf = tf.nn.nce_loss(
          [tf.constant(shard) for shard in sharded_weights],
          biases_tf,
          inputs_tf,
          labels_tf,
          num_sampled=1,
          num_classes=self._num_classes,
          num_true=1,
          sampled_values=test_sampled_vals)

      self.assertAllClose(nce_loss_np, nce_loss_tf.eval(), 1e-4)

  def testSampledSoftmaxLoss(self):
    # A simple test to verify the numerics.

    def _SoftmaxCrossEntropyWithLogits(logits, targets):
      # logits, targets: float arrays of the same shape.
      assert logits.shape == targets.shape
      stable_exp_logits = np.exp(logits - np.amax(
          logits, axis=1, keepdims=True))
      pred = stable_exp_logits / np.sum(stable_exp_logits, 1, keepdims=True)
      return -np.sum(targets * np.log(pred + 1.0e-20), axis=1)

    weights, biases, hidden_acts, sharded_weights = self._GenerateTestInputs()
    labels = [0, 1, 2]
    true_w, true_b = weights[labels], biases[labels]
    sampled = [1, 0, 2, 3]
    num_sampled = len(sampled)
    true_exp = np.full([self._batch_size, 1], fill_value=0.5, dtype=np.float32)
    sampled_exp = np.full([num_sampled], fill_value=0.5, dtype=np.float32)
    sampled_w, sampled_b = weights[sampled], biases[sampled]
    test_sampled_vals = (sampled, true_exp, sampled_exp)

    with self.test_session():
      logits_np, labels_np = self._ComputeSampledLogitsNP(
          true_w, true_b, sampled_w, sampled_b, hidden_acts,
          true_expected=true_exp,
          sampled_expected=sampled_exp)
      sampled_softmax_loss_np = _SoftmaxCrossEntropyWithLogits(logits_np,
                                                               labels_np)

      labels_tf = tf.constant(labels, shape=(self._batch_size, 1))
      weights_tf = tf.constant(weights)
      biases_tf = tf.constant(biases)
      inputs_tf = tf.constant(hidden_acts)

      sampled_softmax_loss_tf = tf.nn.sampled_softmax_loss(
          weights_tf,
          biases_tf,
          inputs_tf,
          labels_tf,
          num_sampled=1,
          num_classes=self._num_classes,
          num_true=1,
          sampled_values=test_sampled_vals,
          remove_accidental_hits=False)

      self.assertAllClose(
          sampled_softmax_loss_np, sampled_softmax_loss_tf.eval(), 1e-4)

      # Test with sharded weights
      sampled_softmax_loss_tf = tf.nn.sampled_softmax_loss(
          [tf.constant(shard) for shard in sharded_weights],
          biases_tf,
          inputs_tf,
          labels_tf,
          num_sampled=1,
          num_classes=self._num_classes,
          num_true=1,
          sampled_values=test_sampled_vals,
          remove_accidental_hits=False)

      self.assertAllClose(
          sampled_softmax_loss_np, sampled_softmax_loss_tf.eval(), 1e-4)

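
# Illustrative sketch (not part of the original tests): how the sampled losses
# above compose. _compute_sampled_logits builds per-example logits
# [true_logits, sampled_logits] with labels [1/num_true, ..., 0, ...];
# nce_loss then applies sigmoid cross-entropy and sums over that axis, while
# sampled_softmax_loss applies softmax cross-entropy instead.
def _example_nce_from_sampled_logits(out_logits, out_labels):
  pred = 1.0 / (1.0 + np.exp(-out_logits))
  pred = np.minimum(np.maximum(pred, 0.0001), 1 - 0.0001)
  xent = -out_labels * np.log(pred) - (1.0 - out_labels) * np.log(1.0 - pred)
  return np.sum(xent, axis=1)
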

if __name__ == "__main__":
  tf.test.main()
