1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/core/framework/node_def_builder.h" 17#include "tensorflow/core/framework/op.h" 18#include "tensorflow/core/framework/shape_inference_testutil.h" 19#include "tensorflow/core/framework/tensor_testutil.h" 20#include "tensorflow/core/lib/core/status_test_util.h" 21#include "tensorflow/core/platform/test.h" 22 23namespace tensorflow { 24 25TEST(SparseOpsTest, SparseTensorDenseAdd_ShapeFn) { 26 ShapeInferenceTestOp op("SparseTensorDenseAdd"); 27 28 // Copies input 3 to output 0. 29 INFER_OK(op, "?;?;?;?", "in3"); 30} 31 32TEST(SparseOpsTest, SparseAdd_ShapeFn) { 33 ShapeInferenceTestOp op("SparseAdd"); 34 35 INFER_OK(op, "?;?;?;?;?;?;?", "[?,?];[?];[?]"); 36 37 // input(2) determines the output[0]. 38 INFER_OK(op, "?;?;[?];?;?;?;?", "[?,d2_0];[?];in2"); 39 INFER_OK(op, "?;?;[1];?;?;?;?", "[?,d2_0];[?];in2"); 40} 41 42TEST(SparseOpsTest, SparseAddGrad_ShapeFn) { 43 ShapeInferenceTestOp op("SparseAddGrad"); 44 45 // Rank checks. 46 INFER_ERROR("must be rank 2", op, "?;?;[1];?"); 47 INFER_ERROR("must be rank 2", op, "?;[1];?;?"); 48 49 INFER_OK(op, "?;?;?;?", "[?];[?]"); 50 51 // input[1].dim(0) and input[2].dim(0) determine output. 52 INFER_OK(op, "?;[?,?];[?,?];?", "[d1_0];[d2_0]"); 53} 54 55TEST(SparseOpsTest, SparseReorder_ShapeFn) { 56 ShapeInferenceTestOp op("SparseReorder"); 57 58 // Inputs are input_indices, input_values, and input_shape. 59 60 // Rank checks. 61 INFER_ERROR("must be rank 2", op, "[1];?;?"); 62 INFER_ERROR("must be rank 1", op, "?;[];?"); 63 INFER_ERROR("must be rank 1", op, "?;?;[]"); 64 65 // output is always matrix and vector. 66 INFER_OK(op, "?;?;?", "[?,?];[?]"); 67 68 // input_indices and input_values and transferred to outputs 0 and 1. 69 INFER_OK(op, "[?,?];[?];?", "in0;in1"); 70} 71 72TEST(SparseOpsTest, SparseReshape_ShapeFn) { 73 ShapeInferenceTestOp op("SparseReshape"); 74 75 // Inputs are input_indices, input_shape, and new_shape. 76 77 // Rank checks. 78 INFER_ERROR("must be rank 2", op, "[1];?;?"); 79 INFER_ERROR("must be rank 1", op, "?;[];?"); 80 INFER_ERROR("must be rank 1", op, "?;?;[]"); 81 82 // output is always matrix and vector. 83 INFER_OK(op, "?;?;?", "[?,?];[?]"); 84 85 // first output is matrix [input_indices.dim(0), new_shape.dim(0)]. 86 // new_shape is transferred to second output. 87 INFER_OK(op, "[?,?];?;[?]", "[d0_0,d2_0];in2"); 88} 89 90TEST(SparseOpsTest, SparseSplit_ShapeFn) { 91 ShapeInferenceTestOp op("SparseSplit"); 92 TF_ASSERT_OK(NodeDefBuilder("test", "SparseSplit") 93 .Input({"split_dim", 0, DT_INT64}) 94 .Input({"indices", 1, DT_INT64}) 95 .Input({"values", 2, DT_INT64}) 96 .Input({"shape", 3, DT_INT64}) 97 .Attr("num_split", 2) // each output is copied twice. 98 .Finalize(&op.node_def)); 99 100 // output has three shape types, derived from input_shape (which is input(3)). 101 // each type is copied #splits times. 102 // First output is [?, NumElements(input_shape)]. 103 // Second output is [?] 104 // Third output is input_shape. 105 INFER_OK(op, "?;?;?;?", "[?,?];[?,?];[?];[?];in3;in3"); 106 INFER_OK(op, "?;?;?;[5,4,3,2,1]", "[?,120];[?,120];[?];[?];in3;in3"); 107} 108 109TEST(SparseOpsTest, SparseToDense_ShapeFn) { 110 ShapeInferenceTestOp op("SparseToDense"); 111 op.input_tensors.resize(4); 112 113 // input[1] is the shape tensor. 114 INFER_OK(op, "?;?;?;?", "?"); 115 INFER_OK(op, "?;[?];?;?", "?"); 116 INFER_OK(op, "?;[4];?;?", "[?,?,?,?]"); 117 Tensor in_t = test::AsTensor<int32>({1, 2, 3, 4}); 118 op.input_tensors[1] = &in_t; 119 INFER_OK(op, "?;[4];?;?", "[1,2,3,4]"); 120} 121 122TEST(SparseOpsTest, SparseReduceSum_ShapeFn) { 123 ShapeInferenceTestOp op("SparseReduceSum"); 124 125 // Shape fn always yields unknown. 126 INFER_OK(op, "?;?;?;?", "?"); 127} 128 129TEST(SparseOpsTest, SerializeSparse_ShapeFn) { 130 ShapeInferenceTestOp op("SerializeSparse"); 131 132 // Rank checks. 133 INFER_ERROR("must be rank 2", op, "[1];?;?"); 134 INFER_ERROR("must be rank 1", op, "?;[];?"); 135 INFER_ERROR("must be rank 1", op, "?;?;[]"); 136 137 // output is always vector of size 3. 138 INFER_OK(op, "?;?;?", "[3]"); 139} 140 141TEST(SparseOpsTest, SerializeManySparse_ShapeFn) { 142 ShapeInferenceTestOp op("SerializeManySparse"); 143 144 // Rank checks. 145 INFER_ERROR("must be rank 2", op, "[1];?;?"); 146 INFER_ERROR("must be rank 1", op, "?;[];?"); 147 INFER_ERROR("must be rank 1", op, "?;?;[]"); 148 149 // output is always matrix of [?,3]. 150 INFER_OK(op, "?;?;?", "[?,3]"); 151} 152 153TEST(SparseOpsTest, DeserializeManySparse_ShapeFn) { 154 ShapeInferenceTestOp op("DeserializeManySparse"); 155 156 // Rank checks. 157 INFER_ERROR("must be rank 2", op, "[1]"); 158 INFER_ERROR("must be 3", op, "[?,4]"); 159 160 // output is always [?,?];[?];[?]. 161 INFER_OK(op, "?", "[?,?];[?];[?]"); 162 INFER_OK(op, "[?,3]", "[?,?];[?];[?]"); 163} 164 165TEST(SparseOpsTest, SparseTensorDenseMatMul_ShapeFn) { 166 ShapeInferenceTestOp op("SparseTensorDenseMatMul"); 167 auto set_adjoints = [&op](bool adjoint_a, bool adjoint_b) { 168 TF_ASSERT_OK(NodeDefBuilder("test", "SparseTensorDenseMatMul") 169 .Input({"a_indices", 1, DT_INT64}) 170 .Input({"a_values", 2, DT_INT64}) 171 .Input({"a_shape", 3, DT_INT64}) 172 .Input({"b", 3, DT_INT64}) 173 .Attr("adjoint_a", adjoint_a) 174 .Attr("adjoint_b", adjoint_b) 175 .Finalize(&op.node_def)); 176 }; 177 178 // Inputs are a_indices, a_values, a_shape, b. 179 set_adjoints(false, false); 180 181 // Rank checks. 182 INFER_ERROR("must be rank 2", op, "[1];?;?;?"); 183 INFER_ERROR("must be rank 1", op, "?;[];?;?"); 184 INFER_ERROR("must be rank 1", op, "?;?;[];?"); 185 INFER_ERROR("must be rank 2", op, "?;?;[3];?"); 186 INFER_ERROR("must be rank 2", op, "?;?;?;[]"); 187 188 // second output dim comes from b, depending on adjoint_b value. 189 INFER_OK(op, "?;?;?;?", "[?,?]"); 190 INFER_OK(op, "?;?;?;[?,?]", "[?,d3_1]"); // use d3_1, !adjoint_b. 191 INFER_OK(op, "?;?;?;[1,2]", "[?,d3_1]"); // use d3_1, !adjoint_b. 192 INFER_OK(op, "?;?;[2];[1,2]", "[?,d3_1]"); // use d3_1, !adjoint_b. 193 194 set_adjoints(false, true); 195 INFER_OK(op, "?;?;?;[?,?]", "[?,d3_0]"); // use d3_0, adjoint_b. 196 INFER_OK(op, "?;?;?;[1,2]", "[?,d3_0]"); // use d3_0, adjoint_b. 197 198 // first output comes from a, depending on adjoint_a value. 199 // When input tensor is known, its values determine output shape. 200 Tensor a_shape_t = test::AsTensor<int64>(std::vector<int64>{3, 1}); 201 op.input_tensors.resize(4); 202 op.input_tensors[2] = &a_shape_t; 203 204 // Multiplying matrices of shape [3, 1] x [1, 2] 205 set_adjoints(false, false); 206 INFER_OK(op, "?;?;[2];[1,2]", "[3,d3_1]"); // use d3_1, !adjoint_b. 207 INFER_OK(op, "?;?;?;[1,2]", "[3,d3_1]"); // use d3_1, !adjoint_b. 208 209 set_adjoints(true, false); 210 // Trying to multiply matrices of [1, 3] x [1, 2] 211 INFER_ERROR("must be equal", op, "?;?;[2];[1,2]"); // adjoint_a, !adjoint_b. 212 213 // Try with shape tensor describing shape of rank 3. 214 a_shape_t = test::AsTensor<int64>(std::vector<int64>{3, 1, 2}); 215 INFER_ERROR("must be rank 2 but is rank 3", op, "?;?;[3];[1,2]"); 216} 217 218TEST(SparseOpsTest, SparseSoftmax_ShapeFn) { 219 ShapeInferenceTestOp op("SparseSoftmax"); 220 221 // Inputs are sp_indices, sp_values, sp_shape. 222 223 // Rank checks. 224 INFER_ERROR("must be rank 2", op, "[1];?;?"); 225 INFER_ERROR("must be rank 1", op, "?;[];?"); 226 INFER_ERROR("must be rank 1", op, "?;?;[]"); 227 228 // output is values_shape. 229 INFER_OK(op, "?;?;?", "[?]"); 230 INFER_OK(op, "?;[?];?", "in1"); 231 INFER_OK(op, "?;[5];?", "in1"); 232} 233 234TEST(SparseOpsTest, SparseSparseMinAndMin_ShapeFn) { 235 for (const char* op_name : {"SparseSparseMaximum", "SparseSparseMinimum"}) { 236 ShapeInferenceTestOp op(op_name); 237 238 // Rank checks. 239 INFER_ERROR("must be rank 2", op, "[1];?;?;?;?;?"); // a_indices 240 INFER_ERROR("must be rank 1", op, "?;[];?;?;?;?"); // a_values 241 INFER_ERROR("must be rank 1", op, "?;?;[];?;?;?"); // a_shape 242 INFER_ERROR("must be rank 2", op, "?;?;?;[];?;?"); // b_indices 243 INFER_ERROR("must be rank 1", op, "?;?;?;?;[];?"); // b_values 244 INFER_ERROR("must be rank 1", op, "?;?;?;?;?;[]"); // b_shape 245 246 // output is always [?,?];[?] 247 INFER_OK(op, "?;?;?;?;?;?", "[?,?];[?]"); 248 INFER_OK(op, "?;[?];?;?;?;?", "[?,?];[?]"); 249 INFER_OK(op, "?;[5];?;?;?;?", "[?,?];[?]"); 250 } 251} 252 253TEST(SparseOpsTest, SparseConcat_ShapeFn) { 254 ShapeInferenceTestOp op("SparseConcat"); 255 std::vector<NodeDefBuilder::NodeOut> src_list; 256 int n = 2; 257 src_list.reserve(n); 258 for (int i = 0; i < n; ++i) src_list.emplace_back("a", 0, DT_INT64); 259 TF_ASSERT_OK(NodeDefBuilder("test", "SparseConcat") 260 .Input(src_list) 261 .Input(src_list) 262 .Input(src_list) 263 .Attr("N", n) 264 .Finalize(&op.node_def)); 265 266 // Rank checks. 267 INFER_ERROR("must be rank 2", op, "[1];?;?;?;?;?"); // indices 268 INFER_ERROR("must be rank 2", op, "?;[1];?;?;?;?"); // indices 269 INFER_ERROR("must be rank 1", op, "?;?;[];?;?;?"); // values 270 INFER_ERROR("must be rank 1", op, "?;?;?;[];?;?"); // values 271 INFER_ERROR("must be rank 1", op, "?;?;?;?;[];?"); // shapes 272 INFER_ERROR("must be rank 1", op, "?;?;?;?;?;[]"); // shapes 273 274 // row count is sum of (indices[i].dim(0) merge values[i].dim(0)) 275 // ind_cols is merge of (indices[i].dim(1)) 276 // 277 // output 0 is matrix [row_count, ind_cols] 278 // output 1 is matrix [row_count] 279 // output 2 is merge of all shapes 280 281 // Test merge of shapes. 282 INFER_OK(op, "?;?;?;?;?;?", "[?,?];[?];[?]"); 283 INFER_OK(op, "?;?;?;?;[?];[?]", "[?,?];[?];in4|in5"); 284 INFER_OK(op, "?;?;?;?;[?];[5]", "[?,?];[?];in5"); 285 286 // Test accumulation of row_count and ind_cols from indices. 287 INFER_OK(op, "[4,5];[3,?];?;?;?;?", "[7,d0_1];[7];[?]"); 288 289 // Test accumulation of row_count and ind_cols from values. 290 INFER_OK(op, "?;?;[4];[3];?;?", "[7,?];[7];[?]"); 291 292 // Test merge between row_count and ind_cols. 293 INFER_OK(op, "[?,2];[3,?];[4];[?];?;?", "[7,d0_1];[7];[?]"); 294 295 // Test some errors during merge. 296 INFER_ERROR("but are 100 and 200", op, "[100,?];[?,?];[200];[?];?;?"); 297 INFER_ERROR("but are 2 and 3", op, "[?,2];[?,3];[?];[?];?;?"); 298 INFER_ERROR("but are 4 and 5", op, "?;?;?;?;[4];[5]"); 299} 300 301TEST(SparseOpsTest, SparseDenseCwise_ShapeFn) { 302 for (const char* op_name : 303 {"SparseDenseCwiseMul", "SparseDenseCwiseDiv", "SparseDenseCwiseAdd"}) { 304 ShapeInferenceTestOp op(op_name); 305 306 // output is always a vector. 307 INFER_OK(op, "?;?;?;?", "[?]"); 308 309 // input(0).dim(0) determines output[0]. 310 INFER_OK(op, "[?,?];?;?;?", "[d0_0]"); 311 312 // Rank checks. 313 INFER_ERROR("must be rank 2", op, "[1];?;?;?"); 314 } 315} 316 317TEST(SparseOpsTest, AddSparseToTensorsMap_ShapeFn) { 318 ShapeInferenceTestOp op("AddSparseToTensorsMap"); 319 320 // Rank checks. 321 INFER_ERROR("must be rank 2", op, "[1];?;?"); 322 INFER_ERROR("must be rank 1", op, "?;[];?"); 323 INFER_ERROR("must be rank 1", op, "?;?;[]"); 324 325 // output is always scalar 326 INFER_OK(op, "?;?;?", "[]"); 327} 328 329TEST(SparseOpsTest, AddManySparseToTensorsMap_ShapeFn) { 330 ShapeInferenceTestOp op("AddManySparseToTensorsMap"); 331 332 // Rank checks. 333 INFER_ERROR("must be rank 2", op, "[1];?;?"); 334 INFER_ERROR("must be rank 1", op, "?;[];?"); 335 INFER_ERROR("must be rank 1", op, "?;?;[]"); 336 337 // output is always matrix of [?]. 338 INFER_OK(op, "?;?;?", "[?]"); 339} 340 341TEST(SparseOpsTest, TakeManySparseFromTensorsMap_ShapeFn) { 342 ShapeInferenceTestOp op("TakeManySparseFromTensorsMap"); 343 344 // Rank checks. 345 INFER_ERROR("must be rank 1", op, "[?,1]"); 346 347 // output is always [?,?];[?];[?]. 348 INFER_OK(op, "?", "[?,?];[?];[?]"); 349 INFER_OK(op, "[?]", "[?,?];[?];[?]"); 350} 351 352} // end namespace tensorflow 353