# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryan"""Gradients for operators defined in spectral_ops.py."""
16b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryanfrom __future__ import absolute_import
17b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryanfrom __future__ import division
18b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryanfrom __future__ import print_function
19b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryan
20b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryanimport numpy as np
21b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryan
22b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryanfrom tensorflow.python.framework import dtypes
23b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryanfrom tensorflow.python.framework import ops
24b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryanfrom tensorflow.python.ops import array_ops
25b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryanfrom tensorflow.python.ops import math_ops
26b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryanfrom tensorflow.python.ops import spectral_ops
27b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryan
28b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryan
def _FFTSizeForGrad(grad, rank):
  """Returns the product of the last `rank` entries of `grad`'s dynamic shape.

  Args:
    grad: Tensor whose trailing dimensions are measured.
    rank: Python integer; number of trailing dimensions to multiply together.

  Returns:
    A scalar tensor holding the number of elements spanned by the innermost
    `rank` dimensions of `grad`.
  """
  inner_dims = array_ops.shape(grad)[-rank:]
  return math_ops.reduce_prod(inner_dims)
31b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryan
32b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryan
@ops.RegisterGradient("FFT")
def _FFTGrad(_, grad):
  """Gradient for FFT: the IFFT of `grad` scaled by the innermost dim size."""
  # The element count is an integer tensor; make it float32 so it can become
  # the real part of a complex scale factor.
  num_points = math_ops.cast(_FFTSizeForGrad(grad, 1), dtypes.float32)
  inverse = spectral_ops.ifft(grad)
  return inverse * math_ops.complex(num_points, 0.)
37b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryan
38b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryan
@ops.RegisterGradient("IFFT")
def _IFFTGrad(_, grad):
  """Gradient for IFFT: the FFT of `grad` divided by the innermost dim size."""
  num_points = math_ops.cast(_FFTSizeForGrad(grad, 1), dtypes.float32)
  inv_num_points = 1. / num_points
  forward = spectral_ops.fft(grad)
  return forward * math_ops.complex(inv_num_points, 0.)
43b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryan
44b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryan
@ops.RegisterGradient("FFT2D")
def _FFT2DGrad(_, grad):
  """Gradient for FFT2D: IFFT2D of `grad` scaled by the inner 2 dims' size."""
  # The element count is an integer tensor; make it float32 so it can become
  # the real part of a complex scale factor.
  num_points = math_ops.cast(_FFTSizeForGrad(grad, 2), dtypes.float32)
  inverse = spectral_ops.ifft2d(grad)
  return inverse * math_ops.complex(num_points, 0.)
49b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryan
50b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryan
@ops.RegisterGradient("IFFT2D")
def _IFFT2DGrad(_, grad):
  """Gradient for IFFT2D: FFT2D of `grad` divided by the inner 2 dims' size."""
  num_points = math_ops.cast(_FFTSizeForGrad(grad, 2), dtypes.float32)
  inv_num_points = 1. / num_points
  forward = spectral_ops.fft2d(grad)
  return forward * math_ops.complex(inv_num_points, 0.)
55b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryan
56b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryan
@ops.RegisterGradient("FFT3D")
def _FFT3DGrad(_, grad):
  """Gradient for FFT3D: IFFT3D of `grad` scaled by the inner 3 dims' size."""
  # The element count is an integer tensor; make it float32 so it can become
  # the real part of a complex scale factor.
  num_points = math_ops.cast(_FFTSizeForGrad(grad, 3), dtypes.float32)
  inverse = spectral_ops.ifft3d(grad)
  return inverse * math_ops.complex(num_points, 0.)
61b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryan
62b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryan
@ops.RegisterGradient("IFFT3D")
def _IFFT3DGrad(_, grad):
  """Gradient for IFFT3D: FFT3D of `grad` divided by the inner 3 dims' size."""
  num_points = math_ops.cast(_FFTSizeForGrad(grad, 3), dtypes.float32)
  inv_num_points = 1. / num_points
  forward = spectral_ops.fft3d(grad)
  return forward * math_ops.complex(inv_num_points, 0.)
67b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryan
68b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryan
def _RFFTGradHelper(rank, irfft_fn):
  """Builds the gradient function for an RFFT op of the given rank.

  Args:
    rank: Python integer, 1 or 2; the number of innermost dimensions the
      forward RFFT transforms.
    irfft_fn: Callable `(grad, fft_length) -> Tensor` computing the inverse
      real FFT of matching rank.

  Returns:
    A function `(op, grad) -> (input_grad, None)` suitable for
    `ops.RegisterGradient`. The `None` is the (absent) gradient for the
    `fft_length` input.
  """
  # Can't happen because we don't register a gradient for RFFT3D.
  assert rank in (1, 2), "Gradient for RFFT3D is not implemented."

  def _Grad(op, grad):
    """Computes the RFFT input gradient using `rank` and `irfft_fn`."""
    fft_length = op.inputs[1]
    input_shape = array_ops.shape(op.inputs[0])
    # 1+0j when the innermost FFT length is even, 0+0j when it is odd.
    length_parity = fft_length[-1] % 2
    is_even = math_ops.cast(1 - length_parity, dtypes.complex64)

    def _BroadcastOverBatch(matrix, like):
      """Tiles 2-D `matrix` across all but the last two dims of `like`."""
      # Prepend size-1 dims so `matrix` has the same rank as `like`.
      leading_ones = array_ops.ones([array_ops.rank(like) - 2], dtypes.int32)
      padded_shape = array_ops.concat(
          [leading_ones, array_ops.shape(matrix)], 0)
      padded = array_ops.reshape(matrix, padded_shape)
      # Repeat along the batch dims only; the inner two dims stay as they are.
      multiples = array_ops.concat([array_ops.shape(like)[:-2], [1, 1]], 0)
      return array_ops.tile(padded, multiples)

    def _DftMatrix(length):
      """Returns the [length, length] matrix exp(-2*pi*i * j * k / length)."""
      # TODO(rjryan): Speed up computation of twiddle factors using the
      # following recurrence relation and cache them across invocations of RFFT.
      #
      # t_n = exp(sqrt(-1) * pi * n^2 / line_len)
      # for n = 0, 1,..., line_len-1.
      # For n > 2, use t_n = t_{n-1}^2 / t_{n-2} * t_1^2
      col_index = array_ops.tile(
          array_ops.expand_dims(math_ops.range(length), 0), (length, 1))
      row_index = array_ops.transpose(col_index, [1, 0])
      index_product = math_ops.cast(col_index * row_index, dtypes.complex64)
      phase = -2j * np.pi * index_product
      return math_ops.exp(phase / math_ops.cast(length, dtypes.complex64))

    def _AlternatingSigns(length):
      """A sequence of [1+0j, -1+0j, 1+0j, -1+0j, ...] with length `length`."""
      parity = math_ops.range(length) % 2
      return math_ops.cast(1 - 2 * parity, dtypes.complex64)

    # First entry of the incoming gradient along the innermost axis.
    first_bin = grad[..., 0:1]
    if rank == 1:
      last_bin = grad[..., -1:]
      # The last-entry contribution only applies for even FFT lengths.
      last_contrib = is_even * last_bin
      signs = _AlternatingSigns(input_shape[-1])
      extra_terms = first_bin + last_contrib * signs
    elif rank == 2:
      # Matrix applied to the first/last gradient columns along the
      # second-to-last axis.
      dft = _DftMatrix(input_shape[-2])

      # Give the matrix the same batch dims as `first_bin` so the inner two
      # dimensions can be batch-matmul'd.
      batched_dft = _BroadcastOverBatch(dft, first_bin)

      first_term = math_ops.matmul(batched_dft, math_ops.conj(first_bin))

      last_bin = grad[..., -1:]
      last_term = math_ops.matmul(batched_dft, math_ops.conj(last_bin))

      # Repeat the single-column result across the input's innermost
      # dimension, then apply the alternating sign pattern.
      inner_dim = input_shape[-1]
      tile_multiples = array_ops.concat([
          array_ops.ones([array_ops.rank(grad) - 1], dtypes.int32),
          [inner_dim]
      ], 0)
      last_term = (array_ops.tile(last_term, tile_multiples) *
                   _AlternatingSigns(inner_dim))

      extra_terms = first_term + is_even * last_term

    # The gradient of RFFT is the IRFFT of the incoming gradient times a scaling
    # factor, plus some additional terms to make up for the components dropped
    # due to Hermitian symmetry.
    num_points = math_ops.to_float(_FFTSizeForGrad(op.inputs[0], rank))
    inverse = irfft_fn(grad, fft_length)
    input_grad = 0.5 * (inverse * num_points + math_ops.real(extra_terms))
    return input_grad, None

  return _Grad
144b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryan
145b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryan
def _IRFFTGradHelper(rank, rfft_fn):
  """Builds the gradient function for an IRFFT op of the given rank.

  Args:
    rank: Python integer, 1 or 2; the number of innermost dimensions the
      forward IRFFT transforms.
    rfft_fn: Callable `(grad, fft_length) -> Tensor` computing the real FFT of
      matching rank.

  Returns:
    A function `(op, grad) -> (input_grad, None)` suitable for
    `ops.RegisterGradient`. The `None` is the (absent) gradient for the
    `fft_length` input.
  """
  # Can't happen because we don't register a gradient for IRFFT3D.
  assert rank in (1, 2), "Gradient for IRFFT3D is not implemented."

  def _Grad(op, grad):
    """Computes the IRFFT input gradient using `rank` and `rfft_fn`."""
    fft_length = op.inputs[1]
    # 1 when the innermost FFT length is odd, 0 when it is even.
    is_odd = math_ops.mod(fft_length[-1], 2)
    last_dim = array_ops.shape(op.inputs[0])[-1]

    # Per-bin weights along the innermost axis: [1.0, 2.0, ..., 2.0, 1.0] for
    # even-length FFTs and [1.0, 2.0, ..., 2.0] for odd-length FFTs.
    doubled = 2.0 * array_ops.ones([last_dim - 2 + is_odd])
    trailing = array_ops.ones([1 - is_odd])
    weights = array_ops.concat([[1.0], doubled, trailing], 0)

    inv_num_points = math_ops.reciprocal(
        math_ops.to_float(_FFTSizeForGrad(grad, rank)))

    # The gradient of IRFFT is the RFFT of the incoming gradient times a scaling
    # factor and a mask. The mask scales the gradient for the Hermitian
    # symmetric components of the RFFT by a factor of two, since these
    # components are de-duplicated in the RFFT.
    forward = rfft_fn(grad, fft_length)
    scale = math_ops.cast(inv_num_points * weights, dtypes.complex64)
    return forward * scale, None

  return _Grad
174b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryan
175b03a72c804d2e6ececcbe4fe4cd603edc9f8049dRJ Ryan
# Register gradients for the real-valued transforms. Each gradient function is
# built from the transform rank and the opposite-direction transform it calls.
for _op_type, _grad_fn in (
    ("RFFT", _RFFTGradHelper(1, spectral_ops.irfft)),
    ("IRFFT", _IRFFTGradHelper(1, spectral_ops.rfft)),
    ("RFFT2D", _RFFTGradHelper(2, spectral_ops.irfft2d)),
    ("IRFFT2D", _IRFFTGradHelper(2, spectral_ops.rfft2d)),
):
  ops.RegisterGradient(_op_type)(_grad_fn)
del _op_type, _grad_fn  # Keep the loop variables out of the module namespace.
180