1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include <cmath> 17#include <memory> 18#include <unordered_map> 19 20#include "tensorflow/c/checkpoint_reader.h" 21#include "tensorflow/core/framework/tensor.h" 22#include "tensorflow/core/graph/graph_constructor.h" 23#include "tensorflow/core/graph/node_builder.h" 24#include "tensorflow/core/graph/subgraph.h" 25#include "tensorflow/core/platform/init_main.h" 26#include "tensorflow/core/public/session.h" 27#include "tensorflow/core/util/command_line_flags.h" 28#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h" 29#include "tensorflow/tools/graph_transforms/transform_utils.h" 30 31namespace tensorflow { 32using str_util::Join; 33using str_util::Split; 34using str_util::StringReplace; 35using strings::StrCat; 36 37namespace graph_transforms { 38 39// Sparsify Tensor of shape [N, 1]. Return the indices and values vectors for 40// non-zero tensor content. 41Status SparsifyWeights(const Tensor& tensor, Tensor* indices_tensor, 42 Tensor* values_tensor) { 43 if (tensor.dims() != 2 || tensor.dim_size(1) != 1) { 44 return tensorflow::errors::FailedPrecondition( 45 "Transform only applicable to subgraph with 'Const' with " 46 "tensor of shape [N, 1]. But instead get shape ", 47 tensor.shape().DebugString(), "."); 48 } 49 50 auto flat = tensor.flat<float>(); 51 std::vector<int64> indices; 52 std::vector<float> values; 53 54 for (int64 i = 0; i < flat.size(); i++) { 55 float val = flat(i); 56 if (std::abs(val) >= 1.0e-5) { 57 indices.push_back(i); 58 values.push_back(val); 59 } 60 } 61 62 // During model initialization, InitializeTableOp makes use of 63 // KeyValueTensorIterator, which does not accept empty keys or values. 64 // Consequently, adding a dummy pair of indices and values as a walkaround. 65 if (indices.empty() || values.empty()) { 66 indices.push_back(0); 67 values.push_back(0); 68 } 69 *indices_tensor = Tensor(DataTypeToEnum<int64>::value, 70 {static_cast<int64>(indices.size())}); 71 std::copy_n(indices.begin(), indices.size(), 72 indices_tensor->flat<int64>().data()); 73 74 *values_tensor = 75 Tensor(DataTypeToEnum<float>::value, {static_cast<int64>(values.size())}); 76 std::copy_n(values.begin(), values.size(), 77 values_tensor->flat<float>().data()); 78 79 return Status::OK(); 80} 81 82void CreateConstNode(const Tensor& tensor, const string& name, 83 NodeDef* node_def) { 84 node_def->set_op("Const"); 85 node_def->set_name(name); 86 SetNodeTensorAttr<float>("value", tensor, node_def); 87} 88 89string GetMonolithicTensorKey(const string& tensor_slice_name) { 90 std::vector<string> names = Split(tensor_slice_name, "/"); 91 if (StringPiece(names[names.size() - 1]).starts_with("part_")) { 92 CHECK_GE(names.size(), 2); 93 names.pop_back(); 94 } 95 return Join(names, "/"); 96} 97 98Status ObtainTensorSlice(const GraphDef& input_graph_def, 99 const string& target_name, 100 string* shape_slice_string) { 101 string restore_node_name; 102 for (const auto& node : input_graph_def.node()) { 103 std::vector<string> node_name_parts = Split(node.name(), "/"); 104 if (node_name_parts.size() == 2 && 105 StringPiece(node_name_parts[0]).starts_with("save") && 106 StringPiece(node_name_parts[1]).starts_with("Assign") && 107 node.input(0) == target_name) { 108 restore_node_name = node.input(1); 109 break; 110 } 111 } 112 113 std::vector<string> restore_node_parts = Split(restore_node_name, ":"); 114 CHECK_LE(restore_node_parts.size(), 2); 115 string tensor_names_node; 116 string shape_and_slices_node; 117 for (const auto& node : input_graph_def.node()) { 118 if ((node.name() == restore_node_parts[0]) && (node.op() == "RestoreV2")) { 119 tensor_names_node = node.input(1); 120 shape_and_slices_node = node.input(2); 121 break; 122 } 123 } 124 125 int offset = -1; 126 for (const auto& node : input_graph_def.node()) { 127 if (node.name() == tensor_names_node) { 128 Tensor tensor_names_tensor; 129 TF_RETURN_IF_ERROR(GetNodeAttr(node, "value", &tensor_names_tensor)); 130 const auto& tensor_names_value = tensor_names_tensor.flat<string>(); 131 for (int i = 0; i < tensor_names_value.size(); i++) { 132 if (tensor_names_value(i) == GetMonolithicTensorKey(target_name)) { 133 offset = i; 134 break; 135 } 136 } 137 } 138 } 139 if (offset == -1) { 140 return errors::Internal("Unable to find RestoreV2 entry for variable: ", 141 target_name); 142 } 143 for (const auto& node : input_graph_def.node()) { 144 if (node.name() == shape_and_slices_node) { 145 Tensor shape_and_slices_tensor; 146 TF_RETURN_IF_ERROR(GetNodeAttr(node, "value", &shape_and_slices_tensor)); 147 const auto& shape_and_slices_value = 148 shape_and_slices_tensor.flat<string>(); 149 *shape_slice_string = shape_and_slices_value(offset); 150 return Status::OK(); 151 } 152 } 153 return errors::Internal("Unable to find slice for variable: ", target_name); 154} 155 156Status ReadTensorFromCheckpoint( 157 const string& tensor_name, const std::unique_ptr<BundleReader>& ckpt_reader, 158 const string& shape_and_slice, Tensor* tensor) { 159 if (ckpt_reader) { 160 TensorShape parsed_full_shape; 161 TensorSlice parsed_slice; 162 TensorShape parsed_slice_shape; 163 164 bool get_slice = false; 165 if (!shape_and_slice.empty()) { 166 TF_RETURN_IF_ERROR( 167 checkpoint::ParseShapeAndSlice(shape_and_slice, &parsed_full_shape, 168 &parsed_slice, &parsed_slice_shape)); 169 get_slice = (parsed_full_shape != parsed_slice_shape); 170 } 171 if (get_slice) { 172 TF_RETURN_IF_ERROR(ckpt_reader->LookupSlice( 173 GetMonolithicTensorKey(tensor_name), parsed_slice, tensor)); 174 } else { 175 TF_RETURN_IF_ERROR( 176 ckpt_reader->Lookup(GetMonolithicTensorKey(tensor_name), tensor)); 177 } 178 return Status::OK(); 179 } 180 return errors::Internal("Checkpoint reader was not initialized. "); 181} 182 183Status InitializeCheckpointReader(const TransformFuncContext& context, 184 std::unique_ptr<BundleReader>* ckpt_reader) { 185 if (context.params.count("input_checkpoint")) { 186 const string input_checkpoint = context.params.at("input_checkpoint")[0]; 187 ckpt_reader->reset(new BundleReader(Env::Default(), input_checkpoint)); 188 TF_RETURN_IF_ERROR((*ckpt_reader)->status()); 189 } 190 return Status::OK(); 191} 192 193Status ObtainVariableInfo( 194 const GraphDef& input_graph_def, 195 std::unique_ptr<std::unordered_map<string, string> >* shapes_and_slices) { 196 shapes_and_slices->reset(new std::unordered_map<string, string>()); 197 for (const auto& node : input_graph_def.node()) { 198 if ((node.op() == "Variable") || (node.op() == "VariableV2")) { 199 string s; 200 TF_RETURN_IF_ERROR(ObtainTensorSlice(input_graph_def, node.name(), &s)); 201 (**shapes_and_slices)[node.name()] = s; 202 } 203 } 204 return Status::OK(); 205} 206 207Status RemoveInputAtIndex(NodeDef* n, int index) { 208 for (int i = index; i < n->input_size() - 1; i++) { 209 n->mutable_input()->SwapElements(i, i + 1); 210 } 211 n->mutable_input()->RemoveLast(); 212 return Status::OK(); 213} 214 215Status RemoveNodeAtIndex(GraphDef* g, int index) { 216 for (int i = index; i < g->node_size() - 1; i++) { 217 g->mutable_node()->SwapElements(i, i + 1); 218 } 219 g->mutable_node()->RemoveLast(); 220 return Status::OK(); 221} 222 223Status SparsifyGatherInternal( 224 const GraphDef& input_graph_def, 225 const std::unique_ptr<std::unordered_map<string, string> >& 226 shapes_and_slices, 227 const TransformFuncContext& context, const OpTypePattern& pattern, 228 const std::unique_ptr<BundleReader>& ckpt_reader, 229 GraphDef* output_graph_def) { 230 string group_init_node = "group_deps"; 231 if (context.params.count("group_init_node")) { 232 group_init_node = context.params.at("group_init_node")[0]; 233 } 234 GraphDef current_graph_def = input_graph_def; 235 bool any_match_found = false; 236 237 // Populate references. 238 std::unordered_map<string, int> refs; 239 for (const auto& node : current_graph_def.node()) { 240 for (const auto& input : node.input()) { 241 auto parsed_input = StringReplace(input, "^", "", true); 242 refs[parsed_input] += 1; 243 } 244 } 245 246 // The subgraphs may have overlapping components, therefore GraphMatcher 247 // doesn't return all subgraphs in one round -- this has to be multi-round 248 // update. 249 do { 250 any_match_found = false; 251 GraphDef replaced_graph_def = current_graph_def; 252 std::vector<string> init_table_node_names; 253 std::vector<string> removed_node_names; 254 255 TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( 256 current_graph_def, pattern, 257 [&ckpt_reader, &any_match_found, &init_table_node_names, 258 &shapes_and_slices, &removed_node_names, 259 &refs](const NodeMatch& match, const std::set<string>& input_nodes, 260 const std::set<string>& output_nodes, 261 std::vector<NodeDef>* new_nodes) { 262 any_match_found = true; 263 264 // The captured subgraph should be of the following pattern: 265 // Const --> Identity --> Gather --> ... 266 // ^ 267 // | 268 // (ids) 269 // 270 // After transform, it becomes: 271 // --> NoOp(group_deps) 272 // | 273 // Const --> InitializeTable --> HashTable 274 // ^ | 275 // | | 276 // Const ------------- | 277 // v 278 // (ids) ---> LookupTableFind <--- Const(default) 279 // | 280 // v 281 // ... 282 283 // clang-format off 284 // For each subgraph, do the following 285 // 1. Sparsify the `Const`, creating two `Const`, for hashtable 286 // key/val. 287 // 2. Create a `InitializeTable` op connecting to the above 2 `Const`. 288 // 3. Create a `HashTable` op connecting to `InitializeTable` op. 289 // 4. Replace the `Gather` with a `LookupTableFind` op. 290 // 5. Connect the `LookupTableFind` with 291 // a. `HashTable` 292 // b. `Gather`'s ids input 293 // c. a `default_val` arg, valued at 0 294 // clang-format on 295 const NodeDef& gather_node = match.node; 296 297 // GatherV2 adds an "axis" parameter. sparsify_gather only supports 298 // axis 0 gathers. 299 if (gather_node.op() == "GatherV2") { 300 // Per the OpTypePattern, the 3rd input to Gather must be a Const. 301 const NodeDef& axis_node = match.inputs[2].node; 302 303 Tensor axis_t; 304 TF_RETURN_IF_ERROR(GetNodeAttr(axis_node, "value", &axis_t)); 305 int64 axis = 0; 306 if (axis_t.dtype() == DT_INT32) { 307 axis = axis_t.scalar<int32>()(); 308 } else if (axis_t.dtype() == DT_INT64) { 309 axis = axis_t.scalar<int64>()(); 310 } else { 311 return tensorflow::errors::FailedPrecondition( 312 "Gather axis was not int32 or int64."); 313 } 314 315 if (axis != 0) { 316 return tensorflow::errors::FailedPrecondition( 317 "Transform only applicable to subgraph with GatherV2 over " 318 "axis 0. Found axis ", 319 axis, "."); 320 } 321 } 322 323 const NodeDef& weights_node = match.inputs[0].inputs[0].node; 324 325 DataType data_type; 326 TF_RETURN_IF_ERROR(GetNodeAttr(weights_node, "dtype", &data_type)); 327 if (data_type != DT_FLOAT) { 328 return tensorflow::errors::FailedPrecondition( 329 "Transform only applicable to subgraph with 'Const'," 330 "'Variable', or 'VariableV2' of dtype " 331 "'DT_FLOAT'. Found '" + 332 weights_node.op() + "' with name '", 333 weights_node.name(), "' and dtype '", data_type, "'."); 334 } 335 336 Tensor weight; 337 if (weights_node.op() == "Const") { 338 weight = GetNodeTensorAttr(weights_node, "value"); 339 } else { 340 TF_RETURN_IF_ERROR(ReadTensorFromCheckpoint( 341 weights_node.name(), ckpt_reader, 342 (*shapes_and_slices)[weights_node.name()], &weight)); 343 } 344 // Add both both weight and identity node names. 345 removed_node_names.push_back(weights_node.name()); 346 removed_node_names.push_back(match.inputs[0].node.name()); 347 for (auto input_node : match.inputs[0].node.input()) { 348 auto parsed_input = StringReplace(input_node, "^", "", true); 349 refs[parsed_input]--; 350 } 351 Tensor indices_tensor; 352 Tensor values_tensor; 353 TF_RETURN_IF_ERROR( 354 SparsifyWeights(weight, &indices_tensor, &values_tensor)); 355 356 // indices and values of sparsified `Const` 357 DataType key_dtype = DT_INT64; 358 NodeDef indices_node; 359 CreateConstNode(indices_tensor, 360 StrCat(weights_node.name(), "/indices"), 361 &indices_node); 362 SetNodeAttr("dtype", key_dtype, &indices_node); 363 364 NodeDef values_node; 365 CreateConstNode(values_tensor, StrCat(weights_node.name(), "/values"), 366 &values_node); 367 SetNodeAttr("dtype", data_type, &values_node); 368 369 // HashTable node 370 NodeDef hashtable_node; 371 hashtable_node.set_op("HashTable"); 372 hashtable_node.set_name(StrCat(weights_node.name(), "/HashTable")); 373 SetNodeAttr("key_dtype", key_dtype, &hashtable_node); 374 SetNodeAttr("value_dtype", data_type, &hashtable_node); 375 376 // InitializeTable node 377 NodeDef init_table_node; 378 init_table_node.set_op("InitializeTable"); 379 init_table_node.set_name( 380 StrCat(weights_node.name(), "/InitializeTable")); 381 SetNodeAttr("Tkey", key_dtype, &init_table_node); 382 SetNodeAttr("Tval", data_type, &init_table_node); 383 init_table_node_names.push_back(init_table_node.name()); 384 385 // LookupTableFind node 386 NodeDef lookup_node; 387 lookup_node.set_op("LookupTableFind"); 388 lookup_node.set_name(StrCat(gather_node.name(), "/LookupTableFind")); 389 SetNodeAttr("Tin", key_dtype, &lookup_node); 390 SetNodeAttr("Tout", data_type, &lookup_node); 391 392 // Default return value of hashtable lookup 393 Tensor zero_tensor(data_type, TensorShape({})); 394 zero_tensor.flat<float>()(0) = 0.0; 395 NodeDef default_value_node; 396 CreateConstNode(zero_tensor, StrCat(gather_node.name(), "/Const"), 397 &default_value_node); 398 SetNodeAttr("dtype", data_type, &default_value_node); 399 400 // ExpandDims argument 401 Tensor dim_idx(DT_INT32, TensorShape({})); 402 dim_idx.flat<int32>()(0) = -1; 403 NodeDef dim_idx_node; 404 dim_idx_node.set_op("Const"); 405 dim_idx_node.set_name( 406 StrCat(gather_node.name(), "/ExpandDims/Const")); 407 SetNodeAttr("value", dim_idx, &dim_idx_node); 408 SetNodeAttr("dtype", DT_INT32, &dim_idx_node); 409 410 // ExpandDims node 411 NodeDef expand_dims_node; 412 expand_dims_node.set_op("ExpandDims"); 413 // Reuse gather_node's name so not to change dependent's inputs 414 expand_dims_node.set_name(gather_node.name()); 415 SetNodeAttr("T", data_type, &expand_dims_node); 416 417 // Connect nodes 418 AddNodeInput(hashtable_node.name(), &init_table_node); 419 refs[hashtable_node.name()]++; 420 AddNodeInput(indices_node.name(), &init_table_node); 421 refs[indices_node.name()]++; 422 AddNodeInput(values_node.name(), &init_table_node); 423 refs[values_node.name()]++; 424 425 AddNodeInput(hashtable_node.name(), &lookup_node); 426 refs[hashtable_node.name()]++; 427 AddNodeInput(gather_node.input(1), &lookup_node); 428 refs[gather_node.input(1)]++; 429 AddNodeInput(default_value_node.name(), &lookup_node); 430 refs[default_value_node.name()]++; 431 432 AddNodeInput(lookup_node.name(), &expand_dims_node); 433 refs[lookup_node.name()]++; 434 AddNodeInput(dim_idx_node.name(), &expand_dims_node); 435 refs[dim_idx_node.name()]++; 436 437 // Copy 'ids' input of original 'Gather' 438 new_nodes->push_back(match.inputs[1].node); 439 new_nodes->push_back(indices_node); 440 new_nodes->push_back(values_node); 441 new_nodes->push_back(hashtable_node); 442 new_nodes->push_back(init_table_node); 443 new_nodes->push_back(lookup_node); 444 new_nodes->push_back(default_value_node); 445 new_nodes->push_back(dim_idx_node); 446 new_nodes->push_back(expand_dims_node); 447 448 return Status::OK(); 449 }, 450 {true}, &replaced_graph_def)); 451 452 NodeDef* init_op = nullptr; 453 for (int i = 0; i < replaced_graph_def.node_size(); i++) { 454 if (replaced_graph_def.node(i).name() == group_init_node && 455 replaced_graph_def.node(i).op() == "NoOp") { 456 init_op = replaced_graph_def.mutable_node(i); 457 break; 458 } 459 } 460 if (!init_op) { 461 // Init node 462 init_op = replaced_graph_def.mutable_node()->Add(); 463 init_op->set_op("NoOp"); 464 init_op->set_name(group_init_node); 465 } 466 for (const string& name : init_table_node_names) { 467 // Add control dependence from init_table_node to group_deps_node 468 AddNodeInput(StrCat("^", name), init_op); 469 refs[name]++; 470 } 471 472 // Erase inputs and outputs as they are not considered for deletion. 473 for (const auto& output : context.output_names) { 474 refs.erase(output); 475 } 476 477 for (const auto& input : context.input_names) { 478 refs.erase(input); 479 } 480 481 // Add nodes with a reference count of 0 for deletion. 482 for (auto entry : refs) { 483 if (entry.second == 0) { 484 removed_node_names.push_back(entry.first); 485 } 486 } 487 488 while (!removed_node_names.empty()) { 489 auto name = removed_node_names.back(); 490 removed_node_names.pop_back(); 491 492 int i = 0; 493 while (i < replaced_graph_def.node_size()) { 494 // Revisit this to see if we can safely remove RestoreV2 nodes. 495 if ((replaced_graph_def.node(i).name() == name) && 496 (replaced_graph_def.node(i).op() != "RestoreV2")) { 497 for (const auto& input : replaced_graph_def.node(i).input()) { 498 auto parsed_input = StringReplace(input, "^", "", true); 499 refs[parsed_input] -= 1; 500 if (refs[parsed_input] == 0) { 501 removed_node_names.push_back(parsed_input); 502 } 503 } 504 TF_RETURN_IF_ERROR(RemoveNodeAtIndex(&replaced_graph_def, i)); 505 continue; 506 } 507 int j = 0; 508 bool deleted_inputs = false; 509 while (j < replaced_graph_def.node(i).input_size()) { 510 if (replaced_graph_def.node(i).input(j) == name || 511 replaced_graph_def.node(i).input(j) == ("^" + name)) { 512 TF_RETURN_IF_ERROR( 513 RemoveInputAtIndex(replaced_graph_def.mutable_node(i), j)); 514 deleted_inputs = true; 515 continue; 516 } 517 j++; 518 } 519 if (deleted_inputs) { 520 if (replaced_graph_def.node(i).op() == "ConcatV2") { 521 if (replaced_graph_def.node(i).input_size() > 2) { 522 SetNodeAttr("N", replaced_graph_def.node(i).input_size() - 1, 523 replaced_graph_def.mutable_node(i)); 524 } else if (replaced_graph_def.node(i).input_size() == 2) { 525 if (refs[replaced_graph_def.node(i).input(1)] != 1) { 526 return errors::Internal( 527 "Expect axis tensor of ConcatV2 node to only be referenced " 528 "once."); 529 } 530 refs[replaced_graph_def.node(i).input(1)] -= 1; 531 removed_node_names.push_back(replaced_graph_def.node(i).input(1)); 532 replaced_graph_def.mutable_node(i)->mutable_input()->RemoveLast(); 533 replaced_graph_def.mutable_node(i)->mutable_attr()->erase("N"); 534 replaced_graph_def.mutable_node(i)->set_op("Identity"); 535 } else { 536 return errors::Internal( 537 "ConcatV2 should have at least two elements"); 538 } 539 } 540 if ((replaced_graph_def.node(i).op() == "Assign" || 541 replaced_graph_def.node(i).op() == "Reshape" || 542 replaced_graph_def.node(i).op() == "Equal" || 543 replaced_graph_def.node(i).op() == "Mean" || 544 replaced_graph_def.node(i).op() == "ScalarSummary") && 545 replaced_graph_def.node(i).input_size() == 1) { 546 removed_node_names.push_back(replaced_graph_def.node(i).name()); 547 } 548 if (!replaced_graph_def.node(i).input_size()) { 549 removed_node_names.push_back(replaced_graph_def.node(i).name()); 550 } 551 } 552 i++; 553 } 554 } 555 current_graph_def = replaced_graph_def; 556 } while (any_match_found); 557 *output_graph_def = current_graph_def; 558 return Status::OK(); 559} 560 561Status SparsifyGather(const GraphDef& input_graph_def, 562 const TransformFuncContext& context, 563 GraphDef* output_graph_def) { 564 // clang-format off 565 const OpTypePattern gather_pattern = 566 {"Gather", 567 { 568 {"Identity", 569 { 570 {"Const|Variable|VariableV2"} 571 } 572 }, 573 {"*"}, 574 } 575 }; 576 const OpTypePattern gather_v2_pattern = 577 {"GatherV2", 578 { 579 {"Identity", 580 { 581 {"Const|Variable|VariableV2"} 582 } 583 }, 584 {"*"}, 585 // GatherV2's axis must be constant. 586 {"Const"}, 587 } 588 }; 589 // clang-format on 590 591 GraphDef cleaned_input_graph_def; 592 RemoveAttributes(input_graph_def, {"_output_shapes"}, 593 &cleaned_input_graph_def); 594 595 GraphDef temp_output; 596 597 std::unique_ptr<BundleReader> ckpt_reader; 598 TF_RETURN_IF_ERROR(InitializeCheckpointReader(context, &ckpt_reader)); 599 600 std::unique_ptr<std::unordered_map<string, string> > shapes_and_slices; 601 TF_RETURN_IF_ERROR( 602 ObtainVariableInfo(cleaned_input_graph_def, &shapes_and_slices)); 603 604 TF_RETURN_IF_ERROR(SparsifyGatherInternal( 605 cleaned_input_graph_def, shapes_and_slices, context, gather_pattern, 606 ckpt_reader, &temp_output)); 607 608 TF_RETURN_IF_ERROR(SparsifyGatherInternal(temp_output, shapes_and_slices, 609 context, gather_v2_pattern, 610 ckpt_reader, output_graph_def)); 611 612 return Status::OK(); 613} 614 615REGISTER_GRAPH_TRANSFORM("sparsify_gather", SparsifyGather); 616 617} // namespace graph_transforms 618} // namespace tensorflow 619