math_ops_test.cc revision 7d9c0c891d82fb5d35dc4669abe832708940a810
1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/core/framework/node_def_builder.h" 17#include "tensorflow/core/framework/op.h" 18#include "tensorflow/core/framework/shape_inference_testutil.h" 19#include "tensorflow/core/framework/tensor.h" 20#include "tensorflow/core/framework/tensor_testutil.h" 21#include "tensorflow/core/lib/core/status_test_util.h" 22#include "tensorflow/core/platform/test.h" 23 24namespace tensorflow { 25 26TEST(MathOpsTest, AddN_ShapeFn) { 27 ShapeInferenceTestOp op("AddN"); 28 auto set_n = [&op](int n) { 29 std::vector<NodeDefBuilder::NodeOut> src_list; 30 for (int i = 0; i < n; ++i) src_list.emplace_back("a", 0, DT_FLOAT); 31 TF_ASSERT_OK(NodeDefBuilder("test", "AddN") 32 .Input(src_list) 33 .Attr("N", n) 34 .Finalize(&op.node_def)); 35 }; 36 37 set_n(2); 38 // Adding two unknowns returns either input. 39 INFER_OK(op, "?;?", "in0|in1"); 40 41 // known+unknown returns the known input. 42 INFER_OK(op, "[1];[?]", "in0"); 43 INFER_OK(op, "[1];?", "in0"); 44 INFER_OK(op, "[?];[1]", "in1"); 45 INFER_OK(op, "?;[1]", "in1"); 46 47 set_n(2); 48 INFER_OK(op, "[1,2];[?,2]", "in0"); 49 INFER_OK(op, "[1,2];[1,2]", "in0|in1"); 50 INFER_OK(op, "[?,2];[1,2]", "in1"); 51 52 set_n(3); 53 INFER_OK(op, "[1,?];[?,2];[1,2]", "in2"); 54 INFER_OK(op, "[1,2];[?,2];[1,?]", "in0"); 55 INFER_OK(op, "?;?;[1,2]", "in2"); 56 57 set_n(2); 58 INFER_OK(op, "?;[1,2]", "in1"); 59 INFER_OK(op, "[1,?];[?,2]", "[d0_0,d1_1]"); 60 INFER_OK(op, "[?,2,?];[?,?,3]", "[d0_0|d1_0,d0_1,d1_2]"); 61 INFER_OK(op, "[?,2];[1,?]", "[d1_0,d0_1]"); 62 63 set_n(3); 64 INFER_ERROR("Dimension 1 in both shapes must be equal, but are 2 and 4", op, 65 "[1,2];?;[1,4]"); 66 INFER_ERROR("From merging shape 0 with other shapes.", op, "[1,2];?;[1,4]"); 67 set_n(4); 68 INFER_ERROR("Shapes must be equal rank, but are 2 and 3", op, 69 "?;[1,2];?;[1,2,3]"); 70 INFER_ERROR("From merging shape 1 with other shapes.", op, 71 "?;[1,2];?;[1,2,3]"); 72} 73 74TEST(MathOpsTest, UnchangedShape_ShapeFn) { 75 ShapeInferenceTestOp op("Cast"); 76 INFER_OK(op, "?", "in0"); 77 INFER_OK(op, "[?]", "in0"); 78 INFER_OK(op, "[1,?,3,4]", "in0"); 79} 80 81TEST(MathOpsTest, FFT_ShapeFn) { 82 for (const auto* op_name : {"FFT", "IFFT"}) { 83 ShapeInferenceTestOp op(op_name); 84 INFER_OK(op, "?", "?"); 85 INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[]"); 86 INFER_OK(op, "[?]", "in0"); 87 INFER_OK(op, "[1]", "in0"); 88 INFER_OK(op, "[1,2,3,4,5,6,7]", "in0"); 89 } 90 91 for (const auto* op_name : {"FFT2D", "IFFT2D"}) { 92 ShapeInferenceTestOp op(op_name); 93 INFER_OK(op, "?", "?"); 94 INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1]"); 95 INFER_OK(op, "[?,1]", "in0"); 96 INFER_OK(op, "[1,2]", "in0"); 97 INFER_OK(op, "[1,2,3,4,5,6,7]", "in0"); 98 } 99 100 for (const auto* op_name : {"FFT3D", "IFFT3D"}) { 101 ShapeInferenceTestOp op(op_name); 102 INFER_OK(op, "?", "?"); 103 INFER_ERROR("Shape must be at least rank 3 but is rank 2", op, "[1,2]"); 104 INFER_OK(op, "[?,1,?]", "in0"); 105 INFER_OK(op, "[1,2,3]", "in0"); 106 INFER_OK(op, "[1,2,3,4,5,6,7]", "in0"); 107 } 108} 109 110TEST(MathOpsTest, Segment_ShapeFn) { 111 // Tests SegmentReductionShapeFn. 112 for (const auto* op_name : {"SegmentMax", "SegmentMean", "SegmentMin", 113 "SegmentProd", "SegmentSum"}) { 114 ShapeInferenceTestOp op(op_name); 115 INFER_OK(op, "?;?", "?"); 116 INFER_OK(op, "?;[100]", "?"); 117 118 // Data shape with single dimension. 119 INFER_OK(op, "[?];?", "[?]"); 120 INFER_OK(op, "[?];[100]", "[?]"); 121 INFER_OK(op, "[1];?", "[?]"); 122 INFER_OK(op, "[1];[100]", "[?]"); 123 124 // Data shape with multiple dimensions. 125 INFER_OK(op, "[?,?];?", "[?,d0_1]"); 126 INFER_OK(op, "[?,2];[100]", "[?,d0_1]"); 127 INFER_OK(op, "[?,2,?,4];[100]", "[?,d0_1,d0_2,d0_3]"); 128 INFER_OK(op, "[1,?];?", "[?,d0_1]"); 129 INFER_OK(op, "[1,2];[100]", "[?,d0_1]"); 130 INFER_OK(op, "[1,2,?,4];[100]", "[?,d0_1,d0_2,d0_3]"); 131 132 // Error cases. 133 INFER_ERROR("Shape must be rank 1 but is rank 2", op, "?;[1,2]"); 134 INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[];[1]"); 135 } 136} 137 138TEST(MathOpsTest, BroadcastBinaryOps_ShapeFn) { 139 for (const auto* op_name : {"Add", "Complex", 140 "Div", "Equal", 141 "Greater", "GreaterEqual", 142 "Igamma", "Igammac", 143 "Zeta", "Polygamma", 144 "Less", "LessEqual", 145 "LogicalAnd", "LogicalOr", 146 "Maximum", "Minimum", 147 "Mod", "Mul", 148 "NotEqual", "Pow", 149 "Sub", "SquaredDifference"}) { 150 ShapeInferenceTestOp op(op_name); 151 INFER_OK(op, "?;?", "?"); 152 INFER_OK(op, "[1,2];?", "?"); 153 INFER_OK(op, "?;[1,2]", "?"); 154 155 INFER_OK(op, "[?];[1]", "[d0_0]"); 156 INFER_OK(op, "[1];[?]", "[d1_0]"); 157 INFER_OK(op, "[?];[2]", "[d1_0]"); 158 INFER_OK(op, "[2];[?]", "[d0_0]"); 159 INFER_OK(op, "[?];[?]", "[?]"); 160 INFER_OK(op, "[];[?]", "[d1_0]"); 161 INFER_OK(op, "[?];[]", "[d0_0]"); 162 163 INFER_OK(op, "[1];[1]", "[d0_0|d1_0]"); 164 INFER_OK(op, "[];[1]", "[d1_0]"); 165 INFER_OK(op, "[1];[]", "[d0_0]"); 166 167 INFER_OK(op, "[2];[2]", "[d0_0|d1_0]"); 168 INFER_OK(op, "[];[2]", "[d1_0]"); 169 INFER_OK(op, "[1];[2]", "[d1_0]"); 170 INFER_OK(op, "[2];[1]", "[d0_0]"); 171 INFER_OK(op, "[2];[]", "[d0_0]"); 172 173 INFER_OK(op, "[0];[0]", "[d0_0|d1_0]"); 174 INFER_OK(op, "[];[0]", "[d1_0]"); 175 INFER_OK(op, "[1];[0]", "[d1_0]"); 176 INFER_OK(op, "[0];[1]", "[d0_0]"); 177 INFER_OK(op, "[0];[]", "[d0_0]"); 178 179 // Multiple dimension cases (same test cases, switching x and y). 180 INFER_OK(op, "[?,1,2,3,4,5];[3,1,?]", 181 "[d0_0,d0_1,d0_2,d0_3|d1_0,d0_4,d0_5]"); 182 INFER_OK(op, "[3,1,?];[?,1,2,3,4,5]", 183 "[d1_0,d1_1,d1_2,d1_3|d0_0,d1_4,d1_5]"); 184 } 185} 186 187TEST(MathOpsTest, Select_ShapeFn) { 188 ShapeInferenceTestOp op("Select"); 189 INFER_OK(op, "?;?;?", "in1|in2"); 190 191 INFER_OK(op, "[];?;?", "in1|in2"); 192 INFER_OK(op, "[1];?;?", 193 "in1|in2"); // When cond is vector, t/e may not match it. 194 INFER_OK(op, "[1,2];?;?", "in1|in2?"); 195 196 INFER_OK(op, "?;[];?", "in1"); 197 INFER_OK(op, "?;?;[]", "in2"); 198 INFER_OK(op, "?;[1];?", "in1"); 199 INFER_OK(op, "?;?;[1]", "in2"); 200 INFER_OK(op, "?;[1,2];?", "in1"); 201 INFER_OK(op, "?;?;[1,2]", "in2"); 202 203 INFER_OK(op, "[1];[];?", "in1"); 204 INFER_ERROR("Shapes must be equal rank, but are 1 and 0", op, "[];[1];?"); 205 INFER_ERROR("Shapes must be equal rank, but are 1 and 2", op, "[1,2];[1];?"); 206 INFER_OK(op, "[2];[?];[?]", "in1|in2"); 207 208 INFER_OK(op, "[?];[?,?,3];[1,2,?]", "[d2_0,d2_1,d1_2]"); 209 INFER_OK(op, "[2];[?,?,3];[?,2,?]", "[d1_0|d2_0,d2_1,d1_2]"); 210 INFER_ERROR("must be equal", op, "[1];[2,?,3];[?,2,?]"); 211 INFER_ERROR("Shapes must be equal rank, but are 3 and 2", op, 212 "[2,?];[?,?,3];[?,2,?]"); 213 INFER_OK(op, "[2,?,?];[?,?,3];[?,2,?]", "[d0_0,d2_1,d1_2]"); 214 INFER_ERROR("Dimension 2 in both shapes must be equal, but are 3 and 5", op, 215 "[2,?,5];[?,?,3];[?,2,?]"); 216} 217 218TEST(MathOpsTest, Range_ShapeFn) { 219 ShapeInferenceTestOp op("Range"); 220 op.input_tensors.resize(3); 221 INFER_OK(op, "?;?;?", "[?]"); 222 INFER_ERROR("Shape must be rank 0 but is rank 2", op, "[1,2];?;?"); 223 INFER_ERROR("for 'start'", op, "[1,2];?;?"); 224 225 INFER_ERROR("Shape must be rank 0 but is rank 2", op, "?;[1,2];?"); 226 INFER_ERROR("for 'limit'", op, "?;[1,2];?"); 227 228 INFER_ERROR("Shape must be rank 0 but is rank 2", op, "?;?;[1,2]"); 229 INFER_ERROR("for 'delta'", op, "?;?;[1,2]"); 230 231 Tensor start_t = test::AsScalar(1); 232 op.input_tensors[0] = &start_t; 233 INFER_OK(op, "?;?;?", "[?]"); 234 Tensor limit_t = test::AsScalar(1); 235 op.input_tensors[1] = &limit_t; 236 INFER_OK(op, "?;?;?", "[?]"); 237 238 Tensor delta_t = test::AsScalar(1); 239 op.input_tensors[2] = &delta_t; 240 INFER_OK(op, "?;?;?", "[0]"); 241 242 delta_t = test::AsScalar(0); 243 INFER_ERROR("Requires delta > 0: 0", op, "?;?;?"); 244 delta_t = test::AsScalar(3); 245 246 limit_t = test::AsScalar(-1); 247 INFER_ERROR("Requires start <= limit: 1/-1", op, "?;?;?"); 248 249 limit_t = test::AsScalar(100); 250 start_t = test::AsScalar(2); 251 delta_t = test::AsScalar(3); 252 INFER_OK(op, "?;?;?", "[33]"); 253} 254 255TEST(MathOpsTest, LinSpace_ShapeFn) { 256 ShapeInferenceTestOp op("LinSpace"); 257 op.input_tensors.resize(3); 258 INFER_OK(op, "?;?;?", "[?]"); 259 INFER_ERROR("Shape must be rank 0 but is rank 2", op, "[1,2];?;?"); 260 INFER_ERROR("for 'start'", op, "[1,2];?;?"); 261 INFER_ERROR("Shape must be rank 0 but is rank 2", op, "?;[1,2];?"); 262 INFER_ERROR("for 'stop'", op, "?;[1,2];?"); 263 INFER_ERROR("Shape must be rank 0 but is rank 2", op, "?;?;[1,2]"); 264 INFER_ERROR("for 'num'", op, "?;?;[1,2]"); 265 266 Tensor num_t = test::AsScalar(1); 267 op.input_tensors[2] = &num_t; 268 INFER_OK(op, "?;?;?", "[1]"); 269 num_t = test::AsScalar(2); 270 INFER_OK(op, "?;?;?", "[2]"); 271 num_t = test::AsScalar(-1); 272 INFER_ERROR("Requires num > 0: -1", op, "?;?;?"); 273} 274 275TEST(MathOpsTest, UnsortedSegmentSum_ShapeFn) { 276 ShapeInferenceTestOp op("UnsortedSegmentSum"); 277 op.input_tensors.resize(3); 278 INFER_OK(op, "?;?;?", "?"); 279 INFER_OK(op, "?;[?];?", "?"); 280 INFER_ERROR("Shape must be rank 0 but is rank 2", op, "?;?;[1,2]"); 281 INFER_ERROR("Dimensions must be equal, but are 2 and 3", op, 282 "[1,?,2];[1,?,3];?"); 283 INFER_OK(op, "?;[3];?", "?"); 284 INFER_ERROR("Shape must be at least rank 3 but is rank 2", op, 285 "[1,2];[1,2,3];?"); 286 287 Tensor num_segments_t = test::AsScalar(100); 288 op.input_tensors[2] = &num_segments_t; 289 INFER_OK(op, "[?,2,3,?,5];[1,2,?];[]", "[100,d0_3,d0_4]"); 290 291 num_segments_t = test::AsScalar(-1); 292 INFER_ERROR(("Dimension size, given by scalar input 2, must be " 293 "non-negative but is -1"), 294 op, "[3];[3];?"); 295} 296 297TEST(MathOpsTest, SparseSegment_ShapeFn) { 298 ShapeInferenceTestOp op("SparseSegmentSum"); 299 op.input_tensors.resize(3); 300 INFER_OK(op, "?;?;?", "?"); 301 INFER_OK(op, "[2,4,3];[3];[3]", "[?,d0_1,d0_2]"); 302 303 INFER_ERROR("Shape must be rank 1 but is rank 0", op, "[2,4,3];[];[3]"); 304 INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[2,4,3];[3];[3,4]"); 305 306 INFER_ERROR("Dimension 0 in both shapes must be equal, but are 3 and 4", op, 307 "[2,4,3];[3];[4]"); 308} 309 310TEST(MathOpsTest, SparseSegmentGrad_ShapeFn) { 311 ShapeInferenceTestOp op("SparseSegmentMeanGrad"); 312 op.input_tensors.resize(4); 313 INFER_OK(op, "?;?;?;?", "?"); 314 INFER_OK(op, "[2,4,3];[3];[3];[]", "[?,d0_1,d0_2]"); 315 316 Tensor num_segments_t = test::AsScalar(100); 317 op.input_tensors[3] = &num_segments_t; 318 INFER_OK(op, "[2,4,3];[3];[3];[]", "[100,d0_1,d0_2]"); 319 320 INFER_ERROR("Shape must be rank 0 but is rank 2", op, 321 "[2,4,3];[3];[3];[1,1]"); 322 323 // Negative value is not allowed 324 num_segments_t = test::AsScalar(-100); 325 op.input_tensors[3] = &num_segments_t; 326 INFER_ERROR("Cannot specify a negative value", op, "[2,4,3];[3];[3];[]"); 327} 328 329TEST(MathOpsTest, BatchMatMul_ShapeFn) { 330 ShapeInferenceTestOp op("BatchMatMul"); 331 auto set_adj = [&op](bool adj_x, bool adj_y) { 332 TF_ASSERT_OK(NodeDefBuilder("test", "BatchMatMul") 333 .Input({"a", 0, DT_FLOAT}) 334 .Input({"b", 0, DT_FLOAT}) 335 .Attr("adj_x", adj_x) 336 .Attr("adj_y", adj_y) 337 .Finalize(&op.node_def)); 338 }; 339 340 set_adj(false, false); 341 342 // Rank checks. 343 INFER_ERROR("at least rank 2", op, "[1];?"); 344 INFER_ERROR("at least rank 2", op, "?;[2]"); 345 346 INFER_OK(op, "?;?", "?"); 347 348 // 0 batch dims. 349 INFER_OK(op, "[?,?];[?,?]", "[d0_0,d1_1]"); 350 351 // 2 batch dims. 352 INFER_OK(op, "[?,?,?,?];?", "[d0_0,d0_1,d0_2,?]"); 353 354 // Test adj_a, testing output and that inner dims are compared. 355 set_adj(false, false); 356 INFER_OK(op, "[1,2,3,4];[1,2,?,?]", "[d0_0,d0_1,d0_2,d1_3]"); 357 INFER_ERROR("are 2 and 3", op, "[?,1,2];[?,3,1]"); // inner dim mismatch 358 set_adj(true, false); 359 INFER_OK(op, "[1,2,3,4];[1,2,?,?]", "[d0_0,d0_1,d0_3,d1_3]"); 360 INFER_ERROR("are 2 and 3", op, "[?,2,1];[?,3,1]"); // inner dim mismatch 361 362 // Test adj_b=true. 363 set_adj(false, true); 364 INFER_OK(op, "[1,2,?,?];[1,2,3,4]", "[d0_0,d0_1,d0_2,d1_2]"); 365 INFER_ERROR("are 2 and 3", op, "[?,1,2];[?,1,3]"); // inner dim mismatch 366 set_adj(true, true); 367 INFER_OK(op, "[1,2,?,?];[1,2,3,4]", "[d0_0,d0_1,d0_3,d1_2]"); 368 INFER_ERROR("are 2 and 3", op, "[?,2,1];[?,1,3]"); // inner dim mismatch 369} 370 371TEST(MathOpsTest, ArgOps_ShapeFn) { 372 ShapeInferenceTestOp op("ArgMax"); 373 op.input_tensors.resize(2); 374 375 INFER_OK(op, "?;?", "?"); 376 377 // input rank <= 1 produces scalar 378 INFER_OK(op, "[2];?", "[]"); 379 INFER_OK(op, "[];?", "[]"); 380 381 // Incorrect rank for dimension 382 INFER_ERROR("must be rank 0", op, "[2];[1]"); 383 384 // dimension not available, but input rank is. Output is unknown 385 // shape with rank one less than input rank. 386 INFER_OK(op, "[2,3,4];?", "[?,?]"); 387 INFER_OK(op, "[2,3,4,5,6];?", "[?,?,?,?]"); 388 389 // Dimension values known 390 Tensor dimension = test::AsScalar(0); 391 op.input_tensors[1] = &dimension; 392 INFER_OK(op, "[2,3,4];[]", "[d0_1,d0_2]"); 393 394 dimension = test::AsScalar(1); 395 op.input_tensors[1] = &dimension; 396 INFER_OK(op, "[2,3,4];[]", "[d0_0,d0_2]"); 397 398 dimension = test::AsScalar(2); 399 op.input_tensors[1] = &dimension; 400 INFER_OK(op, "[2,3,4];[]", "[d0_0,d0_1]"); 401 402 // Dimension value out of bounds 403 dimension = test::AsScalar(10); 404 op.input_tensors[1] = &dimension; 405 INFER_ERROR("must be in the range [0, 3)", op, "[2,3,4];[]"); 406 407 dimension = test::AsScalar(-10); 408 op.input_tensors[1] = &dimension; 409 INFER_ERROR("must be in the range [0, 3)", op, "[2,3,4];[]"); 410} 411 412TEST(MathOpsTest, Betainc_ShapeFn) { 413 ShapeInferenceTestOp op("Betainc"); 414 415 INFER_OK(op, "?;?;?", "?"); 416 INFER_OK(op, "[?,?];?;?", "in0"); 417 INFER_OK(op, "[?,2];?;[1,?]", "[d2_0,d0_1]"); 418 INFER_OK(op, "[?,2,?];[1,?,?];[?,?,3]", "[d1_0,d0_1,d2_2]"); 419 420 INFER_OK(op, "[?,2,?];[];[?,?,3]", "[d0_0|d2_0,d0_1,d2_2]"); 421 INFER_OK(op, "[];[];[?,?,3]", "in2"); 422 423 // All but one is a scalar, so use it. 424 INFER_OK(op, "[];[];?", "in2"); 425 INFER_OK(op, "[];[];[1,2,3,4]", "in2"); 426 427 // All scalar input; implementation picks in0. 428 INFER_OK(op, "[];[];[]", "in0"); 429 430 // Non-scalars must match shape. 431 INFER_ERROR("must be equal", op, "[1,2];[];[1,4]"); 432 INFER_ERROR("must be equal", op, "[1,2];[];[1,2,3]"); 433} 434 435} // end namespace tensorflow 436