1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for fft operations.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22from six.moves import xrange # pylint: disable=redefined-builtin 23 24from tensorflow.core.protobuf import config_pb2 25from tensorflow.python.framework import dtypes 26from tensorflow.python.framework import errors 27from tensorflow.python.framework import ops 28from tensorflow.python.ops import array_ops 29from tensorflow.python.ops import gen_spectral_ops 30from tensorflow.python.ops import gradient_checker 31from tensorflow.python.ops import math_ops 32from tensorflow.python.ops import spectral_ops 33from tensorflow.python.ops import spectral_ops_test_util 34from tensorflow.python.platform import test 35 36VALID_FFT_RANKS = (1, 2, 3) 37 38 39class BaseFFTOpsTest(test.TestCase): 40 41 def _compare(self, x, rank, fft_length=None, use_placeholder=False): 42 self._compareForward(x, rank, fft_length, use_placeholder) 43 self._compareBackward(x, rank, fft_length, use_placeholder) 44 45 def _compareForward(self, x, rank, fft_length=None, use_placeholder=False): 46 x_np = self._npFFT(x, rank, fft_length) 47 if use_placeholder: 48 x_ph = array_ops.placeholder(dtype=dtypes.as_dtype(x.dtype)) 49 x_tf = self._tfFFT(x_ph, rank, fft_length, feed_dict={x_ph: x}) 50 else: 51 x_tf = self._tfFFT(x, rank, fft_length) 52 53 self.assertAllClose(x_np, x_tf, rtol=1e-4, atol=1e-4) 54 55 def _compareBackward(self, x, rank, fft_length=None, use_placeholder=False): 56 x_np = self._npIFFT(x, rank, fft_length) 57 if use_placeholder: 58 x_ph = array_ops.placeholder(dtype=dtypes.as_dtype(x.dtype)) 59 x_tf = self._tfIFFT(x_ph, rank, fft_length, feed_dict={x_ph: x}) 60 else: 61 x_tf = self._tfIFFT(x, rank, fft_length) 62 63 self.assertAllClose(x_np, x_tf, rtol=1e-4, atol=1e-4) 64 65 def _checkMemoryFail(self, x, rank): 66 config = config_pb2.ConfigProto() 67 config.gpu_options.per_process_gpu_memory_fraction = 1e-2 68 with self.test_session(config=config, force_gpu=True): 69 self._tfFFT(x, rank, fft_length=None) 70 71 def _checkGradComplex(self, func, x, y, result_is_complex=True): 72 with self.test_session(use_gpu=True): 73 inx = ops.convert_to_tensor(x) 74 iny = ops.convert_to_tensor(y) 75 # func is a forward or inverse, real or complex, batched or unbatched FFT 76 # function with a complex input. 77 z = func(math_ops.complex(inx, iny)) 78 # loss = sum(|z|^2) 79 loss = math_ops.reduce_sum(math_ops.real(z * math_ops.conj(z))) 80 81 ((x_jacob_t, x_jacob_n), 82 (y_jacob_t, y_jacob_n)) = gradient_checker.compute_gradient( 83 [inx, iny], [list(x.shape), list(y.shape)], 84 loss, [1], 85 x_init_value=[x, y], 86 delta=1e-2) 87 88 self.assertAllClose(x_jacob_t, x_jacob_n, rtol=1e-2, atol=1e-2) 89 self.assertAllClose(y_jacob_t, y_jacob_n, rtol=1e-2, atol=1e-2) 90 91 def _checkGradReal(self, func, x): 92 with self.test_session(use_gpu=True): 93 inx = ops.convert_to_tensor(x) 94 # func is a forward RFFT function (batched or unbatched). 95 z = func(inx) 96 # loss = sum(|z|^2) 97 loss = math_ops.reduce_sum(math_ops.real(z * math_ops.conj(z))) 98 x_jacob_t, x_jacob_n = test.compute_gradient( 99 inx, list(x.shape), loss, [1], x_init_value=x, delta=1e-2) 100 101 self.assertAllClose(x_jacob_t, x_jacob_n, rtol=1e-2, atol=1e-2) 102 103 104class FFTOpsTest(BaseFFTOpsTest): 105 106 def _tfFFT(self, x, rank, fft_length=None, feed_dict=None): 107 # fft_length unused for complex FFTs. 108 with self.test_session(use_gpu=True): 109 return self._tfFFTForRank(rank)(x).eval(feed_dict=feed_dict) 110 111 def _tfIFFT(self, x, rank, fft_length=None, feed_dict=None): 112 # fft_length unused for complex FFTs. 113 with self.test_session(use_gpu=True): 114 return self._tfIFFTForRank(rank)(x).eval(feed_dict=feed_dict) 115 116 def _npFFT(self, x, rank, fft_length=None): 117 if rank == 1: 118 return np.fft.fft2(x, s=fft_length, axes=(-1,)) 119 elif rank == 2: 120 return np.fft.fft2(x, s=fft_length, axes=(-2, -1)) 121 elif rank == 3: 122 return np.fft.fft2(x, s=fft_length, axes=(-3, -2, -1)) 123 else: 124 raise ValueError("invalid rank") 125 126 def _npIFFT(self, x, rank, fft_length=None): 127 if rank == 1: 128 return np.fft.ifft2(x, s=fft_length, axes=(-1,)) 129 elif rank == 2: 130 return np.fft.ifft2(x, s=fft_length, axes=(-2, -1)) 131 elif rank == 3: 132 return np.fft.ifft2(x, s=fft_length, axes=(-3, -2, -1)) 133 else: 134 raise ValueError("invalid rank") 135 136 def _tfFFTForRank(self, rank): 137 if rank == 1: 138 return spectral_ops.fft 139 elif rank == 2: 140 return spectral_ops.fft2d 141 elif rank == 3: 142 return spectral_ops.fft3d 143 else: 144 raise ValueError("invalid rank") 145 146 def _tfIFFTForRank(self, rank): 147 if rank == 1: 148 return spectral_ops.ifft 149 elif rank == 2: 150 return spectral_ops.ifft2d 151 elif rank == 3: 152 return spectral_ops.ifft3d 153 else: 154 raise ValueError("invalid rank") 155 156 def testEmpty(self): 157 with spectral_ops_test_util.fft_kernel_label_map(): 158 for rank in VALID_FFT_RANKS: 159 for dims in xrange(rank, rank + 3): 160 x = np.zeros((0,) * dims).astype(np.complex64) 161 self.assertEqual(x.shape, self._tfFFT(x, rank).shape) 162 self.assertEqual(x.shape, self._tfIFFT(x, rank).shape) 163 164 def testBasic(self): 165 with spectral_ops_test_util.fft_kernel_label_map(): 166 for rank in VALID_FFT_RANKS: 167 for dims in xrange(rank, rank + 3): 168 self._compare( 169 np.mod(np.arange(np.power(4, dims)), 10).reshape( 170 (4,) * dims).astype(np.complex64), rank) 171 172 def testLargeBatch(self): 173 if test.is_gpu_available(cuda_only=True): 174 rank = 1 175 for dims in xrange(rank, rank + 3): 176 self._compare( 177 np.mod(np.arange(np.power(128, dims)), 10).reshape( 178 (128,) * dims).astype(np.complex64), rank) 179 180 # TODO(yangzihao): Disable before we can figure out a way to 181 # properly test memory fail for large batch fft. 182 # def testLargeBatchMemoryFail(self): 183 # if test.is_gpu_available(cuda_only=True): 184 # rank = 1 185 # for dims in xrange(rank, rank + 3): 186 # self._checkMemoryFail( 187 # np.mod(np.arange(np.power(128, dims)), 64).reshape( 188 # (128,) * dims).astype(np.complex64), rank) 189 190 def testBasicPlaceholder(self): 191 with spectral_ops_test_util.fft_kernel_label_map(): 192 for rank in VALID_FFT_RANKS: 193 for dims in xrange(rank, rank + 3): 194 self._compare( 195 np.mod(np.arange(np.power(4, dims)), 10).reshape( 196 (4,) * dims).astype(np.complex64), 197 rank, 198 use_placeholder=True) 199 200 def testRandom(self): 201 with spectral_ops_test_util.fft_kernel_label_map(): 202 np.random.seed(12345) 203 204 def gen(shape): 205 n = np.prod(shape) 206 re = np.random.uniform(size=n) 207 im = np.random.uniform(size=n) 208 return (re + im * 1j).reshape(shape) 209 210 for rank in VALID_FFT_RANKS: 211 for dims in xrange(rank, rank + 3): 212 self._compare(gen((4,) * dims), rank) 213 214 def testError(self): 215 for rank in VALID_FFT_RANKS: 216 for dims in xrange(0, rank): 217 x = np.zeros((1,) * dims).astype(np.complex64) 218 with self.assertRaisesWithPredicateMatch( 219 ValueError, "Shape must be .*rank {}.*".format(rank)): 220 self._tfFFT(x, rank) 221 with self.assertRaisesWithPredicateMatch( 222 ValueError, "Shape must be .*rank {}.*".format(rank)): 223 self._tfIFFT(x, rank) 224 225 def testGrad_Simple(self): 226 with spectral_ops_test_util.fft_kernel_label_map(): 227 for rank in VALID_FFT_RANKS: 228 for dims in xrange(rank, rank + 2): 229 re = np.ones(shape=(4,) * dims, dtype=np.float32) / 10.0 230 im = np.zeros(shape=(4,) * dims, dtype=np.float32) 231 self._checkGradComplex(self._tfFFTForRank(rank), re, im) 232 self._checkGradComplex(self._tfIFFTForRank(rank), re, im) 233 234 def testGrad_Random(self): 235 with spectral_ops_test_util.fft_kernel_label_map(): 236 np.random.seed(54321) 237 for rank in VALID_FFT_RANKS: 238 for dims in xrange(rank, rank + 2): 239 re = np.random.rand(*((3,) * dims)).astype(np.float32) * 2 - 1 240 im = np.random.rand(*((3,) * dims)).astype(np.float32) * 2 - 1 241 self._checkGradComplex(self._tfFFTForRank(rank), re, im) 242 self._checkGradComplex(self._tfIFFTForRank(rank), re, im) 243 244 245class RFFTOpsTest(BaseFFTOpsTest): 246 247 def _compareBackward(self, x, rank, fft_length=None, use_placeholder=False): 248 super(RFFTOpsTest, self)._compareBackward(x, rank, fft_length, 249 use_placeholder) 250 251 def _tfFFT(self, x, rank, fft_length=None, feed_dict=None): 252 with self.test_session(use_gpu=True): 253 return self._tfFFTForRank(rank)(x, fft_length).eval(feed_dict=feed_dict) 254 255 def _tfIFFT(self, x, rank, fft_length=None, feed_dict=None): 256 with self.test_session(use_gpu=True): 257 return self._tfIFFTForRank(rank)(x, fft_length).eval(feed_dict=feed_dict) 258 259 def _npFFT(self, x, rank, fft_length=None): 260 if rank == 1: 261 return np.fft.rfft2(x, s=fft_length, axes=(-1,)) 262 elif rank == 2: 263 return np.fft.rfft2(x, s=fft_length, axes=(-2, -1)) 264 elif rank == 3: 265 return np.fft.rfft2(x, s=fft_length, axes=(-3, -2, -1)) 266 else: 267 raise ValueError("invalid rank") 268 269 def _npIFFT(self, x, rank, fft_length=None): 270 if rank == 1: 271 return np.fft.irfft2(x, s=fft_length, axes=(-1,)) 272 elif rank == 2: 273 return np.fft.irfft2(x, s=fft_length, axes=(-2, -1)) 274 elif rank == 3: 275 return np.fft.irfft2(x, s=fft_length, axes=(-3, -2, -1)) 276 else: 277 raise ValueError("invalid rank") 278 279 def _tfFFTForRank(self, rank): 280 if rank == 1: 281 return spectral_ops.rfft 282 elif rank == 2: 283 return spectral_ops.rfft2d 284 elif rank == 3: 285 return spectral_ops.rfft3d 286 else: 287 raise ValueError("invalid rank") 288 289 def _tfIFFTForRank(self, rank): 290 if rank == 1: 291 return spectral_ops.irfft 292 elif rank == 2: 293 return spectral_ops.irfft2d 294 elif rank == 3: 295 return spectral_ops.irfft3d 296 else: 297 raise ValueError("invalid rank") 298 299 def testEmpty(self): 300 with spectral_ops_test_util.fft_kernel_label_map(): 301 for rank in VALID_FFT_RANKS: 302 for dims in xrange(rank, rank + 3): 303 x = np.zeros((0,) * dims).astype(np.float32) 304 self.assertEqual(x.shape, self._tfFFT(x, rank).shape) 305 x = np.zeros((0,) * dims).astype(np.complex64) 306 self.assertEqual(x.shape, self._tfIFFT(x, rank).shape) 307 308 def testBasic(self): 309 with spectral_ops_test_util.fft_kernel_label_map(): 310 for rank in VALID_FFT_RANKS: 311 for dims in xrange(rank, rank + 3): 312 for size in (5, 6): 313 inner_dim = size // 2 + 1 314 r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape( 315 (size,) * dims) 316 self._compareForward(r2c.astype(np.float32), rank, (size,) * rank) 317 c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim), 318 10).reshape((size,) * (dims - 1) + (inner_dim,)) 319 self._compareBackward( 320 c2r.astype(np.complex64), rank, (size,) * rank) 321 322 def testLargeBatch(self): 323 if test.is_gpu_available(cuda_only=True): 324 rank = 1 325 for dims in xrange(rank, rank + 3): 326 for size in (64, 128): 327 inner_dim = size // 2 + 1 328 r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape( 329 (size,) * dims) 330 self._compareForward(r2c.astype(np.float32), rank, (size,) * rank) 331 c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim), 332 10).reshape((size,) * (dims - 1) + (inner_dim,)) 333 self._compareBackward(c2r.astype(np.complex64), rank, (size,) * rank) 334 335 def testBasicPlaceholder(self): 336 with spectral_ops_test_util.fft_kernel_label_map(): 337 for rank in VALID_FFT_RANKS: 338 for dims in xrange(rank, rank + 3): 339 for size in (5, 6): 340 inner_dim = size // 2 + 1 341 r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape( 342 (size,) * dims) 343 self._compareForward( 344 r2c.astype(np.float32), 345 rank, (size,) * rank, 346 use_placeholder=True) 347 c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim), 348 10).reshape((size,) * (dims - 1) + (inner_dim,)) 349 self._compareBackward( 350 c2r.astype(np.complex64), 351 rank, (size,) * rank, 352 use_placeholder=True) 353 354 def testFftLength(self): 355 if test.is_gpu_available(cuda_only=True): 356 with spectral_ops_test_util.fft_kernel_label_map(): 357 for rank in VALID_FFT_RANKS: 358 for dims in xrange(rank, rank + 3): 359 for size in (5, 6): 360 inner_dim = size // 2 + 1 361 r2c = np.mod(np.arange(np.power(size, dims)), 10).reshape( 362 (size,) * dims) 363 c2r = np.mod(np.arange(np.power(size, dims - 1) * inner_dim), 364 10).reshape((size,) * (dims - 1) + (inner_dim,)) 365 # Test truncation (FFT size < dimensions). 366 fft_length = (size - 2,) * rank 367 self._compareForward(r2c.astype(np.float32), rank, fft_length) 368 self._compareBackward(c2r.astype(np.complex64), rank, fft_length) 369 # Confirm it works with unknown shapes as well. 370 self._compareForward( 371 r2c.astype(np.float32), 372 rank, 373 fft_length, 374 use_placeholder=True) 375 self._compareBackward( 376 c2r.astype(np.complex64), 377 rank, 378 fft_length, 379 use_placeholder=True) 380 # Test padding (FFT size > dimensions). 381 fft_length = (size + 2,) * rank 382 self._compareForward(r2c.astype(np.float32), rank, fft_length) 383 self._compareBackward(c2r.astype(np.complex64), rank, fft_length) 384 # Confirm it works with unknown shapes as well. 385 self._compareForward( 386 r2c.astype(np.float32), 387 rank, 388 fft_length, 389 use_placeholder=True) 390 self._compareBackward( 391 c2r.astype(np.complex64), 392 rank, 393 fft_length, 394 use_placeholder=True) 395 396 def testRandom(self): 397 with spectral_ops_test_util.fft_kernel_label_map(): 398 np.random.seed(12345) 399 400 def gen_real(shape): 401 n = np.prod(shape) 402 re = np.random.uniform(size=n) 403 ret = re.reshape(shape) 404 return ret 405 406 def gen_complex(shape): 407 n = np.prod(shape) 408 re = np.random.uniform(size=n) 409 im = np.random.uniform(size=n) 410 ret = (re + im * 1j).reshape(shape) 411 return ret 412 413 for rank in VALID_FFT_RANKS: 414 for dims in xrange(rank, rank + 3): 415 for size in (5, 6): 416 inner_dim = size // 2 + 1 417 self._compareForward(gen_real((size,) * dims), rank, (size,) * rank) 418 complex_dims = (size,) * (dims - 1) + (inner_dim,) 419 self._compareBackward( 420 gen_complex(complex_dims), rank, (size,) * rank) 421 422 def testError(self): 423 with spectral_ops_test_util.fft_kernel_label_map(): 424 for rank in VALID_FFT_RANKS: 425 for dims in xrange(0, rank): 426 x = np.zeros((1,) * dims).astype(np.complex64) 427 with self.assertRaisesWithPredicateMatch( 428 ValueError, "Shape .* must have rank at least {}".format(rank)): 429 self._tfFFT(x, rank) 430 with self.assertRaisesWithPredicateMatch( 431 ValueError, "Shape .* must have rank at least {}".format(rank)): 432 self._tfIFFT(x, rank) 433 for dims in xrange(rank, rank + 2): 434 x = np.zeros((1,) * rank) 435 436 # Test non-rank-1 fft_length produces an error. 437 fft_length = np.zeros((1, 1)).astype(np.int32) 438 with self.assertRaisesWithPredicateMatch(ValueError, 439 "Shape .* must have rank 1"): 440 self._tfFFT(x, rank, fft_length) 441 with self.assertRaisesWithPredicateMatch(ValueError, 442 "Shape .* must have rank 1"): 443 self._tfIFFT(x, rank, fft_length) 444 445 # Test wrong fft_length length. 446 fft_length = np.zeros((rank + 1,)).astype(np.int32) 447 with self.assertRaisesWithPredicateMatch( 448 ValueError, "Dimension must be .*but is {}.*".format(rank + 1)): 449 self._tfFFT(x, rank, fft_length) 450 with self.assertRaisesWithPredicateMatch( 451 ValueError, "Dimension must be .*but is {}.*".format(rank + 1)): 452 self._tfIFFT(x, rank, fft_length) 453 454 # Test that calling the kernel directly without padding to fft_length 455 # produces an error. 456 rffts_for_rank = { 457 1: [gen_spectral_ops.rfft, gen_spectral_ops.irfft], 458 2: [gen_spectral_ops.rfft2d, gen_spectral_ops.irfft2d], 459 3: [gen_spectral_ops.rfft3d, gen_spectral_ops.irfft3d] 460 } 461 rfft_fn, irfft_fn = rffts_for_rank[rank] 462 with self.assertRaisesWithPredicateMatch( 463 errors.InvalidArgumentError, 464 "Input dimension .* must have length of at least 6 but got: 5"): 465 x = np.zeros((5,) * rank).astype(np.float32) 466 fft_length = [6] * rank 467 with self.test_session(): 468 rfft_fn(x, fft_length).eval() 469 470 with self.assertRaisesWithPredicateMatch( 471 errors.InvalidArgumentError, 472 "Input dimension .* must have length of at least .* but got: 3"): 473 x = np.zeros((3,) * rank).astype(np.complex64) 474 fft_length = [6] * rank 475 with self.test_session(): 476 irfft_fn(x, fft_length).eval() 477 478 def testGrad_Simple(self): 479 with spectral_ops_test_util.fft_kernel_label_map(): 480 for rank in VALID_FFT_RANKS: 481 # rfft3d/irfft3d do not have gradients yet. 482 if rank == 3: 483 continue 484 for dims in xrange(rank, rank + 2): 485 for size in (5, 6): 486 re = np.ones(shape=(size,) * dims, dtype=np.float32) 487 im = -np.ones(shape=(size,) * dims, dtype=np.float32) 488 self._checkGradReal(self._tfFFTForRank(rank), re) 489 self._checkGradComplex( 490 self._tfIFFTForRank(rank), re, im, result_is_complex=False) 491 492 def testGrad_Random(self): 493 with spectral_ops_test_util.fft_kernel_label_map(): 494 np.random.seed(54321) 495 for rank in VALID_FFT_RANKS: 496 # rfft3d/irfft3d do not have gradients yet. 497 if rank == 3: 498 continue 499 for dims in xrange(rank, rank + 2): 500 for size in (5, 6): 501 re = np.random.rand(*((size,) * dims)).astype(np.float32) * 2 - 1 502 im = np.random.rand(*((size,) * dims)).astype(np.float32) * 2 - 1 503 self._checkGradReal(self._tfFFTForRank(rank), re) 504 self._checkGradComplex( 505 self._tfIFFTForRank(rank), re, im, result_is_complex=False) 506 507 508if __name__ == "__main__": 509 test.main() 510