Lines Matching refs:op

18 #include "tensorflow/core/framework/op.h"
31 ShapeInferenceTestOp op("Pack");
32 auto set_axis = [&op](int axis) {
41 .Finalize(&op.node_def));
45 INFER_OK(op, "?;?;?", "?");
49 INFER_OK(op, "?;?;?", "?");
50 INFER_OK(op, "[1,3];[1,3];?", "[3,d0_0|d1_0,d0_1|d1_1]");
51 INFER_OK(op, "[?,3];[1,3];?", "[3,d1_0,d0_1|d1_1]");
52 INFER_OK(op, "[?,?];[1,3];?", "[3,d1_0,d1_1]");
56 INFER_OK(op, "?;?;?", "?");
57 INFER_OK(op, "[1,3];[1,3];?", "[d0_0|d1_0,3,d0_1|d1_1]");
58 INFER_OK(op, "[?,3];[1,3];?", "[d1_0,3,d0_1|d1_1]");
59 INFER_OK(op, "[?,?];[1,3];?", "[d1_0,3,d1_1]");
63 INFER_OK(op, "?;?;?", "?");
64 INFER_OK(op, "[1,3];[1,3];?", "[d0_0|d1_0,d0_1|d1_1,3]");
65 INFER_OK(op, "[?,3];[1,3];?", "[d1_0,d0_1|d1_1,3]");
66 INFER_OK(op, "[?,?];[1,3];?", "[d1_0,d1_1,3]");
70 INFER_ERROR("Invalid axis: -4; must be in [-3,3)", op, "[1,3];[1,3];?");
72 INFER_ERROR("Invalid axis: 3; must be in [-3,3)", op, "[1,3];[1,3];?");
77 INFER_ERROR("Shapes must be equal rank, but are 3 and 2", op,
79 INFER_ERROR("From merging shape 0 with other shapes.", op, "[1,2,3];?;[1,4]");
83 ShapeInferenceTestOp op("Unpack");
84 auto set_axis_and_num = [&op](int axis, int num) {
89 .Finalize(&op.node_def));
93 INFER_OK(op, "?", "?");
97 INFER_OK(op, "?", "?");
98 INFER_OK(op, "[1,2,3]", "[d0_1,d0_2]");
99 INFER_OK(op, "[?,?,?]", "[d0_1,d0_2]");
103 INFER_OK(op, "[1,2,3]", "[d0_0,d0_2];[d0_0,d0_2]");
104 INFER_OK(op, "[?,?,?]", "[d0_0,d0_2];[d0_0,d0_2]");
108 INFER_OK(op, "[1,2,3]", "[d0_0,d0_1];[d0_0,d0_1];[d0_0,d0_1]");
109 INFER_OK(op, "[?,?,?]", "[d0_0,d0_1];[d0_0,d0_1];[d0_0,d0_1]");
113 INFER_ERROR("Dimension must be 2 but is 3", op, "[1,2,3]");
116 INFER_ERROR("Invalid axis: -4; must be in [-3,3)", op, "[1,2,3]");
118 INFER_ERROR("Invalid axis: 3; must be in [-3,3)", op, "[1,2,3]");
122 ShapeInferenceTestOp op("Const");
125 auto rebuild_node_def = [&op, &tensor_proto]() {
128 .Finalize(&op.node_def));
133 INFER_OK(op, "", "[]");
136 INFER_OK(op, "", "[1,2,3,4]");
140 INFER_ERROR("Shape [1,2,3,4,?] is not fully defined", op, "");
153 ShapeInferenceTestOp op(op_name);
154 INFER_OK(op, "?", "in0");
155 INFER_OK(op, "[]", "in0");
156 INFER_OK(op, "[1,2,?,4,5]", "in0");
160 ShapeInferenceTestOp op("MatrixBandPart");
161 INFER_OK(op, "?;?;?", "in0");
162 INFER_OK(op, "[];?;?", "in0");
163 INFER_OK(op, "[1,2,?,4,5];?;?", "in0");
167 ShapeInferenceTestOp op("GuaranteeConst");
168 INFER_OK(op, "?", "in0");
169 INFER_OK(op, "[]", "in0");
170 INFER_OK(op, "[1,2,?,4,5]", "in0");
175 ShapeInferenceTestOp op(op_name);
178 TF_ASSERT_OK(OpRegistry::Global()->LookUp(op.name, &op_reg_data));
185 shape_inference::InferenceContext c(TF_GRAPH_DEF_VERSION, &op.node_def,
199 ShapeInferenceTestOp op("Diag");
200 INFER_OK(op, "?", "?");
201 INFER_OK(op, "[1,?,3]", "[d0_0,d0_1,d0_2,d0_0,d0_1,d0_2]");
202 INFER_OK(op, "[?,1,2,3]", "[d0_0,d0_1,d0_2,d0_3,d0_0,d0_1,d0_2,d0_3]");
203 INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[]");
207 ShapeInferenceTestOp op("DiagPart");
208 INFER_OK(op, "?", "?");
209 INFER_OK(op, "[1,?,?,4]", "[d0_0,d0_3]");
210 INFER_OK(op, "[1,?,3,?,4,3]", "[d0_0,d0_4,d0_2|d0_5]");
211 INFER_OK(op, "[1,2,3,?,?,?,?,4]", "[d0_0,d0_1,d0_2,d0_7]");
212 INFER_ERROR("Input must have even and non-zero rank", op, "[]");
213 INFER_ERROR("Input must have even and non-zero rank", op, "[?]");
214 INFER_ERROR("Input must have even and non-zero rank", op, "[1,2,3]");
215 INFER_ERROR("Dimensions must be equal, but are 2 and 10", op, "[1,2,?,10]");
219 ShapeInferenceTestOp op("MatrixDiag");
220 INFER_OK(op, "?", "?");
221 INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[]");
222 INFER_OK(op, "[?]", "[d0_0,d0_0]");
223 INFER_OK(op, "[1,?,?,4]", "[d0_0,d0_1,d0_2,d0_3,d0_3]");
227 ShapeInferenceTestOp op("MatrixDiagPart");
228 INFER_OK(op, "?", "?");
229 INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[?]");
230 INFER_OK(op, "[?,1,2,2]", "[d0_0,d0_1,d0_2|d0_3]");
231 INFER_OK(op, "[?,1,2,3]", "[d0_0,d0_1,d0_2]");
232 INFER_OK(op, "[?,1,3,2]", "[d0_0,d0_1,d0_3]");
236 ShapeInferenceTestOp op("Reverse");
237 INFER_OK(op, "?;?", "in0");
238 INFER_ERROR("Shape must be rank 1 but is rank 0", op, "?;[]");
239 INFER_ERROR("Shape must be rank 1 but is rank 2", op, "?;[?,2]");
240 INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];[4]");
242 op, "[1,2,3,4,5,6,7,8,9];[9]");
243 INFER_OK(op, "[1,2,3,?];[4]", "in0");
244 INFER_OK(op, "[1,2,3,?,5,6,7,8];[8]", "in0");
248 ShapeInferenceTestOp op("ReverseV2");
249 INFER_OK(op, "?;?", "in0");
250 INFER_ERROR("Shape must be rank 1 but is rank 0", op, "?;[]");
251 INFER_ERROR("Shape must be rank 1 but is rank 2", op, "?;[?,2]");
252 INFER_OK(op, "[1,2,3];[2]", "in0");
254 op, "[1,2,3,4,5,6,7,8,9];[9]");
255 INFER_OK(op, "[1,2,3,?];[4]", "in0");
256 INFER_OK(op, "[1,2,3,?,5,6,7,8];[8]", "in0");
260 ShapeInferenceTestOp op("Fill");
261 AddNodeAttr("index_type", DT_INT32, &op.node_def);
262 op.input_tensors.resize(2);
263 INFER_OK(op, "?;?", "?");
264 INFER_OK(op, "[?];?", "?");
265 INFER_OK(op, "[4];?", "[?,?,?,?]");
268 op.input_tensors[0] = &in_t;
269 INFER_OK(op, "[4];?", "[1,2,3,4]");
273 ShapeInferenceTestOp op("Gather");
274 INFER_OK(op, "?;?", "?");
275 INFER_OK(op, "[1,?,2];[3]", "[d1_0,d0_1,d0_2]");
276 INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[];[1,2,3]");
280 ShapeInferenceTestOp op("GatherV2");
283 INFER_OK(op, "?;?;?", "?");
284 INFER_OK(op, "[1,2,3];[3];[]", "[?,?,?]");
285 INFER_ERROR("Shape must be at least rank 1 but is rank 0", op,
289 INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[1];[1,2,3];[1]");
293 op.input_tensors.resize(3);
294 op.input_tensors[2] = &axis_dim_t;
298 INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
303 INFER_OK(op, "[1,2,3];[];[]", "[d0_1,d0_2]");
305 INFER_OK(op, "[1,2,3];[];[]", "[d0_0,d0_2]");
307 INFER_OK(op, "[1,2,3];[];[]", "[d0_0,d0_1]");
311 INFER_OK(op, "[1,2,3];[5];[]", "[d1_0,d0_1,d0_2]");
313 INFER_OK(op, "[1,2,3];[5];[]", "[d0_0,d1_0,d0_2]");
315 INFER_OK(op, "[1,2,3];[5];[]", "[d0_0,d0_1,d1_0]");
319 INFER_OK(op, "[1,2,3];[5,6];[]", "[d1_0,d1_1,d0_1,d0_2]");
321 INFER_OK(op, "[1,2,3];[5,6];[]", "[d0_0,d1_0,d1_1,d0_2]");
323 INFER_OK(op, "[1,2,3];[5,6];[]", "[d0_0,d0_1,d1_0,d1_1]");
327 INFER_OK(op, "[1,2,3];[5,6];[]", "[d1_0,d1_1,d0_1,d0_2]");
329 INFER_OK(op, "[1,2,3];[5,6];[]", "[d0_0,d1_0,d1_1,d0_2]");
331 INFER_OK(op, "[1,2,3];[5,6];[]", "[d0_0,d0_1,d1_0,d1_1]");
335 ShapeInferenceTestOp op("GatherNd");
338 INFER_OK(op, "?;?", "?");
339 INFER_OK(op, "[1,?,3,?];[?,0]", "[d1_0,d0_0,d0_1,d0_2,d0_3]");
340 INFER_OK(op, "[1,?,3,?];[?,4]", "[d1_0]");
343 INFER_ERROR("indices.shape[-1] must be <= params.rank", op, "[1,2,3];[4]");
347 ShapeInferenceTestOp op("Shape");
348 INFER_OK(op, "?", "[?]");
349 INFER_OK(op, "[?]", "[1]");
350 INFER_OK(op, "[?,2,3,4,5]", "[5]");
354 ShapeInferenceTestOp op("ShapeN");
362 .Finalize(&op.node_def));
363 INFER_OK(op, "?;?;?", "[?];[?];[?]");
364 INFER_OK(op, "[?];[?];[?]", "[1];[1];[1]");
365 INFER_OK(op, "[?,2,3,4,5];?;[1,?,3]", "[5];[?];[3]");
369 ShapeInferenceTestOp op("Unique");
370 INFER_OK(op, "?", "[?];in0");
371 INFER_OK(op, "[1,2,3,?,5]", "[?];in0");
375 ShapeInferenceTestOp op("UniqueWithCounts");
376 INFER_OK(op, "?", "[?];in0;[?]");
377 INFER_OK(op, "[1,2,3,?,5]", "[?];in0;[?]");
381 ShapeInferenceTestOp op("InvertPermutation");
382 INFER_OK(op, "?", "[?]");
383 INFER_OK(op, "[1]", "in0");
384 INFER_ERROR("Shape must be rank 1 but is rank 0", op, "[]");
389 ShapeInferenceTestOp op(op_name);
390 op.input_tensors.resize(2);
394 INFER_OK(op, "?;?", "?");
397 INFER_ERROR("Shape must be rank 2 but is rank 3", op, "?;[1,2,3]");
398 INFER_ERROR("Dimension must be 2 but is 4", op, "?;[1,4]");
402 INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];[4,2]");
403 INFER_OK(op, "[1,2,3];?", "[?,?,?]");
404 INFER_OK(op, "?;[3,2]", "[?,?,?]");
411 op.input_tensors[1] = &paddings_t;
412 INFER_OK(op, "[100,200,300];[3,2]", "[111,222,333]");
413 INFER_OK(op, "[100,?,300];[3,2]", "[111,?,333]");
414 INFER_OK(op, "?;[3,2]", "[?,?,?]");
415 INFER_OK(op, "?;?", "[?,?,?]");
420 ShapeInferenceTestOp op("PadV2");
421 op.input_tensors.resize(3);
425 INFER_OK(op, "?;?;?", "?");
428 INFER_ERROR("Shape must be rank 2 but is rank 3", op, "?;[1,2,3];?");
429 INFER_ERROR("Dimension must be 2 but is 4", op, "?;[1,4];?");
433 INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];[4,2];[]");
434 INFER_OK(op, "[1,2,3];?;[]", "[?,?,?]");
435 INFER_OK(op, "?;[3,2];[]", "[?,?,?]");
442 op.input_tensors[1] = &paddings_t;
443 INFER_OK(op, "[100,200,300];[3,2];[]", "[111,222,333]");
444 INFER_OK(op, "[100,?,300];[3,2];[]", "[111,?,333]");
445 INFER_OK(op, "?;[3,2];[]", "[?,?,?]");
446 INFER_OK(op, "?;?;[]", "[?,?,?]");
450 ShapeInferenceTestOp op("MirrorPadGrad");
451 op.input_tensors.resize(2);
454 INFER_OK(op, "?;?", "?");
457 INFER_OK(op, "?;[?,4]", "?");
460 INFER_ERROR("must be rank 3 but is rank 2", op, "[?,?];[3,2]");
463 INFER_ERROR("Dimension 1 in both shapes must be equal, but are 3 and 2", op,
468 INFER_OK(op, "[?,?,?];[3,2]", "[?,?,?]");
475 op.input_tensors[1] = &paddings_t;
477 INFER_OK(op, "[111,222,333];[3,2]", "[100,200,300]");
478 INFER_OK(op, "[111,?,333];[3,2]", "[100,?,300]");
482 ShapeInferenceTestOp op("BroadcastArgs");
483 INFER_OK(op, "?;?", "[?]");
484 INFER_OK(op, "[123];[1]", "[123]");
485 INFER_OK(op, "[1];[123]", "[123]");
486 INFER_OK(op, "[123];[121]", "[123]");
489 INFER_ERROR("Shape must be rank 1 but is rank 0", op, "[];?");
490 INFER_ERROR("Shape must be rank 1 but is rank 0", op, "?;[]");
494 ShapeInferenceTestOp op("BroadcastGradientArgs");
496 INFER_OK(op, "?;?", "[?];[?]");
497 INFER_OK(op, "[123];[456]", "[?];[?]");
500 INFER_ERROR("Shape must be rank 1 but is rank 0", op, "[];?");
501 INFER_ERROR("Shape must be rank 1 but is rank 0", op, "?;[]");
505 ShapeInferenceTestOp op("BroadcastGradientArgs");
507 INFER_OK(op, "?;?", "[?];[?]");
508 INFER_OK(op, "[123];[456]", "[?];[?]");
511 INFER_ERROR("Shape must be rank 1 but is rank 0", op, "[];?");
512 INFER_ERROR("Shape must be rank 1 but is rank 0", op, "?;[]");
516 ShapeInferenceTestOp op("MatrixSetDiag");
521 INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1];?");
522 INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "?;[]");
523 INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[2,2];[]");
524 INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[2,2];[2,2]");
527 INFER_ERROR("Dimensions must be equal, but are 2 and 3", op, "[2,3];[3]");
530 INFER_OK(op, "?;?", "in0");
531 INFER_OK(op, "[1,2,2];[1,2]", "in0");
532 INFER_OK(op, "[1,2,3];?", "in0");
533 INFER_OK(op, "[1,3,2];?", "in0");
534 INFER_OK(op, "[1,?,2];[?,?]", "in0");
535 INFER_OK(op, "[1,?,?];[?,2]", "in0");
538 INFER_OK(op, "?;[1,2]", "[d1_0,?,?]");
539 INFER_OK(op, "[?,?,3];[1,2]", "[d1_0,d0_1,d0_2]");
540 INFER_OK(op, "[?,3,?];[1,2]", "[d1_0,d0_1,d0_2]");
541 INFER_OK(op, "[?,3,2];[1,2]", "[d1_0,d0_1,d0_2]");
545 ShapeInferenceTestOp op("ExpandDims");
546 op.input_tensors.resize(2);
549 INFER_OK(op, "?;?", "?");
551 op.input_tensors[1] = &dim_t;
556 INFER_OK(op, "?;?", "?");
557 INFER_OK(op, "[5,?,7];?", "[1,d0_0,d0_1,d0_2]");
563 INFER_OK(op, "?;?", "?");
564 INFER_OK(op, "[5,?,7];?", "[d0_0,1,d0_1,d0_2]");
568 INFER_OK(op, "?;?", "?");
569 INFER_OK(op, "[5,?,7];?", "[d0_0,1,d0_1,d0_2]");
573 INFER_OK(op, "?;?", "?");
574 INFER_OK(op, "[5,?,7];?", "[d0_0,d0_1,1,d0_2]");
578 INFER_OK(op, "?;?", "?");
579 INFER_OK(op, "[5,?,7];?", "[d0_0,d0_1,1,d0_2]");
585 INFER_OK(op, "?;?", "?");
586 INFER_OK(op, "[5,?,7];?", "[d0_0,d0_1,d0_2,1]");
590 INFER_OK(op, "?;?", "?");
591 INFER_OK(op, "[5,?,7];?", "[d0_0,d0_1,d0_2,1]");
596 INFER_ERROR("not in the interval [-4, 3]", op, "[5,?,7];?");
598 INFER_ERROR("not in the interval [-4, 3]", op, "[5,?,7];?");
605 INFER_OK(op, "?;?", "?");
606 INFER_OK(op, "[5,?,7];?", "[1,d0_0,d0_1,d0_2]");
611 INFER_ERROR("'dim' input must be a tensor with a single", op, "?;?");
612 INFER_ERROR("'dim' input must be a tensor with a single", op, "[5,6,7];?");
616 INFER_OK(op, "[2];[]", "[1,d0_0]");
618 INFER_OK(op, "[2];[]", "[d0_0,1]");
620 INFER_OK(op, "[2];[]", "[d0_0,1]");
624 ShapeInferenceTestOp op("ImmutableConst");
630 .Finalize(&op.node_def));
631 INFER_OK(op, "", "[1,2,3]");
637 .Finalize(&op.node_def));
638 INFER_OK(op, "", "[]");
644 .Finalize(&op.node_def));
646 op, "");
650 ShapeInferenceTestOp op("Concat");
651 auto set_n = [&op](int n) {
659 .Finalize(&op.node_def));
664 INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[1];?;?");
669 INFER_OK(op, "?;?;?;?;[1,2,3];?;[3,2,1];?", "[?,?,?]");
671 INFER_OK(op, "?;?;?;[1,2,3,4];[4,3,2,1]", "[?,?,?,?]");
672 INFER_OK(op, "?;?;?;?;?", "?"); // output rank unknown
673 INFER_ERROR("Can't concatenate scalars (use tf.stack instead)", op,
675 INFER_ERROR("Shape must be rank 2 but is rank 3", op, "?;?;?;[1,2];[1,2,3]");
680 op.input_tensors.push_back(&concat_dim_t);
686 INFER_OK(op, "[];[100,2,?];[10,?,3]", "[110,d1_1,d2_2]");
687 INFER_ERROR("Dimension 1 in both shapes must be equal, but are 5 and 3", op,
690 INFER_OK(op, "[];[100,2,?];[?,?,3]", "[?,d1_1,d2_2]");
691 INFER_OK(op, "[];[?,2,?];[10,?,3]", "[?,d1_1,d2_2]");
697 INFER_OK(op, "[];[1,100,?];[?,10,3]", "[d1_0,110,d2_2]");
699 INFER_OK(op, "[];[1,100];[?,10]", "[d1_0,110]");
700 INFER_OK(op, "[];[?,100];[1,10]", "[d2_0,110]");
704 INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
706 INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
712 INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
714 INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
720 INFER_OK(op, "[];?;[1,100,?];[?,?,?];[?,10,3];?", "[d2_0,?,d4_2]");
724 ShapeInferenceTestOp op("ConcatV2");
725 auto set_n = [&op](int n) {
733 .Finalize(&op.node_def));
738 INFER_ERROR("Shape must be rank 0 but is rank 1", op, "?;?;[1]");
743 INFER_OK(op, "?;?;?;?;[1,2,3];?;[3,2,1];?", "[?,?,?]");
745 INFER_OK(op, "?;?;[1,2,3,4];[4,3,2,1];?", "[?,?,?,?]");
746 INFER_OK(op, "?;?;?;?;?", "?"); // output rank unknown
747 INFER_ERROR("Can't concatenate scalars (use tf.stack instead)", op,
749 INFER_ERROR("Shape must be rank 2 but is rank 3", op, "?;?;[1,2];[1,2,3];?");
754 op.input_tensors.resize(3);
755 op.input_tensors[2] = &concat_dim_t;
761 // INFER_ERROR("Expected concat_dim >= 0, but got -1", op, "?;?;?");
765 INFER_OK(op, "[100,2,?];[10,?,3];[]", "[110,d0_1,d1_2]");
766 INFER_ERROR("Dimension 1 in both shapes must be equal, but are 5 and 3", op,
769 INFER_OK(op, "[100,2,?];[?,?,3];[]", "[?,d0_1,d1_2]");
770 INFER_OK(op, "[?,2,?];[10,?,3];[]", "[?,d0_1,d1_2]");
774 INFER_OK(op, "[1,100,?];[?,10,3];[]", "[d0_0,110,d1_2]");
775 INFER_OK(op, "[1,100];[?,10];[]", "[d0_0,110]");
776 INFER_OK(op, "[?,100];[1,10];[]", "[d1_0,110]");
778 INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
780 INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
784 INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
786 INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
790 op.input_tensors.resize(6);
791 op.input_tensors[3] = nullptr;
792 op.input_tensors[5] = &concat_dim_t;
796 INFER_OK(op, "?;[1,100,?];[?,?,?];[?,10,3];?;[]", "[d1_0,?,d3_2]");
800 ShapeInferenceTestOp op("ConcatOffset");
810 .Finalize(&op.node_def));
811 INFER_OK(op, "?;?;?;?;?", "in1;in2;in3;in4");
815 ShapeInferenceTestOp op("Reshape");
816 op.input_tensors.resize(2);
819 INFER_OK(op, "?;?", "?");
820 INFER_OK(op, "[?];?", "?");
821 INFER_OK(op, "[?];[?]", "?");
822 INFER_OK(op, "[4];[?]", "?");
826 op.input_tensors[1] = &new_shape;
827 INFER_OK(op, "[?];[3]", "[1,2,3]");
828 INFER_OK(op, "[6];[3]", "[1,2,3]");
832 op, "[3,4];[3]");
837 INFER_OK(op, "[?];[1]", "[?]");
838 INFER_OK(op, "[2,2];[1]", "[4]");
841 INFER_OK(op, "[3,4];[2]", "[2,6]");
844 INFER_ERROR("Dimension size must be evenly divisible by 2 but is 7", op,
848 INFER_OK(op, "[8];[3]", "[?,?,2]");
852 INFER_OK(op, "[1];[0]", "[]");
854 "Cannot reshape a tensor with 2 elements to shape [] (1 elements)", op,
859 INFER_OK(op, "[0];[1]", "[0]");
861 INFER_OK(op, "[0,2];[1]", "[0,6]");
863 INFER_OK(op, "[0,2];[1]", "[0,?]");
867 ShapeInferenceTestOp op("QuantizedReshape");
868 op.input_tensors.resize(2);
872 INFER_OK(op, "?;?;?;?", "?;[];[]");
873 INFER_OK(op, "[?];?;?;?", "?;[];[]");
874 INFER_OK(op, "[?];[?];?;?", "?;[];[]");
875 INFER_OK(op, "[4];[?];?;?", "?;[];[]");
877 op.input_tensors[1] = &new_shape;
878 INFER_OK(op, "[?];[3];?;?", "[1,2,3];[];[]");
879 INFER_OK(op, "[6];[3];?;?", "[1,2,3];[];[]");
882 op, "[3,4];[3];?;?");
885 INFER_ERROR("must be rank 0", op, "?;?;[1];?");
886 INFER_ERROR("must be rank 0", op, "?;?;?;[1]");
892 ShapeInferenceTestOp op("Placeholder");
897 .Finalize(&op.node_def));
898 INFER_OK(op, "", "[1,2]");
903 ShapeInferenceTestOp op("Placeholder");
908 .Finalize(&op.node_def));
909 INFER_OK(op, "", "[]");
914 ShapeInferenceTestOp op("Placeholder");
921 .Finalize(&op.node_def));
922 INFER_OK(op, "", "[1,?]");
927 ShapeInferenceTestOp op("Placeholder");
932 .Finalize(&op.node_def));
933 INFER_OK(op, "", "?");
938 ShapeInferenceTestOp op("Transpose");
939 op.input_tensors.resize(2);
942 INFER_OK(op, "?;?", "?");
943 INFER_OK(op, "?;[?]", "?");
944 INFER_OK(op, "?;[2]", "[?,?]");
945 INFER_OK(op, "[?];?", "[?]");
946 INFER_OK(op, "[?,?];[2]", "[?,?]");
947 INFER_ERROR("Dimension must be 3 but is 2", op, "[1,2,3];[2]");
949 op.input_tensors[1] = &perm;
950 INFER_OK(op, "[?];[?]", "[d0_0]");
952 INFER_OK(op, "?;[2]", "[?,?]");
953 INFER_OK(op, "[?,?];[2]", "[d0_1,d0_0]");
954 INFER_OK(op, "[1,?];[2]", "[d0_1,d0_0]");
958 INFER_ERROR("perm dim 2 is out of range of input rank 2", op, "[1,2];[2]");
960 INFER_ERROR("Dimension must be 2 but is 1", op, "[1,2];[1]");
964 INFER_OK(op, "[0,1,2,3,4];[5]", "[d0_1,d0_0,d0_3,d0_4,d0_2]");
965 INFER_OK(op, "[0,?,2,3,4];[5]", "[d0_1,d0_0,d0_3,d0_4,d0_2]");
969 ShapeInferenceTestOp op("Bitcast");
970 auto rebuild_node_def = [&op](DataType input_type, DataType output_type) {
974 .Finalize(&op.node_def));
979 INFER_OK(op, "?", "?");
982 INFER_OK(op, "[1,2]", "in0");
986 INFER_OK(op, "[1,2]", "[d0_0]"); // last dimension matches divisor.
988 INFER_OK(op, "[1,?]", "[d0_0]");
991 INFER_ERROR("does not match", op, "[1,4]");
992 INFER_ERROR("does not match", op, "[1,3]");
996 INFER_OK(op, "[4,5]", "[d0_0,d0_1,2]");
998 INFER_OK(op, "[4,5]", "[d0_0,d0_1,4]");
1000 INFER_OK(op, "[4,5]", "[d0_0,d0_1,8]");
1002 INFER_OK(op, "[4,5]", "[d0_0,d0_1,16]");
1006 INFER_ERROR("one of the type sizes is zero", op, "[1,2,3]");
1008 INFER_ERROR("one of the type sizes is zero", op, "[1,2,3]");
1012 ShapeInferenceTestOp op("Squeeze");
1014 auto rebuild_node_def = [&op](const std::vector<int32>& squeeze_dims) {
1018 .Finalize(&op.node_def));
1025 INFER_OK(op, "?", "?");
1027 INFER_OK(op, "[1,4,1,5,1]", "[d0_1,d0_3]");
1030 INFER_OK(op, "[1,?,1,?,1]", "?");
1034 INFER_OK(op, "[4,1,5]", "[d0_0,d0_2]");
1036 INFER_OK(op, "[4,?,5]", "[d0_0,d0_2]");
1039 INFER_ERROR("Can not squeeze dim[1]", op, "[4,6,5]");
1043 INFER_OK(op, "[4,1,1,5]", "[d0_0,d0_3]");
1045 INFER_OK(op, "[4,1,1,5]", "[d0_0,d0_3]");
1049 INFER_OK(op, "[4,1,5]", "[d0_0,d0_2]");
1053 INFER_ERROR("not in [-3,3)", op, "[1,2,3]");
1055 INFER_ERROR("not in [-3,3)", op, "[1,2,3]");
1059 ShapeInferenceTestOp op("ReverseSequence");
1060 auto rebuild_node_def = [&op](const int32 seq_dim, const int32 batch_dim) {
1066 .Finalize(&op.node_def));
1071 INFER_OK(op, "?;[10]", "?");
1074 INFER_ERROR("Shape must be rank 1 but is rank 2", op, "?;[10,10]");
1078 INFER_ERROR("batch_dim must be < input rank", op, "[1,2,3];[3]");
1080 INFER_ERROR("seq_dim must be < input rank", op, "[1,2,3];[3]");
1083 INFER_OK(op, "[1,2,3];[3]", "[d0_0,d0_1,d0_2]");
1085 INFER_OK(op, "[1,2,?];[3]", "[d0_0,d0_1,d1_0]");
1086 INFER_OK(op, "[1,2,3];[?]", "[d0_0,d0_1,d0_2]");
1090 ShapeInferenceTestOp op("Split");
1091 op.input_tensors.resize(2);
1098 .Finalize(&op.node_def));
1099 INFER_OK(op, "?;?", "?;?");
1101 INFER_OK(op, "?;[?,?]", "[?,?];[?,?]");
1104 INFER_OK(op, "?;[1,4]", "[?,?];[?,?]");
1108 op.input_tensors[0] = &split_dim;
1109 INFER_ERROR("Input must be scalar but has rank 1", op, "[?];[?,?]");
1111 INFER_OK(op, "?;?", "?;?");
1112 INFER_OK(op, "?;[?,?]", "[d1_0,?];[d1_0,?]");
1113 INFER_OK(op, "?;[1,4]", "[d1_0,2];[d1_0,2]");
1114 INFER_OK(op, "?;[1,?]", "[d1_0,?];[d1_0,?]");
1115 INFER_ERROR("Dimension size must be evenly divisible by 2 but is 5", op,
1121 "Dimension size, given by scalar input 3 must be in range [-3, 3)", op,
1126 INFER_OK(op, "?;?", "?;?");
1127 INFER_OK(op, "?;[?,?]", "[d1_0,?];[d1_0,?]");
1128 INFER_OK(op, "?;[1,?]", "[d1_0,?];[d1_0,?]");
1129 INFER_OK(op, "?;[1,4]", "[d1_0,2];[d1_0,2]");
1130 INFER_OK(op, "?;[1,4,8]", "[d1_0,d1_1,4];[d1_0,d1_1,4]");
1132 INFER_OK(op, "?;[1,4,8]", "[d1_0,2,d1_2];[d1_0,2,d1_2]");
1135 "Dimension size, given by scalar input -4 must be in range [-3, 3)", op,
1140 ShapeInferenceTestOp op("Tile");
1141 op.input_tensors.resize(2);
1147 .Finalize(&op.node_def));
1150 INFER_OK(op, "?;?", "?");
1153 INFER_OK(op, "[2,3,1,4];?", "[?,?,?,?]");
1156 INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[2,3,1,4];[4,1]");
1159 INFER_OK(op, "?;[4]", "[?,?,?,?]");
1163 op.input_tensors[1] = &multiples;
1164 INFER_OK(op, "[2,3,1,4];[4]", "[4,9,4,20]");
1167 INFER_OK(op, "[2,3,1,4];[4]", "[4,9,4,20]");
1171 ShapeInferenceTestOp op("EditDistance");
1172 op.input_tensors.resize(6);
1175 INFER_OK(op, "[?,?];[?];[4];[?,?];[?];[4]", "?");
1178 op.input_tensors[2] = &hypothesis_shape;
1180 op.input_tensors[5] = &truth_shape;
1181 INFER_OK(op, "[?,?];[?];[4];[?,?];[?];[4]", "[20,30,40]");
1185 op.input_tensors[2] = &hypothesis_shape;
1186 INFER_ERROR("Num elements of hypothesis_shape does not match truth_shape", op,
1191 ShapeInferenceTestOp op("OneHot");
1192 op.input_tensors.resize(4);
1193 auto set_axis = [&op](int axis) {
1200 .Finalize(&op.node_def));
1205 INFER_ERROR("axis must be >= -1", op, "?;?;?;?");
1209 INFER_OK(op, "?;[];?;?", "?");
1213 op.input_tensors[1] = &depth;
1214 INFER_ERROR("Input must be scalar but has rank 1", op, "?;[2];?;?");
1218 INFER_OK(op, "[1,3,4];[];?;?", "[d0_0,2,d0_1,d0_2]");
1220 INFER_OK(op, "[1,3,4];[];?;?", "[d0_0,d0_1,d0_2,2]");
1224 ShapeInferenceTestOp op("ExtractImagePatches");
1225 auto set_op = [&op](const std::vector<int32>& ksizes,
1234 .Finalize(&op.node_def));
1245 INFER_OK(op, "[1,7,7,2]", "[d0_0,5,5,8]");
1249 INFER_OK(op, "[1,7,7,2]", "[d0_0,7,7,d0_3]");
1256 op, "[1,7,7,2]");
1260 ShapeInferenceTestOp op("QuantizeAndDequantizeV2");
1261 INFER_OK(op, "?;?;?", "in0");
1262 INFER_OK(op, "[];?;?", "in0");
1263 INFER_OK(op, "[1,2,?,4,5];?;?", "in0");
1265 INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[1,2,?,4,5];[1];[]");
1266 INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[1,2,?,4,5];[];[1]");
1267 INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[1,2,?,4,5];[1];[1]");
1271 ShapeInferenceTestOp op("SpaceToBatch");
1272 op.input_tensors.resize(2);
1277 .Finalize(&op.node_def));
1280 INFER_OK(op, "[1,10,10,3];[2,2]", "[4,?,?,d0_3]");
1283 INFER_OK(op, "[1,10,10,3];?", "[4,?,?,d0_3]");
1286 INFER_ERROR("rank", op, "[1,10,10,3];[4]");
1287 INFER_ERROR("3 and 2", op, "[1,10,10,3];[2,3]");
1290 op.input_tensors[1] = &paddings;
1291 INFER_OK(op, "[1,10,10,3];[2,2]", "[4,8,8,d0_3]");
1293 INFER_OK(op, "[1,10,10,3];[2,2]", "[4,8,8,d0_3]");
1297 op.input_tensors[1] = &paddings;
1298 INFER_ERROR("Dimension size must be evenly divisible by 2 but is 13", op,
1303 op.input_tensors[1] = &paddings;
1304 INFER_ERROR("cannot be negative", op, "[1,10,10,3];[2,2]");
1308 ShapeInferenceTestOp op("SpaceToBatchND");
1309 op.input_tensors.resize(3);
1314 .Finalize(&op.node_def));
1317 INFER_OK(op, "?;[2];?", "?");
1320 INFER_OK(op, "[?,?,?,?];[2];?", "[?,?,?,d0_3]");
1323 INFER_OK(op, "[?,?,?,2];[2];?", "[?,?,?,d0_3]");
1328 op.input_tensors[1] = &block_shape;
1329 INFER_OK(op, "[3,?,?,2];[2];?", "[18,?,?,d0_3]");
1334 op.input_tensors[2] = &paddings;
1335 INFER_OK(op, "[3,?,2,2];[2];[2,2]", "[18,?,1,d0_3]");
1336 op.input_tensors[2] = nullptr;
1342 op.input_tensors[2] = &paddings;
1343 INFER_OK(op, "[3,2,3,2];[2];[2,2]", "[18,2,1,d0_3]");
1344 op.input_tensors[2] = nullptr;
1347 op.input_tensors[1] = nullptr;
1350 INFER_ERROR("block_shape must have rank 1", op, "?;[1,1];?");
1351 INFER_ERROR("block_shape must have known size", op, "?;[?];?");
1355 op.input_tensors[1] = &block_shape;
1356 INFER_ERROR("block_shape must be positive", op, "[1,2,2];[2];[2,2]");
1357 op.input_tensors[1] = nullptr;
1362 op.input_tensors[1] = &block_shape;
1364 op.input_tensors[2] = &paddings;
1365 INFER_ERROR("paddings cannot be negative", op, "[1,2,2];[2];[2,2]");
1366 op.input_tensors[1] = nullptr;
1367 op.input_tensors[2] = nullptr;
1372 op.input_tensors[1] = &block_shape;
1374 op.input_tensors[2] = &paddings;
1375 INFER_ERROR("divisible", op, "[1,2,3,1];[2];[2,2]");
1376 op.input_tensors[1] = nullptr;
1377 op.input_tensors[2] = nullptr;
1380 INFER_ERROR("rank", op, "[1,3,3,1];[2];[1]");
1381 INFER_ERROR("shape", op, "[1,3,3,1];[2];[1,2]");
1385 ShapeInferenceTestOp op("BatchToSpace");
1386 op.input_tensors.resize(2);
1391 .Finalize(&op.node_def));
1394 INFER_OK(op, "[4,8,8,3];[2,2]", "[1,?,?,d0_3]");
1397 INFER_ERROR("Dimension size must be evenly divisible by", op,
1401 INFER_OK(op, "[4,8,8,3];?", "[1,?,?,d0_3]");
1404 INFER_ERROR("rank", op, "[4,8,8,3];[4]");
1405 INFER_ERROR("3 and 2", op, "[4,8,8,3];[2,3]");
1408 op.input_tensors[1] = &croppings;
1409 INFER_OK(op, "[4,8,8,3];[2,2]", "[1,10,10,d0_3]");
1413 op.input_tensors[1] = &croppings;
1414 INFER_ERROR("Negative dimension size caused by subtracting", op,
1417 op.input_tensors[1] = &croppings;
1418 INFER_ERROR("Negative dimension size caused by subtracting", op,
1423 op.input_tensors[1] = &croppings;
1424 INFER_ERROR("cannot be negative", op, "[4,8,8,3];[2,2]");
1428 ShapeInferenceTestOp op("BatchToSpaceND");
1429 op.input_tensors.resize(3);
1434 .Finalize(&op.node_def));
1437 INFER_OK(op, "?;[2];?", "?");
1440 INFER_OK(op, "[?,?,?,?];[2];?", "[?,?,?,d0_3]");
1445 op.input_tensors[1] = &block_shape;
1446 INFER_OK(op, "[?,?,?,2];[2];?", "[?,?,?,d0_3]");
1448 INFER_OK(op, "[18,?,?,2];[2];?", "[3,?,?,d0_3]");
1453 op.input_tensors[2] = &crops;
1454 INFER_OK(op, "[18,?,2,2];[2];[2,2]", "[3,?,5,d0_3]");
1455 op.input_tensors[2] = nullptr;
1461 op.input_tensors[2] = &crops;
1462 INFER_OK(op, "[18,2,1,2];[2];[2,2]", "[3,2,3,d0_3]");
1463 op.input_tensors[2] = nullptr;
1466 op.input_tensors[1] = nullptr;
1469 INFER_ERROR("block_shape must have rank 1", op, "?;[1,1];?");
1470 INFER_ERROR("block_shape must have known size", op, "?;[?];?");
1471 INFER_ERROR("rank", op, "[2,2];[2];[2,2]");
1472 INFER_ERROR("rank", op, "[2,2,3];[3];[3,2]");
1476 op.input_tensors[1] = &block_shape;
1477 INFER_ERROR("block_shape must be positive", op, "[1,2,2];[2];[2,2]");
1478 op.input_tensors[1] = nullptr;
1483 op.input_tensors[1] = &block_shape;
1485 op.input_tensors[2] = &paddings;
1486 INFER_ERROR("crops cannot be negative", op, "[1,2,2];[2];[2,2]");
1487 op.input_tensors[1] = nullptr;
1488 op.input_tensors[2] = nullptr;
1494 op.input_tensors[1] = &block_shape;
1496 op.input_tensors[2] = &crops;
1497 INFER_ERROR("Negative", op, "[4,2,3,1];[2];[2,2]");
1498 op.input_tensors[1] = nullptr;
1499 op.input_tensors[2] = nullptr;
1505 op.input_tensors[1] = &block_shape;
1506 INFER_ERROR("divisible", op, "[3,1,1,1];[2];[2,2]");
1507 op.input_tensors[1] = nullptr;
1512 ShapeInferenceTestOp op("SpaceToDepth");
1516 .Finalize(&op.node_def));
1518 INFER_OK(op, "[1,2,4,4]", "[d0_0,1,2,16]");
1521 INFER_ERROR("Dimension size must be evenly divisible by 2 but is 3", op,
1523 INFER_ERROR("Dimension size must be evenly divisible by 2 but is 5", op,
1527 INFER_OK(op, "[1,2,4,?]", "[d0_0,1,2,?]");
1531 ShapeInferenceTestOp op("DepthToSpace");
1535 .Finalize(&op.node_def));
1537 INFER_OK(op, "[1,1,2,16]", "[d0_0,2,4,4]");
1540 INFER_ERROR("Dimension size must be evenly divisible by 4 but is 15", op,
1544 INFER_OK(op, "[1,2,4,?]", "[d0_0,4,8,?]");
1550 .Finalize(&op.node_def));
1551 INFER_OK(op, "[1,1,2,200]", "[d0_0,10,20,2]");
1555 ShapeInferenceTestOp op("Slice");
1560 .Finalize(&op.node_def));
1564 INFER_OK(op, "[2,3,4,5];[4];[4]", "[?,?,?,?]");
1567 INFER_OK(op, "[2,3,4,5];[?];[?]", "[?,?,?,?]");
1569 INFER_OK(op, "?;[?];[?]", "?");
1571 INFER_OK(op, "?;[4];[?]", "[?,?,?,?]");
1574 INFER_ERROR("must be rank 1", op, "[2,3,4,5];[2,3];[3]");
1575 INFER_ERROR("must be rank 1", op, "[2,3,4,5];[2];[3,4]");
1577 INFER_ERROR("must be rank 2", op, "[2,3,4,5];[2];[2]");
1580 op.input_tensors.resize(3);
1583 op.input_tensors[1] = &begin;
1584 op.input_tensors[2] = &sizes;
1585 INFER_OK(op, "[2,3,4,5];[4];[4]", "[1,2,1,3]");
1589 INFER_OK(op, "[2,3,4,5];[4];[4]", "[d0_0,2,1,4]");
1593 INFER_ERROR("Negative dimension size", op, "[2,3,4,5];[4];[4]");
1597 INFER_ERROR("cannot be < -1", op, "[2,3,4,5];[4];[4]");
1601 ShapeInferenceTestOp op("StridedSliceGrad");
1602 op.input_tensors.resize(5);
1603 INFER_OK(op, "?;?;?;?;?", "?");
1604 INFER_OK(op, "[?];?;?;?;?", "?");
1605 INFER_OK(op, "[4];?;?;?;?", "[?,?,?,?]");
1608 op.input_tensors[0] = &in_t;
1609 INFER_OK(op, "[4];?;?;?;?", "[1,2,3,4]");
1614 ShapeInferenceTestOp op(op_name);
1616 INFER_OK(op, "?;?;?", "in0");
1617 INFER_OK(op, "[1,?,3];[];[]", "in0");
1620 INFER_ERROR("be rank 0", op, "[1,?,3];[1];[]");
1621 INFER_ERROR("be rank 0", op, "[1,?,3];[];[1]");
1626 ShapeInferenceTestOp op("FakeQuantWithMinMaxVarsPerChannel");
1628 INFER_OK(op, "?;?;?", "in0");
1629 INFER_OK(op, "[?];?;?", "in0");
1630 INFER_OK(op, "[1,?,3];[3];[3]", "in0");
1631 INFER_OK(op, "[3];[3];[3]", "in0");
1634 INFER_ERROR("be rank 1", op, "[1,?,3];[1];[]");
1635 INFER_ERROR("be rank 1", op, "[1,?,3];[];[1]");
1638 INFER_ERROR("must be equal", op, "[1,?,3];[2];[?]");
1639 INFER_ERROR("must be equal", op, "[1,?,3];[?];[2]");
1640 INFER_ERROR("must be equal", op, "[1,?,?];[1];[2]");
1641 INFER_ERROR("must be equal", op, "[5];[4];[?]");
1645 ShapeInferenceTestOp op("FakeQuantWithMinMaxVarsPerChannelGradient");
1647 INFER_OK(op, "?;?;?;?", "in0;[?];[?]");
1648 INFER_OK(op, "[3];[3];[3];[3]", "in0;in3;in3");
1649 INFER_OK(op, "[1,3];[1,3];[3];[3]", "in0;in3;in3");
1650 INFER_OK(op, "[1,2,3,4];[1,2,3,4];[4];[4]", "in0;in3;in3");
1653 INFER_ERROR("be equal rank", op, "[1,?,3];[1,?,3];[3];[]");
1654 INFER_ERROR("be rank 1", op, "[1,?,3];[1,?,3];[];[3]");
1655 INFER_ERROR("be at least rank 1", op, "[];[];[1];[1]");
1656 INFER_ERROR("be at most rank 4", op, "[1,2,3,4,5];[1,2,3,4,5];[1];[1]");
1659 INFER_ERROR("must be equal", op, "[1,3];[1,3];[2];[3]");
1660 INFER_ERROR("must be equal", op, "[1,3];[1,3];[3];[2]");
1664 ShapeInferenceTestOp op("QuantizedConcat");
1665 auto set_n = [&op](int n) {
1678 .Finalize(&op.node_def));
1683 INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[1];?;?;?");
1687 INFER_ERROR("must be rank 0", op, "[];?;?;?;?;?;[1]");
1688 INFER_ERROR("must be rank 0", op, "[];?;?;?;?;[1];?");
1689 INFER_ERROR("must be rank 0", op, "[];?;?;?;[1];?;?");
1690 INFER_ERROR("must be rank 0", op, "[];?;?;[1];?;?;?");
1694 INFER_ERROR("must be rank 2", op, "[];[1,2];[1,2,3];?;?;?;?");
1695 INFER_OK(op, "[];[1,2];[1,3];?;?;?;?", "[?,?];[];[]");
1700 op.input_tensors.push_back(&concat_dim_t);
1703 INFER_OK(op, "[];[100,2,?];[10,?,3];?;?;?;?", "[110,d1_1,d2_2];[];[]");
1704 INFER_ERROR("Dimension 1 in both shapes must be equal, but are 5 and 3", op,
1710 ShapeInferenceTestOp op("_ParallelConcatStart");
1717 .Finalize(&op.node_def));
1718 INFER_OK(op, "", "[1,2,3]");