1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/node_def_builder.h"
17#include "tensorflow/core/framework/op.h"
18#include "tensorflow/core/framework/shape_inference_testutil.h"
19#include "tensorflow/core/framework/tensor_testutil.h"
20#include "tensorflow/core/lib/core/status_test_util.h"
21#include "tensorflow/core/platform/test.h"
22
23namespace tensorflow {
24
25TEST(SparseOpsTest, SparseTensorDenseAdd_ShapeFn) {
26  ShapeInferenceTestOp op("SparseTensorDenseAdd");
27
28  // Copies input 3 to output 0.
29  INFER_OK(op, "?;?;?;?", "in3");
30}
31
32TEST(SparseOpsTest, SparseAdd_ShapeFn) {
33  ShapeInferenceTestOp op("SparseAdd");
34
35  INFER_OK(op, "?;?;?;?;?;?;?", "[?,?];[?];[?]");
36
37  // input(2) determines the output[0].
38  INFER_OK(op, "?;?;[?];?;?;?;?", "[?,d2_0];[?];in2");
39  INFER_OK(op, "?;?;[1];?;?;?;?", "[?,d2_0];[?];in2");
40}
41
42TEST(SparseOpsTest, SparseAddGrad_ShapeFn) {
43  ShapeInferenceTestOp op("SparseAddGrad");
44
45  // Rank checks.
46  INFER_ERROR("must be rank 2", op, "?;?;[1];?");
47  INFER_ERROR("must be rank 2", op, "?;[1];?;?");
48
49  INFER_OK(op, "?;?;?;?", "[?];[?]");
50
51  // input[1].dim(0) and input[2].dim(0) determine output.
52  INFER_OK(op, "?;[?,?];[?,?];?", "[d1_0];[d2_0]");
53}
54
55TEST(SparseOpsTest, SparseReorder_ShapeFn) {
56  ShapeInferenceTestOp op("SparseReorder");
57
58  // Inputs are input_indices, input_values, and input_shape.
59
60  // Rank checks.
61  INFER_ERROR("must be rank 2", op, "[1];?;?");
62  INFER_ERROR("must be rank 1", op, "?;[];?");
63  INFER_ERROR("must be rank 1", op, "?;?;[]");
64
65  // output is always matrix and vector.
66  INFER_OK(op, "?;?;?", "[?,?];[?]");
67
68  // input_indices and input_values and transferred to outputs 0 and 1.
69  INFER_OK(op, "[?,?];[?];?", "in0;in1");
70}
71
72TEST(SparseOpsTest, SparseReshape_ShapeFn) {
73  ShapeInferenceTestOp op("SparseReshape");
74
75  // Inputs are input_indices, input_shape, and new_shape.
76
77  // Rank checks.
78  INFER_ERROR("must be rank 2", op, "[1];?;?");
79  INFER_ERROR("must be rank 1", op, "?;[];?");
80  INFER_ERROR("must be rank 1", op, "?;?;[]");
81
82  // output is always matrix and vector.
83  INFER_OK(op, "?;?;?", "[?,?];[?]");
84
85  // first output is matrix [input_indices.dim(0), new_shape.dim(0)].
86  // new_shape is transferred to second output.
87  INFER_OK(op, "[?,?];?;[?]", "[d0_0,d2_0];in2");
88}
89
90TEST(SparseOpsTest, SparseSplit_ShapeFn) {
91  ShapeInferenceTestOp op("SparseSplit");
92  TF_ASSERT_OK(NodeDefBuilder("test", "SparseSplit")
93                   .Input({"split_dim", 0, DT_INT64})
94                   .Input({"indices", 1, DT_INT64})
95                   .Input({"values", 2, DT_INT64})
96                   .Input({"shape", 3, DT_INT64})
97                   .Attr("num_split", 2)  // each output is copied twice.
98                   .Finalize(&op.node_def));
99
100  // output has three shape types, derived from input_shape (which is input(3)).
101  // each type is copied #splits times.
102  // First output is [?, NumElements(input_shape)].
103  // Second output is [?]
104  // Third output is input_shape.
105  INFER_OK(op, "?;?;?;?", "[?,?];[?,?];[?];[?];in3;in3");
106  INFER_OK(op, "?;?;?;[5,4,3,2,1]", "[?,120];[?,120];[?];[?];in3;in3");
107}
108
109TEST(SparseOpsTest, SparseToDense_ShapeFn) {
110  ShapeInferenceTestOp op("SparseToDense");
111  op.input_tensors.resize(4);
112
113  // input[1] is the shape tensor.
114  INFER_OK(op, "?;?;?;?", "?");
115  INFER_OK(op, "?;[?];?;?", "?");
116  INFER_OK(op, "?;[4];?;?", "[?,?,?,?]");
117  Tensor in_t = test::AsTensor<int32>({1, 2, 3, 4});
118  op.input_tensors[1] = &in_t;
119  INFER_OK(op, "?;[4];?;?", "[1,2,3,4]");
120}
121
122TEST(SparseOpsTest, SparseReduceSum_ShapeFn) {
123  ShapeInferenceTestOp op("SparseReduceSum");
124
125  // Shape fn always yields unknown.
126  INFER_OK(op, "?;?;?;?", "?");
127}
128
129TEST(SparseOpsTest, SerializeSparse_ShapeFn) {
130  ShapeInferenceTestOp op("SerializeSparse");
131
132  // Rank checks.
133  INFER_ERROR("must be rank 2", op, "[1];?;?");
134  INFER_ERROR("must be rank 1", op, "?;[];?");
135  INFER_ERROR("must be rank 1", op, "?;?;[]");
136
137  // output is always vector of size 3.
138  INFER_OK(op, "?;?;?", "[3]");
139}
140
141TEST(SparseOpsTest, SerializeManySparse_ShapeFn) {
142  ShapeInferenceTestOp op("SerializeManySparse");
143
144  // Rank checks.
145  INFER_ERROR("must be rank 2", op, "[1];?;?");
146  INFER_ERROR("must be rank 1", op, "?;[];?");
147  INFER_ERROR("must be rank 1", op, "?;?;[]");
148
149  // output is always matrix of [?,3].
150  INFER_OK(op, "?;?;?", "[?,3]");
151}
152
153TEST(SparseOpsTest, DeserializeManySparse_ShapeFn) {
154  ShapeInferenceTestOp op("DeserializeManySparse");
155
156  // Rank checks.
157  INFER_ERROR("must be rank 2", op, "[1]");
158  INFER_ERROR("must be 3", op, "[?,4]");
159
160  // output is always [?,?];[?];[?].
161  INFER_OK(op, "?", "[?,?];[?];[?]");
162  INFER_OK(op, "[?,3]", "[?,?];[?];[?]");
163}
164
165TEST(SparseOpsTest, SparseTensorDenseMatMul_ShapeFn) {
166  ShapeInferenceTestOp op("SparseTensorDenseMatMul");
167  auto set_adjoints = [&op](bool adjoint_a, bool adjoint_b) {
168    TF_ASSERT_OK(NodeDefBuilder("test", "SparseTensorDenseMatMul")
169                     .Input({"a_indices", 1, DT_INT64})
170                     .Input({"a_values", 2, DT_INT64})
171                     .Input({"a_shape", 3, DT_INT64})
172                     .Input({"b", 3, DT_INT64})
173                     .Attr("adjoint_a", adjoint_a)
174                     .Attr("adjoint_b", adjoint_b)
175                     .Finalize(&op.node_def));
176  };
177
178  // Inputs are a_indices, a_values, a_shape, b.
179  set_adjoints(false, false);
180
181  // Rank checks.
182  INFER_ERROR("must be rank 2", op, "[1];?;?;?");
183  INFER_ERROR("must be rank 1", op, "?;[];?;?");
184  INFER_ERROR("must be rank 1", op, "?;?;[];?");
185  INFER_ERROR("must be rank 2", op, "?;?;[3];?");
186  INFER_ERROR("must be rank 2", op, "?;?;?;[]");
187
188  // second output dim comes from b, depending on adjoint_b value.
189  INFER_OK(op, "?;?;?;?", "[?,?]");
190  INFER_OK(op, "?;?;?;[?,?]", "[?,d3_1]");    // use d3_1, !adjoint_b.
191  INFER_OK(op, "?;?;?;[1,2]", "[?,d3_1]");    // use d3_1, !adjoint_b.
192  INFER_OK(op, "?;?;[2];[1,2]", "[?,d3_1]");  // use d3_1, !adjoint_b.
193
194  set_adjoints(false, true);
195  INFER_OK(op, "?;?;?;[?,?]", "[?,d3_0]");  // use d3_0, adjoint_b.
196  INFER_OK(op, "?;?;?;[1,2]", "[?,d3_0]");  // use d3_0, adjoint_b.
197
198  // first output comes from a, depending on adjoint_a value.
199  // When input tensor is known, its values determine output shape.
200  Tensor a_shape_t = test::AsTensor<int64>(std::vector<int64>{3, 1});
201  op.input_tensors.resize(4);
202  op.input_tensors[2] = &a_shape_t;
203
204  // Multiplying matrices of shape [3, 1] x [1, 2]
205  set_adjoints(false, false);
206  INFER_OK(op, "?;?;[2];[1,2]", "[3,d3_1]");  // use d3_1, !adjoint_b.
207  INFER_OK(op, "?;?;?;[1,2]", "[3,d3_1]");    // use d3_1, !adjoint_b.
208
209  set_adjoints(true, false);
210  // Trying to multiply matrices of [1, 3] x [1, 2]
211  INFER_ERROR("must be equal", op, "?;?;[2];[1,2]");  // adjoint_a, !adjoint_b.
212
213  // Try with shape tensor describing shape of rank 3.
214  a_shape_t = test::AsTensor<int64>(std::vector<int64>{3, 1, 2});
215  INFER_ERROR("must be rank 2 but is rank 3", op, "?;?;[3];[1,2]");
216}
217
218TEST(SparseOpsTest, SparseSoftmax_ShapeFn) {
219  ShapeInferenceTestOp op("SparseSoftmax");
220
221  // Inputs are sp_indices, sp_values, sp_shape.
222
223  // Rank checks.
224  INFER_ERROR("must be rank 2", op, "[1];?;?");
225  INFER_ERROR("must be rank 1", op, "?;[];?");
226  INFER_ERROR("must be rank 1", op, "?;?;[]");
227
228  // output is values_shape.
229  INFER_OK(op, "?;?;?", "[?]");
230  INFER_OK(op, "?;[?];?", "in1");
231  INFER_OK(op, "?;[5];?", "in1");
232}
233
234TEST(SparseOpsTest, SparseSparseMinAndMin_ShapeFn) {
235  for (const char* op_name : {"SparseSparseMaximum", "SparseSparseMinimum"}) {
236    ShapeInferenceTestOp op(op_name);
237
238    // Rank checks.
239    INFER_ERROR("must be rank 2", op, "[1];?;?;?;?;?");  // a_indices
240    INFER_ERROR("must be rank 1", op, "?;[];?;?;?;?");   // a_values
241    INFER_ERROR("must be rank 1", op, "?;?;[];?;?;?");   // a_shape
242    INFER_ERROR("must be rank 2", op, "?;?;?;[];?;?");   // b_indices
243    INFER_ERROR("must be rank 1", op, "?;?;?;?;[];?");   // b_values
244    INFER_ERROR("must be rank 1", op, "?;?;?;?;?;[]");   // b_shape
245
246    // output is always [?,?];[?]
247    INFER_OK(op, "?;?;?;?;?;?", "[?,?];[?]");
248    INFER_OK(op, "?;[?];?;?;?;?", "[?,?];[?]");
249    INFER_OK(op, "?;[5];?;?;?;?", "[?,?];[?]");
250  }
251}
252
253TEST(SparseOpsTest, SparseConcat_ShapeFn) {
254  ShapeInferenceTestOp op("SparseConcat");
255  std::vector<NodeDefBuilder::NodeOut> src_list;
256  int n = 2;
257  src_list.reserve(n);
258  for (int i = 0; i < n; ++i) src_list.emplace_back("a", 0, DT_INT64);
259  TF_ASSERT_OK(NodeDefBuilder("test", "SparseConcat")
260                   .Input(src_list)
261                   .Input(src_list)
262                   .Input(src_list)
263                   .Attr("N", n)
264                   .Finalize(&op.node_def));
265
266  // Rank checks.
267  INFER_ERROR("must be rank 2", op, "[1];?;?;?;?;?");  // indices
268  INFER_ERROR("must be rank 2", op, "?;[1];?;?;?;?");  // indices
269  INFER_ERROR("must be rank 1", op, "?;?;[];?;?;?");   // values
270  INFER_ERROR("must be rank 1", op, "?;?;?;[];?;?");   // values
271  INFER_ERROR("must be rank 1", op, "?;?;?;?;[];?");   // shapes
272  INFER_ERROR("must be rank 1", op, "?;?;?;?;?;[]");   // shapes
273
274  // row count is sum of (indices[i].dim(0) merge values[i].dim(0))
275  // ind_cols is merge of (indices[i].dim(1))
276  //
277  // output 0 is matrix [row_count, ind_cols]
278  // output 1 is matrix [row_count]
279  // output 2 is merge of all shapes
280
281  // Test merge of shapes.
282  INFER_OK(op, "?;?;?;?;?;?", "[?,?];[?];[?]");
283  INFER_OK(op, "?;?;?;?;[?];[?]", "[?,?];[?];in4|in5");
284  INFER_OK(op, "?;?;?;?;[?];[5]", "[?,?];[?];in5");
285
286  // Test accumulation of row_count and ind_cols from indices.
287  INFER_OK(op, "[4,5];[3,?];?;?;?;?", "[7,d0_1];[7];[?]");
288
289  // Test accumulation of row_count and ind_cols from values.
290  INFER_OK(op, "?;?;[4];[3];?;?", "[7,?];[7];[?]");
291
292  // Test merge between row_count and ind_cols.
293  INFER_OK(op, "[?,2];[3,?];[4];[?];?;?", "[7,d0_1];[7];[?]");
294
295  // Test some errors during merge.
296  INFER_ERROR("but are 100 and 200", op, "[100,?];[?,?];[200];[?];?;?");
297  INFER_ERROR("but are 2 and 3", op, "[?,2];[?,3];[?];[?];?;?");
298  INFER_ERROR("but are 4 and 5", op, "?;?;?;?;[4];[5]");
299}
300
301TEST(SparseOpsTest, SparseDenseCwise_ShapeFn) {
302  for (const char* op_name :
303       {"SparseDenseCwiseMul", "SparseDenseCwiseDiv", "SparseDenseCwiseAdd"}) {
304    ShapeInferenceTestOp op(op_name);
305
306    // output is always a vector.
307    INFER_OK(op, "?;?;?;?", "[?]");
308
309    // input(0).dim(0) determines output[0].
310    INFER_OK(op, "[?,?];?;?;?", "[d0_0]");
311
312    // Rank checks.
313    INFER_ERROR("must be rank 2", op, "[1];?;?;?");
314  }
315}
316
317TEST(SparseOpsTest, AddSparseToTensorsMap_ShapeFn) {
318  ShapeInferenceTestOp op("AddSparseToTensorsMap");
319
320  // Rank checks.
321  INFER_ERROR("must be rank 2", op, "[1];?;?");
322  INFER_ERROR("must be rank 1", op, "?;[];?");
323  INFER_ERROR("must be rank 1", op, "?;?;[]");
324
325  // output is always scalar
326  INFER_OK(op, "?;?;?", "[]");
327}
328
329TEST(SparseOpsTest, AddManySparseToTensorsMap_ShapeFn) {
330  ShapeInferenceTestOp op("AddManySparseToTensorsMap");
331
332  // Rank checks.
333  INFER_ERROR("must be rank 2", op, "[1];?;?");
334  INFER_ERROR("must be rank 1", op, "?;[];?");
335  INFER_ERROR("must be rank 1", op, "?;?;[]");
336
337  // output is always matrix of [?].
338  INFER_OK(op, "?;?;?", "[?]");
339}
340
341TEST(SparseOpsTest, TakeManySparseFromTensorsMap_ShapeFn) {
342  ShapeInferenceTestOp op("TakeManySparseFromTensorsMap");
343
344  // Rank checks.
345  INFER_ERROR("must be rank 1", op, "[?,1]");
346
347  // output is always [?,?];[?];[?].
348  INFER_OK(op, "?", "[?,?];[?];[?]");
349  INFER_OK(op, "[?]", "[?,?];[?];[?]");
350}
351
352}  // end namespace tensorflow
353