1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/core/framework/node_def_builder.h" 17#include "tensorflow/core/framework/op.h" 18#include "tensorflow/core/framework/shape_inference_testutil.h" 19#include "tensorflow/core/framework/tensor.h" 20#include "tensorflow/core/framework/tensor_shape.pb.h" 21#include "tensorflow/core/framework/tensor_testutil.h" 22#include "tensorflow/core/lib/core/status_test_util.h" 23#include "tensorflow/core/platform/test.h" 24 25namespace tensorflow { 26 27TEST(MathOpsTest, AddN_ShapeFn) { 28 ShapeInferenceTestOp op("AddN"); 29 auto set_n = [&op](int n) { 30 std::vector<NodeDefBuilder::NodeOut> src_list; 31 src_list.reserve(n); 32 for (int i = 0; i < n; ++i) src_list.emplace_back("a", 0, DT_FLOAT); 33 TF_ASSERT_OK(NodeDefBuilder("test", "AddN") 34 .Input(src_list) 35 .Attr("N", n) 36 .Finalize(&op.node_def)); 37 }; 38 39 set_n(2); 40 // Adding two unknowns returns either input. 41 INFER_OK(op, "?;?", "in0|in1"); 42 43 // known+unknown returns the known input. 44 INFER_OK(op, "[1];[?]", "in0"); 45 INFER_OK(op, "[1];?", "in0"); 46 INFER_OK(op, "[?];[1]", "in1"); 47 INFER_OK(op, "?;[1]", "in1"); 48 49 set_n(2); 50 INFER_OK(op, "[1,2];[?,2]", "in0"); 51 INFER_OK(op, "[1,2];[1,2]", "in0|in1"); 52 INFER_OK(op, "[?,2];[1,2]", "in1"); 53 54 set_n(3); 55 INFER_OK(op, "[1,?];[?,2];[1,2]", "in2"); 56 INFER_OK(op, "[1,2];[?,2];[1,?]", "in0"); 57 INFER_OK(op, "?;?;[1,2]", "in2"); 58 59 set_n(2); 60 INFER_OK(op, "?;[1,2]", "in1"); 61 INFER_OK(op, "[1,?];[?,2]", "[d0_0,d1_1]"); 62 INFER_OK(op, "[?,2,?];[?,?,3]", "[d0_0|d1_0,d0_1,d1_2]"); 63 INFER_OK(op, "[?,2];[1,?]", "[d1_0,d0_1]"); 64 65 set_n(3); 66 INFER_ERROR("Dimension 1 in both shapes must be equal, but are 2 and 4", op, 67 "[1,2];?;[1,4]"); 68 INFER_ERROR("From merging shape 0 with other shapes.", op, "[1,2];?;[1,4]"); 69 set_n(4); 70 INFER_ERROR("Shapes must be equal rank, but are 2 and 3", op, 71 "?;[1,2];?;[1,2,3]"); 72 INFER_ERROR("From merging shape 1 with other shapes.", op, 73 "?;[1,2];?;[1,2,3]"); 74} 75 76TEST(MathOpsTest, UnchangedShape_ShapeFn) { 77 ShapeInferenceTestOp op("Cast"); 78 INFER_OK(op, "?", "in0"); 79 INFER_OK(op, "[?]", "in0"); 80 INFER_OK(op, "[1,?,3,4]", "in0"); 81} 82 83TEST(MathOpsTest, Segment_ShapeFn) { 84 // Tests SegmentReductionShapeFn. 85 for (const auto* op_name : {"SegmentMax", "SegmentMean", "SegmentMin", 86 "SegmentProd", "SegmentSum"}) { 87 ShapeInferenceTestOp op(op_name); 88 INFER_OK(op, "?;?", "?"); 89 INFER_OK(op, "?;[100]", "?"); 90 91 // Data shape with single dimension. 92 INFER_OK(op, "[?];?", "[?]"); 93 INFER_OK(op, "[?];[100]", "[?]"); 94 INFER_OK(op, "[1];?", "[?]"); 95 INFER_OK(op, "[1];[100]", "[?]"); 96 97 // Data shape with multiple dimensions. 98 INFER_OK(op, "[?,?];?", "[?,d0_1]"); 99 INFER_OK(op, "[?,2];[100]", "[?,d0_1]"); 100 INFER_OK(op, "[?,2,?,4];[100]", "[?,d0_1,d0_2,d0_3]"); 101 INFER_OK(op, "[1,?];?", "[?,d0_1]"); 102 INFER_OK(op, "[1,2];[100]", "[?,d0_1]"); 103 INFER_OK(op, "[1,2,?,4];[100]", "[?,d0_1,d0_2,d0_3]"); 104 105 // Error cases. 106 INFER_ERROR("Shape must be rank 1 but is rank 2", op, "?;[1,2]"); 107 INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[];[1]"); 108 } 109} 110 111TEST(MathOpsTest, BroadcastBinaryOps_ShapeFn) { 112 for (const auto* op_name : {"Add", "Complex", 113 "Div", "Equal", 114 "Greater", "GreaterEqual", 115 "Igamma", "Igammac", 116 "Zeta", "Polygamma", 117 "Less", "LessEqual", 118 "LogicalAnd", "LogicalOr", 119 "Maximum", "Minimum", 120 "Mod", "Mul", 121 "NotEqual", "Pow", 122 "Sub", "SquaredDifference"}) { 123 ShapeInferenceTestOp op(op_name); 124 INFER_OK(op, "?;?", "?"); 125 INFER_OK(op, "[1,2];?", "?"); 126 INFER_OK(op, "?;[1,2]", "?"); 127 128 INFER_OK(op, "[?];[1]", "[d0_0]"); 129 INFER_OK(op, "[1];[?]", "[d1_0]"); 130 INFER_OK(op, "[?];[2]", "[d1_0]"); 131 INFER_OK(op, "[2];[?]", "[d0_0]"); 132 INFER_OK(op, "[?];[?]", "[?]"); 133 INFER_OK(op, "[];[?]", "[d1_0]"); 134 INFER_OK(op, "[?];[]", "[d0_0]"); 135 136 INFER_OK(op, "[1];[1]", "[d0_0|d1_0]"); 137 INFER_OK(op, "[];[1]", "[d1_0]"); 138 INFER_OK(op, "[1];[]", "[d0_0]"); 139 140 INFER_OK(op, "[2];[2]", "[d0_0|d1_0]"); 141 INFER_OK(op, "[];[2]", "[d1_0]"); 142 INFER_OK(op, "[1];[2]", "[d1_0]"); 143 INFER_OK(op, "[2];[1]", "[d0_0]"); 144 INFER_OK(op, "[2];[]", "[d0_0]"); 145 146 INFER_OK(op, "[0];[0]", "[d0_0|d1_0]"); 147 INFER_OK(op, "[];[0]", "[d1_0]"); 148 INFER_OK(op, "[1];[0]", "[d1_0]"); 149 INFER_OK(op, "[0];[1]", "[d0_0]"); 150 INFER_OK(op, "[0];[]", "[d0_0]"); 151 152 // Multiple dimension cases (same test cases, switching x and y). 153 INFER_OK(op, "[?,1,2,3,4,5];[3,1,?]", 154 "[d0_0,d0_1,d0_2,d0_3|d1_0,d0_4,d0_5]"); 155 INFER_OK(op, "[3,1,?];[?,1,2,3,4,5]", 156 "[d1_0,d1_1,d1_2,d1_3|d0_0,d1_4,d1_5]"); 157 } 158} 159 160TEST(MathOpsTest, Select_ShapeFn) { 161 ShapeInferenceTestOp op("Select"); 162 INFER_OK(op, "?;?;?", "in1|in2"); 163 164 // scalar case 165 INFER_OK(op, "[];[1];?", "in1"); 166 INFER_OK(op, "[];?;?", "in1|in2"); 167 168 INFER_OK(op, "[1];?;?", 169 "in1|in2"); // When cond is vector, t/e may not match it. 170 INFER_OK(op, "[1,2];?;?", "in1|in2?"); 171 172 INFER_OK(op, "?;[];?", "in1"); 173 INFER_OK(op, "?;?;[]", "in2"); 174 INFER_OK(op, "?;[1];?", "in1"); 175 INFER_OK(op, "?;?;[1]", "in2"); 176 INFER_OK(op, "?;[1,2];?", "in1"); 177 INFER_OK(op, "?;?;[1,2]", "in2"); 178 179 INFER_ERROR("Shapes must be equal rank, but are 0 and 1", op, "[1];[];?"); 180 INFER_ERROR("Shapes must be equal rank, but are 1 and 2", op, "[];[1];[1,2]"); 181 INFER_ERROR("Shapes must be equal rank, but are 1 and 2", op, "[1,2];[1];?"); 182 INFER_OK(op, "[2];[?];[?]", "in1|in2"); 183 184 INFER_OK(op, "[?];[?,?,3];[1,2,?]", "[d2_0,d2_1,d1_2]"); 185 INFER_OK(op, "[2];[?,?,3];[?,2,?]", "[d1_0|d2_0,d2_1,d1_2]"); 186 INFER_ERROR("must be equal", op, "[1];[2,?,3];[?,2,?]"); 187 INFER_ERROR("Shapes must be equal rank, but are 3 and 2", op, 188 "[2,?];[?,?,3];[?,2,?]"); 189 INFER_OK(op, "[2,?,?];[?,?,3];[?,2,?]", "[d0_0,d2_1,d1_2]"); 190 INFER_ERROR("Dimension 2 in both shapes must be equal, but are 3 and 5", op, 191 "[2,?,5];[?,?,3];[?,2,?]"); 192 193 // Test that handles were merged. 194 // 195 // Tests below will modify handle_data and call run_inference_for_handles to 196 // rerun shape inference, updating the context <c>. 197 const OpRegistrationData* op_reg_data; 198 TF_ASSERT_OK(OpRegistry::Global()->LookUp(op.name, &op_reg_data)); 199 typedef std::vector<std::pair<TensorShapeProto, DataType>> ShapeDtypeV; 200 std::vector<std::unique_ptr<ShapeDtypeV>> handle_data; 201 std::unique_ptr<shape_inference::InferenceContext> c; 202 Status run_status; 203 auto run_inference_for_handles = [&]() -> Status { 204 CHECK(op_reg_data->shape_inference_fn != nullptr); 205 c.reset(new shape_inference::InferenceContext( 206 TF_GRAPH_DEF_VERSION, &op.node_def, op_reg_data->op_def, 207 {TensorShapeProto(), TensorShapeProto(), TensorShapeProto()}, {}, {}, 208 handle_data)); 209 TF_CHECK_OK(c->construction_status()); 210 Status s = c->Run(op_reg_data->shape_inference_fn); 211 LOG(INFO) << "Inference got " << s; 212 return s; 213 }; 214 auto shape_proto = [](std::initializer_list<int64> dim_sizes) { 215 TensorShapeProto p; 216 for (auto i : dim_sizes) p.add_dim()->set_size(i); 217 return p; 218 }; 219 220 TensorShapeProto i0 = shape_proto({1, -1}); 221 TensorShapeProto i1 = shape_proto({-1, 2}); 222 TensorShapeProto unknown_shape; 223 unknown_shape.set_unknown_rank(true); 224 TensorShapeProto scalar; 225 226 handle_data.emplace_back( 227 new ShapeDtypeV{{scalar, DT_FLOAT}, {unknown_shape, DT_INT32}}); 228 handle_data.emplace_back(new ShapeDtypeV{{i0, DT_FLOAT}, {i1, DT_INT32}}); 229 handle_data.emplace_back( 230 new ShapeDtypeV{{i1, DT_FLOAT}, {unknown_shape, DT_INT32}}); 231 232 TF_ASSERT_OK(run_inference_for_handles()); 233 auto* out = c->output_handle_shapes_and_types(0); 234 ASSERT_EQ(2, out->size()); 235 EXPECT_EQ("[1,2]", c->DebugString(out->at(0).shape)); 236 EXPECT_EQ(DT_FLOAT, out->at(0).dtype); 237 EXPECT_EQ("[?,2]", c->DebugString(out->at(1).shape)); 238 EXPECT_EQ(DT_INT32, out->at(1).dtype); 239 240 // Expect an error when the shapes can't be merged. 241 handle_data[2]->at(0).first = shape_proto({2, 2}); 242 EXPECT_TRUE(StringPiece(run_inference_for_handles().error_message()) 243 .contains("must be equal, but are 1 and 2")); 244 handle_data[2]->at(0).first = i1; // restore to valid 245 246 // Expect an error when the types can't be merged. 247 handle_data[2]->at(1).second = DT_INT64; 248 EXPECT_TRUE(StringPiece(run_inference_for_handles().error_message()) 249 .contains("pointing to different dtypes")); 250 handle_data[2]->at(1).second = DT_INT32; // restore to valid 251 252 // Expect an error when different numbers of tensors are merged. 253 handle_data[2]->push_back({i1, DT_FLOAT}); 254 EXPECT_TRUE(StringPiece(run_inference_for_handles().error_message()) 255 .contains("pointing to different numbers of tensors")); 256 handle_data[2]->pop_back(); // restore to valid. 257} 258 259TEST(MathOpsTest, Range_ShapeFn) { 260 ShapeInferenceTestOp op("Range"); 261 262 TF_ASSERT_OK(NodeDefBuilder("test", "Range") 263 .Input({"start", {}, DT_INT32}) 264 .Input({"limit", {}, DT_INT32}) 265 .Input({"delta", {}, DT_INT32}) 266 .Attr("Tidx", DT_INT32) 267 .Finalize(&op.node_def)); 268 269 op.input_tensors.resize(3); 270 INFER_OK(op, "?;?;?", "[?]"); 271 INFER_ERROR("Shape must be rank 0 but is rank 2", op, "[1,2];?;?"); 272 INFER_ERROR("for 'start'", op, "[1,2];?;?"); 273 274 INFER_ERROR("Shape must be rank 0 but is rank 2", op, "?;[1,2];?"); 275 INFER_ERROR("for 'limit'", op, "?;[1,2];?"); 276 277 INFER_ERROR("Shape must be rank 0 but is rank 2", op, "?;?;[1,2]"); 278 INFER_ERROR("for 'delta'", op, "?;?;[1,2]"); 279 280 Tensor start_t = test::AsScalar(1); 281 op.input_tensors[0] = &start_t; 282 INFER_OK(op, "?;?;?", "[?]"); 283 Tensor limit_t = test::AsScalar(1); 284 op.input_tensors[1] = &limit_t; 285 INFER_OK(op, "?;?;?", "[?]"); 286 287 Tensor delta_t = test::AsScalar(1); 288 op.input_tensors[2] = &delta_t; 289 INFER_OK(op, "?;?;?", "[0]"); 290 291 delta_t = test::AsScalar(0); 292 INFER_ERROR("Requires delta != 0", op, "?;?;?"); 293 delta_t = test::AsScalar(3); 294 295 limit_t = test::AsScalar(-1); 296 INFER_ERROR("Requires start <= limit when delta > 0: 1/-1", op, "?;?;?"); 297 298 delta_t = test::AsScalar(-1); 299 INFER_OK(op, "?;?;?", "[2]"); 300 301 limit_t = test::AsScalar(4); 302 INFER_ERROR("Requires start >= limit when delta < 0: 1/4", op, "?;?;?"); 303 304 limit_t = test::AsScalar(100); 305 start_t = test::AsScalar(2); 306 delta_t = test::AsScalar(3); 307 INFER_OK(op, "?;?;?", "[33]"); 308} 309 310TEST(MathOpsTest, LinSpace_ShapeFn) { 311 ShapeInferenceTestOp op("LinSpace"); 312 op.input_tensors.resize(3); 313 INFER_OK(op, "?;?;?", "[?]"); 314 INFER_ERROR("Shape must be rank 0 but is rank 2", op, "[1,2];?;?"); 315 INFER_ERROR("for 'start'", op, "[1,2];?;?"); 316 INFER_ERROR("Shape must be rank 0 but is rank 2", op, "?;[1,2];?"); 317 INFER_ERROR("for 'stop'", op, "?;[1,2];?"); 318 INFER_ERROR("Shape must be rank 0 but is rank 2", op, "?;?;[1,2]"); 319 INFER_ERROR("for 'num'", op, "?;?;[1,2]"); 320 321 Tensor num_t = test::AsScalar(1); 322 op.input_tensors[2] = &num_t; 323 INFER_OK(op, "?;?;?", "[1]"); 324 num_t = test::AsScalar(2); 325 INFER_OK(op, "?;?;?", "[2]"); 326 num_t = test::AsScalar(-1); 327 INFER_ERROR("Requires num > 0: -1", op, "?;?;?"); 328} 329 330TEST(MathOpsTest, UnsortedSegmentSum_ShapeFn) { 331 ShapeInferenceTestOp op("UnsortedSegmentSum"); 332 op.input_tensors.resize(3); 333 INFER_OK(op, "?;?;?", "?"); 334 INFER_OK(op, "?;[?];?", "?"); 335 INFER_ERROR("Shape must be rank 0 but is rank 2", op, "?;?;[1,2]"); 336 INFER_ERROR("Dimensions must be equal, but are 2 and 3", op, 337 "[1,?,2];[1,?,3];?"); 338 INFER_OK(op, "?;[3];?", "?"); 339 INFER_ERROR("Shape must be at least rank 3 but is rank 2", op, 340 "[1,2];[1,2,3];?"); 341 342 Tensor num_segments_t = test::AsScalar(100); 343 op.input_tensors[2] = &num_segments_t; 344 INFER_OK(op, "[?,2,3,?,5];[1,2,?];[]", "[100,d0_3,d0_4]"); 345 346 num_segments_t = test::AsScalar(-1); 347 INFER_ERROR(("Dimension size, given by scalar input 2, must be " 348 "non-negative but is -1"), 349 op, "[3];[3];?"); 350} 351 352TEST(MathOpsTest, SparseSegment_ShapeFn) { 353 ShapeInferenceTestOp op("SparseSegmentSum"); 354 op.input_tensors.resize(3); 355 INFER_OK(op, "?;?;?", "?"); 356 INFER_OK(op, "[2,4,3];[3];[3]", "[?,d0_1,d0_2]"); 357 358 INFER_ERROR("Shape must be rank 1 but is rank 0", op, "[2,4,3];[];[3]"); 359 INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[2,4,3];[3];[3,4]"); 360 361 INFER_ERROR("Dimension 0 in both shapes must be equal, but are 3 and 4", op, 362 "[2,4,3];[3];[4]"); 363} 364 365TEST(MathOpsTest, SparseSegmentGrad_ShapeFn) { 366 ShapeInferenceTestOp op("SparseSegmentMeanGrad"); 367 op.input_tensors.resize(4); 368 INFER_OK(op, "?;?;?;?", "?"); 369 INFER_OK(op, "[2,4,3];[3];[3];[]", "[?,d0_1,d0_2]"); 370 371 Tensor num_segments_t = test::AsScalar(100); 372 op.input_tensors[3] = &num_segments_t; 373 INFER_OK(op, "[2,4,3];[3];[3];[]", "[100,d0_1,d0_2]"); 374 375 INFER_ERROR("Shape must be rank 0 but is rank 2", op, 376 "[2,4,3];[3];[3];[1,1]"); 377 378 // Negative value is not allowed 379 num_segments_t = test::AsScalar(-100); 380 op.input_tensors[3] = &num_segments_t; 381 INFER_ERROR("Cannot specify a negative value", op, "[2,4,3];[3];[3];[]"); 382} 383 384TEST(MathOpsTest, BatchMatMul_ShapeFn) { 385 ShapeInferenceTestOp op("BatchMatMul"); 386 auto set_adj = [&op](bool adj_x, bool adj_y) { 387 TF_ASSERT_OK(NodeDefBuilder("test", "BatchMatMul") 388 .Input({"a", 0, DT_FLOAT}) 389 .Input({"b", 0, DT_FLOAT}) 390 .Attr("adj_x", adj_x) 391 .Attr("adj_y", adj_y) 392 .Finalize(&op.node_def)); 393 }; 394 395 set_adj(false, false); 396 397 // Rank checks. 398 INFER_ERROR("at least rank 2", op, "[1];?"); 399 INFER_ERROR("at least rank 2", op, "?;[2]"); 400 401 INFER_OK(op, "?;?", "?"); 402 403 // 0 batch dims. 404 INFER_OK(op, "[?,?];[?,?]", "[d0_0,d1_1]"); 405 406 // 2 batch dims. 407 INFER_OK(op, "[?,?,?,?];?", "[d0_0,d0_1,d0_2,?]"); 408 409 // Test adj_a, testing output and that inner dims are compared. 410 set_adj(false, false); 411 INFER_OK(op, "[1,2,3,4];[1,2,?,?]", "[d0_0,d0_1,d0_2,d1_3]"); 412 INFER_ERROR("are 2 and 3", op, "[?,1,2];[?,3,1]"); // inner dim mismatch 413 set_adj(true, false); 414 INFER_OK(op, "[1,2,3,4];[1,2,?,?]", "[d0_0,d0_1,d0_3,d1_3]"); 415 INFER_ERROR("are 2 and 3", op, "[?,2,1];[?,3,1]"); // inner dim mismatch 416 417 // Test adj_b=true. 418 set_adj(false, true); 419 INFER_OK(op, "[1,2,?,?];[1,2,3,4]", "[d0_0,d0_1,d0_2,d1_2]"); 420 INFER_ERROR("are 2 and 3", op, "[?,1,2];[?,1,3]"); // inner dim mismatch 421 set_adj(true, true); 422 INFER_OK(op, "[1,2,?,?];[1,2,3,4]", "[d0_0,d0_1,d0_3,d1_2]"); 423 INFER_ERROR("are 2 and 3", op, "[?,2,1];[?,1,3]"); // inner dim mismatch 424} 425 426TEST(MathOpsTest, ArgOps_ShapeFn) { 427 ShapeInferenceTestOp op("ArgMax"); 428 op.input_tensors.resize(2); 429 430 INFER_OK(op, "?;?", "?"); 431 432 // input rank <= 1 produces scalar 433 INFER_OK(op, "[2];?", "[]"); 434 INFER_OK(op, "[];?", "[]"); 435 436 // Incorrect rank for dimension 437 INFER_ERROR("must be rank 0", op, "[2];[1]"); 438 439 // dimension not available, but input rank is. Output is unknown 440 // shape with rank one less than input rank. 441 INFER_OK(op, "[2,3,4];?", "[?,?]"); 442 INFER_OK(op, "[2,3,4,5,6];?", "[?,?,?,?]"); 443 444 // Dimension values known 445 Tensor dimension = test::AsScalar(0); 446 op.input_tensors[1] = &dimension; 447 INFER_OK(op, "[2,3,4];[]", "[d0_1,d0_2]"); 448 449 dimension = test::AsScalar(1); 450 op.input_tensors[1] = &dimension; 451 INFER_OK(op, "[2,3,4];[]", "[d0_0,d0_2]"); 452 453 dimension = test::AsScalar(2); 454 op.input_tensors[1] = &dimension; 455 INFER_OK(op, "[2,3,4];[]", "[d0_0,d0_1]"); 456 457 // Dimension value out of bounds 458 dimension = test::AsScalar(10); 459 op.input_tensors[1] = &dimension; 460 INFER_ERROR("must be in the range [-3, 3)", op, "[2,3,4];[]"); 461 462 dimension = test::AsScalar(-10); 463 op.input_tensors[1] = &dimension; 464 INFER_ERROR("must be in the range [-3, 3)", op, "[2,3,4];[]"); 465 466 dimension = test::AsScalar(-1); 467 op.input_tensors[1] = &dimension; 468 INFER_OK(op, "[2,3,4];[]", "[d0_0,d0_1]"); 469} 470 471TEST(MathOpsTest, Betainc_ShapeFn) { 472 ShapeInferenceTestOp op("Betainc"); 473 474 INFER_OK(op, "?;?;?", "?"); 475 INFER_OK(op, "[?,?];?;?", "in0"); 476 INFER_OK(op, "[?,2];?;[1,?]", "[d2_0,d0_1]"); 477 INFER_OK(op, "[?,2,?];[1,?,?];[?,?,3]", "[d1_0,d0_1,d2_2]"); 478 479 INFER_OK(op, "[?,2,?];[];[?,?,3]", "[d0_0|d2_0,d0_1,d2_2]"); 480 INFER_OK(op, "[];[];[?,?,3]", "in2"); 481 482 // All but one is a scalar, so use it. 483 INFER_OK(op, "[];[];?", "in2"); 484 INFER_OK(op, "[];[];[1,2,3,4]", "in2"); 485 486 // All scalar input; implementation picks in0. 487 INFER_OK(op, "[];[];[]", "in0"); 488 489 // Non-scalars must match shape. 490 INFER_ERROR("must be equal", op, "[1,2];[];[1,4]"); 491 INFER_ERROR("must be equal", op, "[1,2];[];[1,2,3]"); 492} 493 494TEST(MathOpsTest, Requantize_ShapeFn) { 495 ShapeInferenceTestOp op("Requantize"); 496 497 INFER_OK(op, "?;?;?;?;?", "in0;[];[]"); 498 INFER_OK(op, "?;[];[];[];[]", "in0;[];[]"); 499 500 // Rank checks on input scalars. 501 INFER_ERROR("must be rank 0", op, "?;[1];?;?;?"); 502 INFER_ERROR("must be rank 0", op, "?;?;[2];?;?"); 503 INFER_ERROR("must be rank 0", op, "?;?;?;[3];?"); 504 INFER_ERROR("must be rank 0", op, "?;?;?;?;[4]"); 505} 506 507TEST(MathOpstest, RequantizationRange_ShapeFn) { 508 ShapeInferenceTestOp op("RequantizationRange"); 509 510 INFER_OK(op, "?;?;?", "[];[]"); 511 INFER_OK(op, "?;[];[]", "[];[]"); 512 513 // Rank checks on input scalars. 514 INFER_ERROR("must be rank 0", op, "?;[1];?"); 515 INFER_ERROR("must be rank 0", op, "?;?;[2]"); 516} 517 518TEST(MathOpsTest, Cross_ShapeFn) { 519 ShapeInferenceTestOp op("Cross"); 520 521 INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[];[]"); 522 INFER_ERROR("Dimension 0 in both shapes must be equal, but", op, "[3];[5]"); 523 INFER_ERROR("Dimension must be 3 but", op, "[3,5];[3,5]"); 524 525 INFER_OK(op, "?;?", "in0"); 526 INFER_OK(op, "[?];[?]", "in0"); 527 INFER_OK(op, "[1,?,3];[?,?,?]", "in0"); 528} 529} // end namespace tensorflow 530