1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/node_def_builder.h"
17#include "tensorflow/core/framework/node_def_util.h"
18#include "tensorflow/core/framework/op.h"
19#include "tensorflow/core/framework/shape_inference.h"
20#include "tensorflow/core/framework/shape_inference_testutil.h"
21#include "tensorflow/core/framework/tensor.pb.h"
22#include "tensorflow/core/framework/tensor_shape.pb.h"
23#include "tensorflow/core/framework/tensor_testutil.h"
24#include "tensorflow/core/lib/core/status_test_util.h"
25#include "tensorflow/core/platform/test.h"
26#include "tensorflow/core/public/version.h"
27
28namespace tensorflow {
29
30TEST(ArrayOpsTest, Pack_ShapeFn) {
31  ShapeInferenceTestOp op("Pack");
32  auto set_axis = [&op](int axis) {
33    int n = 3;
34    std::vector<NodeDefBuilder::NodeOut> src_list;
35    src_list.reserve(n);
36    for (int i = 0; i < n; ++i) src_list.emplace_back("a", 0, DT_FLOAT);
37    TF_ASSERT_OK(NodeDefBuilder("test", "Pack")
38                     .Input(src_list)
39                     .Attr("N", n)
40                     .Attr("axis", axis)
41                     .Finalize(&op.node_def));
42  };
43
44  set_axis(0);
45  INFER_OK(op, "?;?;?", "?");
46
47  for (int axis : {0, -3}) {
48    set_axis(axis);
49    INFER_OK(op, "?;?;?", "?");
50    INFER_OK(op, "[1,3];[1,3];?", "[3,d0_0|d1_0,d0_1|d1_1]");
51    INFER_OK(op, "[?,3];[1,3];?", "[3,d1_0,d0_1|d1_1]");
52    INFER_OK(op, "[?,?];[1,3];?", "[3,d1_0,d1_1]");
53  }
54  for (int axis : {1, -2}) {
55    set_axis(axis);
56    INFER_OK(op, "?;?;?", "?");
57    INFER_OK(op, "[1,3];[1,3];?", "[d0_0|d1_0,3,d0_1|d1_1]");
58    INFER_OK(op, "[?,3];[1,3];?", "[d1_0,3,d0_1|d1_1]");
59    INFER_OK(op, "[?,?];[1,3];?", "[d1_0,3,d1_1]");
60  }
61  for (int axis : {2, -1}) {
62    set_axis(axis);
63    INFER_OK(op, "?;?;?", "?");
64    INFER_OK(op, "[1,3];[1,3];?", "[d0_0|d1_0,d0_1|d1_1,3]");
65    INFER_OK(op, "[?,3];[1,3];?", "[d1_0,d0_1|d1_1,3]");
66    INFER_OK(op, "[?,?];[1,3];?", "[d1_0,d1_1,3]");
67  }
68
69  set_axis(-4);
70  INFER_ERROR("Invalid axis: -4; must be in [-3,3)", op, "[1,3];[1,3];?");
71  set_axis(3);
72  INFER_ERROR("Invalid axis: 3; must be in [-3,3)", op, "[1,3];[1,3];?");
73
74  set_axis(0);
75
76  // Check that both components of error message are there.
77  INFER_ERROR("Shapes must be equal rank, but are 3 and 2", op,
78              "[1,2,3];?;[1,4]");
79  INFER_ERROR("From merging shape 0 with other shapes.", op, "[1,2,3];?;[1,4]");
80}
81
82TEST(ArrayOpsTest, UnPack_ShapeFn) {
83  ShapeInferenceTestOp op("Unpack");
84  auto set_axis_and_num = [&op](int axis, int num) {
85    TF_ASSERT_OK(NodeDefBuilder("test", "Unpack")
86                     .Input("a", 0, DT_FLOAT)
87                     .Attr("axis", axis)
88                     .Attr("num", num)
89                     .Finalize(&op.node_def));
90  };
91
92  set_axis_and_num(0, 1);
93  INFER_OK(op, "?", "?");
94
95  for (int axis : {0, -3}) {
96    set_axis_and_num(axis, 1);
97    INFER_OK(op, "?", "?");
98    INFER_OK(op, "[1,2,3]", "[d0_1,d0_2]");
99    INFER_OK(op, "[?,?,?]", "[d0_1,d0_2]");
100  }
101  for (int axis : {1, -2}) {
102    set_axis_and_num(axis, 2);
103    INFER_OK(op, "[1,2,3]", "[d0_0,d0_2];[d0_0,d0_2]");
104    INFER_OK(op, "[?,?,?]", "[d0_0,d0_2];[d0_0,d0_2]");
105  }
106  for (int axis : {2, -1}) {
107    set_axis_and_num(axis, 3);
108    INFER_OK(op, "[1,2,3]", "[d0_0,d0_1];[d0_0,d0_1];[d0_0,d0_1]");
109    INFER_OK(op, "[?,?,?]", "[d0_0,d0_1];[d0_0,d0_1];[d0_0,d0_1]");
110  }
111
112  set_axis_and_num(2, 2);
113  INFER_ERROR("Dimension must be 2 but is 3", op, "[1,2,3]");
114
115  set_axis_and_num(-4, 3);
116  INFER_ERROR("Invalid axis: -4; must be in [-3,3)", op, "[1,2,3]");
117  set_axis_and_num(3, 3);
118  INFER_ERROR("Invalid axis: 3; must be in [-3,3)", op, "[1,2,3]");
119}
120
121TEST(ArrayOpsTest, Const_ShapeFn) {
122  ShapeInferenceTestOp op("Const");
123  TensorProto tensor_proto;
124  auto* shape_proto = tensor_proto.mutable_tensor_shape();
125  auto rebuild_node_def = [&op, &tensor_proto]() {
126    TF_ASSERT_OK(NodeDefBuilder("test", "Const")
127                     .Attr("value", tensor_proto)
128                     .Finalize(&op.node_def));
129  };
130
131  TensorShape{}.AsProto(shape_proto);
132  rebuild_node_def();
133  INFER_OK(op, "", "[]");
134  TensorShape{1, 2, 3, 4}.AsProto(shape_proto);
135  rebuild_node_def();
136  INFER_OK(op, "", "[1,2,3,4]");
137
138  shape_proto->add_dim()->set_size(-1);
139  rebuild_node_def();
140  INFER_ERROR("Shape [1,2,3,4,?] is not fully defined", op, "");
141}
142
143TEST(ArrayOpsTest, UnchangedShapes_ShapeFn) {
144  for (const char* op_name : {
145           "CheckNumerics",
146           "Identity",
147           "RefIdentity",
148           "QuantizeAndDequantize",
149           "StopGradient",
150           "ZerosLike",
151           "OnesLike",
152       }) {
153    ShapeInferenceTestOp op(op_name);
154    INFER_OK(op, "?", "in0");
155    INFER_OK(op, "[]", "in0");
156    INFER_OK(op, "[1,2,?,4,5]", "in0");
157  }
158
159  // inputs 1 and 2 are ignored; input 0 is transferred to output 0.
160  ShapeInferenceTestOp op("MatrixBandPart");
161  INFER_OK(op, "?;?;?", "in0");
162  INFER_OK(op, "[];?;?", "in0");
163  INFER_OK(op, "[1,2,?,4,5];?;?", "in0");
164}
165
166TEST(ArrayOpsTest, GuaranteeConst_ShapeFn) {
167  ShapeInferenceTestOp op("GuaranteeConst");
168  INFER_OK(op, "?", "in0");
169  INFER_OK(op, "[]", "in0");
170  INFER_OK(op, "[1,2,?,4,5]", "in0");
171}
172
173TEST(ArrayOpsTest, Identity_ShapeFnHandles) {
174  const char* op_name = "Identity";
175  ShapeInferenceTestOp op(op_name);
176  // Check that handle dtypes are preserved.
177  const OpRegistrationData* op_reg_data;
178  TF_ASSERT_OK(OpRegistry::Global()->LookUp(op.name, &op_reg_data));
179  std::vector<
180      std::unique_ptr<std::vector<std::pair<TensorShapeProto, DataType>>>>
181      handle_data;
182  handle_data.emplace_back(
183      new std::vector<std::pair<TensorShapeProto, DataType>>{
184          {TensorShapeProto(), DT_BOOL}});
185  shape_inference::InferenceContext c(TF_GRAPH_DEF_VERSION, &op.node_def,
186                                      op_reg_data->op_def, {TensorShapeProto()},
187                                      {}, {}, handle_data);
188  TF_ASSERT_OK(c.construction_status());
189  ASSERT_TRUE(op_reg_data->shape_inference_fn != nullptr);
190  TF_ASSERT_OK(c.Run(op_reg_data->shape_inference_fn));
191
192  const auto* shapes_and_types = c.output_handle_shapes_and_types(0);
193  ASSERT_TRUE(shapes_and_types != nullptr);
194  ASSERT_EQ(1, shapes_and_types->size());
195  EXPECT_EQ((*shapes_and_types)[0].dtype, DT_BOOL);
196}
197
198TEST(ArrayOpsTest, Diag_ShapeFn) {
199  ShapeInferenceTestOp op("Diag");
200  INFER_OK(op, "?", "?");
201  INFER_OK(op, "[1,?,3]", "[d0_0,d0_1,d0_2,d0_0,d0_1,d0_2]");
202  INFER_OK(op, "[?,1,2,3]", "[d0_0,d0_1,d0_2,d0_3,d0_0,d0_1,d0_2,d0_3]");
203  INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[]");
204}
205
206TEST(ArrayOpsTest, DiagPart_ShapeFn) {
207  ShapeInferenceTestOp op("DiagPart");
208  INFER_OK(op, "?", "?");
209  INFER_OK(op, "[1,?,?,4]", "[d0_0,d0_3]");
210  INFER_OK(op, "[1,?,3,?,4,3]", "[d0_0,d0_4,d0_2|d0_5]");
211  INFER_OK(op, "[1,2,3,?,?,?,?,4]", "[d0_0,d0_1,d0_2,d0_7]");
212  INFER_ERROR("Input must have even and non-zero rank", op, "[]");
213  INFER_ERROR("Input must have even and non-zero rank", op, "[?]");
214  INFER_ERROR("Input must have even and non-zero rank", op, "[1,2,3]");
215  INFER_ERROR("Dimensions must be equal, but are 2 and 10", op, "[1,2,?,10]");
216}
217
218TEST(ArrayOpsTest, MatrixDiag_ShapeFn) {
219  ShapeInferenceTestOp op("MatrixDiag");
220  INFER_OK(op, "?", "?");
221  INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[]");
222  INFER_OK(op, "[?]", "[d0_0,d0_0]");
223  INFER_OK(op, "[1,?,?,4]", "[d0_0,d0_1,d0_2,d0_3,d0_3]");
224}
225
226TEST(ArrayOpsTest, MatrixDiagPart_ShapeFn) {
227  ShapeInferenceTestOp op("MatrixDiagPart");
228  INFER_OK(op, "?", "?");
229  INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[?]");
230  INFER_OK(op, "[?,1,2,2]", "[d0_0,d0_1,d0_2|d0_3]");
231  INFER_OK(op, "[?,1,2,3]", "[d0_0,d0_1,d0_2]");
232  INFER_OK(op, "[?,1,3,2]", "[d0_0,d0_1,d0_3]");
233}
234
235TEST(ArrayOpsTest, Reverse_ShapeFn) {
236  ShapeInferenceTestOp op("Reverse");
237  INFER_OK(op, "?;?", "in0");
238  INFER_ERROR("Shape must be rank 1 but is rank 0", op, "?;[]");
239  INFER_ERROR("Shape must be rank 1 but is rank 2", op, "?;[?,2]");
240  INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];[4]");
241  INFER_ERROR("reverse does not work on tensors with more than 8 dimensions",
242              op, "[1,2,3,4,5,6,7,8,9];[9]");
243  INFER_OK(op, "[1,2,3,?];[4]", "in0");
244  INFER_OK(op, "[1,2,3,?,5,6,7,8];[8]", "in0");
245}
246
247TEST(ArrayOpsTest, ReverseV2_ShapeFn) {
248  ShapeInferenceTestOp op("ReverseV2");
249  INFER_OK(op, "?;?", "in0");
250  INFER_ERROR("Shape must be rank 1 but is rank 0", op, "?;[]");
251  INFER_ERROR("Shape must be rank 1 but is rank 2", op, "?;[?,2]");
252  INFER_OK(op, "[1,2,3];[2]", "in0");
253  INFER_ERROR("reverse does not work on tensors with more than 8 dimensions",
254              op, "[1,2,3,4,5,6,7,8,9];[9]");
255  INFER_OK(op, "[1,2,3,?];[4]", "in0");
256  INFER_OK(op, "[1,2,3,?,5,6,7,8];[8]", "in0");
257}
258
259TEST(ArrayOpsTest, Fill_ShapeFn) {
260  ShapeInferenceTestOp op("Fill");
261  AddNodeAttr("index_type", DT_INT32, &op.node_def);
262  op.input_tensors.resize(2);
263  INFER_OK(op, "?;?", "?");
264  INFER_OK(op, "[?];?", "?");
265  INFER_OK(op, "[4];?", "[?,?,?,?]");
266
267  Tensor in_t = test::AsTensor<int32>({1, 2, 3, 4});
268  op.input_tensors[0] = &in_t;
269  INFER_OK(op, "[4];?", "[1,2,3,4]");
270}
271
272TEST(ArrayOpsTest, Gather_ShapeFn) {
273  ShapeInferenceTestOp op("Gather");
274  INFER_OK(op, "?;?", "?");
275  INFER_OK(op, "[1,?,2];[3]", "[d1_0,d0_1,d0_2]");
276  INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[];[1,2,3]");
277}
278
279TEST(ArrayOpsTest, GatherV2_ShapeFn) {
280  ShapeInferenceTestOp op("GatherV2");
281
282  // Tests when axis is unknown.
283  INFER_OK(op, "?;?;?", "?");
284  INFER_OK(op, "[1,2,3];[3];[]", "[?,?,?]");
285  INFER_ERROR("Shape must be at least rank 1 but is rank 0", op,
286              "[];[1,2,3];[]");
287
288  // Non-scalar axis.
289  INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[1];[1,2,3];[1]");
290
291  // Test when axis dim is known.
292  Tensor axis_dim_t;
293  op.input_tensors.resize(3);
294  op.input_tensors[2] = &axis_dim_t;
295
296  // Out of range axis.
297  axis_dim_t = test::AsScalar(1);
298  INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
299              "[1];[1,2];[]");
300
301  // Rank 0 indices.
302  axis_dim_t = test::AsScalar(0);
303  INFER_OK(op, "[1,2,3];[];[]", "[d0_1,d0_2]");
304  axis_dim_t = test::AsScalar(1);
305  INFER_OK(op, "[1,2,3];[];[]", "[d0_0,d0_2]");
306  axis_dim_t = test::AsScalar(2);
307  INFER_OK(op, "[1,2,3];[];[]", "[d0_0,d0_1]");
308
309  // Rank 1 indices.
310  axis_dim_t = test::AsScalar(0);
311  INFER_OK(op, "[1,2,3];[5];[]", "[d1_0,d0_1,d0_2]");
312  axis_dim_t = test::AsScalar(1);
313  INFER_OK(op, "[1,2,3];[5];[]", "[d0_0,d1_0,d0_2]");
314  axis_dim_t = test::AsScalar(2);
315  INFER_OK(op, "[1,2,3];[5];[]", "[d0_0,d0_1,d1_0]");
316
317  // Rank 2 indices.
318  axis_dim_t = test::AsScalar(0);
319  INFER_OK(op, "[1,2,3];[5,6];[]", "[d1_0,d1_1,d0_1,d0_2]");
320  axis_dim_t = test::AsScalar(1);
321  INFER_OK(op, "[1,2,3];[5,6];[]", "[d0_0,d1_0,d1_1,d0_2]");
322  axis_dim_t = test::AsScalar(2);
323  INFER_OK(op, "[1,2,3];[5,6];[]", "[d0_0,d0_1,d1_0,d1_1]");
324
325  // Negative axis.
326  axis_dim_t = test::AsScalar(-3);
327  INFER_OK(op, "[1,2,3];[5,6];[]", "[d1_0,d1_1,d0_1,d0_2]");
328  axis_dim_t = test::AsScalar(-2);
329  INFER_OK(op, "[1,2,3];[5,6];[]", "[d0_0,d1_0,d1_1,d0_2]");
330  axis_dim_t = test::AsScalar(-1);
331  INFER_OK(op, "[1,2,3];[5,6];[]", "[d0_0,d0_1,d1_0,d1_1]");
332}
333
334TEST(ArrayOpsTest, GatherNd_ShapeFn) {
335  ShapeInferenceTestOp op("GatherNd");
336
337  // Inputs are (params, indices).
338  INFER_OK(op, "?;?", "?");
339  INFER_OK(op, "[1,?,3,?];[?,0]", "[d1_0,d0_0,d0_1,d0_2,d0_3]");
340  INFER_OK(op, "[1,?,3,?];[?,4]", "[d1_0]");
341
342  // params.rank >= indices.dim(-1).
343  INFER_ERROR("indices.shape[-1] must be <= params.rank", op, "[1,2,3];[4]");
344}
345
346TEST(ArrayOpsTest, Shape_ShapeFn) {
347  ShapeInferenceTestOp op("Shape");
348  INFER_OK(op, "?", "[?]");
349  INFER_OK(op, "[?]", "[1]");
350  INFER_OK(op, "[?,2,3,4,5]", "[5]");
351}
352
353TEST(ArrayOpsTest, ShapeN_ShapeFn) {
354  ShapeInferenceTestOp op("ShapeN");
355  int n = 3;
356  std::vector<NodeDefBuilder::NodeOut> src_list;
357  src_list.reserve(n);
358  for (int i = 0; i < n; ++i) src_list.emplace_back("a", 0, DT_FLOAT);
359  TF_ASSERT_OK(NodeDefBuilder("test", "ShapeN")
360                   .Input(src_list)
361                   .Attr("N", n)
362                   .Finalize(&op.node_def));
363  INFER_OK(op, "?;?;?", "[?];[?];[?]");
364  INFER_OK(op, "[?];[?];[?]", "[1];[1];[1]");
365  INFER_OK(op, "[?,2,3,4,5];?;[1,?,3]", "[5];[?];[3]");
366}
367
368TEST(ArrayOpsTest, Unique_ShapeFn) {
369  ShapeInferenceTestOp op("Unique");
370  INFER_OK(op, "?", "[?];in0");
371  INFER_OK(op, "[1,2,3,?,5]", "[?];in0");
372}
373
374TEST(ArrayOpsTest, UniqueWithCounts_ShapeFn) {
375  ShapeInferenceTestOp op("UniqueWithCounts");
376  INFER_OK(op, "?", "[?];in0;[?]");
377  INFER_OK(op, "[1,2,3,?,5]", "[?];in0;[?]");
378}
379
380TEST(ArrayOpsTest, InvertPermutation_ShapeFn) {
381  ShapeInferenceTestOp op("InvertPermutation");
382  INFER_OK(op, "?", "[?]");
383  INFER_OK(op, "[1]", "in0");
384  INFER_ERROR("Shape must be rank 1 but is rank 0", op, "[]");
385}
386
387TEST(ArrayOpsTest, PadD_ShapeFn) {
388  for (const char* op_name : {"Pad", "MirrorPad"}) {
389    ShapeInferenceTestOp op(op_name);
390    op.input_tensors.resize(2);
391
392    // Inputs are input and paddings.
393
394    INFER_OK(op, "?;?", "?");
395
396    // Check shape of paddings.
397    INFER_ERROR("Shape must be rank 2 but is rank 3", op, "?;[1,2,3]");
398    INFER_ERROR("Dimension must be 2 but is 4", op, "?;[1,4]");
399
400    // input.rank and paddings.dim(0) are equal. This is the number of dims in
401    // output.
402    INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];[4,2]");
403    INFER_OK(op, "[1,2,3];?", "[?,?,?]");
404    INFER_OK(op, "?;[3,2]", "[?,?,?]");
405
406    // Make the paddings tensor known and verify padding values get added.
407    // E.g., if padding is ((1,10),(2,20),(3,30)) then values 11,22,23 are added
408    // to input dims to get output.
409    Tensor paddings_t(DT_INT64, TensorShape{3, 2});
410    test::FillValues<int64>(&paddings_t, {1, 10, 2, 20, 3, 30});
411    op.input_tensors[1] = &paddings_t;
412    INFER_OK(op, "[100,200,300];[3,2]", "[111,222,333]");
413    INFER_OK(op, "[100,?,300];[3,2]", "[111,?,333]");
414    INFER_OK(op, "?;[3,2]", "[?,?,?]");
415    INFER_OK(op, "?;?", "[?,?,?]");
416  }
417}
418
419TEST(ArrayOpsTest, PadV2_ShapeFn) {
420  ShapeInferenceTestOp op("PadV2");
421  op.input_tensors.resize(3);
422
423  // Inputs are input, paddings and constant_values.
424
425  INFER_OK(op, "?;?;?", "?");
426
427  // Check shape of paddings.
428  INFER_ERROR("Shape must be rank 2 but is rank 3", op, "?;[1,2,3];?");
429  INFER_ERROR("Dimension must be 2 but is 4", op, "?;[1,4];?");
430
431  // input.rank and paddings.dim(0) are equal. This is the number of dims in
432  // output.
433  INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];[4,2];[]");
434  INFER_OK(op, "[1,2,3];?;[]", "[?,?,?]");
435  INFER_OK(op, "?;[3,2];[]", "[?,?,?]");
436
437  // Make the paddings tensor known and verify padding values get added.
438  // E.g., if padding is ((1,10),(2,20),(3,30)) then values 11,22,23 are added
439  // to input dims to get output.
440  Tensor paddings_t(DT_INT64, TensorShape{3, 2});
441  test::FillValues<int64>(&paddings_t, {1, 10, 2, 20, 3, 30});
442  op.input_tensors[1] = &paddings_t;
443  INFER_OK(op, "[100,200,300];[3,2];[]", "[111,222,333]");
444  INFER_OK(op, "[100,?,300];[3,2];[]", "[111,?,333]");
445  INFER_OK(op, "?;[3,2];[]", "[?,?,?]");
446  INFER_OK(op, "?;?;[]", "[?,?,?]");
447}
448
449TEST(ArrayOpsTest, MirrorPadGrad_ShapeFn) {
450  ShapeInferenceTestOp op("MirrorPadGrad");
451  op.input_tensors.resize(2);
452
453  // Inputs are input and paddings.
454  INFER_OK(op, "?;?", "?");
455
456  // First padding dimension is unknown, so rank is unknown.
457  INFER_OK(op, "?;[?,4]", "?");
458
459  // Input tensor rank doesn't match paddings dimension.
460  INFER_ERROR("must be rank 3 but is rank 2", op, "[?,?];[3,2]");
461
462  // Paddings tensor is not a [rank x 2] matrix.
463  INFER_ERROR("Dimension 1 in both shapes must be equal, but are 3 and 2", op,
464              "[?,?,?];[3,3]");
465
466  // Paddings tensor is unknown, but rank is known, so the output
467  // shape is a rank 3 unknown shape.
468  INFER_OK(op, "[?,?,?];[3,2]", "[?,?,?]");
469
470  // Make the paddings tensor known and verify padding values get
471  // subtracted.  E.g., if padding is ((1,10),(2,20),(3,30)) then
472  // values 11,22,23 are subtracted to input dims to get output.
473  Tensor paddings_t(DT_INT64, TensorShape{3, 2});
474  test::FillValues<int64>(&paddings_t, {1, 10, 2, 20, 3, 30});
475  op.input_tensors[1] = &paddings_t;
476
477  INFER_OK(op, "[111,222,333];[3,2]", "[100,200,300]");
478  INFER_OK(op, "[111,?,333];[3,2]", "[100,?,300]");
479}
480
481TEST(ArrayOpsTest, BroadcastArgs_ShapeFn) {
482  ShapeInferenceTestOp op("BroadcastArgs");
483  INFER_OK(op, "?;?", "[?]");
484  INFER_OK(op, "[123];[1]", "[123]");
485  INFER_OK(op, "[1];[123]", "[123]");
486  INFER_OK(op, "[123];[121]", "[123]");
487
488  // Rank checks
489  INFER_ERROR("Shape must be rank 1 but is rank 0", op, "[];?");
490  INFER_ERROR("Shape must be rank 1 but is rank 0", op, "?;[]");
491}
492
493TEST(ArrayOpsTest, BroadcastGradientArgs_ShapeFn) {
494  ShapeInferenceTestOp op("BroadcastGradientArgs");
495  // Output is always two unknown vectors.
496  INFER_OK(op, "?;?", "[?];[?]");
497  INFER_OK(op, "[123];[456]", "[?];[?]");
498
499  // Rank checks
500  INFER_ERROR("Shape must be rank 1 but is rank 0", op, "[];?");
501  INFER_ERROR("Shape must be rank 1 but is rank 0", op, "?;[]");
502}
503
504TEST(ArrayOpsTest, ListDiff_ShapeFn) {
505  ShapeInferenceTestOp op("BroadcastGradientArgs");
506  // Output is always two matching unknown vectors.
507  INFER_OK(op, "?;?", "[?];[?]");
508  INFER_OK(op, "[123];[456]", "[?];[?]");
509
510  // Rank checks
511  INFER_ERROR("Shape must be rank 1 but is rank 0", op, "[];?");
512  INFER_ERROR("Shape must be rank 1 but is rank 0", op, "?;[]");
513}
514
515TEST(ArrayOpsTest, MatrixSetDiag_ShapeFn) {
516  ShapeInferenceTestOp op("MatrixSetDiag");
517
518  // Inputs are input and diagonal.
519
520  // Rank checks.
521  INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1];?");
522  INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "?;[]");
523  INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[2,2];[]");
524  INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[2,2];[2,2]");
525
526  // diagonal[-1] must match smallest matrix dimension.
527  INFER_ERROR("Dimensions must be equal, but are 2 and 3", op, "[2,3];[3]");
528
529  // Output matches input.
530  INFER_OK(op, "?;?", "in0");
531  INFER_OK(op, "[1,2,2];[1,2]", "in0");
532  INFER_OK(op, "[1,2,3];?", "in0");
533  INFER_OK(op, "[1,3,2];?", "in0");
534  INFER_OK(op, "[1,?,2];[?,?]", "in0");
535  INFER_OK(op, "[1,?,?];[?,2]", "in0");
536
537  // Infer batch shape from diag when input is not fully specified.
538  INFER_OK(op, "?;[1,2]", "[d1_0,?,?]");
539  INFER_OK(op, "[?,?,3];[1,2]", "[d1_0,d0_1,d0_2]");
540  INFER_OK(op, "[?,3,?];[1,2]", "[d1_0,d0_1,d0_2]");
541  INFER_OK(op, "[?,3,2];[1,2]", "[d1_0,d0_1,d0_2]");
542}
543
544TEST(ArrayOpsTest, ExpandDims_ShapeFn) {
545  ShapeInferenceTestOp op("ExpandDims");
546  op.input_tensors.resize(2);
547
548  // With unknown dim tensor value, output is unknown.
549  INFER_OK(op, "?;?", "?");
550  Tensor dim_t;
551  op.input_tensors[1] = &dim_t;
552
553  // Expand at front of tensor.
554  for (int32 idx : {0, -4}) {
555    dim_t = test::AsScalar<int32>(idx);
556    INFER_OK(op, "?;?", "?");
557    INFER_OK(op, "[5,?,7];?", "[1,d0_0,d0_1,d0_2]");
558  }
559
560  // Expand at middle of tensor.
561  for (int32 idx : {1, -3}) {
562    dim_t = test::AsScalar<int32>(idx);
563    INFER_OK(op, "?;?", "?");
564    INFER_OK(op, "[5,?,7];?", "[d0_0,1,d0_1,d0_2]");
565
566    // Repeat with int64.
567    dim_t = test::AsScalar<int64>(idx);
568    INFER_OK(op, "?;?", "?");
569    INFER_OK(op, "[5,?,7];?", "[d0_0,1,d0_1,d0_2]");
570  }
571  for (int32 idx : {2, -2}) {
572    dim_t = test::AsScalar<int32>(idx);
573    INFER_OK(op, "?;?", "?");
574    INFER_OK(op, "[5,?,7];?", "[d0_0,d0_1,1,d0_2]");
575
576    // Repeat with int64.
577    dim_t = test::AsScalar<int64>(idx);
578    INFER_OK(op, "?;?", "?");
579    INFER_OK(op, "[5,?,7];?", "[d0_0,d0_1,1,d0_2]");
580  }
581
582  for (int32 idx : {3, -1}) {
583    // Expand at the end.
584    dim_t = test::AsScalar<int32>(idx);
585    INFER_OK(op, "?;?", "?");
586    INFER_OK(op, "[5,?,7];?", "[d0_0,d0_1,d0_2,1]");
587
588    // Repeat with int64.
589    dim_t = test::AsScalar<int64>(idx);
590    INFER_OK(op, "?;?", "?");
591    INFER_OK(op, "[5,?,7];?", "[d0_0,d0_1,d0_2,1]");
592  }
593  for (int32 idx : {4, -5}) {
594    // Invalid idx.
595    dim_t = test::AsScalar<int32>(idx);
596    INFER_ERROR("not in the interval [-4, 3]", op, "[5,?,7];?");
597    dim_t = test::AsScalar<int64>(idx);
598    INFER_ERROR("not in the interval [-4, 3]", op, "[5,?,7];?");
599  }
600
601  // Expand using an input vector tensor.
602  std::vector<int32> dims;
603  dims.push_back(0);
604  dim_t = test::AsTensor<int32>(dims);
605  INFER_OK(op, "?;?", "?");
606  INFER_OK(op, "[5,?,7];?", "[1,d0_0,d0_1,d0_2]");
607
608  // Expand using too many input elements.
609  dims.push_back(1);
610  dim_t = test::AsTensor<int32>(dims);
611  INFER_ERROR("'dim' input must be a tensor with a single", op, "?;?");
612  INFER_ERROR("'dim' input must be a tensor with a single", op, "[5,6,7];?");
613
614  // Examples from ExpandDims doc.
615  dim_t = test::AsScalar<int32>(0);
616  INFER_OK(op, "[2];[]", "[1,d0_0]");
617  dim_t = test::AsScalar<int32>(1);
618  INFER_OK(op, "[2];[]", "[d0_0,1]");
619  dim_t = test::AsScalar<int32>(-1);
620  INFER_OK(op, "[2];[]", "[d0_0,1]");
621}
622
623TEST(ArrayOpsTest, ImmutableConst_ShapeFn) {
624  ShapeInferenceTestOp op("ImmutableConst");
625
626  TF_ASSERT_OK(NodeDefBuilder("test", "ImmutableConst")
627                   .Attr("dtype", DT_FLOAT)
628                   .Attr("shape", TensorShape({1, 2, 3}))
629                   .Attr("memory_region_name", "test_region")
630                   .Finalize(&op.node_def));
631  INFER_OK(op, "", "[1,2,3]");
632
633  TF_ASSERT_OK(NodeDefBuilder("test", "ImmutableConst")
634                   .Attr("dtype", DT_FLOAT)
635                   .Attr("shape", TensorShape({}))
636                   .Attr("memory_region_name", "test_region")
637                   .Finalize(&op.node_def));
638  INFER_OK(op, "", "[]");
639
640  TF_ASSERT_OK(NodeDefBuilder("test", "ImmutableConst")
641                   .Attr("dtype", DT_FLOAT)
642                   .Attr("shape", "invalid")
643                   .Attr("memory_region_name", "test_region")
644                   .Finalize(&op.node_def));
645  INFER_ERROR("AttrValue had value with type 'string' when 'shape' expected",
646              op, "");
647}
648
649TEST(ArrayOpsTest, Concat_ShapeFn) {
650  ShapeInferenceTestOp op("Concat");
651  auto set_n = [&op](int n) {
652    std::vector<NodeDefBuilder::NodeOut> src_list;
653    src_list.reserve(n);
654    for (int i = 0; i < n; ++i) src_list.emplace_back("a", 0, DT_FLOAT);
655    TF_ASSERT_OK(NodeDefBuilder("test", "Concat")
656                     .Input({"concat_dim", 0, DT_INT32})
657                     .Input(src_list)
658                     .Attr("n", n)
659                     .Finalize(&op.node_def));
660  };
661
662  // Confirm dimension[0] of the input (the concat_dim) is a scalar.
663  set_n(2);
664  INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[1];?;?");
665
666  // Test with the input concat_dim tensor not known. This takes the known rank
667  // of the inputs and makes a tensor of that many unknown dims.
668  set_n(7);
669  INFER_OK(op, "?;?;?;?;[1,2,3];?;[3,2,1];?", "[?,?,?]");
670  set_n(4);
671  INFER_OK(op, "?;?;?;[1,2,3,4];[4,3,2,1]", "[?,?,?,?]");
672  INFER_OK(op, "?;?;?;?;?", "?");  // output rank unknown
673  INFER_ERROR("Can't concatenate scalars (use tf.stack instead)", op,
674              "?;?;?;[];[]");
675  INFER_ERROR("Shape must be rank 2 but is rank 3", op, "?;?;?;[1,2];[1,2,3]");
676
677  // Test when the concat_dim tensor is known. The concatenated dimension is
678  // summed across all input tensors, and other dimensions are merged.
679  Tensor concat_dim_t;
680  op.input_tensors.push_back(&concat_dim_t);
681  set_n(2);
682
683  // Sum dim 0, merge the other two dims.
684  for (int concat_dim : {0, -3}) {
685    concat_dim_t = test::AsScalar(concat_dim);
686    INFER_OK(op, "[];[100,2,?];[10,?,3]", "[110,d1_1,d2_2]");
687    INFER_ERROR("Dimension 1 in both shapes must be equal, but are 5 and 3", op,
688                "[];[100,2,5];[10,?,3]");
689    // concat_dim can't be summed, as one value is unknown.
690    INFER_OK(op, "[];[100,2,?];[?,?,3]", "[?,d1_1,d2_2]");
691    INFER_OK(op, "[];[?,2,?];[10,?,3]", "[?,d1_1,d2_2]");
692  }
693
694  // Test with a higher concat_dim.
695  for (bool use_negative : {false, true}) {
696    concat_dim_t = test::AsScalar(use_negative ? -2 : 1);
697    INFER_OK(op, "[];[1,100,?];[?,10,3]", "[d1_0,110,d2_2]");
698    concat_dim_t = test::AsScalar(use_negative ? -1 : 1);
699    INFER_OK(op, "[];[1,100];[?,10]", "[d1_0,110]");
700    INFER_OK(op, "[];[?,100];[1,10]", "[d2_0,110]");
701
702    // concat_dim is out of bounds.
703    concat_dim_t = test::AsScalar(use_negative ? -2 : 1);
704    INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
705                "[];[100];[10,?]");
706    INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
707                "[];[100,5];[10]");
708  }
709
710  // concat_dim is too low.
711  concat_dim_t = test::AsScalar(-2);
712  INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
713              "[];[100];[10,?]");
714  INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
715              "[];[100,5];[10]");
716
717  // Repeat successful case with several unknown inputs.
718  set_n(5);
719  concat_dim_t = test::AsScalar(1);
720  INFER_OK(op, "[];?;[1,100,?];[?,?,?];[?,10,3];?", "[d2_0,?,d4_2]");
721}
722
723TEST(ArrayOpsTest, ConcatV2_ShapeFn) {
724  ShapeInferenceTestOp op("ConcatV2");
725  auto set_n = [&op](int n) {
726    std::vector<NodeDefBuilder::NodeOut> src_list;
727    src_list.reserve(n);
728    for (int i = 0; i < n; ++i) src_list.emplace_back("a", 0, DT_FLOAT);
729    TF_ASSERT_OK(NodeDefBuilder("test", "ConcatV2")
730                     .Input(src_list)
731                     .Input({"axis", 0, DT_INT32})
732                     .Attr("n", n)
733                     .Finalize(&op.node_def));
734  };
735
736  // Confirm dimension[0] of the input (the concat_dim) is a scalar.
737  set_n(2);
738  INFER_ERROR("Shape must be rank 0 but is rank 1", op, "?;?;[1]");
739
740  // Test with the input concat_dim tensor not known. This takes the known rank
741  // of the inputs and makes a tensor of that many unknown dims.
742  set_n(7);
743  INFER_OK(op, "?;?;?;?;[1,2,3];?;[3,2,1];?", "[?,?,?]");
744  set_n(4);
745  INFER_OK(op, "?;?;[1,2,3,4];[4,3,2,1];?", "[?,?,?,?]");
746  INFER_OK(op, "?;?;?;?;?", "?");  // output rank unknown
747  INFER_ERROR("Can't concatenate scalars (use tf.stack instead)", op,
748              "?;?;[];[];?");
749  INFER_ERROR("Shape must be rank 2 but is rank 3", op, "?;?;[1,2];[1,2,3];?");
750
751  // Test when the concat_dim tensor is known. The concatenated dimension is
752  // summed across all input tensors, and other dimensions are merged.
753  Tensor concat_dim_t;
754  op.input_tensors.resize(3);
755  op.input_tensors[2] = &concat_dim_t;
756
757  set_n(2);
758
759  // Invalid concat dim value.
760  // concat_dim_t = test::AsScalar(-1);
761  // INFER_ERROR("Expected concat_dim >= 0, but got -1", op, "?;?;?");
762
763  // Sum dim 0, merge the other two dims.
764  concat_dim_t = test::AsScalar(0);
765  INFER_OK(op, "[100,2,?];[10,?,3];[]", "[110,d0_1,d1_2]");
766  INFER_ERROR("Dimension 1 in both shapes must be equal, but are 5 and 3", op,
767              "[100,2,5];[10,?,3];[]");
768  // concat_dim can't be summed, as one value is unknown.
769  INFER_OK(op, "[100,2,?];[?,?,3];[]", "[?,d0_1,d1_2]");
770  INFER_OK(op, "[?,2,?];[10,?,3];[]", "[?,d0_1,d1_2]");
771
772  // Test with a higher concat_dim.
773  concat_dim_t = test::AsScalar(1);
774  INFER_OK(op, "[1,100,?];[?,10,3];[]", "[d0_0,110,d1_2]");
775  INFER_OK(op, "[1,100];[?,10];[]", "[d0_0,110]");
776  INFER_OK(op, "[?,100];[1,10];[]", "[d1_0,110]");
777  // concat_dim is too high.
778  INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
779              "[100];[10,?];[]");
780  INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
781              "[100,5];[10];[]");
782  // concat_dim is too low.
783  concat_dim_t = test::AsScalar(-2);
784  INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
785              "[100];[10,?];[]");
786  INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
787              "[100,5];[10];[]");
788
789  // Repeat successful case with several unknown inputs.
790  op.input_tensors.resize(6);
791  op.input_tensors[3] = nullptr;
792  op.input_tensors[5] = &concat_dim_t;
793  concat_dim_t = test::AsScalar(1);
794
795  set_n(5);
796  INFER_OK(op, "?;[1,100,?];[?,?,?];[?,10,3];?;[]", "[d1_0,?,d3_2]");
797}
798
799TEST(ArrayOpsTest, ConcatOffset_ShapeFn) {
800  ShapeInferenceTestOp op("ConcatOffset");
801
802  const int n = 4;
803  std::vector<NodeDefBuilder::NodeOut> src_list;
804  src_list.reserve(n);
805  for (int i = 0; i < n; ++i) src_list.emplace_back("a", 0, DT_INT32);
806  TF_ASSERT_OK(NodeDefBuilder("test", "ConcatOffset")
807                   .Input({"concat_dim", 0, DT_INT32})
808                   .Input(src_list)
809                   .Attr("n", n)
810                   .Finalize(&op.node_def));
811  INFER_OK(op, "?;?;?;?;?", "in1;in2;in3;in4");
812}
813
// Reshape shape-function tests. When the `shape` tensor (input 1) is known,
// its values determine the output shape. A single -1 entry is inferred from
// the input's element count; with several -1 entries they stay unknown.
TEST(ArrayOpsTest, Reshape_ShapeFn) {
  ShapeInferenceTestOp op("Reshape");
  op.input_tensors.resize(2);

  // No valid shape provided.
  INFER_OK(op, "?;?", "?");
  INFER_OK(op, "[?];?", "?");
  INFER_OK(op, "[?];[?]", "?");
  INFER_OK(op, "[4];[?]", "?");

  // All dimensions provided.
  // From here on op.input_tensors[1] aliases `new_shape`, so each later
  // reassignment of `new_shape` changes the value the shape function sees.
  Tensor new_shape = test::AsTensor<int32>({1, 2, 3});
  op.input_tensors[1] = &new_shape;
  INFER_OK(op, "[?];[3]", "[1,2,3]");
  INFER_OK(op, "[6];[3]", "[1,2,3]");
  // The number of elements should match for the reshape to succeed.
  INFER_ERROR(
      "Cannot reshape a tensor with 12 elements to shape [1,2,3] (6 elements)",
      op, "[3,4];[3]");

  // Unknown dimensions.
  // Flatten:
  new_shape = test::AsTensor<int32>({-1});
  INFER_OK(op, "[?];[1]", "[?]");
  INFER_OK(op, "[2,2];[1]", "[4]");
  // The -1 dimension is inferred from the element count (12 / 2 = 6):
  new_shape = test::AsTensor<int32>({2, -1});
  INFER_OK(op, "[3,4];[2]", "[2,6]");
  // The total number of elements must be evenly divisible by the known
  // dimensions.
  INFER_ERROR("Dimension size must be evenly divisible by 2 but is 7", op,
              "[7];[2]");
  // Multiple missing dimensions cannot be inferred.
  new_shape = test::AsTensor<int32>({-1, -1, 2});
  INFER_OK(op, "[8];[3]", "[?,?,2]");

  // Reshaping to a scalar.
  new_shape = test::AsTensor<int32>({});
  INFER_OK(op, "[1];[0]", "[]");
  INFER_ERROR(
      "Cannot reshape a tensor with 2 elements to shape [] (1 elements)", op,
      "[1,2];[0]");

  // Reshaping a tensor with no elements.
  new_shape = test::AsTensor<int32>({-1});
  INFER_OK(op, "[0];[1]", "[0]");
  new_shape = test::AsTensor<int32>({-1, 6});
  INFER_OK(op, "[0,2];[1]", "[0,6]");
  // With an explicit 0 in the target shape, the -1 entry is left unknown.
  new_shape = test::AsTensor<int32>({0, -1});
  INFER_OK(op, "[0,2];[1]", "[0,?]");
}
865
// QuantizedReshape shape-function tests. Output 0 follows Reshape's rules on
// inputs 0 and 1; outputs 1 and 2 are always scalars. The third and fourth
// inputs must be scalars.
TEST(ArrayOpsTest, QuantizedReshape_ShapeFn) {
  ShapeInferenceTestOp op("QuantizedReshape");
  op.input_tensors.resize(2);

  // First test a subset of the Reshape_ShapeFn tests. Not all are tested, as
  // QuantizedReshape uses the same code for the reshape part of the operation.
  INFER_OK(op, "?;?;?;?", "?;[];[]");
  INFER_OK(op, "[?];?;?;?", "?;[];[]");
  INFER_OK(op, "[?];[?];?;?", "?;[];[]");
  INFER_OK(op, "[4];[?];?;?", "?;[];[]");
  // With the target shape tensor known, output 0 takes its values.
  Tensor new_shape = test::AsTensor<int32>({1, 2, 3});
  op.input_tensors[1] = &new_shape;
  INFER_OK(op, "[?];[3];?;?", "[1,2,3];[];[]");
  INFER_OK(op, "[6];[3];?;?", "[1,2,3];[];[]");
  INFER_ERROR(
      "Cannot reshape a tensor with 12 elements to shape [1,2,3] (6 elements)",
      op, "[3,4];[3];?;?");

  // Test the scalar rank checks on input_min and input_max.
  INFER_ERROR("must be rank 0", op, "?;?;[1];?");
  INFER_ERROR("must be rank 0", op, "?;?;?;[1]");
}
888
// Placeholder has no inputs; its output shape comes entirely from the "shape"
// attr. Each scope below builds a fresh op with a different kind of shape attr.
TEST(ArrayOpsTest, Placeholder_ShapeFn) {
  {
    // 2D shape
    ShapeInferenceTestOp op("Placeholder");
    TensorShape shape({1, 2});
    TF_ASSERT_OK(NodeDefBuilder("test", "Placeholder")
                     .Attr("shape", shape)
                     .Attr("dtype", DT_FLOAT)
                     .Finalize(&op.node_def));
    INFER_OK(op, "", "[1,2]");
  }

  {
    // Scalar shapes are supported
    ShapeInferenceTestOp op("Placeholder");
    TensorShape shape({});
    TF_ASSERT_OK(NodeDefBuilder("test", "Placeholder")
                     .Attr("shape", shape)
                     .Attr("dtype", DT_FLOAT)
                     .Finalize(&op.node_def));
    INFER_OK(op, "", "[]");
  }

  {
    // Partial shape: a -1 dimension becomes an unknown dimension.
    ShapeInferenceTestOp op("Placeholder");
    const int64 dims[2] = {1, -1};
    PartialTensorShape shape;
    TF_ASSERT_OK(PartialTensorShape::MakePartialShape(dims, 2, &shape));
    TF_ASSERT_OK(NodeDefBuilder("test", "Placeholder")
                     .Attr("shape", shape)
                     .Attr("dtype", DT_FLOAT)
                     .Finalize(&op.node_def));
    INFER_OK(op, "", "[1,?]");
  }

  {
    // Unknown shape: a default-constructed PartialTensorShape yields an
    // output of unknown rank.
    ShapeInferenceTestOp op("Placeholder");
    PartialTensorShape shape;
    TF_ASSERT_OK(NodeDefBuilder("test", "Placeholder")
                     .Attr("shape", shape)
                     .Attr("dtype", DT_FLOAT)
                     .Finalize(&op.node_def));
    INFER_OK(op, "", "?");
  }
}
936
// Transpose shape-function tests. The output rank comes from either the input
// rank or the length of `perm` (input 1). When `perm` is known, output
// dimension i is input dimension perm[i].
TEST(ArrayOpsTest, Transpose_ShapeFn) {
  ShapeInferenceTestOp op("Transpose");
  op.input_tensors.resize(2);

  // Missing shape information.
  INFER_OK(op, "?;?", "?");
  INFER_OK(op, "?;[?]", "?");
  // perm's length alone fixes the output rank.
  INFER_OK(op, "?;[2]", "[?,?]");
  INFER_OK(op, "[?];?", "[?]");
  INFER_OK(op, "[?,?];[2]", "[?,?]");
  // perm's length must equal the input rank.
  INFER_ERROR("Dimension must be 3 but is 2", op, "[1,2,3];[2]");
  // From here on op.input_tensors[1] aliases `perm`; reassigning `perm`
  // changes the permutation seen by the shape function.
  Tensor perm = test::AsTensor<int32>({0});
  op.input_tensors[1] = &perm;
  INFER_OK(op, "[?];[?]", "[d0_0]");
  perm = test::AsTensor<int32>({1, 0});
  INFER_OK(op, "?;[2]", "[?,?]");
  INFER_OK(op, "[?,?];[2]", "[d0_1,d0_0]");
  INFER_OK(op, "[1,?];[2]", "[d0_1,d0_0]");

  // Invalid arguments.
  perm = test::AsTensor<int32>({1, 2});
  INFER_ERROR("perm dim 2 is out of range of input rank 2", op, "[1,2];[2]");
  perm = test::AsTensor<int32>({0});
  INFER_ERROR("Dimension must be 2 but is 1", op, "[1,2];[1]");

  // Larger valid cases.
  perm = test::AsTensor<int32>({1, 0, 3, 4, 2});
  INFER_OK(op, "[0,1,2,3,4];[5]", "[d0_1,d0_0,d0_3,d0_4,d0_2]");
  INFER_OK(op, "[0,?,2,3,4];[5]", "[d0_1,d0_0,d0_3,d0_4,d0_2]");
}
967
// Bitcast shape-function tests. The output shape depends on the relative byte
// sizes of the input dtype and the "type" attr.
TEST(ArrayOpsTest, Bitcast_ShapeFn) {
  ShapeInferenceTestOp op("Bitcast");
  // Rebuilds op.node_def for a given (input dtype, output "type" attr) pair.
  auto rebuild_node_def = [&op](DataType input_type, DataType output_type) {
    TF_ASSERT_OK(NodeDefBuilder("test", "Bitcast")
                     .Input("input", 0, input_type)
                     .Attr("type", output_type)
                     .Finalize(&op.node_def));
  };

  rebuild_node_def(DT_FLOAT, DT_INT32);
  // No valid shape provided, so output is unknown.
  INFER_OK(op, "?", "?");

  // Bitcasting from two equal sizes propagates shape.
  INFER_OK(op, "[1,2]", "in0");

  // Bitcasting from a smaller type to a larger one drops the last dimension,
  // which must equal the size ratio (int64 / int32 = 2).
  rebuild_node_def(DT_INT32, DT_INT64);
  INFER_OK(op, "[1,2]", "[d0_0]");  // last dimension matches divisor.
  // TODO(vrv): Seems like a bug, or at least, too lenient.
  INFER_OK(op, "[1,?]", "[d0_0]");
  // 4 is divisible by 2, but the shape function signature requires
  // that the last dimension matches the last value exactly.
  INFER_ERROR("does not match", op, "[1,4]");
  INFER_ERROR("does not match", op, "[1,3]");

  // Bitcasting from a larger type to a smaller type appends a new last
  // dimension equal to the size ratio (e.g. complex128 / int8 = 16).
  rebuild_node_def(DT_INT64, DT_INT32);
  INFER_OK(op, "[4,5]", "[d0_0,d0_1,2]");
  rebuild_node_def(DT_COMPLEX128, DT_INT32);
  INFER_OK(op, "[4,5]", "[d0_0,d0_1,4]");
  rebuild_node_def(DT_COMPLEX128, DT_HALF);
  INFER_OK(op, "[4,5]", "[d0_0,d0_1,8]");
  rebuild_node_def(DT_COMPLEX128, DT_INT8);
  INFER_OK(op, "[4,5]", "[d0_0,d0_1,16]");

  // Bitcasting to or from a non-POD type such as string is not allowed.
  rebuild_node_def(DT_STRING, DT_INT32);
  INFER_ERROR("one of the type sizes is zero", op, "[1,2,3]");
  rebuild_node_def(DT_INT32, DT_STRING);
  INFER_ERROR("one of the type sizes is zero", op, "[1,2,3]");
}
1010
// Squeeze shape-function tests, covering the implicit "squeeze every size-1
// dim" mode and explicit (possibly negative) squeeze_dims.
TEST(ArrayOpsTest, Squeeze_ShapeFn) {
  ShapeInferenceTestOp op("Squeeze");

  // Rebuilds op.node_def with the given squeeze_dims attr.
  auto rebuild_node_def = [&op](const std::vector<int32>& squeeze_dims) {
    TF_ASSERT_OK(NodeDefBuilder("test", "Squeeze")
                     .Input("input", 0, DT_FLOAT)
                     .Attr("squeeze_dims", squeeze_dims)
                     .Finalize(&op.node_def));
  };

  // Default squeeze_dims = []
  rebuild_node_def({});

  // No valid shape provided, so output is unknown.
  INFER_OK(op, "?", "?");

  INFER_OK(op, "[1,4,1,5,1]", "[d0_1,d0_3]");

  // With no explicit squeeze_dims, any unknown dimension might be 1 at
  // runtime, so the output shape is unknown.
  INFER_OK(op, "[1,?,1,?,1]", "?");

  // Test simple squeeze of an explicit dimension
  rebuild_node_def({1});
  INFER_OK(op, "[4,1,5]", "[d0_0,d0_2]");
  // Squeezing unknown dim explicitly, assumes it's 1 at runtime.
  INFER_OK(op, "[4,?,5]", "[d0_0,d0_2]");

  // Attempt to squeeze non-one dimension
  INFER_ERROR("Can not squeeze dim[1]", op, "[4,6,5]");

  // Squeeze multiple dimensions; -2 in a rank-4 input refers to dim 2.
  rebuild_node_def({1, 2});
  INFER_OK(op, "[4,1,1,5]", "[d0_0,d0_3]");
  rebuild_node_def({1, -2});
  INFER_OK(op, "[4,1,1,5]", "[d0_0,d0_3]");

  // Negative squeeze dim
  rebuild_node_def({-2});
  INFER_OK(op, "[4,1,5]", "[d0_0,d0_2]");

  // Test validation of squeeze dimensions: valid range is [-rank, rank).
  rebuild_node_def({-4});
  INFER_ERROR("not in [-3,3)", op, "[1,2,3]");
  rebuild_node_def({3});
  INFER_ERROR("not in [-3,3)", op, "[1,2,3]");
}
1057
// ReverseSequence shape-function tests. The output mirrors the input; the
// rank-1 seq_lengths input is merged with the input's batch_dim.
TEST(ArrayOpsTest, ReverseSequence_ShapeFn) {
  ShapeInferenceTestOp op("ReverseSequence");
  // Rebuilds op.node_def with the given seq_dim / batch_dim attrs.
  auto rebuild_node_def = [&op](const int32 seq_dim, const int32 batch_dim) {
    TF_ASSERT_OK(NodeDefBuilder("test", "ReverseSequence")
                     .Input("input", 0, DT_FLOAT)
                     .Input("seq_lengths", 1, DT_INT64)
                     .Attr("seq_dim", seq_dim)
                     .Attr("batch_dim", batch_dim)
                     .Finalize(&op.node_def));
  };

  rebuild_node_def(1, 2);
  // No valid shape provided, so output is unknown.
  INFER_OK(op, "?;[10]", "?");

  // Bad rank for seq_lengths
  INFER_ERROR("Shape must be rank 1 but is rank 2", op, "?;[10,10]");

  // Validate seq_dim and batch_dim
  rebuild_node_def(1, 4);
  INFER_ERROR("batch_dim must be < input rank", op, "[1,2,3];[3]");
  rebuild_node_def(4, 1);
  INFER_ERROR("seq_dim must be < input rank", op, "[1,2,3];[3]");

  rebuild_node_def(1, 2);
  INFER_OK(op, "[1,2,3];[3]", "[d0_0,d0_1,d0_2]");
  // Resolves uncertainty on batch dimension by merging.
  INFER_OK(op, "[1,2,?];[3]", "[d0_0,d0_1,d1_0]");
  INFER_OK(op, "[1,2,3];[?]", "[d0_0,d0_1,d0_2]");
}
1088
// Split shape-function tests with num_split = 2. When the scalar split_dim
// (input 0) is known, that dimension of `value` (input 1) is divided by
// num_split in each output; otherwise only the output rank is known.
TEST(ArrayOpsTest, Split_ShapeFn) {
  ShapeInferenceTestOp op("Split");
  op.input_tensors.resize(2);

  // No value for split_dim and no input.
  TF_ASSERT_OK(NodeDefBuilder("test", "Split")
                   .Input("split_dim", 0, DT_INT32)
                   .Input("value", 1, DT_FLOAT)
                   .Attr("num_split", 2)
                   .Finalize(&op.node_def));
  INFER_OK(op, "?;?", "?;?");
  // If the rank is known, we know the rank of each output.
  INFER_OK(op, "?;[?,?]", "[?,?];[?,?]");

  // split_dim is unknown but other inputs are known.
  INFER_OK(op, "?;[1,4]", "[?,?];[?,?]");

  // split_dim is known. op.input_tensors[0] aliases `split_dim` from here
  // on, so later reassignments change the value the shape function sees.
  Tensor split_dim = test::AsTensor<int32>({1, 2});
  op.input_tensors[0] = &split_dim;
  INFER_ERROR("Input must be scalar but has rank 1", op, "[?];[?,?]");
  split_dim = test::AsScalar<int32>(1);
  INFER_OK(op, "?;?", "?;?");
  INFER_OK(op, "?;[?,?]", "[d1_0,?];[d1_0,?]");
  INFER_OK(op, "?;[1,4]", "[d1_0,2];[d1_0,2]");
  INFER_OK(op, "?;[1,?]", "[d1_0,?];[d1_0,?]");
  INFER_ERROR("Dimension size must be evenly divisible by 2 but is 5", op,
              "?;[1,5]");

  // split_dim too large.
  split_dim = test::AsScalar<int32>(3);
  INFER_ERROR(
      "Dimension size, given by scalar input 3 must be in range [-3, 3)", op,
      "?;[1,4,8]");

  // Negative split_dim counts from the end (-1 is the last dimension).
  split_dim = test::AsScalar<int32>(-1);
  INFER_OK(op, "?;?", "?;?");
  INFER_OK(op, "?;[?,?]", "[d1_0,?];[d1_0,?]");
  INFER_OK(op, "?;[1,?]", "[d1_0,?];[d1_0,?]");
  INFER_OK(op, "?;[1,4]", "[d1_0,2];[d1_0,2]");
  INFER_OK(op, "?;[1,4,8]", "[d1_0,d1_1,4];[d1_0,d1_1,4]");
  split_dim = test::AsScalar<int32>(-2);
  INFER_OK(op, "?;[1,4,8]", "[d1_0,2,d1_2];[d1_0,2,d1_2]");
  split_dim = test::AsScalar<int32>(-4);
  INFER_ERROR(
      "Dimension size, given by scalar input -4 must be in range [-3, 3)", op,
      "?;[1,4,8]");
}
1138
// Tile shape-function tests. The output rank equals the input rank (or the
// length of `multiples`). When `multiples` is known, each output dimension is
// input_dim * multiples[i].
TEST(ArrayOpsTest, Tile_ShapeFn) {
  ShapeInferenceTestOp op("Tile");
  op.input_tensors.resize(2);

  // Tile takes the tensor to replicate and a rank-1 `multiples` tensor.
  TF_ASSERT_OK(NodeDefBuilder("test", "Tile")
                   .Input("input", 0, DT_FLOAT)
                   .Input("multiples", 1, DT_INT32)
                   .Finalize(&op.node_def));

  // If both are unknown, output is unknown.
  INFER_OK(op, "?;?", "?");

  // If the multiples shape is unknown but the input rank is known, the output
  // rank is known.
  INFER_OK(op, "[2,3,1,4];?", "[?,?,?,?]");

  // Bad rank for 'multiples'
  INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[2,3,1,4];[4,1]");

  // No multiples tensor available, but output rank is known from multiples.
  INFER_OK(op, "?;[4]", "[?,?,?,?]");

  // Test a tile of a 4D input.
  Tensor multiples = test::AsTensor<int32>({2, 3, 4, 5});
  op.input_tensors[1] = &multiples;
  INFER_OK(op, "[2,3,1,4];[4]", "[4,9,4,20]");
  // Test 64-bit tensor type
  multiples = test::AsTensor<int64>({2, 3, 4, 5});
  INFER_OK(op, "[2,3,1,4];[4]", "[4,9,4,20]");
}
1169
// EditDistance shape-function tests. The output shape is only known when both
// dense shape tensors (input 2 = hypothesis_shape, input 5 = truth_shape) are
// available.
TEST(ArrayOpsTest, EditDistance_ShapeFn) {
  ShapeInferenceTestOp op("EditDistance");
  op.input_tensors.resize(6);

  // If the shape tensors are not available, the output shape is unknown.
  INFER_OK(op, "[?,?];[?];[4];[?,?];[?];[4]", "?");

  // With both shapes known, the output drops the last dimension and takes
  // the larger of the two sizes per dimension:
  // [max(2,20), max(30,3), max(4,40)].
  Tensor hypothesis_shape = test::AsTensor<int64>({2, 30, 4, 50});
  op.input_tensors[2] = &hypothesis_shape;
  Tensor truth_shape = test::AsTensor<int64>({20, 3, 40, 5});
  op.input_tensors[5] = &truth_shape;
  INFER_OK(op, "[?,?];[?];[4];[?,?];[?];[4]", "[20,30,40]");

  // Shape elements don't match
  hypothesis_shape = test::AsTensor<int64>({2});
  op.input_tensors[2] = &hypothesis_shape;
  INFER_ERROR("Num elements of hypothesis_shape does not match truth_shape", op,
              "[?,?];[?];[1];[?,?];[?];[4]");
}
1189
// OneHot shape-function tests. When the scalar `depth` (input 1) is known, it
// is inserted into the indices shape at position `axis` (-1 appends it).
TEST(ArrayOpsTest, OneHot_ShapeFn) {
  ShapeInferenceTestOp op("OneHot");
  op.input_tensors.resize(4);
  // Rebuilds op.node_def with the given "axis" attr.
  auto set_axis = [&op](int axis) {
    TF_ASSERT_OK(NodeDefBuilder("test", "OneHot")
                     .Input("indices", 0, DT_FLOAT)
                     .Input("depth", 1, DT_INT32)
                     .Input("on_value", 2, DT_FLOAT)
                     .Input("off_value", 3, DT_FLOAT)
                     .Attr("axis", axis)
                     .Finalize(&op.node_def));
  };

  // Invalid axis value.
  set_axis(-2);
  INFER_ERROR("axis must be >= -1", op, "?;?;?;?");
  set_axis(1);

  // If indices shape is unknown, we return an unknown shape.
  INFER_OK(op, "?;[];?;?", "?");

  // Depth must be scalar.
  Tensor depth = test::AsTensor<int32>({1, 2});
  op.input_tensors[1] = &depth;
  INFER_ERROR("Input must be scalar but has rank 1", op, "?;[2];?;?");

  // Full information is available.
  depth = test::AsScalar<int32>(2);
  INFER_OK(op, "[1,3,4];[];?;?", "[d0_0,2,d0_1,d0_2]");
  set_axis(-1);
  INFER_OK(op, "[1,3,4];[];?;?", "[d0_0,d0_1,d0_2,2]");
}
1222
// ExtractImagePatches shape-function tests, focusing on how `rates` enlarges
// the effective kernel and how ksizes determine the output depth.
TEST(ArrayOpsTest, ExtractImagePatchesShapeTest) {
  ShapeInferenceTestOp op("ExtractImagePatches");
  // Rebuilds op.node_def with the given ksizes / strides / rates / padding.
  auto set_op = [&op](const std::vector<int32>& ksizes,
                      const std::vector<int32>& strides,
                      const std::vector<int32>& rates, const string& padding) {
    TF_ASSERT_OK(NodeDefBuilder("test", "ExtractImagePatches")
                     .Input("input", 0, DT_FLOAT)
                     .Attr("ksizes", ksizes)
                     .Attr("strides", strides)
                     .Attr("rates", rates)
                     .Attr("padding", padding)
                     .Finalize(&op.node_def));
  };

  // Just tests that the ksize calculation with rates works.  Most of
  // the other code is boilerplate that is tested by a variety of
  // other ops.
  //
  // ksizes is 2x2.  rate rows and cols is 2, so ksize_rows and
  // cols are changed to be 2 + (2 - 1) = 3.  7x7 input with 3x3
  // filter and 1x1 stride gives a 5x5 output. The output depth is
  // 2 * 2 (ksizes) * 2 (input depth) = 8.
  set_op({1, 2, 2, 1}, {1, 1, 1, 1}, {1, 2, 2, 1}, "VALID");
  INFER_OK(op, "[1,7,7,2]", "[d0_0,5,5,8]");
  // With ksizes as 1x1, the rates have no effect: the output depth equals the
  // input depth and the 7x7 spatial size is unchanged.
  set_op({1, 1, 1, 1}, {1, 1, 1, 1}, {1, 2, 2, 1}, "VALID");
  INFER_OK(op, "[1,7,7,2]", "[d0_0,7,7,d0_3]");

  // Bad ksize rank
  set_op({1, 2, 2, 1, 1}, {1, 1, 1, 1}, {1, 2, 2, 1}, "VALID");
  INFER_ERROR(
      "ExtractImagePatches requires the ksizes attribute to contain 4 values, "
      "but got: 5",
      op, "[1,7,7,2]");
}
1258
1259TEST(ArrayOpsTest, QuantizeAndDequantizeV2_ShapeFn) {
1260  ShapeInferenceTestOp op("QuantizeAndDequantizeV2");
1261  INFER_OK(op, "?;?;?", "in0");
1262  INFER_OK(op, "[];?;?", "in0");
1263  INFER_OK(op, "[1,2,?,4,5];?;?", "in0");
1264
1265  INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[1,2,?,4,5];[1];[]");
1266  INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[1,2,?,4,5];[];[1]");
1267  INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[1,2,?,4,5];[1];[1]");
1268}
1269
1270TEST(ArrayOpsTest, SpaceToBatch_ShapeFn) {
1271  ShapeInferenceTestOp op("SpaceToBatch");
1272  op.input_tensors.resize(2);
1273  TF_ASSERT_OK(NodeDefBuilder("test", "SpaceToBatch")
1274                   .Input("input", 0, DT_FLOAT)
1275                   .Input("paddings", 1, DT_INT32)
1276                   .Attr("block_size", 2)
1277                   .Finalize(&op.node_def));
1278
1279  // Paddings not known, but batch size can be computed.
1280  INFER_OK(op, "[1,10,10,3];[2,2]", "[4,?,?,d0_3]");
1281
1282  // Unknown paddings means width and height.
1283  INFER_OK(op, "[1,10,10,3];?", "[4,?,?,d0_3]");
1284
1285  // Paddings not correct shape
1286  INFER_ERROR("rank", op, "[1,10,10,3];[4]");
1287  INFER_ERROR("3 and 2", op, "[1,10,10,3];[2,3]");
1288
1289  Tensor paddings = test::AsTensor<int32>({4, 2, 2, 4}, {{2, 2}});
1290  op.input_tensors[1] = &paddings;
1291  INFER_OK(op, "[1,10,10,3];[2,2]", "[4,8,8,d0_3]");
1292  paddings = test::AsTensor<int64>({4, 2, 2, 4}, {{2, 2}});
1293  INFER_OK(op, "[1,10,10,3];[2,2]", "[4,8,8,d0_3]");
1294
1295  // Bad paddings values
1296  paddings = test::AsTensor<int32>({1, 2, 3, 4}, {{2, 2}});
1297  op.input_tensors[1] = &paddings;
1298  INFER_ERROR("Dimension size must be evenly divisible by 2 but is 13", op,
1299              "[1,10,10,3];[2,2]");
1300
1301  // Negative paddsings
1302  paddings = test::AsTensor<int32>({1, -2, 3, 4}, {{2, 2}});
1303  op.input_tensors[1] = &paddings;
1304  INFER_ERROR("cannot be negative", op, "[1,10,10,3];[2,2]");
1305}
1306
// SpaceToBatchND shape-function tests. Batch is multiplied by the product of
// block_shape. Each padded block dimension is divided by its block size, and
// the trailing depth dimension passes through.
TEST(ArrayOpsTest, SpaceToBatchND_ShapeFn) {
  ShapeInferenceTestOp op("SpaceToBatchND");
  op.input_tensors.resize(3);
  TF_ASSERT_OK(NodeDefBuilder("test", "SpaceToBatchND")
                   .Input("input", 0, DT_FLOAT)
                   .Input("block_shape", 1, DT_INT32)
                   .Input("paddings", 2, DT_INT32)
                   .Finalize(&op.node_def));

  // Verify that input shape and paddings shape can be unknown.
  INFER_OK(op, "?;[2];?", "?");

  // Only number of input dimensions is known.
  INFER_OK(op, "[?,?,?,?];[2];?", "[?,?,?,d0_3]");

  // Dimensions are partially known.
  INFER_OK(op, "[?,?,?,2];[2];?", "[?,?,?,d0_3]");

  {
    // Dimensions are partially known, block_shape known.
    // Batch becomes 3 * (2 * 3) = 18.
    Tensor block_shape = test::AsTensor<int32>({2, 3});
    op.input_tensors[1] = &block_shape;
    INFER_OK(op, "[3,?,?,2];[2];?", "[18,?,?,d0_3]");

    // Dimensions are partially known, block_shape and paddings known.
    // Second block dim: (2 + 0 + 1) / 3 = 1.
    {
      Tensor paddings = test::AsTensor<int32>({1, 1, 0, 1}, {{2, 2}});
      op.input_tensors[2] = &paddings;
      INFER_OK(op, "[3,?,2,2];[2];[2,2]", "[18,?,1,d0_3]");
      op.input_tensors[2] = nullptr;
    }

    // Dimensions are fully known, block_shape and paddings are known.
    // (2 + 1 + 1) / 2 = 2 and 3 / 3 = 1.
    {
      Tensor paddings = test::AsTensor<int32>({1, 1, 0, 0}, {{2, 2}});
      op.input_tensors[2] = &paddings;
      INFER_OK(op, "[3,2,3,2];[2];[2,2]", "[18,2,1,d0_3]");
      op.input_tensors[2] = nullptr;
    }

    // Clear the pointer before block_shape goes out of scope.
    op.input_tensors[1] = nullptr;
  }

  INFER_ERROR("block_shape must have rank 1", op, "?;[1,1];?");
  INFER_ERROR("block_shape must have known size", op, "?;[?];?");

  {
    Tensor block_shape = test::AsTensor<int32>({0, 2});
    op.input_tensors[1] = &block_shape;
    INFER_ERROR("block_shape must be positive", op, "[1,2,2];[2];[2,2]");
    op.input_tensors[1] = nullptr;
  }

  {
    Tensor block_shape = test::AsTensor<int32>({1, 1});
    op.input_tensors[1] = &block_shape;
    Tensor paddings = test::AsTensor<int32>({0, -1, 0, 0}, {{2, 2}});
    op.input_tensors[2] = &paddings;
    INFER_ERROR("paddings cannot be negative", op, "[1,2,2];[2];[2,2]");
    op.input_tensors[1] = nullptr;
    op.input_tensors[2] = nullptr;
  }

  {
    // Padded size 2 is not divisible by block size 3.
    Tensor block_shape = test::AsTensor<int32>({3, 3});
    op.input_tensors[1] = &block_shape;
    Tensor paddings = test::AsTensor<int32>({0, 0, 0, 0}, {{2, 2}});
    op.input_tensors[2] = &paddings;
    INFER_ERROR("divisible", op, "[1,2,3,1];[2];[2,2]");
    op.input_tensors[1] = nullptr;
    op.input_tensors[2] = nullptr;
  }

  // paddings must be a [num_block_dims, 2] matrix.
  INFER_ERROR("rank", op, "[1,3,3,1];[2];[1]");
  INFER_ERROR("shape", op, "[1,3,3,1];[2];[1,2]");
}
1383
1384TEST(ArrayOpsTest, BatchToSpace_ShapeFn) {
1385  ShapeInferenceTestOp op("BatchToSpace");
1386  op.input_tensors.resize(2);
1387  TF_ASSERT_OK(NodeDefBuilder("test", "BatchToSpace")
1388                   .Input("input", 0, DT_FLOAT)
1389                   .Input("crops", 1, DT_INT32)
1390                   .Attr("block_size", 2)
1391                   .Finalize(&op.node_def));
1392
1393  // croppings not known, but batch size can be computed.
1394  INFER_OK(op, "[4,8,8,3];[2,2]", "[1,?,?,d0_3]");
1395
1396  // block_size not compatible with batch size
1397  INFER_ERROR("Dimension size must be evenly divisible by", op,
1398              "[5,8,8,3];[2,2]");
1399
1400  // Unknown croppings means unknown width and height.
1401  INFER_OK(op, "[4,8,8,3];?", "[1,?,?,d0_3]");
1402
1403  // croppings not correct shape
1404  INFER_ERROR("rank", op, "[4,8,8,3];[4]");
1405  INFER_ERROR("3 and 2", op, "[4,8,8,3];[2,3]");
1406
1407  Tensor croppings = test::AsTensor<int64>({4, 2, 2, 4}, {{2, 2}});
1408  op.input_tensors[1] = &croppings;
1409  INFER_OK(op, "[4,8,8,3];[2,2]", "[1,10,10,d0_3]");
1410
1411  // Bad croppings values
1412  croppings = test::AsTensor<int32>({100, 2, 3, 4}, {{2, 2}});
1413  op.input_tensors[1] = &croppings;
1414  INFER_ERROR("Negative dimension size caused by subtracting", op,
1415              "[4,8,8,3];[2,2]");
1416  croppings = test::AsTensor<int32>({1, 2, 3, 400}, {{2, 2}});
1417  op.input_tensors[1] = &croppings;
1418  INFER_ERROR("Negative dimension size caused by subtracting", op,
1419              "[4,8,8,3];[2,2]");
1420
1421  // Negative paddsings
1422  croppings = test::AsTensor<int32>({1, -2, 3, 4}, {{2, 2}});
1423  op.input_tensors[1] = &croppings;
1424  INFER_ERROR("cannot be negative", op, "[4,8,8,3];[2,2]");
1425}
1426
1427TEST(ArrayOpsTest, BatchToSpaceND_ShapeFn) {
1428  ShapeInferenceTestOp op("BatchToSpaceND");
1429  op.input_tensors.resize(3);
1430  TF_ASSERT_OK(NodeDefBuilder("test", "BatchToSpaceND")
1431                   .Input("input", 0, DT_FLOAT)
1432                   .Input("block_shape", 1, DT_INT32)
1433                   .Input("crops", 2, DT_INT32)
1434                   .Finalize(&op.node_def));
1435
1436  // Verify that input shape and crops shape can be unknown.
1437  INFER_OK(op, "?;[2];?", "?");
1438
1439  // Only number of input dimensions is known.
1440  INFER_OK(op, "[?,?,?,?];[2];?", "[?,?,?,d0_3]");
1441
1442  {
1443    // Dimensions are partially known, block_shape known.
1444    Tensor block_shape = test::AsTensor<int32>({2, 3});
1445    op.input_tensors[1] = &block_shape;
1446    INFER_OK(op, "[?,?,?,2];[2];?", "[?,?,?,d0_3]");
1447
1448    INFER_OK(op, "[18,?,?,2];[2];?", "[3,?,?,d0_3]");
1449
1450    // Dimensions are partially known, block_shape and crops known.
1451    {
1452      Tensor crops = test::AsTensor<int32>({1, 1, 0, 1}, {{2, 2}});
1453      op.input_tensors[2] = &crops;
1454      INFER_OK(op, "[18,?,2,2];[2];[2,2]", "[3,?,5,d0_3]");
1455      op.input_tensors[2] = nullptr;
1456    }
1457
1458    // Dimensions are fully known, block_shape and crops are known.
1459    {
1460      Tensor crops = test::AsTensor<int32>({1, 1, 0, 0}, {{2, 2}});
1461      op.input_tensors[2] = &crops;
1462      INFER_OK(op, "[18,2,1,2];[2];[2,2]", "[3,2,3,d0_3]");
1463      op.input_tensors[2] = nullptr;
1464    }
1465
1466    op.input_tensors[1] = nullptr;
1467  }
1468
1469  INFER_ERROR("block_shape must have rank 1", op, "?;[1,1];?");
1470  INFER_ERROR("block_shape must have known size", op, "?;[?];?");
1471  INFER_ERROR("rank", op, "[2,2];[2];[2,2]");
1472  INFER_ERROR("rank", op, "[2,2,3];[3];[3,2]");
1473
1474  {
1475    Tensor block_shape = test::AsTensor<int32>({0, 2});
1476    op.input_tensors[1] = &block_shape;
1477    INFER_ERROR("block_shape must be positive", op, "[1,2,2];[2];[2,2]");
1478    op.input_tensors[1] = nullptr;
1479  }
1480
1481  {
1482    Tensor block_shape = test::AsTensor<int32>({1, 1});
1483    op.input_tensors[1] = &block_shape;
1484    Tensor paddings = test::AsTensor<int32>({0, -1, 0, 0}, {{2, 2}});
1485    op.input_tensors[2] = &paddings;
1486    INFER_ERROR("crops cannot be negative", op, "[1,2,2];[2];[2,2]");
1487    op.input_tensors[1] = nullptr;
1488    op.input_tensors[2] = nullptr;
1489  }
1490
1491  // The amount to crop exceeds the padded size.
1492  {
1493    Tensor block_shape = test::AsTensor<int32>({2, 2});
1494    op.input_tensors[1] = &block_shape;
1495    Tensor crops = test::AsTensor<int32>({3, 2, 0, 0}, {{2, 2}});
1496    op.input_tensors[2] = &crops;
1497    INFER_ERROR("Negative", op, "[4,2,3,1];[2];[2,2]");
1498    op.input_tensors[1] = nullptr;
1499    op.input_tensors[2] = nullptr;
1500  }
1501
1502  // The batch size is not divisible by the product of the block_shape.
1503  {
1504    Tensor block_shape = test::AsTensor<int32>({2, 3});
1505    op.input_tensors[1] = &block_shape;
1506    INFER_ERROR("divisible", op, "[3,1,1,1];[2];[2,2]");
1507    op.input_tensors[1] = nullptr;
1508  }
1509}
1510
// SpaceToDepth shape-function tests with block_size = 2. Height and width are
// divided by 2, and depth is multiplied by 2 * 2 = 4.
TEST(ArrayOpsTest, SpaceToDepth_ShapeFn) {
  ShapeInferenceTestOp op("SpaceToDepth");
  TF_ASSERT_OK(NodeDefBuilder("test", "SpaceToDepth")
                   .Input("input", 0, DT_FLOAT)
                   .Attr("block_size", 2)
                   .Finalize(&op.node_def));

  INFER_OK(op, "[1,2,4,4]", "[d0_0,1,2,16]");

  // block_size not compatible with space
  INFER_ERROR("Dimension size must be evenly divisible by 2 but is 3", op,
              "[1,3,8,4]");
  INFER_ERROR("Dimension size must be evenly divisible by 2 but is 5", op,
              "[1,2,5,4]");

  // Unknown depth --> Unknown depth.
  INFER_OK(op, "[1,2,4,?]", "[d0_0,1,2,?]");
}
1529
// DepthToSpace shape-function tests. Height and width are multiplied by
// block_size, and depth is divided by block_size^2.
TEST(ArrayOpsTest, DepthToSpace_ShapeFn) {
  ShapeInferenceTestOp op("DepthToSpace");
  TF_ASSERT_OK(NodeDefBuilder("test", "DepthToSpace")
                   .Input("input", 0, DT_FLOAT)
                   .Attr("block_size", 2)
                   .Finalize(&op.node_def));

  INFER_OK(op, "[1,1,2,16]", "[d0_0,2,4,4]");

  // Bad depth: 15 is not divisible by 2 * 2.
  INFER_ERROR("Dimension size must be evenly divisible by 4 but is 15", op,
              "[1,1,2,15]");

  // Unknown depth --> Unknown depth.
  INFER_OK(op, "[1,2,4,?]", "[d0_0,4,8,?]");

  // Check another block size: 200 / (10 * 10) = 2.
  TF_ASSERT_OK(NodeDefBuilder("test", "DepthToSpace")
                   .Input("input", 0, DT_FLOAT)
                   .Attr("block_size", 10)
                   .Finalize(&op.node_def));
  INFER_OK(op, "[1,1,2,200]", "[d0_0,10,20,2]");
}
1553
// Slice shape-function tests. With begin/sizes values known, the output shape
// is `sizes`. A -1 size means "to the end", i.e. input_dim - begin.
TEST(ArrayOpsTest, Slice_ShapeFn) {
  ShapeInferenceTestOp op("Slice");
  // NOTE(review): begin/sizes are declared DT_INT64 here while the value
  // tensors set below are int32; the expectations imply the shape function
  // accepts this, but confirm it is intentional.
  TF_ASSERT_OK(NodeDefBuilder("test", "Slice")
                   .Input("input", 0, DT_FLOAT)
                   .Input("begin", 1, DT_INT64)
                   .Input("sizes", 2, DT_INT64)
                   .Finalize(&op.node_def));

  // Known rank of input and shape of begin/sizes, but unknown values.
  // The best we know is the rank of the output.
  INFER_OK(op, "[2,3,4,5];[4];[4]", "[?,?,?,?]");

  // Unknown shape of begin/sizes, we still know the rank.
  INFER_OK(op, "[2,3,4,5];[?];[?]", "[?,?,?,?]");
  // Unknown all around
  INFER_OK(op, "?;[?];[?]", "?");
  // Can infer based on begin
  INFER_OK(op, "?;[4];[?]", "[?,?,?,?]");

  // Bad rank of begin, sizes
  INFER_ERROR("must be rank 1", op, "[2,3,4,5];[2,3];[3]");
  INFER_ERROR("must be rank 1", op, "[2,3,4,5];[2];[3,4]");
  // Length of begin doesn't match input rank
  INFER_ERROR("must be rank 2", op, "[2,3,4,5];[2];[2]");

  // Tests with known values. op.input_tensors[1] and [2] alias `begin` and
  // `sizes`, so later reassignments change the values seen.
  op.input_tensors.resize(3);
  Tensor begin = test::AsTensor<int32>({0, 1, 2, 1});
  Tensor sizes = test::AsTensor<int32>({1, 2, 1, 3});
  op.input_tensors[1] = &begin;
  op.input_tensors[2] = &sizes;
  INFER_OK(op, "[2,3,4,5];[4];[4]", "[1,2,1,3]");

  // -1 in sizes means "get the rest"
  sizes = test::AsTensor<int32>({-1, -1, 1, -1});
  INFER_OK(op, "[2,3,4,5];[4];[4]", "[d0_0,2,1,4]");

  // begin 6 past the end of a size-5 dimension gives a negative size.
  begin = test::AsTensor<int32>({0, 1, 2, 6});
  sizes = test::AsTensor<int32>({-1, -1, -1, -1});
  INFER_ERROR("Negative dimension size", op, "[2,3,4,5];[4];[4]");

  begin = test::AsTensor<int32>({0, 1, 2, 5});
  sizes = test::AsTensor<int32>({-1, -1, -1, -2});
  INFER_ERROR("cannot be < -1", op, "[2,3,4,5];[4];[4]");
}
1599
// StridedSliceGrad shape-function tests. The output shape is given by the
// value of input 0 when available; otherwise only the rank (input 0's length)
// may be known.
TEST(ArrayOpsTest, StridedSliceGrad_ShapeFn) {
  ShapeInferenceTestOp op("StridedSliceGrad");
  op.input_tensors.resize(5);
  INFER_OK(op, "?;?;?;?;?", "?");
  INFER_OK(op, "[?];?;?;?;?", "?");
  INFER_OK(op, "[4];?;?;?;?", "[?,?,?,?]");

  // With the shape tensor known, its values become the output dimensions.
  Tensor in_t = test::AsTensor<int32>({1, 2, 3, 4});
  op.input_tensors[0] = &in_t;
  INFER_OK(op, "[4];?;?;?;?", "[1,2,3,4]");
}
1611
// Ops whose output mirrors input 0 and whose second and third inputs must be
// scalars; each listed op is checked with the same cases.
TEST(ArrayOpsTest, UnchangedWithQuantizationScalars_ShapeFn) {
  for (const char* op_name : {"Dequantize", "FakeQuantWithMinMaxVars"}) {
    ShapeInferenceTestOp op(op_name);

    INFER_OK(op, "?;?;?", "in0");
    INFER_OK(op, "[1,?,3];[];[]", "in0");

    // Rank check scalars.
    INFER_ERROR("be rank 0", op, "[1,?,3];[1];[]");
    INFER_ERROR("be rank 0", op, "[1,?,3];[];[1]");
  }
}
1624
1625TEST(ArrayOpsTest, FakeQuantWithMinMaxVarsPerChannel) {
1626  ShapeInferenceTestOp op("FakeQuantWithMinMaxVarsPerChannel");
1627
1628  INFER_OK(op, "?;?;?", "in0");
1629  INFER_OK(op, "[?];?;?", "in0");
1630  INFER_OK(op, "[1,?,3];[3];[3]", "in0");
1631  INFER_OK(op, "[3];[3];[3]", "in0");
1632
1633  // Rank check vectors.
1634  INFER_ERROR("be rank 1", op, "[1,?,3];[1];[]");
1635  INFER_ERROR("be rank 1", op, "[1,?,3];[];[1]");
1636
1637  // Vectors must match each other, and match last dim of input.
1638  INFER_ERROR("must be equal", op, "[1,?,3];[2];[?]");
1639  INFER_ERROR("must be equal", op, "[1,?,3];[?];[2]");
1640  INFER_ERROR("must be equal", op, "[1,?,?];[1];[2]");
1641  INFER_ERROR("must be equal", op, "[5];[4];[?]");
1642}
1643
1644TEST(ArrayOpsTest, FakeQuantWithMinMaxVarsPerChannelGradient) {
1645  ShapeInferenceTestOp op("FakeQuantWithMinMaxVarsPerChannelGradient");
1646
1647  INFER_OK(op, "?;?;?;?", "in0;[?];[?]");
1648  INFER_OK(op, "[3];[3];[3];[3]", "in0;in3;in3");
1649  INFER_OK(op, "[1,3];[1,3];[3];[3]", "in0;in3;in3");
1650  INFER_OK(op, "[1,2,3,4];[1,2,3,4];[4];[4]", "in0;in3;in3");
1651
1652  // Rank check vectors.
1653  INFER_ERROR("be equal rank", op, "[1,?,3];[1,?,3];[3];[]");
1654  INFER_ERROR("be rank 1", op, "[1,?,3];[1,?,3];[];[3]");
1655  INFER_ERROR("be at least rank 1", op, "[];[];[1];[1]");
1656  INFER_ERROR("be at most rank 4", op, "[1,2,3,4,5];[1,2,3,4,5];[1];[1]");
1657
1658  // Vectors must match each other, and match last dim of input.
1659  INFER_ERROR("must be equal", op, "[1,3];[1,3];[2];[3]");
1660  INFER_ERROR("must be equal", op, "[1,3];[1,3];[3];[2]");
1661}
1662
1663TEST(ArrayOpsTest, QuantizedConcat_ShapeFn) {
1664  ShapeInferenceTestOp op("QuantizedConcat");
1665  auto set_n = [&op](int n) {
1666    std::vector<NodeDefBuilder::NodeOut> src_list;
1667    std::vector<NodeDefBuilder::NodeOut> limit_list;
1668    for (int i = 0; i < n; ++i) {
1669      src_list.emplace_back("a", 0, DT_QUINT8);
1670      limit_list.emplace_back("b", 0, DT_FLOAT);
1671    }
1672    TF_ASSERT_OK(NodeDefBuilder("test", "QuantizedConcat")
1673                     .Input({"concat_dim", 0, DT_INT32})
1674                     .Input(src_list)
1675                     .Input(limit_list)
1676                     .Input(limit_list)
1677                     .Attr("N", n)
1678                     .Finalize(&op.node_def));
1679  };
1680
1681  // Confirm dimension[0] of the input (the concat_dim) is a scalar.
1682  set_n(1);
1683  INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[1];?;?;?");
1684
1685  // Last 2*<N> are all scalars.
1686  set_n(2);
1687  INFER_ERROR("must be rank 0", op, "[];?;?;?;?;?;[1]");
1688  INFER_ERROR("must be rank 0", op, "[];?;?;?;?;[1];?");
1689  INFER_ERROR("must be rank 0", op, "[];?;?;?;[1];?;?");
1690  INFER_ERROR("must be rank 0", op, "[];?;?;[1];?;?;?");
1691
1692  // First is concat dim; next N must be compatible for concat.
1693  set_n(2);
1694  INFER_ERROR("must be rank 2", op, "[];[1,2];[1,2,3];?;?;?;?");
1695  INFER_OK(op, "[];[1,2];[1,3];?;?;?;?", "[?,?];[];[]");
1696
1697  // Test when the concat_dim tensor is known. The concatenated dimension is
1698  // summed across all input tensors, and other dimensions are merged.
1699  Tensor concat_dim_t;
1700  op.input_tensors.push_back(&concat_dim_t);
1701  set_n(2);
1702  concat_dim_t = test::AsScalar(0);  // Sum dim 0, merge the other two dims.
1703  INFER_OK(op, "[];[100,2,?];[10,?,3];?;?;?;?", "[110,d1_1,d2_2];[];[]");
1704  INFER_ERROR("Dimension 1 in both shapes must be equal, but are 5 and 3", op,
1705              "[];[100,2,5];[10,?,3];?;?;?;?");
1706  // Note that other cases of concat are covered in the Concat tests.
1707}
1708
1709TEST(StateOpsTest, _ParallelConcatStart_ShapeFn) {
1710  ShapeInferenceTestOp op("_ParallelConcatStart");
1711  TensorShape shape({1, 2, 3});
1712  TensorShapeProto shape_proto;
1713  shape.AsProto(&shape_proto);
1714  TF_ASSERT_OK(NodeDefBuilder("test", "_ParallelConcatStart")
1715                   .Attr("shape", shape_proto)
1716                   .Attr("dtype", DT_FLOAT)
1717                   .Finalize(&op.node_def));
1718  INFER_OK(op, "", "[1,2,3]");
1719}
1720
1721}  // end namespace tensorflow
1722