1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/node_def_builder.h"
17#include "tensorflow/core/framework/op.h"
18#include "tensorflow/core/framework/shape_inference_testutil.h"
19#include "tensorflow/core/framework/tensor_testutil.h"
20#include "tensorflow/core/platform/test.h"
21
22namespace tensorflow {
23
24TEST(SetOpsTest, DenseToDenseShape_InvalidNumberOfInputs) {
25  ShapeInferenceTestOp op("DenseToDenseSetOperation");
26  op.input_tensors.resize(3);
27  INFER_ERROR("Wrong number of inputs passed", op, "?;?;?");
28}
29
30TEST(SetOpsTest, DenseToDenseShape) {
31  ShapeInferenceTestOp op("DenseToDenseSetOperation");
32
33  // Unknown shapes.
34  INFER_OK(op, "?;?", "[?,?];[?];[?]");
35
36  // Invalid rank.
37  INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[?];?");
38  INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "?;[?]");
39  INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[2];?");
40  INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "?;[2]");
41
42  // Mismatched ranks.
43  INFER_ERROR("Shape must be rank 2 but is rank 3", op, "[?,?];[?,?,?]");
44  INFER_ERROR("Shape must be rank 3 but is rank 2", op, "[?,?,?];[?,?]");
45  INFER_ERROR("Shape must be rank 2 but is rank 3", op, "[2,1];[2,1,2]");
46  INFER_ERROR("Shape must be rank 3 but is rank 2", op, "[2,1,2];[2,1]");
47
48  // Rank 2, unknown dims.
49  INFER_OK(op, "[?,?];?", "[?,2];[?];[2]");
50  INFER_OK(op, "?;[?,?]", "[?,2];[?];[2]");
51  INFER_OK(op, "[?,?];[?,?]", "[?,2];[?];[2]");
52
53  // Rank 4, unknown dims.
54  INFER_OK(op, "[?,?,?,?];?", "[?,4];[?];[4]");
55  INFER_OK(op, "?;[?,?,?,?]", "[?,4];[?];[4]");
56  INFER_OK(op, "[?,?,?,?];[?,?,?,?]", "[?,4];[?];[4]");
57
58  // Known rank for 1 input.
59  INFER_OK(op, "[5,3,2,1];?", "[?,4];[?];[4]");
60  INFER_OK(op, "?;[5,3,2,1]", "[?,4];[?];[4]");
61  INFER_OK(op, "[5,3,2,1];[?,?,?,?]", "[?,4];[?];[4]");
62  INFER_OK(op, "[?,?,?,?];[5,3,2,1]", "[?,4];[?];[4]");
63  INFER_OK(op, "[5,3,2,1];[?,?,?,?]", "[?,4];[?];[4]");
64
65  // Mismatched n-1 dims.
66  INFER_ERROR("Dimension 0 in both shapes must be equal", op,
67              "[4,?,2,?];[3,1,?,5]");
68  INFER_ERROR("Dimension 2 in both shapes must be equal", op,
69              "[4,3,2,1];[4,3,3,1]");
70
71  // Matched n-1 dims.
72  INFER_OK(op, "[4,5,6,7];[?,?,?,?]", "[?,4];[?];[4]");
73  INFER_OK(op, "[4,5,6,7];[?,?,?,4]", "[?,4];[?];[4]");
74  INFER_OK(op, "[?,?,?,?];[4,5,6,7]", "[?,4];[?];[4]");
75  INFER_OK(op, "[4,?,2,?];[?,1,?,5]", "[?,4];[?];[4]");
76  INFER_OK(op, "[4,5,6,7];[4,?,6,?]", "[?,4];[?];[4]");
77  INFER_OK(op, "[4,5,6,7];[4,5,6,4]", "[?,4];[?];[4]");
78}
79
80TEST(SetOpsTest, DenseToSparseShape_InvalidNumberOfInputs) {
81  ShapeInferenceTestOp op("DenseToSparseSetOperation");
82  op.input_tensors.resize(5);
83  INFER_ERROR("Wrong number of inputs passed", op, "?;?;?;?;?");
84}
85
86TEST(SetOpsTest, DenseToSparseShape) {
87  ShapeInferenceTestOp op("DenseToSparseSetOperation");
88  INFER_OK(op, "?;?;?;?", "[?,?];[?];[?]");
89
90  // Unknown shapes.
91  INFER_OK(op, "?;?;?;?", "[?,?];[?];[?]");
92  INFER_OK(op, "?;[?,?];[?];[?]", "[?,?];[?];[?]");
93
94  // Invalid rank.
95  INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[?];?;?;?");
96  INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
97              "[?];[?,?];[?];[?]");
98  INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
99              "[?];[5,3];[5];[3]");
100  INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[2];?;?;?");
101  INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
102              "[2];[?,?];[?];[?]");
103  INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
104              "[2];[5,3];[5];[3]");
105
106  // Unknown sparse rank.
107  INFER_OK(op, "[?,?];?;?;?", "[?,2];[?];[2]");
108  INFER_OK(op, "[?,?];[?,?];[?];[?]", "[?,2];[?];[2]");
109
110  // Unknown dense rank.
111  INFER_OK(op, "?;[?,2];[?];[2]", "[?,d3_0];[?];[d3_0]");
112  INFER_OK(op, "?;[5,2];[5];[2]", "[?,d3_0];[?];[d3_0]");
113
114  // Known both ranks.
115  INFER_OK(op, "[?,?];[5,2];[5];[2]", "[?,2];[?];[2]");
116  INFER_OK(op, "[4,3];[5,2];[5];[2]", "[?,2];[?];[2]");
117
118  // Invalid input sparse tensor.
119  INFER_ERROR("elements in index (5) and values (6) do not match", op,
120              "?;[5,3];[6];[3]");
121  INFER_ERROR("rank (3) and shape rank (4) do not match", op,
122              "?;[5,3];[5];[4]");
123}
124
125TEST(SetOpsTest, SparseToSparseShape_InvalidNumberOfInputs) {
126  ShapeInferenceTestOp op("SparseToSparseSetOperation");
127  op.input_tensors.resize(7);
128  INFER_ERROR("Wrong number of inputs passed", op, "?;?;?;?;?;?;?");
129}
130
131TEST(SetOpsTest, SparseToSparseShape) {
132  ShapeInferenceTestOp op("SparseToSparseSetOperation");
133
134  // Unknown.
135  INFER_OK(op, "?;?;?;?;?;?", "[?,?];[?];[?]");
136  INFER_OK(op, "[?,?];[?];[?];[?,?];[?];[?]", "[?,?];[?];[?]");
137  INFER_OK(op, "?;?;?;[?,?];[?];[?]", "[?,?];[?];[?]");
138  INFER_OK(op, "[?,?];[?];[?];?;?;?", "[?,?];[?];[?]");
139
140  // Known rank for 1 input.
141  INFER_OK(op, "[?,2];[?];[2];?;?;?", "[?,d2_0];[?];[d2_0]");
142  INFER_OK(op, "?;?;?;[?,2];[?];[2]", "[?,d5_0];[?];[d5_0]");
143  INFER_OK(op, "[?,2];[?];[2];[?,?];[?];[?]", "[?,d2_0];[?];[d2_0]");
144  INFER_OK(op, "[?,?];[?];[?];[?,2];[?];[2]", "[?,d5_0];[?];[d5_0]");
145
146  // Known rank for both inputs.
147  INFER_OK(op, "[?,2];[?];[2];[?,2];[?];[2]", "[?,d2_0];[?];[d2_0]");
148}
149
150}  // end namespace tensorflow
151