1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/core/framework/shape_inference_testutil.h" 17#include "tensorflow/core/framework/tensor.h" 18#include "tensorflow/core/framework/tensor_testutil.h" 19 20namespace tensorflow { 21 22TEST(MathOpsTest, FFT_ShapeFn) { 23 for (const auto* op_name : {"FFT", "IFFT"}) { 24 ShapeInferenceTestOp op(op_name); 25 INFER_OK(op, "?", "in0"); 26 INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[]"); 27 INFER_OK(op, "[?]", "in0"); 28 INFER_OK(op, "[1]", "in0"); 29 INFER_OK(op, "[1,2,3,4,5,6,7]", "in0"); 30 } 31 32 for (const auto* op_name : {"FFT2D", "IFFT2D"}) { 33 ShapeInferenceTestOp op(op_name); 34 INFER_OK(op, "?", "in0"); 35 INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1]"); 36 INFER_OK(op, "[?,1]", "in0"); 37 INFER_OK(op, "[1,2]", "in0"); 38 INFER_OK(op, "[1,2,3,4,5,6,7]", "in0"); 39 } 40 41 for (const auto* op_name : {"FFT3D", "IFFT3D"}) { 42 ShapeInferenceTestOp op(op_name); 43 INFER_OK(op, "?", "in0"); 44 INFER_ERROR("Shape must be at least rank 3 but is rank 2", op, "[1,2]"); 45 INFER_OK(op, "[?,1,?]", "in0"); 46 INFER_OK(op, "[1,2,3]", "in0"); 47 INFER_OK(op, "[1,2,3,4,5,6,7]", "in0"); 48 } 49} 50 51TEST(MathOpsTest, RFFT_ShapeFn) { 52 // Rank 1 53 for (const bool forward : {true, false}) { 54 ShapeInferenceTestOp op(forward ? "RFFT" : "IRFFT"); 55 56 // Unknown rank or shape of inputs. 57 INFER_OK(op, "?;?", "?"); 58 INFER_OK(op, "?;[1]", "?"); 59 60 // Unknown fft_length (whether or not rank/shape is known) implies unknown 61 // FFT shape. 62 INFER_OK(op, "[1];?", "[?]"); 63 INFER_OK(op, "[1];[1]", "[?]"); 64 INFER_OK(op, "[?];[1]", "[?]"); 65 66 // Batch dimensions preserved. 67 INFER_OK(op, "[1,2,3,4];[1]", "[d0_0,d0_1,d0_2,?]"); 68 69 INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[];?"); 70 INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[1];[1,1]"); 71 INFER_ERROR("Dimension must be 1 but is 2", op, "[1];[2]"); 72 73 // Tests with known values for fft_length input. 74 op.input_tensors.resize(2); 75 Tensor fft_length = test::AsTensor<int32>({10}); 76 op.input_tensors[1] = &fft_length; 77 78 // The inner-most dimension of the RFFT is n/2+1 while for IRFFT it's n. 79 if (forward) { 80 INFER_OK(op, "[?];[1]", "[6]"); 81 INFER_OK(op, "[1];[1]", "[6]"); 82 INFER_OK(op, "[1,1];[1]", "[d0_0,6]"); 83 } else { 84 INFER_OK(op, "[?];[1]", "[10]"); 85 INFER_OK(op, "[1];[1]", "[10]"); 86 INFER_OK(op, "[1,1];[1]", "[d0_0,10]"); 87 } 88 89 fft_length = test::AsTensor<int32>({11}); 90 if (forward) { 91 INFER_OK(op, "[?];[1]", "[6]"); 92 INFER_OK(op, "[1];[1]", "[6]"); 93 INFER_OK(op, "[1,1];[1]", "[d0_0,6]"); 94 } else { 95 INFER_OK(op, "[?];[1]", "[11]"); 96 INFER_OK(op, "[1];[1]", "[11]"); 97 INFER_OK(op, "[1,1];[1]", "[d0_0,11]"); 98 } 99 100 fft_length = test::AsTensor<int32>({12}); 101 if (forward) { 102 INFER_OK(op, "[?];[1]", "[7]"); 103 INFER_OK(op, "[1];[1]", "[7]"); 104 INFER_OK(op, "[1,1];[1]", "[d0_0,7]"); 105 } else { 106 INFER_OK(op, "[?];[1]", "[12]"); 107 INFER_OK(op, "[1];[1]", "[12]"); 108 INFER_OK(op, "[1,1];[1]", "[d0_0,12]"); 109 } 110 } 111 112 // Rank 2 113 for (const bool forward : {true, false}) { 114 ShapeInferenceTestOp op(forward ? "RFFT2D" : "IRFFT2D"); 115 116 // Unknown rank or shape of inputs. 117 INFER_OK(op, "?;?", "?"); 118 INFER_OK(op, "?;[2]", "?"); 119 120 // Unknown fft_length (whether or not rank/shape is known) implies unknown 121 // FFT shape. 122 INFER_OK(op, "[1,1];?", "[?,?]"); 123 INFER_OK(op, "[1,1];[2]", "[?,?]"); 124 INFER_OK(op, "[?,?];[2]", "[?,?]"); 125 126 // Batch dimensions preserved. 127 INFER_OK(op, "[1,2,3,4];[2]", "[d0_0,d0_1,?,?]"); 128 129 INFER_ERROR("Shape must be at least rank 2 but is rank 0", op, "[];?"); 130 INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[1,1];[1,1]"); 131 INFER_ERROR("Dimension must be 2 but is 3", op, "[1,1];[3]"); 132 133 // Tests with known values for fft_length input. 134 op.input_tensors.resize(2); 135 Tensor fft_length = test::AsTensor<int32>({9, 10}); 136 op.input_tensors[1] = &fft_length; 137 138 // The inner-most dimension of the RFFT is n/2+1 while for IRFFT it's n. 139 if (forward) { 140 INFER_OK(op, "[?,?];[2]", "[9,6]"); 141 INFER_OK(op, "[1,1];[2]", "[9,6]"); 142 INFER_OK(op, "[1,1,1];[2]", "[d0_0,9,6]"); 143 } else { 144 INFER_OK(op, "[?,?];[2]", "[9,10]"); 145 INFER_OK(op, "[1,1];[2]", "[9,10]"); 146 INFER_OK(op, "[1,1,1];[2]", "[d0_0,9,10]"); 147 } 148 149 fft_length = test::AsTensor<int32>({10, 11}); 150 if (forward) { 151 INFER_OK(op, "[?,?];[2]", "[10,6]"); 152 INFER_OK(op, "[1,1];[2]", "[10,6]"); 153 INFER_OK(op, "[1,1,1];[2]", "[d0_0,10,6]"); 154 } else { 155 INFER_OK(op, "[?,?];[2]", "[10,11]"); 156 INFER_OK(op, "[1,1];[2]", "[10,11]"); 157 INFER_OK(op, "[1,1,1];[2]", "[d0_0,10,11]"); 158 } 159 160 fft_length = test::AsTensor<int32>({11, 12}); 161 if (forward) { 162 INFER_OK(op, "[?,?];[2]", "[11,7]"); 163 INFER_OK(op, "[1,1];[2]", "[11,7]"); 164 INFER_OK(op, "[1,1,1];[2]", "[d0_0,11,7]"); 165 } else { 166 INFER_OK(op, "[?,?];[2]", "[11,12]"); 167 INFER_OK(op, "[1,1];[2]", "[11,12]"); 168 INFER_OK(op, "[1,1,1];[2]", "[d0_0,11,12]"); 169 } 170 } 171 172 // Rank 3 173 for (const bool forward : {true, false}) { 174 ShapeInferenceTestOp op(forward ? "RFFT3D" : "IRFFT3D"); 175 176 // Unknown rank or shape of inputs. 177 INFER_OK(op, "?;?", "?"); 178 INFER_OK(op, "?;[3]", "?"); 179 180 // Unknown fft_length (whether or not rank/shape is known) implies unknown 181 // FFT shape. 182 INFER_OK(op, "[1,1,1];?", "[?,?,?]"); 183 INFER_OK(op, "[1,1,1];[3]", "[?,?,?]"); 184 INFER_OK(op, "[?,?,?];[3]", "[?,?,?]"); 185 186 // Batch dimensions preserved. 187 INFER_OK(op, "[1,2,3,4];[3]", "[d0_0,?,?,?]"); 188 189 INFER_ERROR("Shape must be at least rank 3 but is rank 0", op, "[];?"); 190 INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[1,1,1];[1,1]"); 191 INFER_ERROR("Dimension must be 3 but is 4", op, "[1,1,1];[4]"); 192 193 // Tests with known values for fft_length input. 194 op.input_tensors.resize(2); 195 Tensor fft_length = test::AsTensor<int32>({10, 11, 12}); 196 op.input_tensors[1] = &fft_length; 197 198 // The inner-most dimension of the RFFT is n/2+1 while for IRFFT it's n. 199 if (forward) { 200 INFER_OK(op, "[?,?,?];[3]", "[10,11,7]"); 201 INFER_OK(op, "[1,1,1];[3]", "[10,11,7]"); 202 INFER_OK(op, "[1,1,1,1];[3]", "[d0_0,10,11,7]"); 203 } else { 204 INFER_OK(op, "[?,?,?];[3]", "[10,11,12]"); 205 INFER_OK(op, "[1,1,1];[3]", "[10,11,12]"); 206 INFER_OK(op, "[1,1,1,1];[3]", "[d0_0,10,11,12]"); 207 } 208 209 fft_length = test::AsTensor<int32>({11, 12, 13}); 210 if (forward) { 211 INFER_OK(op, "[?,?,?];[3]", "[11,12,7]"); 212 INFER_OK(op, "[1,1,1];[3]", "[11,12,7]"); 213 INFER_OK(op, "[1,1,1,1];[3]", "[d0_0,11,12,7]"); 214 } else { 215 INFER_OK(op, "[?,?,?];[3]", "[11,12,13]"); 216 INFER_OK(op, "[1,1,1];[3]", "[11,12,13]"); 217 INFER_OK(op, "[1,1,1,1];[3]", "[d0_0,11,12,13]"); 218 } 219 220 fft_length = test::AsTensor<int32>({12, 13, 14}); 221 if (forward) { 222 INFER_OK(op, "[?,?,?];[3]", "[12,13,8]"); 223 INFER_OK(op, "[1,1,1];[3]", "[12,13,8]"); 224 INFER_OK(op, "[1,1,1,1];[3]", "[d0_0,12,13,8]"); 225 } else { 226 INFER_OK(op, "[?,?,?];[3]", "[12,13,14]"); 227 INFER_OK(op, "[1,1,1];[3]", "[12,13,14]"); 228 INFER_OK(op, "[1,1,1,1];[3]", "[d0_0,12,13,14]"); 229 } 230 } 231} 232 233} // end namespace tensorflow 234