// math_ops_test.cc revision 7d9c0c891d82fb5d35dc4669abe832708940a810
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

16#include "tensorflow/core/framework/node_def_builder.h"
17#include "tensorflow/core/framework/op.h"
18#include "tensorflow/core/framework/shape_inference_testutil.h"
19#include "tensorflow/core/framework/tensor.h"
20#include "tensorflow/core/framework/tensor_testutil.h"
21#include "tensorflow/core/lib/core/status_test_util.h"
22#include "tensorflow/core/platform/test.h"
23
24namespace tensorflow {
25
26TEST(MathOpsTest, AddN_ShapeFn) {
27  ShapeInferenceTestOp op("AddN");
28  auto set_n = [&op](int n) {
29    std::vector<NodeDefBuilder::NodeOut> src_list;
30    for (int i = 0; i < n; ++i) src_list.emplace_back("a", 0, DT_FLOAT);
31    TF_ASSERT_OK(NodeDefBuilder("test", "AddN")
32                     .Input(src_list)
33                     .Attr("N", n)
34                     .Finalize(&op.node_def));
35  };
36
37  set_n(2);
38  // Adding two unknowns returns either input.
39  INFER_OK(op, "?;?", "in0|in1");
40
41  // known+unknown returns the known input.
42  INFER_OK(op, "[1];[?]", "in0");
43  INFER_OK(op, "[1];?", "in0");
44  INFER_OK(op, "[?];[1]", "in1");
45  INFER_OK(op, "?;[1]", "in1");
46
47  set_n(2);
48  INFER_OK(op, "[1,2];[?,2]", "in0");
49  INFER_OK(op, "[1,2];[1,2]", "in0|in1");
50  INFER_OK(op, "[?,2];[1,2]", "in1");
51
52  set_n(3);
53  INFER_OK(op, "[1,?];[?,2];[1,2]", "in2");
54  INFER_OK(op, "[1,2];[?,2];[1,?]", "in0");
55  INFER_OK(op, "?;?;[1,2]", "in2");
56
57  set_n(2);
58  INFER_OK(op, "?;[1,2]", "in1");
59  INFER_OK(op, "[1,?];[?,2]", "[d0_0,d1_1]");
60  INFER_OK(op, "[?,2,?];[?,?,3]", "[d0_0|d1_0,d0_1,d1_2]");
61  INFER_OK(op, "[?,2];[1,?]", "[d1_0,d0_1]");
62
63  set_n(3);
64  INFER_ERROR("Dimension 1 in both shapes must be equal, but are 2 and 4", op,
65              "[1,2];?;[1,4]");
66  INFER_ERROR("From merging shape 0 with other shapes.", op, "[1,2];?;[1,4]");
67  set_n(4);
68  INFER_ERROR("Shapes must be equal rank, but are 2 and 3", op,
69              "?;[1,2];?;[1,2,3]");
70  INFER_ERROR("From merging shape 1 with other shapes.", op,
71              "?;[1,2];?;[1,2,3]");
72}
73
74TEST(MathOpsTest, UnchangedShape_ShapeFn) {
75  ShapeInferenceTestOp op("Cast");
76  INFER_OK(op, "?", "in0");
77  INFER_OK(op, "[?]", "in0");
78  INFER_OK(op, "[1,?,3,4]", "in0");
79}
80
81TEST(MathOpsTest, FFT_ShapeFn) {
82  for (const auto* op_name : {"FFT", "IFFT"}) {
83    ShapeInferenceTestOp op(op_name);
84    INFER_OK(op, "?", "?");
85    INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[]");
86    INFER_OK(op, "[?]", "in0");
87    INFER_OK(op, "[1]", "in0");
88    INFER_OK(op, "[1,2,3,4,5,6,7]", "in0");
89  }
90
91  for (const auto* op_name : {"FFT2D", "IFFT2D"}) {
92    ShapeInferenceTestOp op(op_name);
93    INFER_OK(op, "?", "?");
94    INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1]");
95    INFER_OK(op, "[?,1]", "in0");
96    INFER_OK(op, "[1,2]", "in0");
97    INFER_OK(op, "[1,2,3,4,5,6,7]", "in0");
98  }
99
100  for (const auto* op_name : {"FFT3D", "IFFT3D"}) {
101    ShapeInferenceTestOp op(op_name);
102    INFER_OK(op, "?", "?");
103    INFER_ERROR("Shape must be at least rank 3 but is rank 2", op, "[1,2]");
104    INFER_OK(op, "[?,1,?]", "in0");
105    INFER_OK(op, "[1,2,3]", "in0");
106    INFER_OK(op, "[1,2,3,4,5,6,7]", "in0");
107  }
108}
109
110TEST(MathOpsTest, Segment_ShapeFn) {
111  // Tests SegmentReductionShapeFn.
112  for (const auto* op_name : {"SegmentMax", "SegmentMean", "SegmentMin",
113                              "SegmentProd", "SegmentSum"}) {
114    ShapeInferenceTestOp op(op_name);
115    INFER_OK(op, "?;?", "?");
116    INFER_OK(op, "?;[100]", "?");
117
118    // Data shape with single dimension.
119    INFER_OK(op, "[?];?", "[?]");
120    INFER_OK(op, "[?];[100]", "[?]");
121    INFER_OK(op, "[1];?", "[?]");
122    INFER_OK(op, "[1];[100]", "[?]");
123
124    // Data shape with multiple dimensions.
125    INFER_OK(op, "[?,?];?", "[?,d0_1]");
126    INFER_OK(op, "[?,2];[100]", "[?,d0_1]");
127    INFER_OK(op, "[?,2,?,4];[100]", "[?,d0_1,d0_2,d0_3]");
128    INFER_OK(op, "[1,?];?", "[?,d0_1]");
129    INFER_OK(op, "[1,2];[100]", "[?,d0_1]");
130    INFER_OK(op, "[1,2,?,4];[100]", "[?,d0_1,d0_2,d0_3]");
131
132    // Error cases.
133    INFER_ERROR("Shape must be rank 1 but is rank 2", op, "?;[1,2]");
134    INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[];[1]");
135  }
136}
137
138TEST(MathOpsTest, BroadcastBinaryOps_ShapeFn) {
139  for (const auto* op_name : {"Add",        "Complex",
140                              "Div",        "Equal",
141                              "Greater",    "GreaterEqual",
142                              "Igamma",     "Igammac",
143                              "Zeta",       "Polygamma",
144                              "Less",       "LessEqual",
145                              "LogicalAnd", "LogicalOr",
146                              "Maximum",    "Minimum",
147                              "Mod",        "Mul",
148                              "NotEqual",   "Pow",
149                              "Sub",        "SquaredDifference"}) {
150    ShapeInferenceTestOp op(op_name);
151    INFER_OK(op, "?;?", "?");
152    INFER_OK(op, "[1,2];?", "?");
153    INFER_OK(op, "?;[1,2]", "?");
154
155    INFER_OK(op, "[?];[1]", "[d0_0]");
156    INFER_OK(op, "[1];[?]", "[d1_0]");
157    INFER_OK(op, "[?];[2]", "[d1_0]");
158    INFER_OK(op, "[2];[?]", "[d0_0]");
159    INFER_OK(op, "[?];[?]", "[?]");
160    INFER_OK(op, "[];[?]", "[d1_0]");
161    INFER_OK(op, "[?];[]", "[d0_0]");
162
163    INFER_OK(op, "[1];[1]", "[d0_0|d1_0]");
164    INFER_OK(op, "[];[1]", "[d1_0]");
165    INFER_OK(op, "[1];[]", "[d0_0]");
166
167    INFER_OK(op, "[2];[2]", "[d0_0|d1_0]");
168    INFER_OK(op, "[];[2]", "[d1_0]");
169    INFER_OK(op, "[1];[2]", "[d1_0]");
170    INFER_OK(op, "[2];[1]", "[d0_0]");
171    INFER_OK(op, "[2];[]", "[d0_0]");
172
173    INFER_OK(op, "[0];[0]", "[d0_0|d1_0]");
174    INFER_OK(op, "[];[0]", "[d1_0]");
175    INFER_OK(op, "[1];[0]", "[d1_0]");
176    INFER_OK(op, "[0];[1]", "[d0_0]");
177    INFER_OK(op, "[0];[]", "[d0_0]");
178
179    // Multiple dimension cases (same test cases, switching x and y).
180    INFER_OK(op, "[?,1,2,3,4,5];[3,1,?]",
181             "[d0_0,d0_1,d0_2,d0_3|d1_0,d0_4,d0_5]");
182    INFER_OK(op, "[3,1,?];[?,1,2,3,4,5]",
183             "[d1_0,d1_1,d1_2,d1_3|d0_0,d1_4,d1_5]");
184  }
185}
186
187TEST(MathOpsTest, Select_ShapeFn) {
188  ShapeInferenceTestOp op("Select");
189  INFER_OK(op, "?;?;?", "in1|in2");
190
191  INFER_OK(op, "[];?;?", "in1|in2");
192  INFER_OK(op, "[1];?;?",
193           "in1|in2");  // When cond is vector, t/e may not match it.
194  INFER_OK(op, "[1,2];?;?", "in1|in2?");
195
196  INFER_OK(op, "?;[];?", "in1");
197  INFER_OK(op, "?;?;[]", "in2");
198  INFER_OK(op, "?;[1];?", "in1");
199  INFER_OK(op, "?;?;[1]", "in2");
200  INFER_OK(op, "?;[1,2];?", "in1");
201  INFER_OK(op, "?;?;[1,2]", "in2");
202
203  INFER_OK(op, "[1];[];?", "in1");
204  INFER_ERROR("Shapes must be equal rank, but are 1 and 0", op, "[];[1];?");
205  INFER_ERROR("Shapes must be equal rank, but are 1 and 2", op, "[1,2];[1];?");
206  INFER_OK(op, "[2];[?];[?]", "in1|in2");
207
208  INFER_OK(op, "[?];[?,?,3];[1,2,?]", "[d2_0,d2_1,d1_2]");
209  INFER_OK(op, "[2];[?,?,3];[?,2,?]", "[d1_0|d2_0,d2_1,d1_2]");
210  INFER_ERROR("must be equal", op, "[1];[2,?,3];[?,2,?]");
211  INFER_ERROR("Shapes must be equal rank, but are 3 and 2", op,
212              "[2,?];[?,?,3];[?,2,?]");
213  INFER_OK(op, "[2,?,?];[?,?,3];[?,2,?]", "[d0_0,d2_1,d1_2]");
214  INFER_ERROR("Dimension 2 in both shapes must be equal, but are 3 and 5", op,
215              "[2,?,5];[?,?,3];[?,2,?]");
216}
217
218TEST(MathOpsTest, Range_ShapeFn) {
219  ShapeInferenceTestOp op("Range");
220  op.input_tensors.resize(3);
221  INFER_OK(op, "?;?;?", "[?]");
222  INFER_ERROR("Shape must be rank 0 but is rank 2", op, "[1,2];?;?");
223  INFER_ERROR("for 'start'", op, "[1,2];?;?");
224
225  INFER_ERROR("Shape must be rank 0 but is rank 2", op, "?;[1,2];?");
226  INFER_ERROR("for 'limit'", op, "?;[1,2];?");
227
228  INFER_ERROR("Shape must be rank 0 but is rank 2", op, "?;?;[1,2]");
229  INFER_ERROR("for 'delta'", op, "?;?;[1,2]");
230
231  Tensor start_t = test::AsScalar(1);
232  op.input_tensors[0] = &start_t;
233  INFER_OK(op, "?;?;?", "[?]");
234  Tensor limit_t = test::AsScalar(1);
235  op.input_tensors[1] = &limit_t;
236  INFER_OK(op, "?;?;?", "[?]");
237
238  Tensor delta_t = test::AsScalar(1);
239  op.input_tensors[2] = &delta_t;
240  INFER_OK(op, "?;?;?", "[0]");
241
242  delta_t = test::AsScalar(0);
243  INFER_ERROR("Requires delta > 0: 0", op, "?;?;?");
244  delta_t = test::AsScalar(3);
245
246  limit_t = test::AsScalar(-1);
247  INFER_ERROR("Requires start <= limit: 1/-1", op, "?;?;?");
248
249  limit_t = test::AsScalar(100);
250  start_t = test::AsScalar(2);
251  delta_t = test::AsScalar(3);
252  INFER_OK(op, "?;?;?", "[33]");
253}
254
255TEST(MathOpsTest, LinSpace_ShapeFn) {
256  ShapeInferenceTestOp op("LinSpace");
257  op.input_tensors.resize(3);
258  INFER_OK(op, "?;?;?", "[?]");
259  INFER_ERROR("Shape must be rank 0 but is rank 2", op, "[1,2];?;?");
260  INFER_ERROR("for 'start'", op, "[1,2];?;?");
261  INFER_ERROR("Shape must be rank 0 but is rank 2", op, "?;[1,2];?");
262  INFER_ERROR("for 'stop'", op, "?;[1,2];?");
263  INFER_ERROR("Shape must be rank 0 but is rank 2", op, "?;?;[1,2]");
264  INFER_ERROR("for 'num'", op, "?;?;[1,2]");
265
266  Tensor num_t = test::AsScalar(1);
267  op.input_tensors[2] = &num_t;
268  INFER_OK(op, "?;?;?", "[1]");
269  num_t = test::AsScalar(2);
270  INFER_OK(op, "?;?;?", "[2]");
271  num_t = test::AsScalar(-1);
272  INFER_ERROR("Requires num > 0: -1", op, "?;?;?");
273}
274
275TEST(MathOpsTest, UnsortedSegmentSum_ShapeFn) {
276  ShapeInferenceTestOp op("UnsortedSegmentSum");
277  op.input_tensors.resize(3);
278  INFER_OK(op, "?;?;?", "?");
279  INFER_OK(op, "?;[?];?", "?");
280  INFER_ERROR("Shape must be rank 0 but is rank 2", op, "?;?;[1,2]");
281  INFER_ERROR("Dimensions must be equal, but are 2 and 3", op,
282              "[1,?,2];[1,?,3];?");
283  INFER_OK(op, "?;[3];?", "?");
284  INFER_ERROR("Shape must be at least rank 3 but is rank 2", op,
285              "[1,2];[1,2,3];?");
286
287  Tensor num_segments_t = test::AsScalar(100);
288  op.input_tensors[2] = &num_segments_t;
289  INFER_OK(op, "[?,2,3,?,5];[1,2,?];[]", "[100,d0_3,d0_4]");
290
291  num_segments_t = test::AsScalar(-1);
292  INFER_ERROR(("Dimension size, given by scalar input 2, must be "
293               "non-negative but is -1"),
294              op, "[3];[3];?");
295}
296
297TEST(MathOpsTest, SparseSegment_ShapeFn) {
298  ShapeInferenceTestOp op("SparseSegmentSum");
299  op.input_tensors.resize(3);
300  INFER_OK(op, "?;?;?", "?");
301  INFER_OK(op, "[2,4,3];[3];[3]", "[?,d0_1,d0_2]");
302
303  INFER_ERROR("Shape must be rank 1 but is rank 0", op, "[2,4,3];[];[3]");
304  INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[2,4,3];[3];[3,4]");
305
306  INFER_ERROR("Dimension 0 in both shapes must be equal, but are 3 and 4", op,
307              "[2,4,3];[3];[4]");
308}
309
310TEST(MathOpsTest, SparseSegmentGrad_ShapeFn) {
311  ShapeInferenceTestOp op("SparseSegmentMeanGrad");
312  op.input_tensors.resize(4);
313  INFER_OK(op, "?;?;?;?", "?");
314  INFER_OK(op, "[2,4,3];[3];[3];[]", "[?,d0_1,d0_2]");
315
316  Tensor num_segments_t = test::AsScalar(100);
317  op.input_tensors[3] = &num_segments_t;
318  INFER_OK(op, "[2,4,3];[3];[3];[]", "[100,d0_1,d0_2]");
319
320  INFER_ERROR("Shape must be rank 0 but is rank 2", op,
321              "[2,4,3];[3];[3];[1,1]");
322
323  // Negative value is not allowed
324  num_segments_t = test::AsScalar(-100);
325  op.input_tensors[3] = &num_segments_t;
326  INFER_ERROR("Cannot specify a negative value", op, "[2,4,3];[3];[3];[]");
327}
328
329TEST(MathOpsTest, BatchMatMul_ShapeFn) {
330  ShapeInferenceTestOp op("BatchMatMul");
331  auto set_adj = [&op](bool adj_x, bool adj_y) {
332    TF_ASSERT_OK(NodeDefBuilder("test", "BatchMatMul")
333                     .Input({"a", 0, DT_FLOAT})
334                     .Input({"b", 0, DT_FLOAT})
335                     .Attr("adj_x", adj_x)
336                     .Attr("adj_y", adj_y)
337                     .Finalize(&op.node_def));
338  };
339
340  set_adj(false, false);
341
342  // Rank checks.
343  INFER_ERROR("at least rank 2", op, "[1];?");
344  INFER_ERROR("at least rank 2", op, "?;[2]");
345
346  INFER_OK(op, "?;?", "?");
347
348  // 0 batch dims.
349  INFER_OK(op, "[?,?];[?,?]", "[d0_0,d1_1]");
350
351  // 2 batch dims.
352  INFER_OK(op, "[?,?,?,?];?", "[d0_0,d0_1,d0_2,?]");
353
354  // Test adj_a, testing output and that inner dims are compared.
355  set_adj(false, false);
356  INFER_OK(op, "[1,2,3,4];[1,2,?,?]", "[d0_0,d0_1,d0_2,d1_3]");
357  INFER_ERROR("are 2 and 3", op, "[?,1,2];[?,3,1]");  // inner dim mismatch
358  set_adj(true, false);
359  INFER_OK(op, "[1,2,3,4];[1,2,?,?]", "[d0_0,d0_1,d0_3,d1_3]");
360  INFER_ERROR("are 2 and 3", op, "[?,2,1];[?,3,1]");  // inner dim mismatch
361
362  // Test adj_b=true.
363  set_adj(false, true);
364  INFER_OK(op, "[1,2,?,?];[1,2,3,4]", "[d0_0,d0_1,d0_2,d1_2]");
365  INFER_ERROR("are 2 and 3", op, "[?,1,2];[?,1,3]");  // inner dim mismatch
366  set_adj(true, true);
367  INFER_OK(op, "[1,2,?,?];[1,2,3,4]", "[d0_0,d0_1,d0_3,d1_2]");
368  INFER_ERROR("are 2 and 3", op, "[?,2,1];[?,1,3]");  // inner dim mismatch
369}
370
371TEST(MathOpsTest, ArgOps_ShapeFn) {
372  ShapeInferenceTestOp op("ArgMax");
373  op.input_tensors.resize(2);
374
375  INFER_OK(op, "?;?", "?");
376
377  // input rank <= 1 produces scalar
378  INFER_OK(op, "[2];?", "[]");
379  INFER_OK(op, "[];?", "[]");
380
381  // Incorrect rank for dimension
382  INFER_ERROR("must be rank 0", op, "[2];[1]");
383
384  // dimension not available, but input rank is.  Output is unknown
385  // shape with rank one less than input rank.
386  INFER_OK(op, "[2,3,4];?", "[?,?]");
387  INFER_OK(op, "[2,3,4,5,6];?", "[?,?,?,?]");
388
389  // Dimension values known
390  Tensor dimension = test::AsScalar(0);
391  op.input_tensors[1] = &dimension;
392  INFER_OK(op, "[2,3,4];[]", "[d0_1,d0_2]");
393
394  dimension = test::AsScalar(1);
395  op.input_tensors[1] = &dimension;
396  INFER_OK(op, "[2,3,4];[]", "[d0_0,d0_2]");
397
398  dimension = test::AsScalar(2);
399  op.input_tensors[1] = &dimension;
400  INFER_OK(op, "[2,3,4];[]", "[d0_0,d0_1]");
401
402  // Dimension value out of bounds
403  dimension = test::AsScalar(10);
404  op.input_tensors[1] = &dimension;
405  INFER_ERROR("must be in the range [0, 3)", op, "[2,3,4];[]");
406
407  dimension = test::AsScalar(-10);
408  op.input_tensors[1] = &dimension;
409  INFER_ERROR("must be in the range [0, 3)", op, "[2,3,4];[]");
410}
411
412TEST(MathOpsTest, Betainc_ShapeFn) {
413  ShapeInferenceTestOp op("Betainc");
414
415  INFER_OK(op, "?;?;?", "?");
416  INFER_OK(op, "[?,?];?;?", "in0");
417  INFER_OK(op, "[?,2];?;[1,?]", "[d2_0,d0_1]");
418  INFER_OK(op, "[?,2,?];[1,?,?];[?,?,3]", "[d1_0,d0_1,d2_2]");
419
420  INFER_OK(op, "[?,2,?];[];[?,?,3]", "[d0_0|d2_0,d0_1,d2_2]");
421  INFER_OK(op, "[];[];[?,?,3]", "in2");
422
423  // All but one is a scalar, so use it.
424  INFER_OK(op, "[];[];?", "in2");
425  INFER_OK(op, "[];[];[1,2,3,4]", "in2");
426
427  // All scalar input; implementation picks in0.
428  INFER_OK(op, "[];[];[]", "in0");
429
430  // Non-scalars must match shape.
431  INFER_ERROR("must be equal", op, "[1,2];[];[1,4]");
432  INFER_ERROR("must be equal", op, "[1,2];[];[1,2,3]");
433}
434
435}  // end namespace tensorflow
436