# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryan# ============================================================================== 15b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryan"""Gradients for operators defined in spectral_ops.py.""" 16b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryanfrom __future__ import absolute_import 17b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryanfrom __future__ import division 18b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryanfrom __future__ import print_function 19b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryan 20b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryanimport numpy as np 21b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryan 22b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryanfrom tensorflow.python.framework import dtypes 23b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryanfrom tensorflow.python.framework import ops 24b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryanfrom tensorflow.python.ops import array_ops 25b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryanfrom tensorflow.python.ops import math_ops 26b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryanfrom tensorflow.python.ops import spectral_ops 27b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryan 28b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryan 29b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryandef _FFTSizeForGrad(grad, rank): 30b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryan return math_ops.reduce_prod(array_ops.shape(grad)[-rank:]) 31b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryan 32b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryan 33b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryan@ops.RegisterGradient("FFT") 34b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryandef _FFTGrad(_, grad): 35b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryan size = math_ops.cast(_FFTSizeForGrad(grad, 1), dtypes.float32) 36b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryan return spectral_ops.ifft(grad) * math_ops.complex(size, 0.) 


@ops.RegisterGradient("IFFT")
def _IFFTGrad(_, grad):
  """Gradient for IFFT.

  The adjoint of the normalized inverse DFT is the forward DFT divided by N,
  so the input gradient is `fft(grad) / N` with N the transform length.

  Args:
    _: The forward IFFT op (unused).
    grad: Gradient with respect to the op's output.

  Returns:
    The gradient with respect to the op's input.
  """
  # Compute 1/N in the real counterpart of grad's dtype, then promote to the
  # complex dtype; avoids hard-coding float32/complex64.
  rsize = math_ops.cast(
      1. / math_ops.cast(_FFTSizeForGrad(grad, 1), grad.dtype.real_dtype),
      grad.dtype)
  return spectral_ops.fft(grad) * rsize


@ops.RegisterGradient("FFT2D")
def _FFT2DGrad(_, grad):
  """Gradient for FFT2D: `N * ifft2d(grad)`, N = size of the inner 2 dims.

  Args:
    _: The forward FFT2D op (unused).
    grad: Gradient with respect to the op's output.

  Returns:
    The gradient with respect to the op's input.
  """
  size = math_ops.cast(_FFTSizeForGrad(grad, 2), grad.dtype)
  return spectral_ops.ifft2d(grad) * size


@ops.RegisterGradient("IFFT2D")
def _IFFT2DGrad(_, grad):
  """Gradient for IFFT2D: `fft2d(grad) / N`, N = size of the inner 2 dims.

  Args:
    _: The forward IFFT2D op (unused).
    grad: Gradient with respect to the op's output.

  Returns:
    The gradient with respect to the op's input.
  """
  rsize = math_ops.cast(
      1. / math_ops.cast(_FFTSizeForGrad(grad, 2), grad.dtype.real_dtype),
      grad.dtype)
  return spectral_ops.fft2d(grad) * rsize


@ops.RegisterGradient("FFT3D")
def _FFT3DGrad(_, grad):
  """Gradient for FFT3D: `N * ifft3d(grad)`, N = size of the inner 3 dims.

  Args:
    _: The forward FFT3D op (unused).
    grad: Gradient with respect to the op's output.

  Returns:
    The gradient with respect to the op's input.
  """
  size = math_ops.cast(_FFTSizeForGrad(grad, 3), grad.dtype)
  return spectral_ops.ifft3d(grad) * size


@ops.RegisterGradient("IFFT3D")
def _IFFT3DGrad(_, grad):
  """Gradient for IFFT3D: `fft3d(grad) / N`, N = size of the inner 3 dims.

  Args:
    _: The forward IFFT3D op (unused).
    grad: Gradient with respect to the op's output.

  Returns:
    The gradient with respect to the op's input.
  """
  rsize = math_ops.cast(
      1. / math_ops.cast(_FFTSizeForGrad(grad, 3), grad.dtype.real_dtype),
      grad.dtype)
  return spectral_ops.fft3d(grad) * rsize


def _RFFTGradHelper(rank, irfft_fn):
  """Returns a gradient function for an RFFT of the provided rank.

  Args:
    rank: Number of trailing dimensions the RFFT covers; only 1 or 2.
    irfft_fn: The inverse real FFT of the same rank (e.g. `spectral_ops.irfft`).

  Returns:
    A function `(op, grad) -> (input_grad, None)` for `ops.RegisterGradient`;
    `None` is the (non-existent) gradient for the `fft_length` input.
  """
  # Can't happen because we don't register a gradient for RFFT3D.
  assert rank in (1, 2), "Gradient for RFFT3D is not implemented."

  def _Grad(op, grad):
    """A gradient function for RFFT with the provided `rank` and `irfft_fn`."""
    fft_length = op.inputs[1]
    # Derive dtypes from the incoming (complex) gradient rather than
    # hard-coding complex64/float32.
    complex_dtype = grad.dtype
    real_dtype = complex_dtype.real_dtype
    input_shape = array_ops.shape(op.inputs[0])
    # 1 when the innermost FFT length is even, 0 when it is odd.
    is_even = math_ops.cast(1 - (fft_length[-1] % 2), complex_dtype)

    def _TileForBroadcasting(matrix, t):
      # Prepend singleton dims to `matrix`, then tile it across the leading
      # (batch) dims of `t` so it can be batch-matmul'd with `t`.
      expanded = array_ops.reshape(
          matrix,
          array_ops.concat([
              array_ops.ones([array_ops.rank(t) - 2], dtypes.int32),
              array_ops.shape(matrix)
          ], 0))
      return array_ops.tile(
          expanded, array_ops.concat([array_ops.shape(t)[:-2], [1, 1]], 0))

    def _MaskMatrix(length):
      # Returns the `length` x `length` matrix exp(-2*pi*i * a*b / length).
      # TODO(rjryan): Speed up computation of twiddle factors using the
      # following recurrence relation and cache them across invocations of RFFT.
      #
      # t_n = exp(sqrt(-1) * pi * n^2 / line_len)
      # for n = 0, 1,..., line_len-1.
      # For n > 2, use t_n = t_{n-1}^2 / t_{n-2} * t_1^2
      a = array_ops.tile(
          array_ops.expand_dims(math_ops.range(length), 0), (length, 1))
      b = array_ops.transpose(a, [1, 0])
      return math_ops.exp(-2j * np.pi * math_ops.cast(a * b, complex_dtype) /
                          math_ops.cast(length, complex_dtype))

    def _YMMask(length):
      """A sequence of [1+0j, -1+0j, 1+0j, -1+0j, ...] with length `length`."""
      return math_ops.cast(1 - 2 * (math_ops.range(length) % 2),
                           complex_dtype)

    y0 = grad[..., 0:1]
    if rank == 1:
      ym = grad[..., -1:]
      extra_terms = y0 + is_even * ym * _YMMask(input_shape[-1])
    elif rank == 2:
      # Create a mask matrix for y0 and ym.
      base_mask = _MaskMatrix(input_shape[-2])

      # Tile base_mask to match y0 in shape so that we can batch-matmul the
      # inner 2 dimensions.
      tiled_mask = _TileForBroadcasting(base_mask, y0)

      y0_term = math_ops.matmul(tiled_mask, math_ops.conj(y0))
      extra_terms = y0_term

      ym = grad[..., -1:]
      ym_term = math_ops.matmul(tiled_mask, math_ops.conj(ym))

      inner_dim = input_shape[-1]
      ym_term = array_ops.tile(
          ym_term,
          array_ops.concat([
              array_ops.ones([array_ops.rank(grad) - 1], dtypes.int32),
              [inner_dim]
          ], 0)) * _YMMask(inner_dim)

      extra_terms += is_even * ym_term

    # The gradient of RFFT is the IRFFT of the incoming gradient times a scaling
    # factor, plus some additional terms to make up for the components dropped
    # due to Hermitian symmetry.
    input_size = math_ops.cast(_FFTSizeForGrad(op.inputs[0], rank), real_dtype)
    irfft = irfft_fn(grad, fft_length)
    return 0.5 * (irfft * input_size + math_ops.real(extra_terms)), None

  return _Grad


def _IRFFTGradHelper(rank, rfft_fn):
  """Returns a gradient function for an IRFFT of the provided rank.

  Args:
    rank: Number of trailing dimensions the IRFFT covers; only 1 or 2.
    rfft_fn: The forward real FFT of the same rank (e.g. `spectral_ops.rfft`).

  Returns:
    A function `(op, grad) -> (input_grad, None)` for `ops.RegisterGradient`;
    `None` is the (non-existent) gradient for the `fft_length` input.
  """
  # Can't happen because we don't register a gradient for IRFFT3D.
  assert rank in (1, 2), "Gradient for IRFFT3D is not implemented."

  def _Grad(op, grad):
    """A gradient function for IRFFT with the provided `rank` and `rfft_fn`."""
    fft_length = op.inputs[1]
    # IRFFT's output (and hence `grad`) is real. Pick the matching complex
    # dtype for the RFFT of `grad`: float64 -> complex128, else complex64.
    real_dtype = grad.dtype
    complex_dtype = (dtypes.complex128 if real_dtype == dtypes.float64
                     else dtypes.complex64)

    # Generate a mask like [1.0, 2.0, ..., 2.0, 1.0] for even-length FFTs
    # and [1.0, 2.0, ..., 2.0] for odd-length FFTs. The mask is built at run
    # time from `fft_length` and the input's last dimension.
    is_odd = math_ops.mod(fft_length[-1], 2)
    input_last_dimension = array_ops.shape(op.inputs[0])[-1]
    mask = array_ops.concat(
        [array_ops.ones([1], real_dtype),
         2.0 * array_ops.ones([input_last_dimension - 2 + is_odd], real_dtype),
         array_ops.ones([1 - is_odd], real_dtype)], 0)

    rsize = math_ops.reciprocal(
        math_ops.cast(_FFTSizeForGrad(grad, rank), real_dtype))

    # The gradient of IRFFT is the RFFT of the incoming gradient times a scaling
    # factor and a mask. The mask scales the gradient for the Hermitian
    # symmetric components of the RFFT by a factor of two, since these
    # components are de-duplicated in the RFFT.
    rfft = rfft_fn(grad, fft_length)
    return rfft * math_ops.cast(rsize * mask, complex_dtype), None

  return _Grad


ops.RegisterGradient("RFFT")(_RFFTGradHelper(1, spectral_ops.irfft))
ops.RegisterGradient("IRFFT")(_IRFFTGradHelper(1, spectral_ops.rfft))
ops.RegisterGradient("RFFT2D")(_RFFTGradHelper(2, spectral_ops.irfft2d))
ops.RegisterGradient("IRFFT2D")(_IRFFTGradHelper(2, spectral_ops.rfft2d))