conv_ops_3d_test.py revision e70c00950d295c519fd9c7f8b12e13a3c5aaf710
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Functional tests for 3d convolutional operations."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import math
23
24from tensorflow.python.framework import constant_op
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import test_util
27from tensorflow.python.ops import gradient_checker
28from tensorflow.python.ops import nn_ops
29import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
30from tensorflow.python.platform import test
31
32
def GetTestConfigs():
  """Lists the (data_format, use_gpu) combinations the tests iterate over.

  "NDHWC" is always listed, once with use_gpu=False and once with
  use_gpu=True. "NCDHW" is appended (GPU only) when a CUDA GPU is available,
  because that layout is only supported on CUDA.

  Returns:
    A list of `(data_format, use_gpu)` tuples.
  """
  configs = [("NDHWC", use_gpu) for use_gpu in (False, True)]
  if test.is_gpu_available(cuda_only=True):
    configs.append(("NCDHW", True))
  return configs
44
45
class Conv3DTest(test.TestCase):
  """Forward-value and gradient tests for `nn_ops.conv3d`."""

  @staticmethod
  def _StridesFromArg(stride):
    """Expands a conv3d `stride` argument into a 5-element strides list.

    Args:
      stride: an int used for all three spatial dimensions, or an iterable of
        three per-dimension strides in (planes, rows, cols) order.

    Returns:
      The list `[1, stride_planes, stride_rows, stride_cols, 1]` in NDHWC
      order (callers permute it themselves for NCDHW).
    """
    # `collections.Iterable` was removed in Python 3.10. The ABCs live in
    # `collections.abc` on Python 3.3+ and directly in `collections` on 2.x.
    try:
      from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top
    except ImportError:
      collections_abc = collections
    if isinstance(stride, collections_abc.Iterable):
      return [1] + list(stride) + [1]
    return [1, stride, stride, stride, 1]

  @staticmethod
  def _NumElements(shape):
    """Returns the product of the entries of `shape` (1 for an empty shape)."""
    num_elements = 1
    for dim in shape:
      num_elements *= dim
    return num_elements

  def _SetupValuesForDevice(self, tensor_in_sizes, filter_in_sizes, stride,
                            padding, data_format, use_gpu):
    """Builds a conv3d op over deterministic input and filter values.

    Both the input and the filter are filled with 1.0, 2.0, 3.0, ... in
    row-major order of their respective shapes.

    Args:
      tensor_in_sizes: input shape in NDHWC order.
      filter_in_sizes: filter shape as [z, y, x, in_channels, out_channels].
      stride: an int or a 3-element iterable of spatial strides.
      padding: padding string passed to conv3d ("VALID" or "SAME" here).
      data_format: "NDHWC" or "NCDHW". For NCDHW the input and strides are
        transposed before the op and the output transposed back afterwards.
      use_gpu: forwarded to `test_session`.

    Returns:
      The (unevaluated) conv3d output tensor, always in NDHWC layout.
    """
    x1 = [f * 1.0 for f in range(1, self._NumElements(tensor_in_sizes) + 1)]
    x2 = [f * 1.0 for f in range(1, self._NumElements(filter_in_sizes) + 1)]
    with self.test_session(use_gpu=use_gpu):
      t1 = constant_op.constant(x1, shape=tensor_in_sizes)
      t2 = constant_op.constant(x2, shape=filter_in_sizes)
      strides = self._StridesFromArg(stride)

      if data_format == "NCDHW":
        t1 = test_util.NHWCToNCHW(t1)
        strides = test_util.NHWCToNCHW(strides)
      conv = nn_ops.conv3d(t1, t2, strides, padding=padding,
                           data_format=data_format)
      if data_format == "NCDHW":
        conv = test_util.NCHWToNHWC(conv)

      return conv

  def _VerifyValues(self, tensor_in_sizes, filter_in_sizes, stride, padding,
                    expected):
    """Checks conv3d output against `expected` for every test config.

    One conv3d op is built per (data_format, use_gpu) config. All of them are
    evaluated in a single session run, and each flattened result is compared
    with `expected` exactly once. The absolute tolerance comes from that
    result's own config: 1e-2 when use_gpu is set, 1e-5 otherwise.
    """
    results = []
    tolerances = []
    for data_format, use_gpu in GetTestConfigs():
      results.append(
          self._SetupValuesForDevice(
              tensor_in_sizes,
              filter_in_sizes,
              stride,
              padding,
              data_format,
              use_gpu=use_gpu))
      tolerances.append(1e-2 if use_gpu else 1e-5)

    # Run once, after all ops exist. The original ran inside the loop,
    # re-evaluating earlier results at later configs' tolerances.
    with self.test_session() as sess:
      values = sess.run(results)
    for value, tolerance in zip(values, tolerances):
      print("expected = ", expected)
      print("actual = ", value)
      self.assertAllClose(expected, value.flatten(), atol=tolerance,
                          rtol=1e-6)

  def testConv3D1x1x1Filter(self):
    expected_output = [
        30.0, 36.0, 42.0, 66.0, 81.0, 96.0, 102.0, 126.0, 150.0, 138.0, 171.0,
        204.0, 174.0, 216.0, 258.0, 210.0, 261.0, 312.0
    ]

    # These are equivalent to the Conv2D1x1 case.
    self._VerifyValues(
        tensor_in_sizes=[1, 2, 3, 1, 3],
        filter_in_sizes=[1, 1, 1, 3, 3],
        stride=1,
        padding="VALID",
        expected=expected_output)
    self._VerifyValues(
        tensor_in_sizes=[1, 2, 1, 3, 3],
        filter_in_sizes=[1, 1, 1, 3, 3],
        stride=1,
        padding="VALID",
        expected=expected_output)
    self._VerifyValues(
        tensor_in_sizes=[1, 1, 2, 3, 3],
        filter_in_sizes=[1, 1, 1, 3, 3],
        stride=1,
        padding="VALID",
        expected=expected_output)

  # Expected values computed using scipy's correlate function.
  def testConv3D2x2x2Filter(self):
    expected_output = [
        19554., 19962., 20370., 22110., 22590., 23070., 34890., 35730., 36570.,
        37446., 38358., 39270., 50226., 51498., 52770., 52782., 54126., 55470.
    ]
    # expected_shape = [1, 3, 1, 2, 3]
    self._VerifyValues(
        tensor_in_sizes=[1, 4, 2, 3, 3],  # b, z, y, x, fin
        filter_in_sizes=[2, 2, 2, 3, 3],  # z, y, x, fin, fout
        stride=1,
        padding="VALID",
        expected=expected_output)

  def testConv3DStrides(self):
    # Output shape is [1, 3, 3, 7, 1]; each line below is one output row of
    # 7 columns.
    expected_output = [
        102., 151., 172., 193., 214., 235., 142.,
        438., 592., 613., 634., 655., 676., 394.,
        774., 1033., 1054., 1075., 1096., 1117., 646.,
        1894., 2503., 2524., 2545., 2566., 2587., 1486.,
        2230., 2944., 2965., 2986., 3007., 3028., 1738.,
        2566., 3385., 3406., 3427., 3448., 3469., 1990.,
        3686., 4855., 4876., 4897., 4918., 4939., 2830.,
        4022., 5296., 5317., 5338., 5359., 5380., 3082.,
        4358., 5737., 5758., 5779., 5800., 5821., 3334.,
    ]
    self._VerifyValues(
        tensor_in_sizes=[1, 5, 8, 7, 1],
        filter_in_sizes=[1, 2, 3, 1, 1],
        stride=[2, 3, 1],  # different stride for each spatial dimension
        padding="SAME",
        expected=expected_output)

  def testConv3D2x2x2FilterStride2(self):
    expected_output = [19554., 19962., 20370., 50226., 51498., 52770.]
    self._VerifyValues(
        tensor_in_sizes=[1, 4, 2, 3, 3],
        filter_in_sizes=[2, 2, 2, 3, 3],
        stride=2,
        padding="VALID",
        expected=expected_output)

  def testConv3DStride3(self):
    expected_output = [
        36564., 38022., 39480., 37824., 39354., 40884., 39084., 40686., 42288.,
        46644., 48678., 50712., 47904., 50010., 52116., 49164., 51342., 53520.,
        107124., 112614., 118104., 108384., 113946., 119508., 109644., 115278.,
        120912., 117204., 123270., 129336., 118464., 124602., 130740., 119724.,
        125934., 132144.
    ]
    self._VerifyValues(
        tensor_in_sizes=[1, 6, 7, 8, 2],
        filter_in_sizes=[3, 2, 1, 2, 3],
        stride=3,
        padding="VALID",
        expected=expected_output)

  def testConv3D2x2x2FilterStride2Same(self):
    expected_output = [
        19554., 19962., 20370., 10452., 10710., 10968., 50226., 51498., 52770.,
        23844., 24534., 25224.
    ]
    self._VerifyValues(
        tensor_in_sizes=[1, 4, 2, 3, 3],
        filter_in_sizes=[2, 2, 2, 3, 3],
        stride=2,
        padding="SAME",
        expected=expected_output)

  def testKernelSmallerThanStride(self):
    expected_output = [1., 3., 7., 9., 19., 21., 25., 27.]
    self._VerifyValues(
        tensor_in_sizes=[1, 3, 3, 3, 1],
        filter_in_sizes=[1, 1, 1, 1, 1],
        stride=2,
        padding="SAME",
        expected=expected_output)
    self._VerifyValues(
        tensor_in_sizes=[1, 3, 3, 3, 1],
        filter_in_sizes=[1, 1, 1, 1, 1],
        stride=2,
        padding="VALID",
        expected=expected_output)

    expected_output = [
        1484., 1592., 770., 2240., 2348., 1106., 1149., 1191., 539.,
        6776., 6884., 3122., 7532., 7640., 3458., 3207., 3249., 1421.,
        3005., 3035., 1225., 3215., 3245., 1309., 1013., 1022., 343.
    ]
    self._VerifyValues(
        tensor_in_sizes=[1, 7, 7, 7, 1],
        filter_in_sizes=[2, 2, 2, 1, 1],
        stride=3,
        padding="SAME",
        expected=expected_output)

    expected_output = [1484., 1592., 2240., 2348., 6776., 6884., 7532., 7640.]
    self._VerifyValues(
        tensor_in_sizes=[1, 7, 7, 7, 1],
        filter_in_sizes=[2, 2, 2, 1, 1],
        stride=3,
        padding="VALID",
        expected=expected_output)

  def testKernelSizeMatchesInputSize(self):
    self._VerifyValues(
        tensor_in_sizes=[1, 2, 1, 2, 1],
        filter_in_sizes=[2, 1, 2, 1, 2],
        stride=1,
        padding="VALID",
        expected=[50, 60])

  def _ConstructAndTestGradientForConfig(
      self, batch, input_shape, filter_shape, in_depth, out_depth, stride,
      padding, test_input, data_format, use_gpu):
    """Compares numeric and symbolic conv3d gradients for one config.

    Args:
      batch: batch size.
      input_shape: (planes, rows, cols) of the input.
      filter_shape: (planes, rows, cols) of the filter.
      in_depth: number of input channels.
      out_depth: number of output channels.
      stride: an int or a 3-element iterable of spatial strides.
      padding: padding string passed to conv3d ("VALID" or "SAME" here).
      test_input: if True, check the gradient w.r.t. the input; otherwise
        w.r.t. the filter.
      data_format: "NDHWC" or "NCDHW".
      use_gpu: forwarded to `test_session`. It also selects float32 data when
        a GPU is available.
    """
    input_planes, input_rows, input_cols = input_shape
    filter_planes, filter_rows, filter_cols = filter_shape

    input_sizes = [batch, input_planes, input_rows, input_cols, in_depth]
    filter_sizes = [
        filter_planes, filter_rows, filter_cols, in_depth, out_depth
    ]
    strides = self._StridesFromArg(stride)

    # The gradient checker needs the NDHWC output shape up front.
    input_spatial = (input_planes, input_rows, input_cols)
    filter_spatial = (filter_planes, filter_rows, filter_cols)
    spatial_strides = strides[1:4]
    if padding == "VALID":
      output_spatial = [
          int(math.ceil((in_dim - filter_dim + 1.0) / dim_stride))
          for in_dim, filter_dim, dim_stride in zip(
              input_spatial, filter_spatial, spatial_strides)
      ]
    else:
      output_spatial = [
          int(math.ceil(float(in_dim) / dim_stride))
          for in_dim, dim_stride in zip(input_spatial, spatial_strides)
      ]
    output_shape = [batch] + output_spatial + [out_depth]

    # Deterministic values in [0, 1).
    input_size = self._NumElements(input_sizes)
    filter_size = self._NumElements(filter_sizes)
    input_data = [x * 1.0 / input_size for x in range(0, input_size)]
    filter_data = [x * 1.0 / filter_size for x in range(0, filter_size)]

    # TODO(mjanusz): Modify gradient_checker to also provide max relative
    # error and synchronize the tolerance levels between the tests for forward
    # and backward computations.
    if use_gpu and test.is_gpu_available():
      data_type = dtypes.float32
      tolerance = 5e-3
    else:
      data_type = dtypes.float64
      tolerance = 1e-8

    with self.test_session(use_gpu=use_gpu):
      orig_input_tensor = constant_op.constant(
          input_data, shape=input_sizes, dtype=data_type, name="input")
      filter_tensor = constant_op.constant(
          filter_data, shape=filter_sizes, dtype=data_type, name="filter")

      if data_format == "NCDHW":
        input_tensor = test_util.NHWCToNCHW(orig_input_tensor)
        strides = test_util.NHWCToNCHW(strides)
      else:
        input_tensor = orig_input_tensor

      conv = nn_ops.conv3d(
          input_tensor, filter_tensor, strides, padding,
          data_format=data_format, name="conv")

      if data_format == "NCDHW":
        conv = test_util.NCHWToNHWC(conv)

      # The gradient is taken w.r.t. the NDHWC-shaped original tensor, so
      # the transposes above are part of the checked graph.
      if test_input:
        err = gradient_checker.compute_gradient_error(orig_input_tensor,
                                                      input_sizes,
                                                      conv, output_shape)
      else:
        err = gradient_checker.compute_gradient_error(filter_tensor,
                                                      filter_sizes, conv,
                                                      output_shape)
    print("conv3d gradient error = ", err)
    self.assertLess(err, tolerance)

  def ConstructAndTestGradient(self, **kwargs):
    """Runs `_ConstructAndTestGradientForConfig` for every test config."""
    for data_format, use_gpu in GetTestConfigs():
      self._ConstructAndTestGradientForConfig(data_format=data_format,
                                              use_gpu=use_gpu, **kwargs)

  def testInputGradientValidPaddingStrideOne(self):
    self.ConstructAndTestGradient(
        batch=2,
        input_shape=(3, 5, 4),
        filter_shape=(3, 3, 3),
        in_depth=2,
        out_depth=3,
        stride=1,
        padding="VALID",
        test_input=True)

  def testFilterGradientValidPaddingStrideOne(self):
    self.ConstructAndTestGradient(
        batch=4,
        input_shape=(4, 6, 5),
        filter_shape=(2, 2, 2),
        in_depth=2,
        out_depth=3,
        stride=1,
        padding="VALID",
        test_input=False)

  def testInputGradientValidPaddingStrideTwo(self):
    self.ConstructAndTestGradient(
        batch=2,
        input_shape=(6, 3, 5),
        filter_shape=(3, 3, 3),
        in_depth=2,
        out_depth=3,
        stride=2,
        padding="VALID",
        test_input=True)

  def testFilterGradientValidPaddingStrideTwo(self):
    self.ConstructAndTestGradient(
        batch=2,
        input_shape=(7, 6, 5),
        filter_shape=(2, 2, 2),
        in_depth=2,
        out_depth=3,
        stride=2,
        padding="VALID",
        test_input=False)

  def testInputGradientValidPaddingStrideThree(self):
    self.ConstructAndTestGradient(
        batch=2,
        input_shape=(3, 7, 6),
        filter_shape=(3, 3, 3),
        in_depth=2,
        out_depth=3,
        stride=3,
        padding="VALID",
        test_input=True)

  def testFilterGradientValidPaddingStrideThree(self):
    self.ConstructAndTestGradient(
        batch=2,
        input_shape=(4, 4, 7),
        filter_shape=(4, 4, 4),
        in_depth=2,
        out_depth=3,
        stride=3,
        padding="VALID",
        test_input=False)

  def testInputGradientSamePaddingStrideOne(self):
    self.ConstructAndTestGradient(
        batch=2,
        input_shape=(3, 2, 2),
        filter_shape=(3, 2, 1),
        in_depth=2,
        out_depth=1,
        stride=1,
        padding="SAME",
        test_input=True)

  def testFilterGradientSamePaddingStrideOne(self):
    self.ConstructAndTestGradient(
        batch=2,
        input_shape=(3, 6, 5),
        filter_shape=(2, 2, 2),
        in_depth=2,
        out_depth=3,
        stride=1,
        padding="SAME",
        test_input=False)

  def testInputGradientSamePaddingStrideTwo(self):
    self.ConstructAndTestGradient(
        batch=2,
        input_shape=(6, 3, 4),
        filter_shape=(3, 3, 3),
        in_depth=2,
        out_depth=3,
        stride=2,
        padding="SAME",
        test_input=True)

  def testFilterGradientSamePaddingStrideTwo(self):
    self.ConstructAndTestGradient(
        batch=4,
        input_shape=(7, 3, 5),
        filter_shape=(2, 2, 2),
        in_depth=2,
        out_depth=3,
        stride=2,
        padding="SAME",
        test_input=False)

  def testInputGradientSamePaddingStrideThree(self):
    self.ConstructAndTestGradient(
        batch=2,
        input_shape=(9, 3, 6),
        filter_shape=(3, 3, 3),
        in_depth=2,
        out_depth=3,
        stride=3,
        padding="SAME",
        test_input=True)

  def testFilterGradientSamePaddingStrideThree(self):
    self.ConstructAndTestGradient(
        batch=2,
        input_shape=(9, 4, 7),
        filter_shape=(4, 4, 4),
        in_depth=2,
        out_depth=3,
        stride=3,
        padding="SAME",
        test_input=False)

  def testInputGradientSamePaddingDifferentStrides(self):
    self.ConstructAndTestGradient(
        batch=1,
        input_shape=(5, 8, 7),
        filter_shape=(1, 2, 3),
        in_depth=2,
        out_depth=3,
        stride=[2, 3, 1],
        padding="SAME",
        test_input=True)

  def testFilterGradientKernelSizeMatchesInputSize(self):
    self.ConstructAndTestGradient(
        batch=2,
        input_shape=(5, 4, 3),
        filter_shape=(5, 4, 3),
        in_depth=2,
        out_depth=3,
        stride=1,
        padding="VALID",
        test_input=False)

  def testInputGradientKernelSizeMatchesInputSize(self):
    self.ConstructAndTestGradient(
        batch=2,
        input_shape=(5, 4, 3),
        filter_shape=(5, 4, 3),
        in_depth=2,
        out_depth=3,
        stride=1,
        padding="VALID",
        test_input=True)

  # Disabled: the name does not start with "test", so the test loader skips
  # it. The reason for disabling is not recorded in this file.
  def disabledtestFilterGradientSamePaddingDifferentStrides(self):
    self.ConstructAndTestGradient(
        batch=1,
        input_shape=(5, 8, 7),
        filter_shape=(1, 2, 3),
        in_depth=2,
        out_depth=3,
        stride=[2, 3, 1],
        padding="SAME",
        test_input=False)
556
557
if __name__ == "__main__":
  # Run the test cases in this module through TensorFlow's test runner.
  test.main()
560