1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/shape_inference_testutil.h"
17#include "tensorflow/core/framework/tensor.h"
18#include "tensorflow/core/framework/tensor_testutil.h"
19
20namespace tensorflow {
21
22TEST(MathOpsTest, FFT_ShapeFn) {
23  for (const auto* op_name : {"FFT", "IFFT"}) {
24    ShapeInferenceTestOp op(op_name);
25    INFER_OK(op, "?", "in0");
26    INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[]");
27    INFER_OK(op, "[?]", "in0");
28    INFER_OK(op, "[1]", "in0");
29    INFER_OK(op, "[1,2,3,4,5,6,7]", "in0");
30  }
31
32  for (const auto* op_name : {"FFT2D", "IFFT2D"}) {
33    ShapeInferenceTestOp op(op_name);
34    INFER_OK(op, "?", "in0");
35    INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1]");
36    INFER_OK(op, "[?,1]", "in0");
37    INFER_OK(op, "[1,2]", "in0");
38    INFER_OK(op, "[1,2,3,4,5,6,7]", "in0");
39  }
40
41  for (const auto* op_name : {"FFT3D", "IFFT3D"}) {
42    ShapeInferenceTestOp op(op_name);
43    INFER_OK(op, "?", "in0");
44    INFER_ERROR("Shape must be at least rank 3 but is rank 2", op, "[1,2]");
45    INFER_OK(op, "[?,1,?]", "in0");
46    INFER_OK(op, "[1,2,3]", "in0");
47    INFER_OK(op, "[1,2,3,4,5,6,7]", "in0");
48  }
49}
50
51TEST(MathOpsTest, RFFT_ShapeFn) {
52  // Rank 1
53  for (const bool forward : {true, false}) {
54    ShapeInferenceTestOp op(forward ? "RFFT" : "IRFFT");
55
56    // Unknown rank or shape of inputs.
57    INFER_OK(op, "?;?", "?");
58    INFER_OK(op, "?;[1]", "?");
59
60    // Unknown fft_length (whether or not rank/shape is known) implies unknown
61    // FFT shape.
62    INFER_OK(op, "[1];?", "[?]");
63    INFER_OK(op, "[1];[1]", "[?]");
64    INFER_OK(op, "[?];[1]", "[?]");
65
66    // Batch dimensions preserved.
67    INFER_OK(op, "[1,2,3,4];[1]", "[d0_0,d0_1,d0_2,?]");
68
69    INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[];?");
70    INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[1];[1,1]");
71    INFER_ERROR("Dimension must be 1 but is 2", op, "[1];[2]");
72
73    // Tests with known values for fft_length input.
74    op.input_tensors.resize(2);
75    Tensor fft_length = test::AsTensor<int32>({10});
76    op.input_tensors[1] = &fft_length;
77
78    // The inner-most dimension of the RFFT is n/2+1 while for IRFFT it's n.
79    if (forward) {
80      INFER_OK(op, "[?];[1]", "[6]");
81      INFER_OK(op, "[1];[1]", "[6]");
82      INFER_OK(op, "[1,1];[1]", "[d0_0,6]");
83    } else {
84      INFER_OK(op, "[?];[1]", "[10]");
85      INFER_OK(op, "[1];[1]", "[10]");
86      INFER_OK(op, "[1,1];[1]", "[d0_0,10]");
87    }
88
89    fft_length = test::AsTensor<int32>({11});
90    if (forward) {
91      INFER_OK(op, "[?];[1]", "[6]");
92      INFER_OK(op, "[1];[1]", "[6]");
93      INFER_OK(op, "[1,1];[1]", "[d0_0,6]");
94    } else {
95      INFER_OK(op, "[?];[1]", "[11]");
96      INFER_OK(op, "[1];[1]", "[11]");
97      INFER_OK(op, "[1,1];[1]", "[d0_0,11]");
98    }
99
100    fft_length = test::AsTensor<int32>({12});
101    if (forward) {
102      INFER_OK(op, "[?];[1]", "[7]");
103      INFER_OK(op, "[1];[1]", "[7]");
104      INFER_OK(op, "[1,1];[1]", "[d0_0,7]");
105    } else {
106      INFER_OK(op, "[?];[1]", "[12]");
107      INFER_OK(op, "[1];[1]", "[12]");
108      INFER_OK(op, "[1,1];[1]", "[d0_0,12]");
109    }
110  }
111
112  // Rank 2
113  for (const bool forward : {true, false}) {
114    ShapeInferenceTestOp op(forward ? "RFFT2D" : "IRFFT2D");
115
116    // Unknown rank or shape of inputs.
117    INFER_OK(op, "?;?", "?");
118    INFER_OK(op, "?;[2]", "?");
119
120    // Unknown fft_length (whether or not rank/shape is known) implies unknown
121    // FFT shape.
122    INFER_OK(op, "[1,1];?", "[?,?]");
123    INFER_OK(op, "[1,1];[2]", "[?,?]");
124    INFER_OK(op, "[?,?];[2]", "[?,?]");
125
126    // Batch dimensions preserved.
127    INFER_OK(op, "[1,2,3,4];[2]", "[d0_0,d0_1,?,?]");
128
129    INFER_ERROR("Shape must be at least rank 2 but is rank 0", op, "[];?");
130    INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[1,1];[1,1]");
131    INFER_ERROR("Dimension must be 2 but is 3", op, "[1,1];[3]");
132
133    // Tests with known values for fft_length input.
134    op.input_tensors.resize(2);
135    Tensor fft_length = test::AsTensor<int32>({9, 10});
136    op.input_tensors[1] = &fft_length;
137
138    // The inner-most dimension of the RFFT is n/2+1 while for IRFFT it's n.
139    if (forward) {
140      INFER_OK(op, "[?,?];[2]", "[9,6]");
141      INFER_OK(op, "[1,1];[2]", "[9,6]");
142      INFER_OK(op, "[1,1,1];[2]", "[d0_0,9,6]");
143    } else {
144      INFER_OK(op, "[?,?];[2]", "[9,10]");
145      INFER_OK(op, "[1,1];[2]", "[9,10]");
146      INFER_OK(op, "[1,1,1];[2]", "[d0_0,9,10]");
147    }
148
149    fft_length = test::AsTensor<int32>({10, 11});
150    if (forward) {
151      INFER_OK(op, "[?,?];[2]", "[10,6]");
152      INFER_OK(op, "[1,1];[2]", "[10,6]");
153      INFER_OK(op, "[1,1,1];[2]", "[d0_0,10,6]");
154    } else {
155      INFER_OK(op, "[?,?];[2]", "[10,11]");
156      INFER_OK(op, "[1,1];[2]", "[10,11]");
157      INFER_OK(op, "[1,1,1];[2]", "[d0_0,10,11]");
158    }
159
160    fft_length = test::AsTensor<int32>({11, 12});
161    if (forward) {
162      INFER_OK(op, "[?,?];[2]", "[11,7]");
163      INFER_OK(op, "[1,1];[2]", "[11,7]");
164      INFER_OK(op, "[1,1,1];[2]", "[d0_0,11,7]");
165    } else {
166      INFER_OK(op, "[?,?];[2]", "[11,12]");
167      INFER_OK(op, "[1,1];[2]", "[11,12]");
168      INFER_OK(op, "[1,1,1];[2]", "[d0_0,11,12]");
169    }
170  }
171
172  // Rank 3
173  for (const bool forward : {true, false}) {
174    ShapeInferenceTestOp op(forward ? "RFFT3D" : "IRFFT3D");
175
176    // Unknown rank or shape of inputs.
177    INFER_OK(op, "?;?", "?");
178    INFER_OK(op, "?;[3]", "?");
179
180    // Unknown fft_length (whether or not rank/shape is known) implies unknown
181    // FFT shape.
182    INFER_OK(op, "[1,1,1];?", "[?,?,?]");
183    INFER_OK(op, "[1,1,1];[3]", "[?,?,?]");
184    INFER_OK(op, "[?,?,?];[3]", "[?,?,?]");
185
186    // Batch dimensions preserved.
187    INFER_OK(op, "[1,2,3,4];[3]", "[d0_0,?,?,?]");
188
189    INFER_ERROR("Shape must be at least rank 3 but is rank 0", op, "[];?");
190    INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[1,1,1];[1,1]");
191    INFER_ERROR("Dimension must be 3 but is 4", op, "[1,1,1];[4]");
192
193    // Tests with known values for fft_length input.
194    op.input_tensors.resize(2);
195    Tensor fft_length = test::AsTensor<int32>({10, 11, 12});
196    op.input_tensors[1] = &fft_length;
197
198    // The inner-most dimension of the RFFT is n/2+1 while for IRFFT it's n.
199    if (forward) {
200      INFER_OK(op, "[?,?,?];[3]", "[10,11,7]");
201      INFER_OK(op, "[1,1,1];[3]", "[10,11,7]");
202      INFER_OK(op, "[1,1,1,1];[3]", "[d0_0,10,11,7]");
203    } else {
204      INFER_OK(op, "[?,?,?];[3]", "[10,11,12]");
205      INFER_OK(op, "[1,1,1];[3]", "[10,11,12]");
206      INFER_OK(op, "[1,1,1,1];[3]", "[d0_0,10,11,12]");
207    }
208
209    fft_length = test::AsTensor<int32>({11, 12, 13});
210    if (forward) {
211      INFER_OK(op, "[?,?,?];[3]", "[11,12,7]");
212      INFER_OK(op, "[1,1,1];[3]", "[11,12,7]");
213      INFER_OK(op, "[1,1,1,1];[3]", "[d0_0,11,12,7]");
214    } else {
215      INFER_OK(op, "[?,?,?];[3]", "[11,12,13]");
216      INFER_OK(op, "[1,1,1];[3]", "[11,12,13]");
217      INFER_OK(op, "[1,1,1,1];[3]", "[d0_0,11,12,13]");
218    }
219
220    fft_length = test::AsTensor<int32>({12, 13, 14});
221    if (forward) {
222      INFER_OK(op, "[?,?,?];[3]", "[12,13,8]");
223      INFER_OK(op, "[1,1,1];[3]", "[12,13,8]");
224      INFER_OK(op, "[1,1,1,1];[3]", "[d0_0,12,13,8]");
225    } else {
226      INFER_OK(op, "[?,?,?];[3]", "[12,13,14]");
227      INFER_OK(op, "[1,1,1];[3]", "[12,13,14]");
228      INFER_OK(op, "[1,1,1,1];[3]", "[d0_0,12,13,14]");
229    }
230  }
231}
232
233}  // end namespace tensorflow
234