1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/node_def_builder.h"
17#include "tensorflow/core/framework/op.h"
18#include "tensorflow/core/framework/shape_inference_testutil.h"
19#include "tensorflow/core/framework/tensor.h"
20#include "tensorflow/core/framework/tensor_shape.pb.h"
21#include "tensorflow/core/framework/tensor_testutil.h"
22#include "tensorflow/core/lib/core/status_test_util.h"
23#include "tensorflow/core/platform/test.h"
24
25namespace tensorflow {
26
27TEST(MathOpsTest, AddN_ShapeFn) {
28  ShapeInferenceTestOp op("AddN");
29  auto set_n = [&op](int n) {
30    std::vector<NodeDefBuilder::NodeOut> src_list;
31    src_list.reserve(n);
32    for (int i = 0; i < n; ++i) src_list.emplace_back("a", 0, DT_FLOAT);
33    TF_ASSERT_OK(NodeDefBuilder("test", "AddN")
34                     .Input(src_list)
35                     .Attr("N", n)
36                     .Finalize(&op.node_def));
37  };
38
39  set_n(2);
40  // Adding two unknowns returns either input.
41  INFER_OK(op, "?;?", "in0|in1");
42
43  // known+unknown returns the known input.
44  INFER_OK(op, "[1];[?]", "in0");
45  INFER_OK(op, "[1];?", "in0");
46  INFER_OK(op, "[?];[1]", "in1");
47  INFER_OK(op, "?;[1]", "in1");
48
49  set_n(2);
50  INFER_OK(op, "[1,2];[?,2]", "in0");
51  INFER_OK(op, "[1,2];[1,2]", "in0|in1");
52  INFER_OK(op, "[?,2];[1,2]", "in1");
53
54  set_n(3);
55  INFER_OK(op, "[1,?];[?,2];[1,2]", "in2");
56  INFER_OK(op, "[1,2];[?,2];[1,?]", "in0");
57  INFER_OK(op, "?;?;[1,2]", "in2");
58
59  set_n(2);
60  INFER_OK(op, "?;[1,2]", "in1");
61  INFER_OK(op, "[1,?];[?,2]", "[d0_0,d1_1]");
62  INFER_OK(op, "[?,2,?];[?,?,3]", "[d0_0|d1_0,d0_1,d1_2]");
63  INFER_OK(op, "[?,2];[1,?]", "[d1_0,d0_1]");
64
65  set_n(3);
66  INFER_ERROR("Dimension 1 in both shapes must be equal, but are 2 and 4", op,
67              "[1,2];?;[1,4]");
68  INFER_ERROR("From merging shape 0 with other shapes.", op, "[1,2];?;[1,4]");
69  set_n(4);
70  INFER_ERROR("Shapes must be equal rank, but are 2 and 3", op,
71              "?;[1,2];?;[1,2,3]");
72  INFER_ERROR("From merging shape 1 with other shapes.", op,
73              "?;[1,2];?;[1,2,3]");
74}
75
76TEST(MathOpsTest, UnchangedShape_ShapeFn) {
77  ShapeInferenceTestOp op("Cast");
78  INFER_OK(op, "?", "in0");
79  INFER_OK(op, "[?]", "in0");
80  INFER_OK(op, "[1,?,3,4]", "in0");
81}
82
83TEST(MathOpsTest, Segment_ShapeFn) {
84  // Tests SegmentReductionShapeFn.
85  for (const auto* op_name : {"SegmentMax", "SegmentMean", "SegmentMin",
86                              "SegmentProd", "SegmentSum"}) {
87    ShapeInferenceTestOp op(op_name);
88    INFER_OK(op, "?;?", "?");
89    INFER_OK(op, "?;[100]", "?");
90
91    // Data shape with single dimension.
92    INFER_OK(op, "[?];?", "[?]");
93    INFER_OK(op, "[?];[100]", "[?]");
94    INFER_OK(op, "[1];?", "[?]");
95    INFER_OK(op, "[1];[100]", "[?]");
96
97    // Data shape with multiple dimensions.
98    INFER_OK(op, "[?,?];?", "[?,d0_1]");
99    INFER_OK(op, "[?,2];[100]", "[?,d0_1]");
100    INFER_OK(op, "[?,2,?,4];[100]", "[?,d0_1,d0_2,d0_3]");
101    INFER_OK(op, "[1,?];?", "[?,d0_1]");
102    INFER_OK(op, "[1,2];[100]", "[?,d0_1]");
103    INFER_OK(op, "[1,2,?,4];[100]", "[?,d0_1,d0_2,d0_3]");
104
105    // Error cases.
106    INFER_ERROR("Shape must be rank 1 but is rank 2", op, "?;[1,2]");
107    INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[];[1]");
108  }
109}
110
111TEST(MathOpsTest, BroadcastBinaryOps_ShapeFn) {
112  for (const auto* op_name : {"Add",        "Complex",
113                              "Div",        "Equal",
114                              "Greater",    "GreaterEqual",
115                              "Igamma",     "Igammac",
116                              "Zeta",       "Polygamma",
117                              "Less",       "LessEqual",
118                              "LogicalAnd", "LogicalOr",
119                              "Maximum",    "Minimum",
120                              "Mod",        "Mul",
121                              "NotEqual",   "Pow",
122                              "Sub",        "SquaredDifference"}) {
123    ShapeInferenceTestOp op(op_name);
124    INFER_OK(op, "?;?", "?");
125    INFER_OK(op, "[1,2];?", "?");
126    INFER_OK(op, "?;[1,2]", "?");
127
128    INFER_OK(op, "[?];[1]", "[d0_0]");
129    INFER_OK(op, "[1];[?]", "[d1_0]");
130    INFER_OK(op, "[?];[2]", "[d1_0]");
131    INFER_OK(op, "[2];[?]", "[d0_0]");
132    INFER_OK(op, "[?];[?]", "[?]");
133    INFER_OK(op, "[];[?]", "[d1_0]");
134    INFER_OK(op, "[?];[]", "[d0_0]");
135
136    INFER_OK(op, "[1];[1]", "[d0_0|d1_0]");
137    INFER_OK(op, "[];[1]", "[d1_0]");
138    INFER_OK(op, "[1];[]", "[d0_0]");
139
140    INFER_OK(op, "[2];[2]", "[d0_0|d1_0]");
141    INFER_OK(op, "[];[2]", "[d1_0]");
142    INFER_OK(op, "[1];[2]", "[d1_0]");
143    INFER_OK(op, "[2];[1]", "[d0_0]");
144    INFER_OK(op, "[2];[]", "[d0_0]");
145
146    INFER_OK(op, "[0];[0]", "[d0_0|d1_0]");
147    INFER_OK(op, "[];[0]", "[d1_0]");
148    INFER_OK(op, "[1];[0]", "[d1_0]");
149    INFER_OK(op, "[0];[1]", "[d0_0]");
150    INFER_OK(op, "[0];[]", "[d0_0]");
151
152    // Multiple dimension cases (same test cases, switching x and y).
153    INFER_OK(op, "[?,1,2,3,4,5];[3,1,?]",
154             "[d0_0,d0_1,d0_2,d0_3|d1_0,d0_4,d0_5]");
155    INFER_OK(op, "[3,1,?];[?,1,2,3,4,5]",
156             "[d1_0,d1_1,d1_2,d1_3|d0_0,d1_4,d1_5]");
157  }
158}
159
160TEST(MathOpsTest, Select_ShapeFn) {
161  ShapeInferenceTestOp op("Select");
162  INFER_OK(op, "?;?;?", "in1|in2");
163
164  // scalar case
165  INFER_OK(op, "[];[1];?", "in1");
166  INFER_OK(op, "[];?;?", "in1|in2");
167
168  INFER_OK(op, "[1];?;?",
169           "in1|in2");  // When cond is vector, t/e may not match it.
170  INFER_OK(op, "[1,2];?;?", "in1|in2?");
171
172  INFER_OK(op, "?;[];?", "in1");
173  INFER_OK(op, "?;?;[]", "in2");
174  INFER_OK(op, "?;[1];?", "in1");
175  INFER_OK(op, "?;?;[1]", "in2");
176  INFER_OK(op, "?;[1,2];?", "in1");
177  INFER_OK(op, "?;?;[1,2]", "in2");
178
179  INFER_ERROR("Shapes must be equal rank, but are 0 and 1", op, "[1];[];?");
180  INFER_ERROR("Shapes must be equal rank, but are 1 and 2", op, "[];[1];[1,2]");
181  INFER_ERROR("Shapes must be equal rank, but are 1 and 2", op, "[1,2];[1];?");
182  INFER_OK(op, "[2];[?];[?]", "in1|in2");
183
184  INFER_OK(op, "[?];[?,?,3];[1,2,?]", "[d2_0,d2_1,d1_2]");
185  INFER_OK(op, "[2];[?,?,3];[?,2,?]", "[d1_0|d2_0,d2_1,d1_2]");
186  INFER_ERROR("must be equal", op, "[1];[2,?,3];[?,2,?]");
187  INFER_ERROR("Shapes must be equal rank, but are 3 and 2", op,
188              "[2,?];[?,?,3];[?,2,?]");
189  INFER_OK(op, "[2,?,?];[?,?,3];[?,2,?]", "[d0_0,d2_1,d1_2]");
190  INFER_ERROR("Dimension 2 in both shapes must be equal, but are 3 and 5", op,
191              "[2,?,5];[?,?,3];[?,2,?]");
192
193  // Test that handles were merged.
194  //
195  // Tests below will modify handle_data and call run_inference_for_handles to
196  // rerun shape inference, updating the context <c>.
197  const OpRegistrationData* op_reg_data;
198  TF_ASSERT_OK(OpRegistry::Global()->LookUp(op.name, &op_reg_data));
199  typedef std::vector<std::pair<TensorShapeProto, DataType>> ShapeDtypeV;
200  std::vector<std::unique_ptr<ShapeDtypeV>> handle_data;
201  std::unique_ptr<shape_inference::InferenceContext> c;
202  Status run_status;
203  auto run_inference_for_handles = [&]() -> Status {
204    CHECK(op_reg_data->shape_inference_fn != nullptr);
205    c.reset(new shape_inference::InferenceContext(
206        TF_GRAPH_DEF_VERSION, &op.node_def, op_reg_data->op_def,
207        {TensorShapeProto(), TensorShapeProto(), TensorShapeProto()}, {}, {},
208        handle_data));
209    TF_CHECK_OK(c->construction_status());
210    Status s = c->Run(op_reg_data->shape_inference_fn);
211    LOG(INFO) << "Inference got " << s;
212    return s;
213  };
214  auto shape_proto = [](std::initializer_list<int64> dim_sizes) {
215    TensorShapeProto p;
216    for (auto i : dim_sizes) p.add_dim()->set_size(i);
217    return p;
218  };
219
220  TensorShapeProto i0 = shape_proto({1, -1});
221  TensorShapeProto i1 = shape_proto({-1, 2});
222  TensorShapeProto unknown_shape;
223  unknown_shape.set_unknown_rank(true);
224  TensorShapeProto scalar;
225
226  handle_data.emplace_back(
227      new ShapeDtypeV{{scalar, DT_FLOAT}, {unknown_shape, DT_INT32}});
228  handle_data.emplace_back(new ShapeDtypeV{{i0, DT_FLOAT}, {i1, DT_INT32}});
229  handle_data.emplace_back(
230      new ShapeDtypeV{{i1, DT_FLOAT}, {unknown_shape, DT_INT32}});
231
232  TF_ASSERT_OK(run_inference_for_handles());
233  auto* out = c->output_handle_shapes_and_types(0);
234  ASSERT_EQ(2, out->size());
235  EXPECT_EQ("[1,2]", c->DebugString(out->at(0).shape));
236  EXPECT_EQ(DT_FLOAT, out->at(0).dtype);
237  EXPECT_EQ("[?,2]", c->DebugString(out->at(1).shape));
238  EXPECT_EQ(DT_INT32, out->at(1).dtype);
239
240  // Expect an error when the shapes can't be merged.
241  handle_data[2]->at(0).first = shape_proto({2, 2});
242  EXPECT_TRUE(StringPiece(run_inference_for_handles().error_message())
243                  .contains("must be equal, but are 1 and 2"));
244  handle_data[2]->at(0).first = i1;  // restore to valid
245
246  // Expect an error when the types can't be merged.
247  handle_data[2]->at(1).second = DT_INT64;
248  EXPECT_TRUE(StringPiece(run_inference_for_handles().error_message())
249                  .contains("pointing to different dtypes"));
250  handle_data[2]->at(1).second = DT_INT32;  // restore to valid
251
252  // Expect an error when different numbers of tensors are merged.
253  handle_data[2]->push_back({i1, DT_FLOAT});
254  EXPECT_TRUE(StringPiece(run_inference_for_handles().error_message())
255                  .contains("pointing to different numbers of tensors"));
256  handle_data[2]->pop_back();  // restore to valid.
257}
258
259TEST(MathOpsTest, Range_ShapeFn) {
260  ShapeInferenceTestOp op("Range");
261
262  TF_ASSERT_OK(NodeDefBuilder("test", "Range")
263                   .Input({"start", {}, DT_INT32})
264                   .Input({"limit", {}, DT_INT32})
265                   .Input({"delta", {}, DT_INT32})
266                   .Attr("Tidx", DT_INT32)
267                   .Finalize(&op.node_def));
268
269  op.input_tensors.resize(3);
270  INFER_OK(op, "?;?;?", "[?]");
271  INFER_ERROR("Shape must be rank 0 but is rank 2", op, "[1,2];?;?");
272  INFER_ERROR("for 'start'", op, "[1,2];?;?");
273
274  INFER_ERROR("Shape must be rank 0 but is rank 2", op, "?;[1,2];?");
275  INFER_ERROR("for 'limit'", op, "?;[1,2];?");
276
277  INFER_ERROR("Shape must be rank 0 but is rank 2", op, "?;?;[1,2]");
278  INFER_ERROR("for 'delta'", op, "?;?;[1,2]");
279
280  Tensor start_t = test::AsScalar(1);
281  op.input_tensors[0] = &start_t;
282  INFER_OK(op, "?;?;?", "[?]");
283  Tensor limit_t = test::AsScalar(1);
284  op.input_tensors[1] = &limit_t;
285  INFER_OK(op, "?;?;?", "[?]");
286
287  Tensor delta_t = test::AsScalar(1);
288  op.input_tensors[2] = &delta_t;
289  INFER_OK(op, "?;?;?", "[0]");
290
291  delta_t = test::AsScalar(0);
292  INFER_ERROR("Requires delta != 0", op, "?;?;?");
293  delta_t = test::AsScalar(3);
294
295  limit_t = test::AsScalar(-1);
296  INFER_ERROR("Requires start <= limit when delta > 0: 1/-1", op, "?;?;?");
297
298  delta_t = test::AsScalar(-1);
299  INFER_OK(op, "?;?;?", "[2]");
300
301  limit_t = test::AsScalar(4);
302  INFER_ERROR("Requires start >= limit when delta < 0: 1/4", op, "?;?;?");
303
304  limit_t = test::AsScalar(100);
305  start_t = test::AsScalar(2);
306  delta_t = test::AsScalar(3);
307  INFER_OK(op, "?;?;?", "[33]");
308}
309
310TEST(MathOpsTest, LinSpace_ShapeFn) {
311  ShapeInferenceTestOp op("LinSpace");
312  op.input_tensors.resize(3);
313  INFER_OK(op, "?;?;?", "[?]");
314  INFER_ERROR("Shape must be rank 0 but is rank 2", op, "[1,2];?;?");
315  INFER_ERROR("for 'start'", op, "[1,2];?;?");
316  INFER_ERROR("Shape must be rank 0 but is rank 2", op, "?;[1,2];?");
317  INFER_ERROR("for 'stop'", op, "?;[1,2];?");
318  INFER_ERROR("Shape must be rank 0 but is rank 2", op, "?;?;[1,2]");
319  INFER_ERROR("for 'num'", op, "?;?;[1,2]");
320
321  Tensor num_t = test::AsScalar(1);
322  op.input_tensors[2] = &num_t;
323  INFER_OK(op, "?;?;?", "[1]");
324  num_t = test::AsScalar(2);
325  INFER_OK(op, "?;?;?", "[2]");
326  num_t = test::AsScalar(-1);
327  INFER_ERROR("Requires num > 0: -1", op, "?;?;?");
328}
329
330TEST(MathOpsTest, UnsortedSegmentSum_ShapeFn) {
331  ShapeInferenceTestOp op("UnsortedSegmentSum");
332  op.input_tensors.resize(3);
333  INFER_OK(op, "?;?;?", "?");
334  INFER_OK(op, "?;[?];?", "?");
335  INFER_ERROR("Shape must be rank 0 but is rank 2", op, "?;?;[1,2]");
336  INFER_ERROR("Dimensions must be equal, but are 2 and 3", op,
337              "[1,?,2];[1,?,3];?");
338  INFER_OK(op, "?;[3];?", "?");
339  INFER_ERROR("Shape must be at least rank 3 but is rank 2", op,
340              "[1,2];[1,2,3];?");
341
342  Tensor num_segments_t = test::AsScalar(100);
343  op.input_tensors[2] = &num_segments_t;
344  INFER_OK(op, "[?,2,3,?,5];[1,2,?];[]", "[100,d0_3,d0_4]");
345
346  num_segments_t = test::AsScalar(-1);
347  INFER_ERROR(("Dimension size, given by scalar input 2, must be "
348               "non-negative but is -1"),
349              op, "[3];[3];?");
350}
351
352TEST(MathOpsTest, SparseSegment_ShapeFn) {
353  ShapeInferenceTestOp op("SparseSegmentSum");
354  op.input_tensors.resize(3);
355  INFER_OK(op, "?;?;?", "?");
356  INFER_OK(op, "[2,4,3];[3];[3]", "[?,d0_1,d0_2]");
357
358  INFER_ERROR("Shape must be rank 1 but is rank 0", op, "[2,4,3];[];[3]");
359  INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[2,4,3];[3];[3,4]");
360
361  INFER_ERROR("Dimension 0 in both shapes must be equal, but are 3 and 4", op,
362              "[2,4,3];[3];[4]");
363}
364
365TEST(MathOpsTest, SparseSegmentGrad_ShapeFn) {
366  ShapeInferenceTestOp op("SparseSegmentMeanGrad");
367  op.input_tensors.resize(4);
368  INFER_OK(op, "?;?;?;?", "?");
369  INFER_OK(op, "[2,4,3];[3];[3];[]", "[?,d0_1,d0_2]");
370
371  Tensor num_segments_t = test::AsScalar(100);
372  op.input_tensors[3] = &num_segments_t;
373  INFER_OK(op, "[2,4,3];[3];[3];[]", "[100,d0_1,d0_2]");
374
375  INFER_ERROR("Shape must be rank 0 but is rank 2", op,
376              "[2,4,3];[3];[3];[1,1]");
377
378  // Negative value is not allowed
379  num_segments_t = test::AsScalar(-100);
380  op.input_tensors[3] = &num_segments_t;
381  INFER_ERROR("Cannot specify a negative value", op, "[2,4,3];[3];[3];[]");
382}
383
384TEST(MathOpsTest, BatchMatMul_ShapeFn) {
385  ShapeInferenceTestOp op("BatchMatMul");
386  auto set_adj = [&op](bool adj_x, bool adj_y) {
387    TF_ASSERT_OK(NodeDefBuilder("test", "BatchMatMul")
388                     .Input({"a", 0, DT_FLOAT})
389                     .Input({"b", 0, DT_FLOAT})
390                     .Attr("adj_x", adj_x)
391                     .Attr("adj_y", adj_y)
392                     .Finalize(&op.node_def));
393  };
394
395  set_adj(false, false);
396
397  // Rank checks.
398  INFER_ERROR("at least rank 2", op, "[1];?");
399  INFER_ERROR("at least rank 2", op, "?;[2]");
400
401  INFER_OK(op, "?;?", "?");
402
403  // 0 batch dims.
404  INFER_OK(op, "[?,?];[?,?]", "[d0_0,d1_1]");
405
406  // 2 batch dims.
407  INFER_OK(op, "[?,?,?,?];?", "[d0_0,d0_1,d0_2,?]");
408
409  // Test adj_a, testing output and that inner dims are compared.
410  set_adj(false, false);
411  INFER_OK(op, "[1,2,3,4];[1,2,?,?]", "[d0_0,d0_1,d0_2,d1_3]");
412  INFER_ERROR("are 2 and 3", op, "[?,1,2];[?,3,1]");  // inner dim mismatch
413  set_adj(true, false);
414  INFER_OK(op, "[1,2,3,4];[1,2,?,?]", "[d0_0,d0_1,d0_3,d1_3]");
415  INFER_ERROR("are 2 and 3", op, "[?,2,1];[?,3,1]");  // inner dim mismatch
416
417  // Test adj_b=true.
418  set_adj(false, true);
419  INFER_OK(op, "[1,2,?,?];[1,2,3,4]", "[d0_0,d0_1,d0_2,d1_2]");
420  INFER_ERROR("are 2 and 3", op, "[?,1,2];[?,1,3]");  // inner dim mismatch
421  set_adj(true, true);
422  INFER_OK(op, "[1,2,?,?];[1,2,3,4]", "[d0_0,d0_1,d0_3,d1_2]");
423  INFER_ERROR("are 2 and 3", op, "[?,2,1];[?,1,3]");  // inner dim mismatch
424}
425
426TEST(MathOpsTest, ArgOps_ShapeFn) {
427  ShapeInferenceTestOp op("ArgMax");
428  op.input_tensors.resize(2);
429
430  INFER_OK(op, "?;?", "?");
431
432  // input rank <= 1 produces scalar
433  INFER_OK(op, "[2];?", "[]");
434  INFER_OK(op, "[];?", "[]");
435
436  // Incorrect rank for dimension
437  INFER_ERROR("must be rank 0", op, "[2];[1]");
438
439  // dimension not available, but input rank is.  Output is unknown
440  // shape with rank one less than input rank.
441  INFER_OK(op, "[2,3,4];?", "[?,?]");
442  INFER_OK(op, "[2,3,4,5,6];?", "[?,?,?,?]");
443
444  // Dimension values known
445  Tensor dimension = test::AsScalar(0);
446  op.input_tensors[1] = &dimension;
447  INFER_OK(op, "[2,3,4];[]", "[d0_1,d0_2]");
448
449  dimension = test::AsScalar(1);
450  op.input_tensors[1] = &dimension;
451  INFER_OK(op, "[2,3,4];[]", "[d0_0,d0_2]");
452
453  dimension = test::AsScalar(2);
454  op.input_tensors[1] = &dimension;
455  INFER_OK(op, "[2,3,4];[]", "[d0_0,d0_1]");
456
457  // Dimension value out of bounds
458  dimension = test::AsScalar(10);
459  op.input_tensors[1] = &dimension;
460  INFER_ERROR("must be in the range [-3, 3)", op, "[2,3,4];[]");
461
462  dimension = test::AsScalar(-10);
463  op.input_tensors[1] = &dimension;
464  INFER_ERROR("must be in the range [-3, 3)", op, "[2,3,4];[]");
465
466  dimension = test::AsScalar(-1);
467  op.input_tensors[1] = &dimension;
468  INFER_OK(op, "[2,3,4];[]", "[d0_0,d0_1]");
469}
470
471TEST(MathOpsTest, Betainc_ShapeFn) {
472  ShapeInferenceTestOp op("Betainc");
473
474  INFER_OK(op, "?;?;?", "?");
475  INFER_OK(op, "[?,?];?;?", "in0");
476  INFER_OK(op, "[?,2];?;[1,?]", "[d2_0,d0_1]");
477  INFER_OK(op, "[?,2,?];[1,?,?];[?,?,3]", "[d1_0,d0_1,d2_2]");
478
479  INFER_OK(op, "[?,2,?];[];[?,?,3]", "[d0_0|d2_0,d0_1,d2_2]");
480  INFER_OK(op, "[];[];[?,?,3]", "in2");
481
482  // All but one is a scalar, so use it.
483  INFER_OK(op, "[];[];?", "in2");
484  INFER_OK(op, "[];[];[1,2,3,4]", "in2");
485
486  // All scalar input; implementation picks in0.
487  INFER_OK(op, "[];[];[]", "in0");
488
489  // Non-scalars must match shape.
490  INFER_ERROR("must be equal", op, "[1,2];[];[1,4]");
491  INFER_ERROR("must be equal", op, "[1,2];[];[1,2,3]");
492}
493
494TEST(MathOpsTest, Requantize_ShapeFn) {
495  ShapeInferenceTestOp op("Requantize");
496
497  INFER_OK(op, "?;?;?;?;?", "in0;[];[]");
498  INFER_OK(op, "?;[];[];[];[]", "in0;[];[]");
499
500  // Rank checks on input scalars.
501  INFER_ERROR("must be rank 0", op, "?;[1];?;?;?");
502  INFER_ERROR("must be rank 0", op, "?;?;[2];?;?");
503  INFER_ERROR("must be rank 0", op, "?;?;?;[3];?");
504  INFER_ERROR("must be rank 0", op, "?;?;?;?;[4]");
505}
506
507TEST(MathOpstest, RequantizationRange_ShapeFn) {
508  ShapeInferenceTestOp op("RequantizationRange");
509
510  INFER_OK(op, "?;?;?", "[];[]");
511  INFER_OK(op, "?;[];[]", "[];[]");
512
513  // Rank checks on input scalars.
514  INFER_ERROR("must be rank 0", op, "?;[1];?");
515  INFER_ERROR("must be rank 0", op, "?;?;[2]");
516}
517
518TEST(MathOpsTest, Cross_ShapeFn) {
519  ShapeInferenceTestOp op("Cross");
520
521  INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[];[]");
522  INFER_ERROR("Dimension 0 in both shapes must be equal, but", op, "[3];[5]");
523  INFER_ERROR("Dimension must be 3 but", op, "[3,5];[3,5]");
524
525  INFER_OK(op, "?;?", "in0");
526  INFER_OK(op, "[?];[?]", "in0");
527  INFER_OK(op, "[1,?,3];[?,?,?]", "in0");
528}
529}  // end namespace tensorflow
530