conv_ops_3d_test.py revision 5866e065bc95c1d7de8a27413b368016941889a6
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Functional tests for 3d convolutional operations."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import math
23
24from tensorflow.python.framework import constant_op
25from tensorflow.python.framework import dtypes
26from tensorflow.python.ops import gradient_checker
27from tensorflow.python.ops import nn_ops
28import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
29from tensorflow.python.platform import test
30
31
class Conv3DTest(test.TestCase):
  """Tests the forward values and gradients of `nn_ops.conv3d`."""

  @staticmethod
  def _StridesFor(stride):
    """Expands a stride argument into the 5-element list conv3d expects.

    Args:
      stride: Either a single int used for all three spatial dimensions, or
        an iterable of three ints giving the (planes, rows, cols) strides.

    Returns:
      `[1, s_planes, s_rows, s_cols, 1]`; batch and depth strides are always 1.
    """
    # `collections.Iterable` was removed in Python 3.10; the ABC lives in
    # `collections.abc` (Python 2 only has the old location).
    try:
      from collections.abc import Iterable  # pylint: disable=g-import-not-at-top
    except ImportError:
      from collections import Iterable  # pylint: disable=g-import-not-at-top
    if isinstance(stride, Iterable):
      return [1] + list(stride) + [1]
    return [1, stride, stride, stride, 1]

  @staticmethod
  def _NumElements(shape):
    """Returns the product of the dimensions in `shape` (1 for an empty shape)."""
    count = 1
    for dim in shape:
      count *= dim
    return count

  @staticmethod
  def _OutputSize(in_size, filter_size, stride, padding):
    """Returns the output length of one spatial dimension.

    Args:
      in_size: Input length along the dimension.
      filter_size: Filter length along the dimension.
      stride: Stride along the dimension.
      padding: "VALID" or "SAME". Any value other than "VALID" is treated
        as "SAME".

    Returns:
      ceil((in_size - filter_size + 1) / stride) for "VALID", otherwise
      ceil(in_size / stride).
    """
    if padding == "VALID":
      return int(math.ceil((in_size - filter_size + 1.0) / stride))
    return int(math.ceil(float(in_size) / stride))

  def _VerifyValues(self, tensor_in_sizes, filter_in_sizes, stride, padding,
                    expected):
    """Runs conv3d on ramp-filled tensors and compares the flattened output.

    The input and the filter are each filled with 1.0, 2.0, 3.0, ... in
    row-major order.

    Args:
      tensor_in_sizes: Input shape, [batch, planes, rows, cols, in_depth].
      filter_in_sizes: Filter shape, [planes, rows, cols, in_depth, out_depth].
      stride: Int, or (planes, rows, cols) iterable of ints; see _StridesFor.
      padding: "SAME" or "VALID".
      expected: Flat list of expected output values (tolerance 1e-5).
    """
    strides = self._StridesFor(stride)

    # Initializes the input tensors with incrementing numbers from 1.
    x1 = [f * 1.0 for f in range(1, self._NumElements(tensor_in_sizes) + 1)]
    x2 = [f * 1.0 for f in range(1, self._NumElements(filter_in_sizes) + 1)]
    with self.test_session(use_gpu=True) as sess:
      t1 = constant_op.constant(x1, shape=tensor_in_sizes)
      t2 = constant_op.constant(x2, shape=filter_in_sizes)
      conv = nn_ops.conv3d(t1, t2, strides, padding=padding)
      value = sess.run(conv)
    print("expected = ", expected)
    print("actual = ", value)
    self.assertArrayNear(expected, value.flatten(), 1e-5)

  def testConv3D1x1x1Filter(self):
    expected_output = [
        30.0, 36.0, 42.0, 66.0, 81.0, 96.0, 102.0, 126.0, 150.0, 138.0, 171.0,
        204.0, 174.0, 216.0, 258.0, 210.0, 261.0, 312.0
    ]

    # These are equivalent to the Conv2D1x1 case: the same 18 input values
    # arranged along different spatial axes give the same output.
    for tensor_in_sizes in ([1, 2, 3, 1, 3], [1, 2, 1, 3, 3], [1, 1, 2, 3, 3]):
      self._VerifyValues(
          tensor_in_sizes=tensor_in_sizes,
          filter_in_sizes=[1, 1, 1, 3, 3],
          stride=1,
          padding="VALID",
          expected=expected_output)

  # Expected values computed using scipy's correlate function.
  def testConv3D2x2x2Filter(self):
    expected_output = [
        19554., 19962., 20370., 22110., 22590., 23070., 34890., 35730., 36570.,
        37446., 38358., 39270., 50226., 51498., 52770., 52782., 54126., 55470.
    ]
    # expected_shape = [1, 3, 1, 2, 3]
    self._VerifyValues(
        tensor_in_sizes=[1, 4, 2, 3, 3],  # b, z, y, x, fin
        filter_in_sizes=[2, 2, 2, 3, 3],  # z, y, x, fin, fout
        stride=1,
        padding="VALID",
        expected=expected_output)

  def testConv3DStrides(self):
    # One row of 7 output columns per (plane, row) pair: 3 planes x 3 rows.
    expected_output = [
        102., 151., 172., 193., 214., 235., 142.,
        438., 592., 613., 634., 655., 676., 394.,
        774., 1033., 1054., 1075., 1096., 1117., 646.,
        1894., 2503., 2524., 2545., 2566., 2587., 1486.,
        2230., 2944., 2965., 2986., 3007., 3028., 1738.,
        2566., 3385., 3406., 3427., 3448., 3469., 1990.,
        3686., 4855., 4876., 4897., 4918., 4939., 2830.,
        4022., 5296., 5317., 5338., 5359., 5380., 3082.,
        4358., 5737., 5758., 5779., 5800., 5821., 3334.,
    ]
    self._VerifyValues(
        tensor_in_sizes=[1, 5, 8, 7, 1],
        filter_in_sizes=[1, 2, 3, 1, 1],
        stride=[2, 3, 1],  # different stride for each spatial dimension
        padding="SAME",
        expected=expected_output)

  def testConv3D2x2x2FilterStride2(self):
    expected_output = [19554., 19962., 20370., 50226., 51498., 52770.]
    self._VerifyValues(
        tensor_in_sizes=[1, 4, 2, 3, 3],
        filter_in_sizes=[2, 2, 2, 3, 3],
        stride=2,
        padding="VALID",
        expected=expected_output)

  def testConv3DStride3(self):
    expected_output = [
        36564., 38022., 39480., 37824., 39354., 40884., 39084., 40686., 42288.,
        46644., 48678., 50712., 47904., 50010., 52116., 49164., 51342., 53520.,
        107124., 112614., 118104., 108384., 113946., 119508., 109644., 115278.,
        120912., 117204., 123270., 129336., 118464., 124602., 130740., 119724.,
        125934., 132144.
    ]
    self._VerifyValues(
        tensor_in_sizes=[1, 6, 7, 8, 2],
        filter_in_sizes=[3, 2, 1, 2, 3],
        stride=3,
        padding="VALID",
        expected=expected_output)

  def testConv3D2x2x2FilterStride2Same(self):
    expected_output = [
        19554., 19962., 20370., 10452., 10710., 10968., 50226., 51498., 52770.,
        23844., 24534., 25224.
    ]
    self._VerifyValues(
        tensor_in_sizes=[1, 4, 2, 3, 3],
        filter_in_sizes=[2, 2, 2, 3, 3],
        stride=2,
        padding="SAME",
        expected=expected_output)

  def testKernelSmallerThanStride(self):
    # A single 1.0 filter weight with stride 2 picks out every other input
    # element; SAME and VALID agree for this 3x3x3 input.
    expected_output = [1., 3., 7., 9., 19., 21., 25., 27.]
    for padding in ("SAME", "VALID"):
      self._VerifyValues(
          tensor_in_sizes=[1, 3, 3, 3, 1],
          filter_in_sizes=[1, 1, 1, 1, 1],
          stride=2,
          padding=padding,
          expected=expected_output)

    expected_output = [
        1484., 1592., 770., 2240., 2348., 1106., 1149., 1191., 539., 6776.,
        6884., 3122., 7532., 7640., 3458., 3207., 3249., 1421., 3005., 3035.,
        1225., 3215., 3245., 1309., 1013., 1022., 343.
    ]
    self._VerifyValues(
        tensor_in_sizes=[1, 7, 7, 7, 1],
        filter_in_sizes=[2, 2, 2, 1, 1],
        stride=3,
        padding="SAME",
        expected=expected_output)

    expected_output = [1484., 1592., 2240., 2348., 6776., 6884., 7532., 7640.]
    self._VerifyValues(
        tensor_in_sizes=[1, 7, 7, 7, 1],
        filter_in_sizes=[2, 2, 2, 1, 1],
        stride=3,
        padding="VALID",
        expected=expected_output)

  def ConstructAndTestGradient(self, batch, input_planes, input_rows,
                               input_cols, filter_planes, filter_rows,
                               filter_cols, in_depth, out_depth, stride,
                               padding, test_input):
    """Checks conv3d's gradient numerically with gradient_checker.

    Args:
      batch: Batch size.
      input_planes: Input size along the planes dimension.
      input_rows: Input size along the rows dimension.
      input_cols: Input size along the cols dimension.
      filter_planes: Filter size along the planes dimension.
      filter_rows: Filter size along the rows dimension.
      filter_cols: Filter size along the cols dimension.
      in_depth: Number of input channels.
      out_depth: Number of output channels.
      stride: Int, or (planes, rows, cols) iterable of ints; see _StridesFor.
      padding: "SAME" or "VALID".
      test_input: If True, check the gradient w.r.t. the input tensor;
        otherwise w.r.t. the filter tensor.
    """
    input_shape = [batch, input_planes, input_rows, input_cols, in_depth]
    filter_shape = [
        filter_planes, filter_rows, filter_cols, in_depth, out_depth
    ]
    strides = self._StridesFor(stride)
    output_shape = [
        batch,
        self._OutputSize(input_planes, filter_planes, strides[1], padding),
        self._OutputSize(input_rows, filter_rows, strides[2], padding),
        self._OutputSize(input_cols, filter_cols, strides[3], padding),
        out_depth,
    ]

    # Ramp values scaled into [0, 1).
    input_size = self._NumElements(input_shape)
    filter_size = self._NumElements(filter_shape)
    input_data = [x * 1.0 / input_size for x in range(0, input_size)]
    filter_data = [x * 1.0 / filter_size for x in range(0, filter_size)]

    # float32 is only exercised when a GPU is available; otherwise the check
    # runs in float64 with a much tighter bound.
    # NOTE: if float32 is ever checked on CPU, a looser bound may be needed
    # (8e-3 was previously noted for some CPU architectures, Aug 2016).
    if test.is_gpu_available():
      data_type = dtypes.float32
      tolerance = 4e-3
    else:
      data_type = dtypes.float64
      tolerance = 1e-8

    with self.test_session(use_gpu=True):
      input_tensor = constant_op.constant(
          input_data, shape=input_shape, dtype=data_type, name="input")
      filter_tensor = constant_op.constant(
          filter_data, shape=filter_shape, dtype=data_type, name="filter")
      conv = nn_ops.conv3d(
          input_tensor, filter_tensor, strides, padding, name="conv")

      if test_input:
        checked_tensor, checked_shape = input_tensor, input_shape
      else:
        checked_tensor, checked_shape = filter_tensor, filter_shape
      err = gradient_checker.compute_gradient_error(checked_tensor,
                                                    checked_shape, conv,
                                                    output_shape)
    print("conv3d gradient error = ", err)
    self.assertLess(err, tolerance)

  def testInputGradientValidPaddingStrideOne(self):
    self.ConstructAndTestGradient(
        batch=2,
        input_planes=3,
        input_rows=5,
        input_cols=4,
        filter_planes=3,
        filter_rows=3,
        filter_cols=3,
        in_depth=2,
        out_depth=3,
        stride=1,
        padding="VALID",
        test_input=True)

  def testFilterGradientValidPaddingStrideOne(self):
    self.ConstructAndTestGradient(
        batch=4,
        input_planes=4,
        input_rows=6,
        input_cols=5,
        filter_planes=2,
        filter_rows=2,
        filter_cols=2,
        in_depth=2,
        out_depth=3,
        stride=1,
        padding="VALID",
        test_input=False)

  def testInputGradientValidPaddingStrideTwo(self):
    self.ConstructAndTestGradient(
        batch=2,
        input_planes=6,
        input_rows=3,
        input_cols=5,
        filter_planes=3,
        filter_rows=3,
        filter_cols=3,
        in_depth=2,
        out_depth=3,
        stride=2,
        padding="VALID",
        test_input=True)

  def testFilterGradientValidPaddingStrideTwo(self):
    self.ConstructAndTestGradient(
        batch=2,
        input_planes=7,
        input_rows=6,
        input_cols=5,
        filter_planes=2,
        filter_rows=2,
        filter_cols=2,
        in_depth=2,
        out_depth=3,
        stride=2,
        padding="VALID",
        test_input=False)

  def testInputGradientValidPaddingStrideThree(self):
    self.ConstructAndTestGradient(
        batch=2,
        input_planes=3,
        input_rows=7,
        input_cols=6,
        filter_planes=3,
        filter_rows=3,
        filter_cols=3,
        in_depth=2,
        out_depth=3,
        stride=3,
        padding="VALID",
        test_input=True)

  def testFilterGradientValidPaddingStrideThree(self):
    self.ConstructAndTestGradient(
        batch=2,
        input_planes=4,
        input_rows=4,
        input_cols=7,
        filter_planes=4,
        filter_rows=4,
        filter_cols=4,
        in_depth=2,
        out_depth=3,
        stride=3,
        padding="VALID",
        test_input=False)

  def testInputGradientSamePaddingStrideOne(self):
    self.ConstructAndTestGradient(
        batch=2,
        input_planes=3,
        input_rows=2,
        input_cols=2,
        filter_planes=3,
        filter_rows=2,
        filter_cols=1,
        in_depth=2,
        out_depth=1,
        stride=1,
        padding="SAME",
        test_input=True)

  def testFilterGradientSamePaddingStrideOne(self):
    self.ConstructAndTestGradient(
        batch=2,
        input_planes=3,
        input_rows=6,
        input_cols=5,
        filter_planes=2,
        filter_rows=2,
        filter_cols=2,
        in_depth=2,
        out_depth=3,
        stride=1,
        padding="SAME",
        test_input=False)

  def testInputGradientSamePaddingStrideTwo(self):
    self.ConstructAndTestGradient(
        batch=2,
        input_planes=6,
        input_rows=3,
        input_cols=4,
        filter_planes=3,
        filter_rows=3,
        filter_cols=3,
        in_depth=2,
        out_depth=3,
        stride=2,
        padding="SAME",
        test_input=True)

  def testFilterGradientSamePaddingStrideTwo(self):
    self.ConstructAndTestGradient(
        batch=4,
        input_planes=7,
        input_rows=3,
        input_cols=5,
        filter_planes=2,
        filter_rows=2,
        filter_cols=2,
        in_depth=2,
        out_depth=3,
        stride=2,
        padding="SAME",
        test_input=False)

  def testInputGradientSamePaddingStrideThree(self):
    self.ConstructAndTestGradient(
        batch=2,
        input_planes=9,
        input_rows=3,
        input_cols=6,
        filter_planes=3,
        filter_rows=3,
        filter_cols=3,
        in_depth=2,
        out_depth=3,
        stride=3,
        padding="SAME",
        test_input=True)

  def testFilterGradientSamePaddingStrideThree(self):
    self.ConstructAndTestGradient(
        batch=2,
        input_planes=9,
        input_rows=4,
        input_cols=7,
        filter_planes=4,
        filter_rows=4,
        filter_cols=4,
        in_depth=2,
        out_depth=3,
        stride=3,
        padding="SAME",
        test_input=False)

  def testInputGradientSamePaddingDifferentStrides(self):
    self.ConstructAndTestGradient(
        batch=1,
        input_planes=5,
        input_rows=8,
        input_cols=7,
        filter_planes=1,
        filter_rows=2,
        filter_cols=3,
        in_depth=2,
        out_depth=3,
        stride=[2, 3, 1],
        padding="SAME",
        test_input=True)

  # Disabled by renaming: without the `test` prefix the unittest loader does
  # not collect this method.
  # NOTE(review): the reason for disabling it is not recorded here.
  def disabledtestFilterGradientSamePaddingDifferentStrides(self):
    self.ConstructAndTestGradient(
        batch=1,
        input_planes=5,
        input_rows=8,
        input_cols=7,
        filter_planes=1,
        filter_rows=2,
        filter_cols=3,
        in_depth=2,
        out_depth=3,
        stride=[2, 3, 1],
        padding="SAME",
        test_input=False)
518
519
if __name__ == "__main__":
  # Run as a script: hand control to the TensorFlow test runner.
  test.main()
522