pooling_ops_test.py revision 61d3a958d6d83cb6037490d933b47621cc4009cc
1"""Functional tests for pooling operations."""
2from __future__ import print_function
3import tensorflow.python.platform
4
5import numpy as np
6import tensorflow as tf
7
8from tensorflow.python.kernel_tests import gradient_checker as gc
9from tensorflow.python.ops import gen_nn_ops
10
11
def GetInceptionMaxPoolShapes():
  """Iterates over several max pool configurations of the Inception 2015 model.

  Yields:
    Tuple (name, input_size, filter_size, out_size, strides, padding)
  """
  # One row per op, laid out in the same order as the yielded tuple.
  configs = [
      ("maxpool2", [32, 71, 71, 192], [1, 3, 3, 1], [32, 35, 35, 192],
       [1, 2, 2, 1], "VALID"),
      ("maxpool3", [32, 35, 35, 288], [1, 3, 3, 1], [32, 17, 17, 288],
       [1, 2, 2, 1], "VALID"),
      ("maxpool4", [32, 17, 17, 1248], [1, 3, 3, 1], [32, 8, 8, 1248],
       [1, 2, 2, 1], "VALID"),
      ("maxpool5", [32, 8, 8, 2048], [1, 3, 3, 1], [32, 8, 8, 2048],
       [1, 1, 1, 1], "SAME"),
  ]
  for config in configs:
    yield config
31
32
33class PoolingTest(tf.test.TestCase):
34
35  def _VerifyValues(self, pool_func, input_sizes, ksize, strides, padding,
36                    expected, use_gpu):
37    """Verifies the output values of the pooling function.
38
39    Args:
40      pool_func: Function to be called, co.MaxPool, co.AvgPool,
41        or the Lua version.
42      input_sizes: Input tensor dimensions.
43      ksize: The kernel size dimensions
44      strides: The stride dimensions
45      padding: Padding type.
46      expected: An array containing the expected operation outputs.
47      use_gpu: Whether we are running on GPU.
48    """
49    total_size = 1
50    for s in input_sizes:
51      total_size *= s
52    # Initializes the input tensor with array containing incrementing
53    # numbers from 1.
54    x = [f * 1.0 for f in range(1, total_size + 1)]
55    with self.test_session(use_gpu=use_gpu) as sess:
56      t = tf.constant(x, shape=input_sizes)
57      t = pool_func(t, ksize=ksize, strides=strides, padding=padding)
58      actual = t.eval()
59      self.assertAllClose(expected, actual.flatten())
60      self.assertShapeEqual(actual, t)
61
62  def _testAvgPoolValidPadding(self, use_gpu):
63    expected_output = [7.0, 8.0, 9.0]
64    self._VerifyValues(tf.nn.avg_pool, input_sizes=[1, 3, 3, 3],
65                       ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
66                       padding="VALID",
67                       expected=expected_output, use_gpu=use_gpu)
68
69  def _testAvgPoolSamePadding(self, use_gpu):
70    expected_output = [8.5, 9.5, 10.5, 14.5, 15.5, 16.5]
71    self._VerifyValues(tf.nn.avg_pool, input_sizes=[1, 2, 4, 3],
72                       ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
73                       padding="SAME",
74                       expected=expected_output, use_gpu=use_gpu)
75
76  def _testAvgPoolSamePaddingNonSquareWindow(self, use_gpu):
77    # input is:
78    # [1.0, 2.0
79    #  3.0  4.0]
80    #
81    # Window of [x, x] should do:
82    #  [avg(1.0, 2.0), avg(2.0, padded0),
83    #   avg(3.0, 4.0), avg(4.0, padded0)]
84    self._VerifyValues(tf.nn.avg_pool, input_sizes=[1, 2, 2, 1],
85                       ksize=[1, 1, 2, 1], strides=[1, 1, 1, 1],
86                       padding="SAME",
87                       expected=[1.5, 2.0, 3.5, 4.0], use_gpu=use_gpu)
88
89    # Window of [x,
90    #            x] should do:
91    #  [avg(1.0, 3.0), avg(2.0, 4.0)
92    #   avg(3.0, padded0), avg(4.0, padded0)]
93    self._VerifyValues(tf.nn.avg_pool, input_sizes=[1, 2, 2, 1],
94                       ksize=[1, 2, 1, 1], strides=[1, 1, 1, 1],
95                       padding="SAME",
96                       expected=[2.0, 3.0, 3.0, 4.0], use_gpu=use_gpu)
97
98  def _testAvgPoolSamePaddingNonSquareWindowMultiBatch(self, use_gpu):
99    self._VerifyValues(tf.nn.avg_pool, input_sizes=[2, 2, 2, 2],
100                       ksize=[1, 1, 2, 1], strides=[1, 1, 1, 1],
101                       padding="SAME",
102                       expected=[2.0, 3.0, 3.0, 4.0,
103                                 6.0, 7.0, 7.0, 8.0,
104                                 10.0, 11.0, 11.0, 12.0,
105                                 14.0, 15.0, 15.0, 16.0],
106                       use_gpu=use_gpu)
107    self._VerifyValues(tf.nn.avg_pool, input_sizes=[2, 2, 2, 2],
108                       ksize=[1, 2, 1, 1], strides=[1, 1, 1, 1],
109                       padding="SAME",
110                       expected=[3.0, 4.0, 5.0, 6.0,
111                                 5.0, 6.0, 7.0, 8.0,
112                                 11.0, 12.0, 13.0, 14.0,
113                                 13.0, 14.0, 15.0, 16.0],
114                       use_gpu=use_gpu)
115
116  def _testAvgPoolValidPaddingUnevenStride(self, use_gpu):
117    self._VerifyValues(tf.nn.avg_pool, input_sizes=[1, 3, 3, 3],
118                       ksize=[1, 2, 2, 1], strides=[1, 1, 2, 1],
119                       padding="VALID",
120                       expected=[7.0, 8.0, 9.0, 16.0, 17.0, 18.0],
121                       use_gpu=use_gpu)
122    self._VerifyValues(tf.nn.avg_pool, input_sizes=[1, 3, 3, 3],
123                       ksize=[1, 2, 2, 1], strides=[1, 2, 1, 1],
124                       padding="VALID",
125                       expected=[7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
126                       use_gpu=use_gpu)
127
128  def _testAvgPoolSamePadding4(self, use_gpu):
129    expected_output = [11.0, 12.0, 13.0, 14.0, 19.0, 20.0, 21.0, 22.0, 43.0,
130                       44.0, 45.0, 46.0, 51.0, 52.0, 53.0, 54.0]
131    self._VerifyValues(tf.nn.avg_pool, input_sizes=[1, 4, 4, 4],
132                       ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
133                       padding="SAME",
134                       expected=expected_output, use_gpu=use_gpu)
135
136  def _testAvgPoolSamePaddingPacket4(self, use_gpu):
137    expected_output = [21.0, 22.0, 23.0, 24.0, 27.0, 28.0, 29.0, 30.0,
138                       45.0, 46.0, 47.0, 48.0, 51.0, 52.0, 53.0, 54.0]
139    self._VerifyValues(tf.nn.avg_pool, input_sizes=[1, 4, 4, 4],
140                       ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
141                       padding="SAME",
142                       expected=expected_output, use_gpu=use_gpu)
143
144  def _testAvgPoolSamePaddingPacket8(self, use_gpu):
145    expected_output = [73.0, 74.0, 75.0, 76.0, 77.0, 78.0, 79.0, 80.0, 89.0,
146                       90.0, 91.0, 92.0, 93.0, 94.0, 95.0, 96.0, 105.0, 106.0,
147                       107.0, 108.0, 109.0, 110.0, 111.0, 112.0, 117.0, 118.0,
148                       119.0, 120.0, 121.0, 122.0, 123.0, 124.0, 201.0, 202.0,
149                       203.0, 204.0, 205.0, 206.0, 207.0, 208.0, 217.0, 218.0,
150                       219.0, 220.0, 221.0, 222.0, 223.0, 224.0, 233.0, 234.0,
151                       235.0, 236.0, 237.0, 238.0, 239.0, 240.0, 245.0, 246.0,
152                       247.0, 248.0, 249.0, 250.0, 251.0, 252.0, 329.0, 330.0,
153                       331.0, 332.0, 333.0, 334.0, 335.0, 336.0, 345.0, 346.0,
154                       347.0, 348.0, 349.0, 350.0, 351.0, 352.0, 361.0, 362.0,
155                       363.0, 364.0, 365.0, 366.0, 367.0, 368.0, 373.0, 374.0,
156                       375.0, 376.0, 377.0, 378.0, 379.0, 380.0, 425.0, 426.0,
157                       427.0, 428.0, 429.0, 430.0, 431.0, 432.0, 441.0, 442.0,
158                       443.0, 444.0, 445.0, 446.0, 447.0, 448.0, 457.0, 458.0,
159                       459.0, 460.0, 461.0, 462.0, 463.0, 464.0, 469.0, 470.0,
160                       471.0, 472.0, 473.0, 474.0, 475.0, 476.0]
161    self._VerifyValues(tf.nn.avg_pool, input_sizes=[1, 8, 8, 8],
162                       ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
163                       padding="SAME",
164                       expected=expected_output, use_gpu=use_gpu)
165
166  def testAvgPooling(self):
167    for use_gpu in True, False:
168      self._testAvgPoolValidPadding(use_gpu)
169      self._testAvgPoolSamePadding(use_gpu)
170      self._testAvgPoolSamePaddingNonSquareWindow(use_gpu)
171      self._testAvgPoolSamePaddingNonSquareWindowMultiBatch(use_gpu)
172      self._testAvgPoolValidPaddingUnevenStride(use_gpu)
173      self._testAvgPoolSamePadding4(use_gpu)
174      self._testAvgPoolSamePaddingPacket4(use_gpu)
175      self._testAvgPoolSamePaddingPacket8(use_gpu)
176
177  def _testMaxPoolValidPadding(self, use_gpu):
178    expected_output = [13.0, 14.0, 15.0]
179    self._VerifyValues(tf.nn.max_pool, input_sizes=[1, 3, 3, 3],
180                       ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
181                       padding="VALID",
182                       expected=expected_output, use_gpu=use_gpu)
183
184  def _testMaxPoolSamePadding(self, use_gpu):
185    expected_output = [13.0, 14.0, 15.0, 16.0, 17.0, 18.0]
186    self._VerifyValues(tf.nn.max_pool, input_sizes=[1, 2, 3, 3],
187                       ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
188                       padding="SAME",
189                       expected=expected_output, use_gpu=use_gpu)
190
191  def _testMaxPoolSamePaddingNonSquareWindow(self, use_gpu):
192    # input is:
193    # [1.0, 2.0
194    #  3.0  4.0]
195    #
196    # Window of [x, x] should do:
197    #
198    #  [max(1.0, 2.0), max(2.0, padded0),
199    #   max(3.0, 4.0), max(4.0, padded0)]
200    self._VerifyValues(tf.nn.max_pool, input_sizes=[1, 2, 2, 1],
201                       ksize=[1, 1, 2, 1], strides=[1, 1, 1, 1],
202                       padding="SAME",
203                       expected=[2.0, 2.0, 4.0, 4.0], use_gpu=use_gpu)
204
205  def _testMaxPoolValidPaddingUnevenStride(self, use_gpu):
206    self._VerifyValues(tf.nn.max_pool, input_sizes=[1, 4, 4, 1],
207                       ksize=[1, 2, 2, 1], strides=[1, 1, 2, 1],
208                       padding="VALID",
209                       expected=[6.0, 8.0, 10.0, 12.0, 14.0, 16.0],
210                       use_gpu=use_gpu)
211    self._VerifyValues(tf.nn.max_pool, input_sizes=[1, 4, 4, 1],
212                       ksize=[1, 2, 2, 1], strides=[1, 2, 1, 1],
213                       padding="VALID",
214                       expected=[6.0, 7.0, 8.0, 14.0, 15.0, 16.0],
215                       use_gpu=use_gpu)
216
217  def _testMaxPoolSamePaddingPacket4(self, use_gpu):
218    expected_output = [21.0, 22.0, 23.0, 24.0, 29.0, 30.0, 31.0, 32.0, 53.0,
219                       54.0, 55.0, 56.0, 61.0, 62.0, 63.0, 64.0]
220    self._VerifyValues(tf.nn.max_pool, input_sizes=[1, 4, 4, 4],
221                       ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
222                       padding="SAME",
223                       expected=expected_output, use_gpu=use_gpu)
224
225  def _testMaxPoolSamePaddingPacket8(self, use_gpu):
226    expected_output = [145.0, 146.0, 147.0, 148.0, 149.0, 150.0, 151.0, 152.0,
227                       161.0, 162.0, 163.0, 164.0, 165.0, 166.0, 167.0, 168.0,
228                       177.0, 178.0, 179.0, 180.0, 181.0, 182.0, 183.0, 184.0,
229                       185.0, 186.0, 187.0, 188.0, 189.0, 190.0, 191.0, 192.0,
230                       273.0, 274.0, 275.0, 276.0, 277.0, 278.0, 279.0, 280.0,
231                       289.0, 290.0, 291.0, 292.0, 293.0, 294.0, 295.0, 296.0,
232                       305.0, 306.0, 307.0, 308.0, 309.0, 310.0, 311.0, 312.0,
233                       313.0, 314.0, 315.0, 316.0, 317.0, 318.0, 319.0, 320.0,
234                       401.0, 402.0, 403.0, 404.0, 405.0, 406.0, 407.0, 408.0,
235                       417.0, 418.0, 419.0, 420.0, 421.0, 422.0, 423.0, 424.0,
236                       433.0, 434.0, 435.0, 436.0, 437.0, 438.0, 439.0, 440.0,
237                       441.0, 442.0, 443.0, 444.0, 445.0, 446.0, 447.0, 448.0,
238                       465.0, 466.0, 467.0, 468.0, 469.0, 470.0, 471.0, 472.0,
239                       481.0, 482.0, 483.0, 484.0, 485.0, 486.0, 487.0, 488.0,
240                       497.0, 498.0, 499.0, 500.0, 501.0, 502.0, 503.0, 504.0,
241                       505.0, 506.0, 507.0, 508.0, 509.0, 510.0, 511.0, 512.0]
242    self._VerifyValues(tf.nn.max_pool, input_sizes=[1, 8, 8, 8],
243                       ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
244                       padding="SAME",
245                       expected=expected_output, use_gpu=use_gpu)
246
247  def testMaxPooling(self):
248    for use_gpu in True, False:
249      self._testMaxPoolValidPadding(use_gpu)
250      self._testMaxPoolSamePadding(use_gpu)
251      self._testMaxPoolSamePaddingNonSquareWindow(use_gpu)
252      self._testMaxPoolValidPaddingUnevenStride(use_gpu)
253      self._testMaxPoolSamePaddingPacket4(use_gpu)
254      self._testMaxPoolSamePaddingPacket8(use_gpu)
255
256  # Tests for DepthwiseMaxPooling on CPU only.
257  def testDepthwiseMaxPool1x1DepthWindow1(self):
258    # input is:
259    # [1.0, ..., 10.0] along depth,
260    #
261    # We maxpool by depth in patches of 2.
262    self._VerifyValues(tf.nn.max_pool, input_sizes=[1, 1, 1, 10],
263                       ksize=[1, 1, 1, 2], strides=[1, 1, 1, 2],
264                       padding="SAME",
265                       expected=[2.0, 4.0, 6.0, 8.0, 10.0], use_gpu=False)
266
267  def testDepthwiseMaxPool2x2DepthWindow3(self):
268    # input is:
269    #
270    # a 2x2x6 cube, and we depthwise max across 3 to produce a 2x2x2
271    # output.  Each node has contiguous values, so the depthwise max
272    # should be multiples of 3.0.
273    self._VerifyValues(tf.nn.max_pool, input_sizes=[1, 2, 2, 6],
274                       ksize=[1, 1, 1, 3], strides=[1, 1, 1, 3],
275                       padding="SAME",
276                       expected=[3.0, 6.0, 9.0, 12.0, 15.0, 18.0, 21.0, 24.0],
277                       use_gpu=False)
278
279  def _testDepthwiseMaxPoolInvalidConfig(self, in_size, ksize, strides,
280                                         error_msg, use_gpu=False):
281    t = tf.constant(1.0, shape=in_size)
282    with self.assertRaisesRegexp(ValueError, error_msg):
283      t = tf.nn.max_pool(t, ksize=ksize, strides=strides, padding="SAME")
284
285  def testDepthwiseMaxPoolInvalidConfigs(self):
286    self._testDepthwiseMaxPoolInvalidConfig(
287        [1, 2, 2, 4], [1, 2, 2, 2],
288        [1, 1, 1, 2], "exactly one of pooling across depth")
289    self._testDepthwiseMaxPoolInvalidConfig(
290        [1, 2, 2, 4], [1, 1, 1, 2],
291        [1, 1, 1, 1], "depth window to equal the depth stride")
292    self._testDepthwiseMaxPoolInvalidConfig(
293        [1, 2, 2, 4], [1, 1, 1, 3],
294        [1, 1, 1, 3], "evenly divide")
295    if tf.test.IsBuiltWithCuda():
296      with self.test_session(use_gpu=True):
297        t = tf.constant(1.0, shape=[1, 2, 2, 4])
298        with self.assertRaisesOpError("for CPU devices"):
299          tf.nn.max_pool(t, ksize=[1, 1, 1, 2], strides=[1, 1, 1, 2],
300                         padding="SAME").eval()
301
  # The following are tests that verify that the CPU and GPU implementations
  # produce the same results.
304  def _CompareMaxPoolingFwd(self, input_shape, ksize, strides, padding):
305    tensor_input = np.random.rand(*input_shape).astype(np.float32)
306    with self.test_session(use_gpu=True):
307      t = tf.constant(tensor_input, shape=input_shape)
308      out_op, _ = tf.nn.max_pool_with_argmax(t, ksize, strides, padding)
309      gpu_val = out_op.eval()
310    with self.test_session(use_gpu=False):
311      t = tf.constant(tensor_input, shape=input_shape)
312      out_op = tf.nn.max_pool(t, ksize, strides, padding)
313      cpu_val = out_op.eval()
314    self.assertAllClose(cpu_val, gpu_val, rtol=1e-5, atol=1e-5)
315
316  def _CompareMaxPoolingBk(self, input_shape, output_shape, ksize, strides,
317                           padding):
318    # Generate numbers in a narrow range, so that there are many duplicates
319    # in the input.
320    tensor_input = np.random.random_integers(0, 3,
321                                             input_shape).astype(np.float32)
322    tensor_output = np.random.rand(*output_shape).astype(np.float32)
323    with self.test_session(use_gpu=True):
324      t = tf.constant(tensor_input, shape=input_shape)
325      _, argmax_op = tf.nn.max_pool_with_argmax(t, ksize, strides, padding)
326      argmax = argmax_op.eval()
327      grad_in = tf.constant(tensor_output, shape=output_shape)
328      out_op = gen_nn_ops._max_pool_grad_with_argmax(t, grad_in, argmax,
329                                                     ksize, strides, padding)
330      gpu_val = out_op.eval()
331      self.assertShapeEqual(gpu_val, out_op)
332    with self.test_session(use_gpu=False):
333      t = tf.constant(tensor_input, shape=input_shape)
334      out_op = tf.nn.max_pool(t, ksize, strides, padding)
335      orig_out = out_op.eval()
336      grad_in = tf.constant(tensor_output, shape=output_shape)
337      out_op = gen_nn_ops._max_pool_grad(t, orig_out, grad_in, ksize,
338                                         strides, padding)
339      cpu_val = out_op.eval()
340      self.assertShapeEqual(cpu_val, out_op)
341    self.assertAllClose(cpu_val, gpu_val, rtol=1e-5, atol=1e-5)
342
343  def testMaxPoolingWithArgmax(self):
344    # MaxPoolWithArgMax is implemented only on GPU.
345    if not tf.test.IsBuiltWithCuda():
346      return
347    tensor_input = [1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0]
348    with self.test_session(use_gpu=True) as sess:
349      t = tf.constant(tensor_input, shape=[1, 3, 3, 1])
350      out_op, argmax_op = tf.nn.max_pool_with_argmax(t,
351                                                   ksize=[1, 2, 2, 1],
352                                                   strides=[1, 1, 1, 1],
353                                                   Targmax=tf.int64,
354                                                   padding="VALID")
355      out, argmax = sess.run([out_op, argmax_op])
356      self.assertShapeEqual(out, out_op)
357      self.assertShapeEqual(argmax, argmax_op)
358      self.assertAllClose(out.ravel(), [1.0, 1.0, 1.0, 1.0])
359      self.assertAllEqual(argmax.ravel(), [0, 1, 3, 5])
360
361  def testMaxPoolingGradWithArgmax(self):
362    # MaxPoolWithArgMax is implemented only on GPU.
363    if not tf.test.IsBuiltWithCuda():
364      return
365    orig_input = [1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0]
366    tensor_input = [11.0, 12.0, 13.0, 14.0]
367    tensor_argmax = list(np.array([0, 1, 3, 5], dtype=np.int64))
368    with self.test_session(use_gpu=True) as sess:
369      orig_in = tf.constant(orig_input, shape=[1, 3, 3, 1])
370      t = tf.constant(tensor_input, shape=[1, 2, 2, 1])
371      argmax = tf.constant(tensor_argmax, shape=[1, 2, 2, 1],
372                                    dtype=tf.int64)
373      out_op = gen_nn_ops._max_pool_grad_with_argmax(orig_in, t, argmax,
374                                                     ksize=[1, 2, 2, 1],
375                                                     strides=[1, 1, 1, 1],
376                                                     padding="VALID")
377      out = out_op.eval().flatten()
378      self.assertAllClose(out, [11.0, 12.0, 0.0, 13.0, 0.0,
379                                14.0, 0.0, 0.0, 0.0])
380
381  def _ConstructAndTestGradient(self, pool_func, input_sizes, output_sizes,
382                                window_rows, window_cols, row_stride,
383                                col_stride, padding, use_gpu,
384                                x_init_value=None):
385    """Verifies the gradients of the avg pooling function.
386
387    Args:
388      pool_func: Function to be called, co.MaxPool, co.AvgPool,
389        or the Lua version.
390      input_sizes: Input tensor dimensions.
391      output_sizes: Output tensor dimensions.
392      window_rows: kernel size in row dim
393      window_cols: kernel size in col dim
394      row_stride: Row Stride.
395      col_stride: Col Stride.
396      padding: Padding type.
397      use_gpu: whether we are running on GPU
398      x_init_value: Values to be passed to the gradient checker.
399    """
400    total_size = 1
401    for s in input_sizes:
402      total_size *= s
403    # Initializes the input tensor with array containing incrementing
404    # numbers from 1.
405    x = [f * 1.0 for f in range(1, total_size + 1)]
406    with self.test_session(use_gpu=use_gpu):
407      input_tensor = tf.constant(x, shape=input_sizes, name="input")
408      if pool_func == tf.nn.avg_pool:
409        func_name = "avg_pool"
410        err_margin = 1e-4
411      else:
412        if x_init_value is None:
413          x_init_value = np.asfarray(
414              np.arange(1, total_size + 1),
415              dtype=np.float32).reshape(input_sizes)
416        func_name = "max_pool"
417        err_margin = 1e-3
418      t = pool_func(input_tensor, ksize=[1, window_rows, window_rows, 1],
419                    strides=[1, row_stride, col_stride, 1],
420                    padding=padding, name=func_name)
421      err = gc.ComputeGradientError(
422          input_tensor, input_sizes, t, output_sizes,
423          x_init_value=x_init_value, delta=1e-2)
424    print("%s gradient error = " % func_name, err)
425    self.assertLess(err, err_margin)
426
427  def _testMaxPoolGradValidPadding1_1(self, use_gpu):
428    self._ConstructAndTestGradient(
429        tf.nn.max_pool, input_sizes=[1, 3, 3, 1],
430        output_sizes=[1, 3, 3, 1], window_rows=1, window_cols=1, row_stride=1,
431        col_stride=1, padding="VALID", use_gpu=use_gpu)
432
433  def _testMaxPoolGradValidPadding2_1_6(self, use_gpu):
434    self._ConstructAndTestGradient(
435        tf.nn.max_pool, input_sizes=[2, 6, 6, 3],
436        output_sizes=[2, 5, 5, 3], window_rows=2, window_cols=2, row_stride=1,
437        col_stride=1, padding="VALID", use_gpu=use_gpu)
438
439  def _testMaxPoolGradValidPadding2_1_7(self, use_gpu):
440    self._ConstructAndTestGradient(
441        tf.nn.max_pool, input_sizes=[2, 7, 7, 3],
442        output_sizes=[2, 6, 6, 3], window_rows=2, window_cols=2, row_stride=1,
443        col_stride=1, padding="VALID", use_gpu=use_gpu)
444
445  def _testMaxPoolGradValidPadding2_2(self, use_gpu):
446    self._ConstructAndTestGradient(
447        tf.nn.max_pool, input_sizes=[2, 2, 2, 3],
448        output_sizes=[2, 1, 1, 3], window_rows=2, window_cols=2, row_stride=2,
449        col_stride=2, padding="VALID", use_gpu=use_gpu)
450
451  def _testMaxPoolGradSamePadding1_1(self, use_gpu):
452    self._ConstructAndTestGradient(
453        tf.nn.max_pool, input_sizes=[2, 2, 4, 3],
454        output_sizes=[2, 2, 4, 3], window_rows=1, window_cols=1, row_stride=1,
455        col_stride=1, padding="SAME", use_gpu=use_gpu)
456
457  def _testMaxPoolGradSamePadding2_1(self, use_gpu):
458    self._ConstructAndTestGradient(
459        tf.nn.max_pool, input_sizes=[2, 2, 4, 3],
460        output_sizes=[2, 2, 4, 3], window_rows=2, window_cols=2, row_stride=1,
461        col_stride=1, padding="SAME", use_gpu=use_gpu)
462
463  def _testMaxPoolGradSamePadding2_2(self, use_gpu):
464    self._ConstructAndTestGradient(
465        tf.nn.max_pool, input_sizes=[2, 2, 4, 3],
466        output_sizes=[2, 1, 2, 3], window_rows=2, window_cols=2, row_stride=2,
467        col_stride=2, padding="SAME", use_gpu=use_gpu)
468
469  def _testMaxPoolGradSamePadding3_1(self, use_gpu):
470    self._ConstructAndTestGradient(
471        tf.nn.max_pool, input_sizes=[1, 7, 7, 1],
472        output_sizes=[1, 7, 7, 1], window_rows=3, window_cols=3, row_stride=1,
473        col_stride=1, padding="SAME", use_gpu=use_gpu)
474
475  def testMaxPoolGrad(self):
476    for use_gpu in True, False:
477      self._testMaxPoolGradValidPadding1_1(use_gpu=use_gpu)
478      self._testMaxPoolGradValidPadding2_1_6(use_gpu=use_gpu)
479      self._testMaxPoolGradValidPadding2_1_7(use_gpu=use_gpu)
480      self._testMaxPoolGradValidPadding2_2(use_gpu=use_gpu)
481      self._testMaxPoolGradSamePadding1_1(use_gpu=use_gpu)
482      self._testMaxPoolGradSamePadding2_1(use_gpu=use_gpu)
483      self._testMaxPoolGradSamePadding2_2(use_gpu=use_gpu)
484      self._testMaxPoolGradSamePadding3_1(use_gpu=use_gpu)
485
486  def _MaxPoolGrad(self, orig_input, orig_output, grad, window_rows,
487                   window_cols, row_stride, col_stride, padding):
488    """Max Pooling Gradient.
489
490    Args:
491      orig_input: A float Tensor. The original input tensor.
492      orig_output: A float Tensor. The original output tensor.
493      grad: A float Tensor.
494        The 4D (batch x rows x cols x depth) output backprop.
495      window_rows: integer. Kernel size along rows dimension.
496      window_cols: integer. Kernel size along cols dimension.
497      row_stride: integer. Stride along rows dimension
498      col_stride: integer. Stride along cols dimension
499      padding: PoolingOpDef.Padding.  Padding type.
500
501    Returns:
502      A Tensor.
503    """
504    return gen_nn_ops._max_pool_grad(
505        orig_input, orig_output, grad,
506        [1, window_rows, window_cols, 1], [1, row_stride, col_stride, 1],
507        padding)
508
509  def _testMaxPoolGradDirect(self, input_data, output_backprop,
510                             expected_input_backprop, input_sizes, output_sizes,
511                             window_rows, window_cols, row_stride, col_stride,
512                             padding, use_gpu):
513    with self.test_session(use_gpu=use_gpu) as sess:
514      input_tensor = tf.constant(input_data, shape=input_sizes)
515      output_tensor = tf.nn.max_pool(
516          input_tensor, [1, window_rows, window_cols, 1],
517          [1, row_stride, col_stride, 1], padding)
518      output_backprop_tensor = tf.constant(output_backprop,
519                                                    shape=output_sizes)
520
521      input_backprop_tensor = self._MaxPoolGrad(
522          input_tensor, output_tensor, output_backprop_tensor,
523          window_rows, window_cols, row_stride, col_stride, padding)
524
525      actual_input_backprop = input_backprop_tensor.eval()
526      self.assertShapeEqual(actual_input_backprop, input_backprop_tensor)
527      actual_input_backprop = actual_input_backprop.flatten()
528      actual_input_backprop = self._GetNdArray(actual_input_backprop)
529
530      actual_output = output_tensor.eval().flatten()
531      actual_output = self._GetNdArray(actual_output)
532
533      self.assertAllClose(expected_input_backprop, actual_input_backprop,
534                          rtol=1e-6, atol=1e-6)
535
536  def _testMaxPoolGradDirect1_1(self):
537    input_data = [
538        1.0, 1.0, 1.0, 1.0,
539        1.0, 1.0, 1.0, 1.0,
540        1.0, 1.0, 1.0, 1.0,
541        1.0, 1.0, 1.0, 1.0]
542    output_backprop = [
543        11.0, 12.0, 13.0,
544        15.0, 16.0, 17.0,
545        19.0, 20.0, 21.0]
546    expected_input_backprop = [
547        11.0, 12.0, 13.0, 0.0,
548        15.0, 16.0, 17.0, 0.0,
549        19.0, 20.0, 21.0, 0.0,
550        0.0, 0.0, 0.0, 0.0]
551
552    for use_gpu in True, False:
553      self._testMaxPoolGradDirect(
554          input_data, output_backprop, expected_input_backprop,
555          input_sizes=[1, 4, 4, 1], output_sizes=[1, 3, 3, 1],
556          window_rows=2, window_cols=2, row_stride=1, col_stride=1,
557          padding="VALID", use_gpu=use_gpu)
558
559  def _testMaxPoolGradDirect1_2(self):
560    input_data = [
561        1.0, 0.0, 1.0, 0.0,
562        0.0, 1.0, 0.0, 1.0,
563        1.0, 0.0, 1.0, 0.0,
564        0.0, 1.0, 0.0, 1.0]
565    output_backprop = [
566        11.0, 12.0, 13.0,
567        15.0, 16.0, 17.0,
568        19.0, 20.0, 21.0]
569    expected_input_backprop = [
570        11.0, 0.0, 25.0, 0.0,
571        0.0, 31.0, 0.0, 17.0,
572        19.0, 0.0, 41.0, 0.0,
573        0.0, 0.0, 0.0, 0.0]
574
575    for use_gpu in True, False:
576      self._testMaxPoolGradDirect(
577          input_data, output_backprop, expected_input_backprop,
578          input_sizes=[1, 4, 4, 1], output_sizes=[1, 3, 3, 1],
579          window_rows=2, window_cols=2, row_stride=1, col_stride=1,
580          padding="VALID", use_gpu=use_gpu)
581
582  def _testMaxPoolGradDirect1_3(self):
583    input_data = [
584        1.0, 0.0, 1.0, 0.0,
585        0.0, 1.0, 0.0, 1.0,
586        1.0, 0.0, 1.0, 0.0,
587        0.0, 1.0, 0.0, 1.0,]
588    output_backprop = [
589        11.0, 12.0, 13.0, 14.0,
590        15.0, 16.0, 17.0, 18.0,
591        19.0, 20.0, 21.0, 22.0,
592        23.0, 24.0, 25.0, 26.0]
593    expected_input_backprop = [
594        54, 0.0, 62, 0.0,
595        0.0, 60, 0.0, 22.0,
596        47, 0.0, 51, 0.0,
597        0.0, 0.0, 0.0, 0.0,]
598
599    for use_gpu in True, False:
600      self._testMaxPoolGradDirect(
601          input_data, output_backprop, expected_input_backprop,
602          input_sizes=[1, 4, 4, 1], output_sizes=[1, 4, 4, 1],
603          window_rows=3, window_cols=3, row_stride=1, col_stride=1,
604          padding="SAME", use_gpu=use_gpu)
605
606  def _testMaxPoolGradDirectWithNans2_1(self):
607    input_data = [float("nan")] * 16
608    output_backprop = [
609        11.0, 12.0, 13.0,
610        15.0, 16.0, 17.0,
611        19.0, 20.0, 21.0]
612    # Test the CPU implementation, which propagates diffs in case of NaN
613    expected_input_backprop_tf_cpu = [
614        11.0, 12.0, 13.0, 0.0,
615        15.0, 16.0, 17.0, 0.0,
616        19.0, 20.0, 21.0, 0.0,
617        0.0, 0.0, 0.0, 0.0]
618    self._testMaxPoolGradDirect(
619        input_data, output_backprop, expected_input_backprop_tf_cpu,
620        input_sizes=[1, 4, 4, 1], output_sizes=[1, 3, 3, 1],
621        window_rows=2, window_cols=2, row_stride=1, col_stride=1,
622        padding="VALID", use_gpu=False)
623
624    if not tf.test.IsBuiltWithCuda():
625      return
626
627    # Test the GPU implementation that uses cudnn for now.
628    # It does not propagate the diff in cases of NaNs
629    expected_input_backprop_cudnn = [
630        0.0, 0.0, 0.0, 0.0,
631        0.0, 0.0, 0.0, 0.0,
632        0.0, 0.0, 0.0, 0.0,
633        0.0, 0.0, 0.0, 0.0]
634    self._testMaxPoolGradDirect(
635        input_data, output_backprop, expected_input_backprop_cudnn,
636        input_sizes=[1, 4, 4, 1], output_sizes=[1, 3, 3, 1],
637        window_rows=2, window_cols=2, row_stride=1, col_stride=1,
638        padding="VALID", use_gpu=True)
639
640  def _testMaxPoolGradDirectWithNans2_2(self):
641    input_data = [float("nan")] * 16
642    output_backprop = [
643        float("nan"), 12.0, 13.0,
644        15.0, float("nan"), 17.0,
645        19.0, 20.0, float("nan")]
646    # Test the CPU implementation, which propagates diffs in case of NaN
647    expected_input_backprop_tf_cpu = [
648        float("nan"), 12.0, 13.0, 0.0,
649        15.0, float("nan"), 17.0, 0.0,
650        19.0, 20.0, float("nan"), 0.0,
651        0.0, 0.0, 0.0, 0.0]
652    self._testMaxPoolGradDirect(
653        input_data, output_backprop, expected_input_backprop_tf_cpu,
654        input_sizes=[1, 4, 4, 1], output_sizes=[1, 3, 3, 1],
655        window_rows=2, window_cols=2, row_stride=1, col_stride=1,
656        padding="VALID", use_gpu=False)
657
658    if not tf.test.IsBuiltWithCuda():
659      return
660
661    # Test the GPU implementation that uses cudnn for now.
662    # It does not propagate the diff in cases of NaNs
663    expected_input_backprop_cudnn = [
664        0.0, 0.0, 0.0, 0.0,
665        0.0, 0.0, 0.0, 0.0,
666        0.0, 0.0, 0.0, 0.0,
667        0.0, 0.0, 0.0, 0.0]
668    self._testMaxPoolGradDirect(
669        input_data, output_backprop, expected_input_backprop_cudnn,
670        input_sizes=[1, 4, 4, 1], output_sizes=[1, 3, 3, 1],
671        window_rows=2, window_cols=2, row_stride=1, col_stride=1,
672        padding="VALID", use_gpu=True)
673
674  def testMaxPoolGradDirect(self):
675    self._testMaxPoolGradDirect1_1()
676    self._testMaxPoolGradDirect1_2()
677    self._testMaxPoolGradDirect1_3()
678    self._testMaxPoolGradDirectWithNans2_1()
679    self._testMaxPoolGradDirectWithNans2_2()
680
681  def testAvgPoolGrad(self):
682    for use_gpu in False, True:
683      self._testAvgPoolGradValidPadding1_1(use_gpu)
684      self._testAvgPoolGradValidPadding2_1(use_gpu)
685      self._testAvgPoolGradValidPadding2_2(use_gpu)
686      self._testAvgPoolGradSamePadding1_1(use_gpu)
687      self._testAvgPoolGradSamePadding2_1(use_gpu)
688      self._testAvgPoolGradSamePadding2_2(use_gpu)
689      self._testAvgPoolGradSamePadding3_1(use_gpu)
690
691  def _testAvgPoolGradValidPadding1_1(self, use_gpu):
692    self._ConstructAndTestGradient(
693        tf.nn.avg_pool, input_sizes=[2, 3, 3, 3],
694        output_sizes=[2, 3, 3, 3], window_rows=1, window_cols=1, row_stride=1,
695        col_stride=1, padding="VALID", use_gpu=use_gpu)
696
697  def _testAvgPoolGradValidPadding2_1(self, use_gpu):
698    self._ConstructAndTestGradient(
699        tf.nn.avg_pool, input_sizes=[2, 3, 3, 3],
700        output_sizes=[2, 2, 2, 3], window_rows=2, window_cols=2, row_stride=1,
701        col_stride=1, padding="VALID", use_gpu=use_gpu)
702
703  def _testAvgPoolGradValidPadding2_2(self, use_gpu):
704    self._ConstructAndTestGradient(
705        tf.nn.avg_pool, input_sizes=[2, 2, 2, 3],
706        output_sizes=[2, 1, 1, 3], window_rows=2, window_cols=2, row_stride=2,
707        col_stride=2, padding="VALID", use_gpu=use_gpu)
708
709  def _testAvgPoolGradSamePadding1_1(self, use_gpu):
710    self._ConstructAndTestGradient(
711        tf.nn.avg_pool, input_sizes=[2, 2, 4, 3],
712        output_sizes=[2, 2, 4, 3], window_rows=1, window_cols=1, row_stride=1,
713        col_stride=1, padding="SAME", use_gpu=use_gpu)
714
715  def _testAvgPoolGradSamePadding2_1(self, use_gpu):
716    self._ConstructAndTestGradient(
717        tf.nn.avg_pool, input_sizes=[2, 2, 4, 3],
718        output_sizes=[2, 2, 4, 3], window_rows=2, window_cols=2, row_stride=1,
719        col_stride=1, padding="SAME", use_gpu=use_gpu)
720
721  def _testAvgPoolGradSamePadding2_2(self, use_gpu):
722    self._ConstructAndTestGradient(
723        tf.nn.avg_pool, input_sizes=[2, 2, 4, 3],
724        output_sizes=[2, 1, 2, 3], window_rows=2, window_cols=2, row_stride=2,
725        col_stride=2, padding="SAME", use_gpu=use_gpu)
726
727  def _testAvgPoolGradSamePadding3_1(self, use_gpu):
728    self._ConstructAndTestGradient(
729        tf.nn.avg_pool, input_sizes=[1, 7, 7, 1],
730        output_sizes=[1, 7, 7, 1], window_rows=3, window_cols=3, row_stride=1,
731        col_stride=1, padding="SAME", use_gpu=use_gpu)
732
733  def testShapeFunctionEdgeCases(self):
734    # All shapes unknown.
735    for pool_func in [tf.nn.max_pool, tf.nn.avg_pool]:
736      p = tf.nn.max_pool(tf.placeholder(tf.float32),
737                         ksize=[1, 1, 1, 1], strides=[1, 1, 1, 1],
738                         padding="SAME")
739      self.assertEqual([None, None, None, None], p.get_shape().as_list())
740    p, am = tf.nn.max_pool_with_argmax(
741        tf.placeholder(tf.float32),
742        ksize=[1, 1, 1, 1], strides=[1, 1, 1, 1],
743        padding="SAME")
744    self.assertEqual([None, None, None, None], p.get_shape().as_list())
745    self.assertEqual([None, None, None, None], am.get_shape().as_list())
746
747    # Incorrect input shape.
748    for pool_func in [tf.nn.max_pool, tf.nn.avg_pool,
749                      tf.nn.max_pool_with_argmax]:
750      with self.assertRaises(ValueError):
751        pool_func(tf.placeholder(tf.float32, shape=[1, 3]),
752                  ksize=[1, 1, 1, 1], strides=[1, 1, 1, 1], padding="SAME")
753
754    # Illegal strides.
755    for pool_func in [tf.nn.max_pool, tf.nn.avg_pool,
756                      tf.nn.max_pool_with_argmax]:
757      with self.assertRaisesRegexp(ValueError, "strides in the batch"):
758        pool_func(tf.placeholder(tf.float32),
759                  ksize=[1, 1, 1, 1], strides=[2, 1, 1, 1], padding="SAME")
760    with self.assertRaisesRegexp(ValueError, "strides in the batch and depth"):
761      tf.nn.avg_pool(tf.placeholder(tf.float32),
762                     ksize=[1, 1, 1, 1], strides=[1, 1, 1, 2], padding="SAME")
763
764    # Filter larger than input.
765    for pool_func in [tf.nn.max_pool, tf.nn.avg_pool,
766                      tf.nn.max_pool_with_argmax]:
767      with self.assertRaisesRegexp(ValueError,
768                                   "filter must not be larger than the input"):
769        pool_func(tf.placeholder(tf.float32,
770                                        shape=[32, 20, 20, 3]),
771                  ksize=[1, 20, 21, 1], strides=[1, 1, 1, 1], padding="SAME")
772      with self.assertRaisesRegexp(ValueError,
773                                   "filter must not be larger than the input"):
774        pool_func(tf.placeholder(tf.float32,
775                                        shape=[32, 20, 20, 3]),
776                  ksize=[1, 21, 20, 1], strides=[1, 1, 1, 1], padding="SAME")
777
778    # Stride larger than filter.
779    for pool_func in [tf.nn.max_pool, tf.nn.avg_pool,
780                      tf.nn.max_pool_with_argmax]:
781      with self.assertRaisesRegexp(
782          ValueError, "stride must be less than or equal to filter"):
783        pool_func(tf.placeholder(tf.float32,
784                                        shape=[32, 20, 20, 3]),
785                  ksize=[1, 5, 3, 1], strides=[1, 5, 5, 1], padding="SAME")
786      with self.assertRaisesRegexp(
787          ValueError, "stride must be less than or equal to filter"):
788        pool_func(tf.placeholder(tf.float32,
789                                        shape=[32, 20, 20, 3]),
790                  ksize=[1, 3, 5, 1], strides=[1, 5, 5, 1], padding="SAME")
791
792
def GetMaxPoolFwdTest(input_size, filter_size, strides, padding):
  """Builds a test method that runs the max-pool forward comparison.

  The returned function is meant to be attached to a test class that provides
  `_CompareMaxPoolingFwd`. It is a no-op unless TensorFlow was built with CUDA.

  Args:
    input_size: Forwarded unchanged to `_CompareMaxPoolingFwd`.
    filter_size: Forwarded unchanged to `_CompareMaxPoolingFwd`.
    strides: Forwarded unchanged to `_CompareMaxPoolingFwd`.
    padding: Forwarded unchanged to `_CompareMaxPoolingFwd`.

  Returns:
    A function taking the test-case instance as its only argument.
  """
  def Test(self):
    # MaxPoolWithArgMax is implemented only on GPU, so the comparison is run
    # only for CUDA builds.
    if tf.test.IsBuiltWithCuda():
      self._CompareMaxPoolingFwd(input_size, filter_size, strides, padding)

  return Test
800
801
def GetMaxPoolGradTest(input_size, filter_size, output_size, strides, padding):
  """Builds a test method that runs the max-pool backward comparison.

  The returned function is meant to be attached to a test class that provides
  `_CompareMaxPoolingBk`. It is a no-op unless TensorFlow was built with CUDA.

  Args:
    input_size: Forwarded unchanged to `_CompareMaxPoolingBk`.
    filter_size: Forwarded unchanged to `_CompareMaxPoolingBk`.
    output_size: Forwarded unchanged to `_CompareMaxPoolingBk`.
    strides: Forwarded unchanged to `_CompareMaxPoolingBk`.
    padding: Forwarded unchanged to `_CompareMaxPoolingBk`.

  Returns:
    A function taking the test-case instance as its only argument.
  """
  def Test(self):
    # MaxPoolWithArgMax is implemented only on GPU, so the comparison is run
    # only for CUDA builds.
    if tf.test.IsBuiltWithCuda():
      # Note the argument order: output_size comes before filter_size.
      self._CompareMaxPoolingBk(
          input_size, output_size, filter_size, strides, padding)

  return Test
810
811
if __name__ == "__main__":
  # Before handing control to the test runner, attach one generated forward
  # test and one generated gradient test to PoolingTest for every max-pool
  # configuration yielded by GetInceptionMaxPoolShapes().
  for (name_, input_size_, filter_size_, output_size_, stride_,
       padding_) in GetInceptionMaxPoolShapes():
    fwd_test_ = GetMaxPoolFwdTest(input_size_, filter_size_, stride_, padding_)
    grad_test_ = GetMaxPoolGradTest(
        input_size_, filter_size_, output_size_, stride_, padding_)
    setattr(PoolingTest, "testMaxPoolFwd_" + name_, fwd_test_)
    setattr(PoolingTest, "testMaxPoolGrad_" + name_, grad_test_)
  tf.test.main()
821