1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for fft operations."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22from six.moves import xrange  # pylint: disable=redefined-builtin
23
24from tensorflow.core.protobuf import config_pb2
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import errors
27from tensorflow.python.framework import ops
28from tensorflow.python.ops import array_ops
29from tensorflow.python.ops import gen_spectral_ops
30from tensorflow.python.ops import gradient_checker
31from tensorflow.python.ops import math_ops
32from tensorflow.python.ops import spectral_ops
33from tensorflow.python.ops import spectral_ops_test_util
34from tensorflow.python.platform import test
35
# Transform ranks exercised by every test below: rank r transforms the r
# inner-most dimensions (fft/rfft for 1, the *2d ops for 2, the *3d ops for 3).
VALID_FFT_RANKS = tuple(range(1, 4))
37
38
class BaseFFTOpsTest(test.TestCase):
  """Shared comparison and gradient-checking helpers for the FFT op tests.

  Subclasses supply `_tfFFT`, `_tfIFFT`, `_npFFT` and `_npIFFT`, which compute
  the forward and inverse transforms with TensorFlow and NumPy respectively.
  """

  def _compare(self, x, rank, fft_length=None, use_placeholder=False):
    """Checks both the forward and the inverse transform of `x`."""
    self._compareForward(x, rank, fft_length, use_placeholder)
    self._compareBackward(x, rank, fft_length, use_placeholder)

  def _compareWithNumpy(self, np_fn, tf_fn, x, rank, fft_length,
                        use_placeholder):
    """Asserts that a TensorFlow transform agrees with its NumPy reference.

    Args:
      np_fn: NumPy reference, called as `np_fn(x, rank, fft_length)`.
      tf_fn: TensorFlow implementation, called as
        `tf_fn(x_or_placeholder, rank, fft_length[, feed_dict=...])`.
      x: NumPy input array.
      rank: Number of inner-most dimensions to transform.
      fft_length: Optional FFT length forwarded to both implementations.
      use_placeholder: If True, `x` is fed through a placeholder created with
        only a dtype (no shape), so the op runs without static shape info.
    """
    expected = np_fn(x, rank, fft_length)
    if use_placeholder:
      x_ph = array_ops.placeholder(dtype=dtypes.as_dtype(x.dtype))
      actual = tf_fn(x_ph, rank, fft_length, feed_dict={x_ph: x})
    else:
      actual = tf_fn(x, rank, fft_length)

    self.assertAllClose(expected, actual, rtol=1e-4, atol=1e-4)

  def _compareForward(self, x, rank, fft_length=None, use_placeholder=False):
    """Compares `_tfFFT` against `_npFFT` on `x`."""
    self._compareWithNumpy(self._npFFT, self._tfFFT, x, rank, fft_length,
                           use_placeholder)

  def _compareBackward(self, x, rank, fft_length=None, use_placeholder=False):
    """Compares `_tfIFFT` against `_npIFFT` on `x`."""
    self._compareWithNumpy(self._npIFFT, self._tfIFFT, x, rank, fft_length,
                           use_placeholder)

  def _checkMemoryFail(self, x, rank):
    """Runs the forward FFT on GPU with a 1% per-process GPU memory cap.

    No assertion is made here; a failure can only surface as an exception
    raised while evaluating the op.
    """
    config = config_pb2.ConfigProto()
    config.gpu_options.per_process_gpu_memory_fraction = 1e-2
    with self.test_session(config=config, force_gpu=True):
      self._tfFFT(x, rank, fft_length=None)

  def _checkGradComplex(self, func, x, y, result_is_complex=True):
    """Numerically checks gradients of sum(|func(x + i*y)|^2) w.r.t. x and y.

    Args:
      func: A forward or inverse, real or complex, batched or unbatched FFT
        function taking a single complex tensor.
      x: Real part of the input (NumPy array).
      y: Imaginary part of the input (NumPy array; callers in this file pass
        the same shape as `x`).
      result_is_complex: Unused; kept so existing callers keep working.
    """
    with self.test_session(use_gpu=True):
      inx = ops.convert_to_tensor(x)
      iny = ops.convert_to_tensor(y)
      z = func(math_ops.complex(inx, iny))
      # loss = sum(|z|^2); real() also accepts a real z (e.g. from irfft).
      loss = math_ops.reduce_sum(math_ops.real(z * math_ops.conj(z)))

      ((x_jacob_t, x_jacob_n),
       (y_jacob_t, y_jacob_n)) = gradient_checker.compute_gradient(
           [inx, iny], [list(x.shape), list(y.shape)],
           loss, [1],
           x_init_value=[x, y],
           delta=1e-2)

    self.assertAllClose(x_jacob_t, x_jacob_n, rtol=1e-2, atol=1e-2)
    self.assertAllClose(y_jacob_t, y_jacob_n, rtol=1e-2, atol=1e-2)

  def _checkGradReal(self, func, x):
    """Numerically checks the gradient of sum(|func(x)|^2) w.r.t. real `x`.

    Args:
      func: A forward RFFT function (batched or unbatched).
      x: Real NumPy input array.
    """
    with self.test_session(use_gpu=True):
      inx = ops.convert_to_tensor(x)
      z = func(inx)
      # loss = sum(|z|^2)
      loss = math_ops.reduce_sum(math_ops.real(z * math_ops.conj(z)))
      # Same checker as _checkGradComplex (test.compute_gradient is an alias).
      x_jacob_t, x_jacob_n = gradient_checker.compute_gradient(
          inx, list(x.shape), loss, [1], x_init_value=x, delta=1e-2)

    self.assertAllClose(x_jacob_t, x_jacob_n, rtol=1e-2, atol=1e-2)
102
103
class FFTOpsTest(BaseFFTOpsTest):
  """Tests the complex-to-complex fft/fft2d/fft3d ops and their inverses."""

  def _tfFFT(self, x, rank, fft_length=None, feed_dict=None):
    """Evaluates the TensorFlow forward FFT of `x` for the given `rank`."""
    # fft_length unused for complex FFTs.
    with self.test_session(use_gpu=True):
      return self._tfFFTForRank(rank)(x).eval(feed_dict=feed_dict)

  def _tfIFFT(self, x, rank, fft_length=None, feed_dict=None):
    """Evaluates the TensorFlow inverse FFT of `x` for the given `rank`."""
    # fft_length unused for complex FFTs.
    with self.test_session(use_gpu=True):
      return self._tfIFFTForRank(rank)(x).eval(feed_dict=feed_dict)

  def _npFFT(self, x, rank, fft_length=None):
    """NumPy reference forward FFT over the `rank` inner-most dimensions.

    Raises:
      ValueError: if `rank` is not 1, 2 or 3.
    """
    if rank not in (1, 2, 3):
      raise ValueError("invalid rank")
    # fftn over axes (-rank, ..., -1); fft2 with explicit axes computes the
    # same result but reads as if it were always a 2-D transform.
    return np.fft.fftn(x, s=fft_length, axes=tuple(range(-rank, 0)))

  def _npIFFT(self, x, rank, fft_length=None):
    """NumPy reference inverse FFT over the `rank` inner-most dimensions.

    Raises:
      ValueError: if `rank` is not 1, 2 or 3.
    """
    if rank not in (1, 2, 3):
      raise ValueError("invalid rank")
    return np.fft.ifftn(x, s=fft_length, axes=tuple(range(-rank, 0)))

  def _tfFFTForRank(self, rank):
    """Returns the TensorFlow forward FFT op for `rank` (1, 2 or 3)."""
    if rank == 1:
      return spectral_ops.fft
    elif rank == 2:
      return spectral_ops.fft2d
    elif rank == 3:
      return spectral_ops.fft3d
    else:
      raise ValueError("invalid rank")

  def _tfIFFTForRank(self, rank):
    """Returns the TensorFlow inverse FFT op for `rank` (1, 2 or 3)."""
    if rank == 1:
      return spectral_ops.ifft
    elif rank == 2:
      return spectral_ops.ifft2d
    elif rank == 3:
      return spectral_ops.ifft3d
    else:
      raise ValueError("invalid rank")

  def testEmpty(self):
    # Zero-sized inputs must come back with an unchanged (empty) shape.
    with spectral_ops_test_util.fft_kernel_label_map():
      for rank in VALID_FFT_RANKS:
        for dims in xrange(rank, rank + 3):
          x = np.zeros((0,) * dims).astype(np.complex64)
          self.assertEqual(x.shape, self._tfFFT(x, rank).shape)
          self.assertEqual(x.shape, self._tfIFFT(x, rank).shape)

  def testBasic(self):
    with spectral_ops_test_util.fft_kernel_label_map():
      for rank in VALID_FFT_RANKS:
        for dims in xrange(rank, rank + 3):
          self._compare(
              np.mod(np.arange(np.power(4, dims)), 10).reshape(
                  (4,) * dims).astype(np.complex64), rank)

  def testLargeBatch(self):
    if test.is_gpu_available(cuda_only=True):
      rank = 1
      for dims in xrange(rank, rank + 3):
        self._compare(
            np.mod(np.arange(np.power(128, dims)), 10).reshape(
                (128,) * dims).astype(np.complex64), rank)

  # TODO(yangzihao): Disable before we can figure out a way to
  # properly test memory fail for large batch fft.
  # def testLargeBatchMemoryFail(self):
  #   if test.is_gpu_available(cuda_only=True):
  #     rank = 1
  #     for dims in xrange(rank, rank + 3):
  #       self._checkMemoryFail(
  #           np.mod(np.arange(np.power(128, dims)), 64).reshape(
  #               (128,) * dims).astype(np.complex64), rank)

  def testBasicPlaceholder(self):
    # Same inputs as testBasic, fed without static shape information.
    with spectral_ops_test_util.fft_kernel_label_map():
      for rank in VALID_FFT_RANKS:
        for dims in xrange(rank, rank + 3):
          self._compare(
              np.mod(np.arange(np.power(4, dims)), 10).reshape(
                  (4,) * dims).astype(np.complex64),
              rank,
              use_placeholder=True)

  def testRandom(self):
    with spectral_ops_test_util.fft_kernel_label_map():
      np.random.seed(12345)

      def gen(shape):
        n = np.prod(shape)
        re = np.random.uniform(size=n)
        im = np.random.uniform(size=n)
        return (re + im * 1j).reshape(shape)

      for rank in VALID_FFT_RANKS:
        for dims in xrange(rank, rank + 3):
          self._compare(gen((4,) * dims), rank)

  def testError(self):
    # Inputs with fewer than `rank` dimensions must be rejected.
    for rank in VALID_FFT_RANKS:
      for dims in xrange(0, rank):
        x = np.zeros((1,) * dims).astype(np.complex64)
        with self.assertRaisesWithPredicateMatch(
            ValueError, "Shape must be .*rank {}.*".format(rank)):
          self._tfFFT(x, rank)
        with self.assertRaisesWithPredicateMatch(
            ValueError, "Shape must be .*rank {}.*".format(rank)):
          self._tfIFFT(x, rank)

  def testGrad_Simple(self):
    with spectral_ops_test_util.fft_kernel_label_map():
      for rank in VALID_FFT_RANKS:
        for dims in xrange(rank, rank + 2):
          re = np.ones(shape=(4,) * dims, dtype=np.float32) / 10.0
          im = np.zeros(shape=(4,) * dims, dtype=np.float32)
          self._checkGradComplex(self._tfFFTForRank(rank), re, im)
          self._checkGradComplex(self._tfIFFTForRank(rank), re, im)

  def testGrad_Random(self):
    with spectral_ops_test_util.fft_kernel_label_map():
      np.random.seed(54321)
      for rank in VALID_FFT_RANKS:
        for dims in xrange(rank, rank + 2):
          re = np.random.rand(*((3,) * dims)).astype(np.float32) * 2 - 1
          im = np.random.rand(*((3,) * dims)).astype(np.float32) * 2 - 1
          self._checkGradComplex(self._tfFFTForRank(rank), re, im)
          self._checkGradComplex(self._tfIFFTForRank(rank), re, im)
243
244
class RFFTOpsTest(BaseFFTOpsTest):
  """Tests the real-to-complex rfft ops and their complex-to-real inverses."""

  def _tfFFT(self, x, rank, fft_length=None, feed_dict=None):
    """Evaluates the TensorFlow RFFT of `x`, optionally with `fft_length`."""
    with self.test_session(use_gpu=True):
      return self._tfFFTForRank(rank)(x, fft_length).eval(feed_dict=feed_dict)

  def _tfIFFT(self, x, rank, fft_length=None, feed_dict=None):
    """Evaluates the TensorFlow IRFFT of `x`, optionally with `fft_length`."""
    with self.test_session(use_gpu=True):
      return self._tfIFFTForRank(rank)(x, fft_length).eval(feed_dict=feed_dict)

  def _npFFT(self, x, rank, fft_length=None):
    """NumPy reference RFFT over the `rank` inner-most dimensions.

    Raises:
      ValueError: if `rank` is not 1, 2 or 3.
    """
    if rank not in (1, 2, 3):
      raise ValueError("invalid rank")
    # rfftn over axes (-rank, ..., -1); rfft2 with explicit axes computes the
    # same result but reads as if it were always a 2-D transform.
    return np.fft.rfftn(x, s=fft_length, axes=tuple(range(-rank, 0)))

  def _npIFFT(self, x, rank, fft_length=None):
    """NumPy reference IRFFT over the `rank` inner-most dimensions.

    Raises:
      ValueError: if `rank` is not 1, 2 or 3.
    """
    if rank not in (1, 2, 3):
      raise ValueError("invalid rank")
    return np.fft.irfftn(x, s=fft_length, axes=tuple(range(-rank, 0)))

  def _tfFFTForRank(self, rank):
    """Returns the TensorFlow RFFT op for `rank` (1, 2 or 3)."""
    if rank == 1:
      return spectral_ops.rfft
    elif rank == 2:
      return spectral_ops.rfft2d
    elif rank == 3:
      return spectral_ops.rfft3d
    else:
      raise ValueError("invalid rank")

  def _tfIFFTForRank(self, rank):
    """Returns the TensorFlow IRFFT op for `rank` (1, 2 or 3)."""
    if rank == 1:
      return spectral_ops.irfft
    elif rank == 2:
      return spectral_ops.irfft2d
    elif rank == 3:
      return spectral_ops.irfft3d
    else:
      raise ValueError("invalid rank")

  def _makeRealAndComplexInputs(self, size, dims):
    """Builds deterministic inputs for the RFFT / IRFFT comparisons.

    Args:
      size: Length of each dimension of the real signal.
      dims: Total number of dimensions of both returned arrays.

    Returns:
      A `(r2c, c2r)` pair. `r2c` is a float32 array of shape `(size,) * dims`
      for the forward RFFT. `c2r` is a complex64 array of shape
      `(size,) * (dims - 1) + (size // 2 + 1,)` for the inverse IRFFT. Both
      hold the repeating pattern 0..9.
    """
    inner_dim = size // 2 + 1
    r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape((size,) * dims)
    c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim),
                 10).reshape((size,) * (dims - 1) + (inner_dim,))
    return r2c.astype(np.float32), c2r.astype(np.complex64)

  def testEmpty(self):
    # Zero-sized inputs must come back with an unchanged (empty) shape.
    with spectral_ops_test_util.fft_kernel_label_map():
      for rank in VALID_FFT_RANKS:
        for dims in xrange(rank, rank + 3):
          x = np.zeros((0,) * dims).astype(np.float32)
          self.assertEqual(x.shape, self._tfFFT(x, rank).shape)
          x = np.zeros((0,) * dims).astype(np.complex64)
          self.assertEqual(x.shape, self._tfIFFT(x, rank).shape)

  def testBasic(self):
    with spectral_ops_test_util.fft_kernel_label_map():
      for rank in VALID_FFT_RANKS:
        for dims in xrange(rank, rank + 3):
          for size in (5, 6):
            r2c, c2r = self._makeRealAndComplexInputs(size, dims)
            self._compareForward(r2c, rank, (size,) * rank)
            self._compareBackward(c2r, rank, (size,) * rank)

  def testLargeBatch(self):
    if test.is_gpu_available(cuda_only=True):
      rank = 1
      for dims in xrange(rank, rank + 3):
        for size in (64, 128):
          r2c, c2r = self._makeRealAndComplexInputs(size, dims)
          self._compareForward(r2c, rank, (size,) * rank)
          self._compareBackward(c2r, rank, (size,) * rank)

  def testBasicPlaceholder(self):
    # Same inputs as testBasic, fed without static shape information.
    with spectral_ops_test_util.fft_kernel_label_map():
      for rank in VALID_FFT_RANKS:
        for dims in xrange(rank, rank + 3):
          for size in (5, 6):
            r2c, c2r = self._makeRealAndComplexInputs(size, dims)
            self._compareForward(
                r2c, rank, (size,) * rank, use_placeholder=True)
            self._compareBackward(
                c2r, rank, (size,) * rank, use_placeholder=True)

  def testFftLength(self):
    if test.is_gpu_available(cuda_only=True):
      with spectral_ops_test_util.fft_kernel_label_map():
        for rank in VALID_FFT_RANKS:
          for dims in xrange(rank, rank + 3):
            for size in (5, 6):
              r2c, c2r = self._makeRealAndComplexInputs(size, dims)
              # Test truncation (FFT size < dimensions), then padding
              # (FFT size > dimensions).
              for fft_length in ((size - 2,) * rank, (size + 2,) * rank):
                # Confirm it works with known and unknown shapes.
                for use_placeholder in (False, True):
                  self._compareForward(
                      r2c, rank, fft_length, use_placeholder=use_placeholder)
                  self._compareBackward(
                      c2r, rank, fft_length, use_placeholder=use_placeholder)

  def testRandom(self):
    with spectral_ops_test_util.fft_kernel_label_map():
      np.random.seed(12345)

      def gen_real(shape):
        n = np.prod(shape)
        re = np.random.uniform(size=n)
        ret = re.reshape(shape)
        return ret

      def gen_complex(shape):
        n = np.prod(shape)
        re = np.random.uniform(size=n)
        im = np.random.uniform(size=n)
        ret = (re + im * 1j).reshape(shape)
        return ret

      for rank in VALID_FFT_RANKS:
        for dims in xrange(rank, rank + 3):
          for size in (5, 6):
            inner_dim = size // 2 + 1
            self._compareForward(gen_real((size,) * dims), rank, (size,) * rank)
            complex_dims = (size,) * (dims - 1) + (inner_dim,)
            self._compareBackward(
                gen_complex(complex_dims), rank, (size,) * rank)

  def testError(self):
    # Raw kernels, used below to check that calling the kernel directly
    # without padding to fft_length produces an error.
    rffts_for_rank = {
        1: [gen_spectral_ops.rfft, gen_spectral_ops.irfft],
        2: [gen_spectral_ops.rfft2d, gen_spectral_ops.irfft2d],
        3: [gen_spectral_ops.rfft3d, gen_spectral_ops.irfft3d]
    }
    with spectral_ops_test_util.fft_kernel_label_map():
      for rank in VALID_FFT_RANKS:
        # Inputs with fewer than `rank` dimensions must be rejected.
        for dims in xrange(0, rank):
          x = np.zeros((1,) * dims).astype(np.complex64)
          with self.assertRaisesWithPredicateMatch(
              ValueError, "Shape .* must have rank at least {}".format(rank)):
            self._tfFFT(x, rank)
          with self.assertRaisesWithPredicateMatch(
              ValueError, "Shape .* must have rank at least {}".format(rank)):
            self._tfIFFT(x, rank)
        for dims in xrange(rank, rank + 2):
          # Use `dims` (not `rank`) so the batched shape is covered as well.
          x = np.zeros((1,) * dims)

          # Test non-rank-1 fft_length produces an error.
          fft_length = np.zeros((1, 1)).astype(np.int32)
          with self.assertRaisesWithPredicateMatch(ValueError,
                                                   "Shape .* must have rank 1"):
            self._tfFFT(x, rank, fft_length)
          with self.assertRaisesWithPredicateMatch(ValueError,
                                                   "Shape .* must have rank 1"):
            self._tfIFFT(x, rank, fft_length)

          # Test wrong fft_length length.
          fft_length = np.zeros((rank + 1,)).astype(np.int32)
          with self.assertRaisesWithPredicateMatch(
              ValueError, "Dimension must be .*but is {}.*".format(rank + 1)):
            self._tfFFT(x, rank, fft_length)
          with self.assertRaisesWithPredicateMatch(
              ValueError, "Dimension must be .*but is {}.*".format(rank + 1)):
            self._tfIFFT(x, rank, fft_length)

        # Test that calling the kernel directly without padding to fft_length
        # produces an error.
        rfft_fn, irfft_fn = rffts_for_rank[rank]
        with self.assertRaisesWithPredicateMatch(
            errors.InvalidArgumentError,
            "Input dimension .* must have length of at least 6 but got: 5"):
          x = np.zeros((5,) * rank).astype(np.float32)
          fft_length = [6] * rank
          with self.test_session():
            rfft_fn(x, fft_length).eval()

        with self.assertRaisesWithPredicateMatch(
            errors.InvalidArgumentError,
            "Input dimension .* must have length of at least .* but got: 3"):
          x = np.zeros((3,) * rank).astype(np.complex64)
          fft_length = [6] * rank
          with self.test_session():
            irfft_fn(x, fft_length).eval()

  def testGrad_Simple(self):
    with spectral_ops_test_util.fft_kernel_label_map():
      for rank in VALID_FFT_RANKS:
        # rfft3d/irfft3d do not have gradients yet.
        if rank == 3:
          continue
        for dims in xrange(rank, rank + 2):
          for size in (5, 6):
            re = np.ones(shape=(size,) * dims, dtype=np.float32)
            im = -np.ones(shape=(size,) * dims, dtype=np.float32)
            self._checkGradReal(self._tfFFTForRank(rank), re)
            self._checkGradComplex(
                self._tfIFFTForRank(rank), re, im, result_is_complex=False)

  def testGrad_Random(self):
    with spectral_ops_test_util.fft_kernel_label_map():
      np.random.seed(54321)
      for rank in VALID_FFT_RANKS:
        # rfft3d/irfft3d do not have gradients yet.
        if rank == 3:
          continue
        for dims in xrange(rank, rank + 2):
          for size in (5, 6):
            re = np.random.rand(*((size,) * dims)).astype(np.float32) * 2 - 1
            im = np.random.rand(*((size,) * dims)).astype(np.float32) * 2 - 1
            self._checkGradReal(self._tfFFTForRank(rank), re)
            self._checkGradComplex(
                self._tfIFFTForRank(rank), re, im, result_is_complex=False)
506
507
if __name__ == "__main__":
  # Entry point: run every test case in this module via the TensorFlow test
  # runner imported as `test`.
  test.main()
510