math_ops_test.cc revision 41734d78d3facf652c25b2a2761aadd978b3f2ef
1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/core/framework/node_def_builder.h" 17#include "tensorflow/core/framework/op.h" 18#include "tensorflow/core/framework/shape_inference_testutil.h" 19#include "tensorflow/core/framework/tensor.h" 20#include "tensorflow/core/framework/tensor_testutil.h" 21#include "tensorflow/core/lib/core/status_test_util.h" 22#include "tensorflow/core/platform/test.h" 23 24namespace tensorflow { 25 26TEST(MathOpsTest, AddN_ShapeFn) { 27 ShapeInferenceTestOp op("AddN"); 28 auto set_n = [&op](int n) { 29 std::vector<NodeDefBuilder::NodeOut> src_list; 30 for (int i = 0; i < n; ++i) src_list.emplace_back("a", 0, DT_FLOAT); 31 TF_ASSERT_OK(NodeDefBuilder("test", "AddN") 32 .Input(src_list) 33 .Attr("N", n) 34 .Finalize(&op.node_def)); 35 }; 36 37 set_n(2); 38 // Adding two unknowns returns either input. 39 INFER_OK(op, "?;?", "in0|in1"); 40 41 // known+unknown returns the known input. 42 INFER_OK(op, "[1];[?]", "in0"); 43 INFER_OK(op, "[1];?", "in0"); 44 INFER_OK(op, "[?];[1]", "in1"); 45 INFER_OK(op, "?;[1]", "in1"); 46 47 set_n(2); 48 INFER_OK(op, "[1,2];[?,2]", "in0"); 49 INFER_OK(op, "[1,2];[1,2]", "in0|in1"); 50 INFER_OK(op, "[?,2];[1,2]", "in1"); 51 52 set_n(3); 53 INFER_OK(op, "[1,?];[?,2];[1,2]", "in2"); 54 INFER_OK(op, "[1,2];[?,2];[1,?]", "in0"); 55 INFER_OK(op, "?;?;[1,2]", "in2"); 56 57 set_n(2); 58 INFER_OK(op, "?;[1,2]", "in1"); 59 INFER_OK(op, "[1,?];[?,2]", "[d0_0,d1_1]"); 60 INFER_OK(op, "[?,2,?];[?,?,3]", "[d0_0|d1_0,d0_1,d1_2]"); 61 INFER_OK(op, "[?,2];[1,?]", "[d1_0,d0_1]"); 62 63 set_n(3); 64 INFER_ERROR("Dimension 1 in both shapes must be equal, but are 2 and 4", op, 65 "[1,2];?;[1,4]"); 66 INFER_ERROR("From merging shape 0 with other shapes.", op, "[1,2];?;[1,4]"); 67 set_n(4); 68 INFER_ERROR("Shapes must be equal rank, but are 2 and 3", op, 69 "?;[1,2];?;[1,2,3]"); 70 INFER_ERROR("From merging shape 1 with other shapes.", op, 71 "?;[1,2];?;[1,2,3]"); 72} 73 74TEST(MathOpsTest, UnchangedShape_ShapeFn) { 75 ShapeInferenceTestOp op("Cast"); 76 INFER_OK(op, "?", "in0"); 77 INFER_OK(op, "[?]", "in0"); 78 INFER_OK(op, "[1,?,3,4]", "in0"); 79} 80 81TEST(MathOpsTest, FFT_ShapeFn) { 82 for (const auto* op_name : {"FFT", "IFFT"}) { 83 ShapeInferenceTestOp op(op_name); 84 INFER_OK(op, "?", "?"); 85 INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[]"); 86 INFER_OK(op, "[?]", "in0"); 87 INFER_OK(op, "[1]", "in0"); 88 INFER_OK(op, "[1,2,3,4,5,6,7]", "in0"); 89 } 90 91 for (const auto* op_name : {"FFT2D", "IFFT2D"}) { 92 ShapeInferenceTestOp op(op_name); 93 INFER_OK(op, "?", "?"); 94 INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1]"); 95 INFER_OK(op, "[?,1]", "in0"); 96 INFER_OK(op, "[1,2]", "in0"); 97 INFER_OK(op, "[1,2,3,4,5,6,7]", "in0"); 98 } 99 100 for (const auto* op_name : {"FFT3D", "IFFT3D"}) { 101 ShapeInferenceTestOp op(op_name); 102 INFER_OK(op, "?", "?"); 103 INFER_ERROR("Shape must be at least rank 3 but is rank 2", op, "[1,2]"); 104 INFER_OK(op, "[?,1,?]", "in0"); 105 INFER_OK(op, "[1,2,3]", "in0"); 106 INFER_OK(op, "[1,2,3,4,5,6,7]", "in0"); 107 } 108} 109 110TEST(MathOpsTest, Segment_ShapeFn) { 111 // Tests SegmentReductionShapeFn. 112 for (const auto* op_name : {"SegmentMax", "SegmentMean", "SegmentMin", 113 "SegmentProd", "SegmentSum"}) { 114 ShapeInferenceTestOp op(op_name); 115 INFER_OK(op, "?;?", "?"); 116 INFER_OK(op, "?;[100]", "?"); 117 118 // Data shape with single dimension. 119 INFER_OK(op, "[?];?", "[?]"); 120 INFER_OK(op, "[?];[100]", "[?]"); 121 INFER_OK(op, "[1];?", "[?]"); 122 INFER_OK(op, "[1];[100]", "[?]"); 123 124 // Data shape with multiple dimensions. 125 INFER_OK(op, "[?,?];?", "[?,d0_1]"); 126 INFER_OK(op, "[?,2];[100]", "[?,d0_1]"); 127 INFER_OK(op, "[?,2,?,4];[100]", "[?,d0_1,d0_2,d0_3]"); 128 INFER_OK(op, "[1,?];?", "[?,d0_1]"); 129 INFER_OK(op, "[1,2];[100]", "[?,d0_1]"); 130 INFER_OK(op, "[1,2,?,4];[100]", "[?,d0_1,d0_2,d0_3]"); 131 132 // Error cases. 133 INFER_ERROR("Shape must be rank 1 but is rank 2", op, "?;[1,2]"); 134 INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[];[1]"); 135 } 136} 137 138TEST(MathOpsTest, BroadcastBinaryOps_ShapeFn) { 139 for (const auto* op_name : {"Add", "Complex", 140 "Div", "Equal", 141 "Greater", "GreaterEqual", 142 "Igamma", "Igammac", 143 "Zeta", "Polygamma", 144 "Less", "LessEqual", 145 "LogicalAnd", "LogicalOr", 146 "Maximum", "Minimum", 147 "Mod", "Mul", 148 "NotEqual", "Pow", 149 "Sub", "SquaredDifference"}) { 150 ShapeInferenceTestOp op(op_name); 151 INFER_OK(op, "?;?", "?"); 152 INFER_OK(op, "[1,2];?", "?"); 153 INFER_OK(op, "?;[1,2]", "?"); 154 155 INFER_OK(op, "[?];[1]", "[d0_0]"); 156 INFER_OK(op, "[1];[?]", "[d1_0]"); 157 INFER_OK(op, "[?];[2]", "[d1_0]"); 158 INFER_OK(op, "[2];[?]", "[d0_0]"); 159 INFER_OK(op, "[?];[?]", "[?]"); 160 INFER_OK(op, "[];[?]", "[d1_0]"); 161 INFER_OK(op, "[?];[]", "[d0_0]"); 162 163 INFER_OK(op, "[1];[1]", "[d0_0|d1_0]"); 164 INFER_OK(op, "[];[1]", "[d1_0]"); 165 INFER_OK(op, "[1];[]", "[d0_0]"); 166 167 INFER_OK(op, "[2];[2]", "[d0_0|d1_0]"); 168 INFER_OK(op, "[];[2]", "[d1_0]"); 169 INFER_OK(op, "[1];[2]", "[d1_0]"); 170 INFER_OK(op, "[2];[1]", "[d0_0]"); 171 INFER_OK(op, "[2];[]", "[d0_0]"); 172 173 INFER_OK(op, "[0];[0]", "[d0_0|d1_0]"); 174 INFER_OK(op, "[];[0]", "[d1_0]"); 175 INFER_OK(op, "[1];[0]", "[d1_0]"); 176 INFER_OK(op, "[0];[1]", "[d0_0]"); 177 INFER_OK(op, "[0];[]", "[d0_0]"); 178 179 // Multiple dimension cases (same test cases, switching x and y). 180 INFER_OK(op, "[?,1,2,3,4,5];[3,1,?]", 181 "[d0_0,d0_1,d0_2,d0_3|d1_0,d0_4,d0_5]"); 182 INFER_OK(op, "[3,1,?];[?,1,2,3,4,5]", 183 "[d1_0,d1_1,d1_2,d1_3|d0_0,d1_4,d1_5]"); 184 } 185} 186 187TEST(MathOpsTest, Select_ShapeFn) { 188 ShapeInferenceTestOp op("Select"); 189 INFER_OK(op, "?;?;?", "in1|in2"); 190 191 // scalar case 192 INFER_OK(op, "[];[1];?", "in1"); 193 INFER_OK(op, "[];?;?", "in1|in2"); 194 195 INFER_OK(op, "[1];?;?", 196 "in1|in2"); // When cond is vector, t/e may not match it. 197 INFER_OK(op, "[1,2];?;?", "in1|in2?"); 198 199 INFER_OK(op, "?;[];?", "in1"); 200 INFER_OK(op, "?;?;[]", "in2"); 201 INFER_OK(op, "?;[1];?", "in1"); 202 INFER_OK(op, "?;?;[1]", "in2"); 203 INFER_OK(op, "?;[1,2];?", "in1"); 204 INFER_OK(op, "?;?;[1,2]", "in2"); 205 206 INFER_ERROR("Shapes must be equal rank, but are 0 and 1", op, "[1];[];?"); 207 INFER_ERROR("Shapes must be equal rank, but are 1 and 2", op, "[];[1];[1,2]"); 208 INFER_ERROR("Shapes must be equal rank, but are 1 and 2", op, "[1,2];[1];?"); 209 INFER_OK(op, "[2];[?];[?]", "in1|in2"); 210 211 INFER_OK(op, "[?];[?,?,3];[1,2,?]", "[d2_0,d2_1,d1_2]"); 212 INFER_OK(op, "[2];[?,?,3];[?,2,?]", "[d1_0|d2_0,d2_1,d1_2]"); 213 INFER_ERROR("must be equal", op, "[1];[2,?,3];[?,2,?]"); 214 INFER_ERROR("Shapes must be equal rank, but are 3 and 2", op, 215 "[2,?];[?,?,3];[?,2,?]"); 216 INFER_OK(op, "[2,?,?];[?,?,3];[?,2,?]", "[d0_0,d2_1,d1_2]"); 217 INFER_ERROR("Dimension 2 in both shapes must be equal, but are 3 and 5", op, 218 "[2,?,5];[?,?,3];[?,2,?]"); 219 220 // Test that handle shapes were merged. 221 const OpRegistrationData* op_reg_data; 222 TF_ASSERT_OK(OpRegistry::Global()->LookUp(op.name, &op_reg_data)); 223 TensorShapeProto i0; 224 i0.add_dim()->set_size(1); 225 i0.add_dim()->set_size(-1); 226 TensorShapeProto i1; 227 i1.add_dim()->set_size(-1); 228 i1.add_dim()->set_size(2); 229 230 ASSERT_TRUE(op_reg_data->shape_inference_fn != nullptr); 231 shape_inference::InferenceContext c( 232 &op.node_def, op_reg_data->op_def, 233 {TensorShapeProto(), TensorShapeProto(), TensorShapeProto()}, {}, {}, 234 {TensorShapeProto(), i0, i1}, {}); 235 TF_ASSERT_OK(c.construction_status()); 236 TF_ASSERT_OK(c.Run(op_reg_data->shape_inference_fn)); 237 EXPECT_TRUE(c.FullyDefined(c.output_handle_shape(0))); 238 EXPECT_EQ("[1,2]", c.DebugString(c.output_handle_shape(0))); 239 240 // Expect an error when the shapes can't be merged. 241 TensorShapeProto i2; 242 i1.add_dim()->set_size(2); 243 i1.add_dim()->set_size(2); 244 shape_inference::InferenceContext c2( 245 &op.node_def, op_reg_data->op_def, 246 {TensorShapeProto(), TensorShapeProto(), TensorShapeProto()}, {}, {}, 247 {TensorShapeProto(), i0, i2}, {}); 248 TF_ASSERT_OK(c.construction_status()); 249 EXPECT_FALSE(c2.Run(op_reg_data->shape_inference_fn).ok()); 250} 251 252TEST(MathOpsTest, Range_ShapeFn) { 253 ShapeInferenceTestOp op("Range"); 254 255 TF_ASSERT_OK(NodeDefBuilder("test", "Range") 256 .Input({"start", {}, DT_INT32}) 257 .Input({"limit", {}, DT_INT32}) 258 .Input({"delta", {}, DT_INT32}) 259 .Attr("Tidx", DT_INT32) 260 .Finalize(&op.node_def)); 261 262 op.input_tensors.resize(3); 263 INFER_OK(op, "?;?;?", "[?]"); 264 INFER_ERROR("Shape must be rank 0 but is rank 2", op, "[1,2];?;?"); 265 INFER_ERROR("for 'start'", op, "[1,2];?;?"); 266 267 INFER_ERROR("Shape must be rank 0 but is rank 2", op, "?;[1,2];?"); 268 INFER_ERROR("for 'limit'", op, "?;[1,2];?"); 269 270 INFER_ERROR("Shape must be rank 0 but is rank 2", op, "?;?;[1,2]"); 271 INFER_ERROR("for 'delta'", op, "?;?;[1,2]"); 272 273 Tensor start_t = test::AsScalar(1); 274 op.input_tensors[0] = &start_t; 275 INFER_OK(op, "?;?;?", "[?]"); 276 Tensor limit_t = test::AsScalar(1); 277 op.input_tensors[1] = &limit_t; 278 INFER_OK(op, "?;?;?", "[?]"); 279 280 Tensor delta_t = test::AsScalar(1); 281 op.input_tensors[2] = &delta_t; 282 INFER_OK(op, "?;?;?", "[0]"); 283 284 delta_t = test::AsScalar(0); 285 INFER_ERROR("Requires delta != 0", op, "?;?;?"); 286 delta_t = test::AsScalar(3); 287 288 limit_t = test::AsScalar(-1); 289 INFER_ERROR("Requires start <= limit when delta > 0: 1/-1", op, "?;?;?"); 290 291 delta_t = test::AsScalar(-1); 292 INFER_OK(op, "?;?;?", "[2]"); 293 294 limit_t = test::AsScalar(4); 295 INFER_ERROR("Requires start >= limit when delta < 0: 1/4", op, "?;?;?"); 296 297 limit_t = test::AsScalar(100); 298 start_t = test::AsScalar(2); 299 delta_t = test::AsScalar(3); 300 INFER_OK(op, "?;?;?", "[33]"); 301} 302 303TEST(MathOpsTest, LinSpace_ShapeFn) { 304 ShapeInferenceTestOp op("LinSpace"); 305 op.input_tensors.resize(3); 306 INFER_OK(op, "?;?;?", "[?]"); 307 INFER_ERROR("Shape must be rank 0 but is rank 2", op, "[1,2];?;?"); 308 INFER_ERROR("for 'start'", op, "[1,2];?;?"); 309 INFER_ERROR("Shape must be rank 0 but is rank 2", op, "?;[1,2];?"); 310 INFER_ERROR("for 'stop'", op, "?;[1,2];?"); 311 INFER_ERROR("Shape must be rank 0 but is rank 2", op, "?;?;[1,2]"); 312 INFER_ERROR("for 'num'", op, "?;?;[1,2]"); 313 314 Tensor num_t = test::AsScalar(1); 315 op.input_tensors[2] = &num_t; 316 INFER_OK(op, "?;?;?", "[1]"); 317 num_t = test::AsScalar(2); 318 INFER_OK(op, "?;?;?", "[2]"); 319 num_t = test::AsScalar(-1); 320 INFER_ERROR("Requires num > 0: -1", op, "?;?;?"); 321} 322 323TEST(MathOpsTest, UnsortedSegmentSum_ShapeFn) { 324 ShapeInferenceTestOp op("UnsortedSegmentSum"); 325 op.input_tensors.resize(3); 326 INFER_OK(op, "?;?;?", "?"); 327 INFER_OK(op, "?;[?];?", "?"); 328 INFER_ERROR("Shape must be rank 0 but is rank 2", op, "?;?;[1,2]"); 329 INFER_ERROR("Dimensions must be equal, but are 2 and 3", op, 330 "[1,?,2];[1,?,3];?"); 331 INFER_OK(op, "?;[3];?", "?"); 332 INFER_ERROR("Shape must be at least rank 3 but is rank 2", op, 333 "[1,2];[1,2,3];?"); 334 335 Tensor num_segments_t = test::AsScalar(100); 336 op.input_tensors[2] = &num_segments_t; 337 INFER_OK(op, "[?,2,3,?,5];[1,2,?];[]", "[100,d0_3,d0_4]"); 338 339 num_segments_t = test::AsScalar(-1); 340 INFER_ERROR(("Dimension size, given by scalar input 2, must be " 341 "non-negative but is -1"), 342 op, "[3];[3];?"); 343} 344 345TEST(MathOpsTest, SparseSegment_ShapeFn) { 346 ShapeInferenceTestOp op("SparseSegmentSum"); 347 op.input_tensors.resize(3); 348 INFER_OK(op, "?;?;?", "?"); 349 INFER_OK(op, "[2,4,3];[3];[3]", "[?,d0_1,d0_2]"); 350 351 INFER_ERROR("Shape must be rank 1 but is rank 0", op, "[2,4,3];[];[3]"); 352 INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[2,4,3];[3];[3,4]"); 353 354 INFER_ERROR("Dimension 0 in both shapes must be equal, but are 3 and 4", op, 355 "[2,4,3];[3];[4]"); 356} 357 358TEST(MathOpsTest, SparseSegmentGrad_ShapeFn) { 359 ShapeInferenceTestOp op("SparseSegmentMeanGrad"); 360 op.input_tensors.resize(4); 361 INFER_OK(op, "?;?;?;?", "?"); 362 INFER_OK(op, "[2,4,3];[3];[3];[]", "[?,d0_1,d0_2]"); 363 364 Tensor num_segments_t = test::AsScalar(100); 365 op.input_tensors[3] = &num_segments_t; 366 INFER_OK(op, "[2,4,3];[3];[3];[]", "[100,d0_1,d0_2]"); 367 368 INFER_ERROR("Shape must be rank 0 but is rank 2", op, 369 "[2,4,3];[3];[3];[1,1]"); 370 371 // Negative value is not allowed 372 num_segments_t = test::AsScalar(-100); 373 op.input_tensors[3] = &num_segments_t; 374 INFER_ERROR("Cannot specify a negative value", op, "[2,4,3];[3];[3];[]"); 375} 376 377TEST(MathOpsTest, BatchMatMul_ShapeFn) { 378 ShapeInferenceTestOp op("BatchMatMul"); 379 auto set_adj = [&op](bool adj_x, bool adj_y) { 380 TF_ASSERT_OK(NodeDefBuilder("test", "BatchMatMul") 381 .Input({"a", 0, DT_FLOAT}) 382 .Input({"b", 0, DT_FLOAT}) 383 .Attr("adj_x", adj_x) 384 .Attr("adj_y", adj_y) 385 .Finalize(&op.node_def)); 386 }; 387 388 set_adj(false, false); 389 390 // Rank checks. 391 INFER_ERROR("at least rank 2", op, "[1];?"); 392 INFER_ERROR("at least rank 2", op, "?;[2]"); 393 394 INFER_OK(op, "?;?", "?"); 395 396 // 0 batch dims. 397 INFER_OK(op, "[?,?];[?,?]", "[d0_0,d1_1]"); 398 399 // 2 batch dims. 400 INFER_OK(op, "[?,?,?,?];?", "[d0_0,d0_1,d0_2,?]"); 401 402 // Test adj_a, testing output and that inner dims are compared. 403 set_adj(false, false); 404 INFER_OK(op, "[1,2,3,4];[1,2,?,?]", "[d0_0,d0_1,d0_2,d1_3]"); 405 INFER_ERROR("are 2 and 3", op, "[?,1,2];[?,3,1]"); // inner dim mismatch 406 set_adj(true, false); 407 INFER_OK(op, "[1,2,3,4];[1,2,?,?]", "[d0_0,d0_1,d0_3,d1_3]"); 408 INFER_ERROR("are 2 and 3", op, "[?,2,1];[?,3,1]"); // inner dim mismatch 409 410 // Test adj_b=true. 411 set_adj(false, true); 412 INFER_OK(op, "[1,2,?,?];[1,2,3,4]", "[d0_0,d0_1,d0_2,d1_2]"); 413 INFER_ERROR("are 2 and 3", op, "[?,1,2];[?,1,3]"); // inner dim mismatch 414 set_adj(true, true); 415 INFER_OK(op, "[1,2,?,?];[1,2,3,4]", "[d0_0,d0_1,d0_3,d1_2]"); 416 INFER_ERROR("are 2 and 3", op, "[?,2,1];[?,1,3]"); // inner dim mismatch 417} 418 419TEST(MathOpsTest, ArgOps_ShapeFn) { 420 ShapeInferenceTestOp op("ArgMax"); 421 op.input_tensors.resize(2); 422 423 INFER_OK(op, "?;?", "?"); 424 425 // input rank <= 1 produces scalar 426 INFER_OK(op, "[2];?", "[]"); 427 INFER_OK(op, "[];?", "[]"); 428 429 // Incorrect rank for dimension 430 INFER_ERROR("must be rank 0", op, "[2];[1]"); 431 432 // dimension not available, but input rank is. Output is unknown 433 // shape with rank one less than input rank. 434 INFER_OK(op, "[2,3,4];?", "[?,?]"); 435 INFER_OK(op, "[2,3,4,5,6];?", "[?,?,?,?]"); 436 437 // Dimension values known 438 Tensor dimension = test::AsScalar(0); 439 op.input_tensors[1] = &dimension; 440 INFER_OK(op, "[2,3,4];[]", "[d0_1,d0_2]"); 441 442 dimension = test::AsScalar(1); 443 op.input_tensors[1] = &dimension; 444 INFER_OK(op, "[2,3,4];[]", "[d0_0,d0_2]"); 445 446 dimension = test::AsScalar(2); 447 op.input_tensors[1] = &dimension; 448 INFER_OK(op, "[2,3,4];[]", "[d0_0,d0_1]"); 449 450 // Dimension value out of bounds 451 dimension = test::AsScalar(10); 452 op.input_tensors[1] = &dimension; 453 INFER_ERROR("must be in the range [0, 3)", op, "[2,3,4];[]"); 454 455 dimension = test::AsScalar(-10); 456 op.input_tensors[1] = &dimension; 457 INFER_ERROR("must be in the range [0, 3)", op, "[2,3,4];[]"); 458} 459 460TEST(MathOpsTest, Betainc_ShapeFn) { 461 ShapeInferenceTestOp op("Betainc"); 462 463 INFER_OK(op, "?;?;?", "?"); 464 INFER_OK(op, "[?,?];?;?", "in0"); 465 INFER_OK(op, "[?,2];?;[1,?]", "[d2_0,d0_1]"); 466 INFER_OK(op, "[?,2,?];[1,?,?];[?,?,3]", "[d1_0,d0_1,d2_2]"); 467 468 INFER_OK(op, "[?,2,?];[];[?,?,3]", "[d0_0|d2_0,d0_1,d2_2]"); 469 INFER_OK(op, "[];[];[?,?,3]", "in2"); 470 471 // All but one is a scalar, so use it. 472 INFER_OK(op, "[];[];?", "in2"); 473 INFER_OK(op, "[];[];[1,2,3,4]", "in2"); 474 475 // All scalar input; implementation picks in0. 476 INFER_OK(op, "[];[];[]", "in0"); 477 478 // Non-scalars must match shape. 479 INFER_ERROR("must be equal", op, "[1,2];[];[1,4]"); 480 INFER_ERROR("must be equal", op, "[1,2];[];[1,2,3]"); 481} 482 483TEST(MathOpsTest, Requantize_ShapeFn) { 484 ShapeInferenceTestOp op("Requantize"); 485 486 INFER_OK(op, "?;?;?;?;?", "in0;[];[]"); 487 INFER_OK(op, "?;[];[];[];[]", "in0;[];[]"); 488 489 // Rank checks on input scalars. 490 INFER_ERROR("must be rank 0", op, "?;[1];?;?;?"); 491 INFER_ERROR("must be rank 0", op, "?;?;[2];?;?"); 492 INFER_ERROR("must be rank 0", op, "?;?;?;[3];?"); 493 INFER_ERROR("must be rank 0", op, "?;?;?;?;[4]"); 494} 495 496TEST(MathOpstest, RequantizationRange_ShapeFn) { 497 ShapeInferenceTestOp op("RequantizationRange"); 498 499 INFER_OK(op, "?;?;?", "[];[]"); 500 INFER_OK(op, "?;[];[]", "[];[]"); 501 502 // Rank checks on input scalars. 503 INFER_ERROR("must be rank 0", op, "?;[1];?"); 504 INFER_ERROR("must be rank 0", op, "?;?;[2]"); 505} 506 507} // end namespace tensorflow 508