/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
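
// Rewrites Gather/GatherV2 reads of large [N, 1] 'Const', 'Variable', or
// 'VariableV2' weights into sparse hash-table lookups (HashTable +
// InitializeTable + LookupTableFind), keeping only the non-zero weight
// entries. See SparsifyGatherInternal below for the exact subgraph pattern
// matched and the replacement produced.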

#include <cmath>
#include <memory>
#include <unordered_map>

#include "tensorflow/c/checkpoint_reader.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"

namespace tensorflow {
using str_util::Join;
using str_util::Split;
using str_util::StringReplace;
using strings::StrCat;

namespace graph_transforms {

// Sparsifies a tensor of shape [N, 1], returning an indices tensor and a
// values tensor holding the non-zero entries.
Status SparsifyWeights(const Tensor& tensor, Tensor* indices_tensor,
                       Tensor* values_tensor) {
  if (tensor.dims() != 2 || tensor.dim_size(1) != 1) {
    return tensorflow::errors::FailedPrecondition(
        "Transform only applicable to subgraph with 'Const' with "
        "tensor of shape [N, 1], but instead got shape ",
        tensor.shape().DebugString(), ".");
  }

  auto flat = tensor.flat<float>();
  std::vector<int64> indices;
  std::vector<float> values;

  for (int64 i = 0; i < flat.size(); i++) {
    float val = flat(i);
    if (std::abs(val) >= 1.0e-5) {
      indices.push_back(i);
      values.push_back(val);
    }
  }

  // During model initialization, InitializeTableOp makes use of
  // KeyValueTensorIterator, which does not accept empty keys or values.
  // Consequently, add a dummy pair of indices and values as a workaround.
  if (indices.empty() || values.empty()) {
    indices.push_back(0);
    values.push_back(0);
  }
  *indices_tensor = Tensor(DataTypeToEnum<int64>::value,
                           {static_cast<int64>(indices.size())});
  std::copy_n(indices.begin(), indices.size(),
              indices_tensor->flat<int64>().data());

  *values_tensor =
      Tensor(DataTypeToEnum<float>::value, {static_cast<int64>(values.size())});
  std::copy_n(values.begin(), values.size(),
              values_tensor->flat<float>().data());

  return Status::OK();
}
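
// Example for SparsifyWeights() above: given the 1.0e-5 threshold, a [4, 1]
// float tensor holding {0.0, 0.5, 0.0, -1.2} sparsifies to indices {1, 3} and
// values {0.5, -1.2}; an all-zero tensor yields the dummy pair {0} / {0.0f}.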

void CreateConstNode(const Tensor& tensor, const string& name,
                     NodeDef* node_def) {
  node_def->set_op("Const");
  node_def->set_name(name);
  SetNodeTensorAttr<float>("value", tensor, node_def);
}

string GetMonolithicTensorKey(const string& tensor_slice_name) {
  std::vector<string> names = Split(tensor_slice_name, "/");
  if (StringPiece(names[names.size() - 1]).starts_with("part_")) {
    CHECK_GE(names.size(), 2);
    names.pop_back();
  }
  return Join(names, "/");
}
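
// Example for GetMonolithicTensorKey() above: a partitioned-variable slice
// name such as "v0/part_3" maps to the monolithic checkpoint key "v0", while
// a name like "v0/weights" is returned unchanged.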

Status ObtainTensorSlice(const GraphDef& input_graph_def,
                         const string& target_name,
                         string* shape_slice_string) {
  string restore_node_name;
  for (const auto& node : input_graph_def.node()) {
    std::vector<string> node_name_parts = Split(node.name(), "/");
    if (node_name_parts.size() == 2 &&
        StringPiece(node_name_parts[0]).starts_with("save") &&
        StringPiece(node_name_parts[1]).starts_with("Assign") &&
        node.input(0) == target_name) {
      restore_node_name = node.input(1);
      break;
    }
  }

  std::vector<string> restore_node_parts = Split(restore_node_name, ":");
  CHECK_LE(restore_node_parts.size(), 2);
  string tensor_names_node;
  string shape_and_slices_node;
  for (const auto& node : input_graph_def.node()) {
    if ((node.name() == restore_node_parts[0]) && (node.op() == "RestoreV2")) {
      tensor_names_node = node.input(1);
      shape_and_slices_node = node.input(2);
      break;
    }
  }

  int offset = -1;
  for (const auto& node : input_graph_def.node()) {
    if (node.name() == tensor_names_node) {
      Tensor tensor_names_tensor;
      TF_RETURN_IF_ERROR(GetNodeAttr(node, "value", &tensor_names_tensor));
      const auto& tensor_names_value = tensor_names_tensor.flat<string>();
      for (int i = 0; i < tensor_names_value.size(); i++) {
        if (tensor_names_value(i) == GetMonolithicTensorKey(target_name)) {
          offset = i;
          break;
        }
      }
    }
  }
  if (offset == -1) {
    return errors::Internal("Unable to find RestoreV2 entry for variable: ",
                            target_name);
  }
  for (const auto& node : input_graph_def.node()) {
    if (node.name() == shape_and_slices_node) {
      Tensor shape_and_slices_tensor;
      TF_RETURN_IF_ERROR(GetNodeAttr(node, "value", &shape_and_slices_tensor));
      const auto& shape_and_slices_value =
          shape_and_slices_tensor.flat<string>();
      *shape_slice_string = shape_and_slices_value(offset);
      return Status::OK();
    }
  }
  return errors::Internal("Unable to find slice for variable: ", target_name);
}

Status ReadTensorFromCheckpoint(
    const string& tensor_name, const std::unique_ptr<BundleReader>& ckpt_reader,
    const string& shape_and_slice, Tensor* tensor) {
  if (ckpt_reader) {
    TensorShape parsed_full_shape;
    TensorSlice parsed_slice;
    TensorShape parsed_slice_shape;

    bool get_slice = false;
    if (!shape_and_slice.empty()) {
      TF_RETURN_IF_ERROR(
          checkpoint::ParseShapeAndSlice(shape_and_slice, &parsed_full_shape,
                                         &parsed_slice, &parsed_slice_shape));
      get_slice = (parsed_full_shape != parsed_slice_shape);
    }
    if (get_slice) {
      TF_RETURN_IF_ERROR(ckpt_reader->LookupSlice(
          GetMonolithicTensorKey(tensor_name), parsed_slice, tensor));
    } else {
      TF_RETURN_IF_ERROR(
          ckpt_reader->Lookup(GetMonolithicTensorKey(tensor_name), tensor));
    }
    return Status::OK();
  }
  return errors::Internal("Checkpoint reader was not initialized.");
}
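
// A note on the shape_and_slice strings consumed above: they follow the
// SaveV2/RestoreV2 convention parsed by checkpoint::ParseShapeAndSlice --
// the full shape's dimensions, space-separated, followed by one slice spec
// per dimension joined by ':' (each "start,length", or "-" for a full
// dimension). For example, "4 1 0,2:-" should describe rows [0, 2) of a
// [4, 1] variable.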

Status InitializeCheckpointReader(const TransformFuncContext& context,
                                  std::unique_ptr<BundleReader>* ckpt_reader) {
  if (context.params.count("input_checkpoint")) {
    const string input_checkpoint = context.params.at("input_checkpoint")[0];
    ckpt_reader->reset(new BundleReader(Env::Default(), input_checkpoint));
    TF_RETURN_IF_ERROR((*ckpt_reader)->status());
  }
  return Status::OK();
}

Status ObtainVariableInfo(
    const GraphDef& input_graph_def,
    std::unique_ptr<std::unordered_map<string, string> >* shapes_and_slices) {
  shapes_and_slices->reset(new std::unordered_map<string, string>());
  for (const auto& node : input_graph_def.node()) {
    if ((node.op() == "Variable") || (node.op() == "VariableV2")) {
      string s;
      TF_RETURN_IF_ERROR(ObtainTensorSlice(input_graph_def, node.name(), &s));
      (**shapes_and_slices)[node.name()] = s;
    }
  }
  return Status::OK();
}

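// The two helpers below delete an element from a repeated proto field by
// swapping it step-by-step to the end and then dropping the last element;
// the successive swaps preserve the relative order of the remaining entries.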
Status RemoveInputAtIndex(NodeDef* n, int index) {
  for (int i = index; i < n->input_size() - 1; i++) {
    n->mutable_input()->SwapElements(i, i + 1);
  }
  n->mutable_input()->RemoveLast();
  return Status::OK();
}

Status RemoveNodeAtIndex(GraphDef* g, int index) {
  for (int i = index; i < g->node_size() - 1; i++) {
    g->mutable_node()->SwapElements(i, i + 1);
  }
  g->mutable_node()->RemoveLast();
  return Status::OK();
}

Status SparsifyGatherInternal(
    const GraphDef& input_graph_def,
    const std::unique_ptr<std::unordered_map<string, string> >&
        shapes_and_slices,
    const TransformFuncContext& context, const OpTypePattern& pattern,
    const std::unique_ptr<BundleReader>& ckpt_reader,
    GraphDef* output_graph_def) {
  string group_init_node = "group_deps";
  if (context.params.count("group_init_node")) {
    group_init_node = context.params.at("group_init_node")[0];
  }
  GraphDef current_graph_def = input_graph_def;
  bool any_match_found = false;

  // Populate reference counts: for every node input (with any '^'
  // control-dependency prefix stripped), count how often each node is
  // referenced.
  std::unordered_map<string, int> refs;
  for (const auto& node : current_graph_def.node()) {
    for (const auto& input : node.input()) {
      auto parsed_input = StringReplace(input, "^", "", true);
      refs[parsed_input] += 1;
    }
  }

  // The subgraphs may have overlapping components, so GraphMatcher may not
  // return all matching subgraphs in a single pass; the rewrite therefore
  // has to run in multiple rounds until no further match is found.
  do {
    any_match_found = false;
    GraphDef replaced_graph_def = current_graph_def;
    std::vector<string> init_table_node_names;
    std::vector<string> removed_node_names;

    TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
        current_graph_def, pattern,
        [&ckpt_reader, &any_match_found, &init_table_node_names,
         &shapes_and_slices, &removed_node_names,
         &refs](const NodeMatch& match, const std::set<string>& input_nodes,
                const std::set<string>& output_nodes,
                std::vector<NodeDef>* new_nodes) {
          any_match_found = true;

          // The captured subgraph should be of the following pattern:
          // Const --> Identity --> Gather --> ...
          //                          ^
          //                          |
          //                        (ids)
          //
          // After transform, it becomes:
          //                   --> NoOp(group_deps)
          //                   |
          // Const --> InitializeTable --> HashTable
          //                   ^              |
          //                   |              |
          // Const -------------              |
          //                                  v
          //               (ids) ---> LookupTableFind <--- Const(default)
          //                                  |
          //                                  v
          //                                 ...

          // clang-format off
          // For each matched subgraph, do the following:
          // 1. Sparsify the `Const`, creating two new `Const` nodes holding
          //    the hashtable keys and values.
          // 2. Create an `InitializeTable` op connected to the two `Const`
          //    nodes above.
          // 3. Create a `HashTable` op connected to the `InitializeTable` op.
          // 4. Replace the `Gather` with a `LookupTableFind` op.
          // 5. Connect the `LookupTableFind` with
          //    a. the `HashTable`
          //    b. the `Gather`'s ids input
          //    c. a `default_val` arg, valued at 0
          // clang-format on
          const NodeDef& gather_node = match.node;

          // GatherV2 adds an "axis" parameter. sparsify_gather only supports
          // axis 0 gathers.
          if (gather_node.op() == "GatherV2") {
            // Per the OpTypePattern, the 3rd input to Gather must be a Const.
            const NodeDef& axis_node = match.inputs[2].node;

            Tensor axis_t;
            TF_RETURN_IF_ERROR(GetNodeAttr(axis_node, "value", &axis_t));
            int64 axis = 0;
            if (axis_t.dtype() == DT_INT32) {
              axis = axis_t.scalar<int32>()();
            } else if (axis_t.dtype() == DT_INT64) {
              axis = axis_t.scalar<int64>()();
            } else {
              return tensorflow::errors::FailedPrecondition(
                  "Gather axis was not int32 or int64.");
            }

            if (axis != 0) {
              return tensorflow::errors::FailedPrecondition(
                  "Transform only applicable to subgraph with GatherV2 over "
                  "axis 0. Found axis ",
                  axis, ".");
            }
          }

          const NodeDef& weights_node = match.inputs[0].inputs[0].node;

          DataType data_type;
          TF_RETURN_IF_ERROR(GetNodeAttr(weights_node, "dtype", &data_type));
          if (data_type != DT_FLOAT) {
            return tensorflow::errors::FailedPrecondition(
                "Transform only applicable to subgraph with 'Const', "
                "'Variable', or 'VariableV2' of dtype 'DT_FLOAT'. Found '" +
                    weights_node.op() + "' with name '",
                weights_node.name(), "' and dtype '", data_type, "'.");
          }

          Tensor weight;
          if (weights_node.op() == "Const") {
            weight = GetNodeTensorAttr(weights_node, "value");
          } else {
            TF_RETURN_IF_ERROR(ReadTensorFromCheckpoint(
                weights_node.name(), ckpt_reader,
                (*shapes_and_slices)[weights_node.name()], &weight));
          }
          // Add both the weights and the identity node names.
          removed_node_names.push_back(weights_node.name());
          removed_node_names.push_back(match.inputs[0].node.name());
          for (auto input_node : match.inputs[0].node.input()) {
            auto parsed_input = StringReplace(input_node, "^", "", true);
            refs[parsed_input]--;
          }
          Tensor indices_tensor;
          Tensor values_tensor;
          TF_RETURN_IF_ERROR(
              SparsifyWeights(weight, &indices_tensor, &values_tensor));

          // indices and values of sparsified `Const`
          DataType key_dtype = DT_INT64;
          NodeDef indices_node;
          CreateConstNode(indices_tensor,
                          StrCat(weights_node.name(), "/indices"),
                          &indices_node);
          SetNodeAttr("dtype", key_dtype, &indices_node);

          NodeDef values_node;
          CreateConstNode(values_tensor, StrCat(weights_node.name(), "/values"),
                          &values_node);
          SetNodeAttr("dtype", data_type, &values_node);

          // HashTable node
          NodeDef hashtable_node;
          hashtable_node.set_op("HashTable");
          hashtable_node.set_name(StrCat(weights_node.name(), "/HashTable"));
          SetNodeAttr("key_dtype", key_dtype, &hashtable_node);
          SetNodeAttr("value_dtype", data_type, &hashtable_node);

          // InitializeTable node
          NodeDef init_table_node;
          init_table_node.set_op("InitializeTable");
          init_table_node.set_name(
              StrCat(weights_node.name(), "/InitializeTable"));
          SetNodeAttr("Tkey", key_dtype, &init_table_node);
          SetNodeAttr("Tval", data_type, &init_table_node);
          init_table_node_names.push_back(init_table_node.name());

          // LookupTableFind node
          NodeDef lookup_node;
          lookup_node.set_op("LookupTableFind");
          lookup_node.set_name(StrCat(gather_node.name(), "/LookupTableFind"));
          SetNodeAttr("Tin", key_dtype, &lookup_node);
          SetNodeAttr("Tout", data_type, &lookup_node);

          // Default return value of hashtable lookup
          Tensor zero_tensor(data_type, TensorShape({}));
          zero_tensor.flat<float>()(0) = 0.0;
          NodeDef default_value_node;
          CreateConstNode(zero_tensor, StrCat(gather_node.name(), "/Const"),
                          &default_value_node);
          SetNodeAttr("dtype", data_type, &default_value_node);

          // ExpandDims argument
          Tensor dim_idx(DT_INT32, TensorShape({}));
          dim_idx.flat<int32>()(0) = -1;
          NodeDef dim_idx_node;
          dim_idx_node.set_op("Const");
          dim_idx_node.set_name(
              StrCat(gather_node.name(), "/ExpandDims/Const"));
          SetNodeAttr("value", dim_idx, &dim_idx_node);
          SetNodeAttr("dtype", DT_INT32, &dim_idx_node);

          // ExpandDims node
          NodeDef expand_dims_node;
          expand_dims_node.set_op("ExpandDims");
          // Reuse gather_node's name so as not to change its dependents'
          // inputs.
          expand_dims_node.set_name(gather_node.name());
          SetNodeAttr("T", data_type, &expand_dims_node);

          // Connect nodes
          AddNodeInput(hashtable_node.name(), &init_table_node);
          refs[hashtable_node.name()]++;
          AddNodeInput(indices_node.name(), &init_table_node);
          refs[indices_node.name()]++;
          AddNodeInput(values_node.name(), &init_table_node);
          refs[values_node.name()]++;

          AddNodeInput(hashtable_node.name(), &lookup_node);
          refs[hashtable_node.name()]++;
          AddNodeInput(gather_node.input(1), &lookup_node);
          refs[gather_node.input(1)]++;
          AddNodeInput(default_value_node.name(), &lookup_node);
          refs[default_value_node.name()]++;

          AddNodeInput(lookup_node.name(), &expand_dims_node);
          refs[lookup_node.name()]++;
          AddNodeInput(dim_idx_node.name(), &expand_dims_node);
          refs[dim_idx_node.name()]++;

          // Copy 'ids' input of original 'Gather'
          new_nodes->push_back(match.inputs[1].node);
          new_nodes->push_back(indices_node);
          new_nodes->push_back(values_node);
          new_nodes->push_back(hashtable_node);
          new_nodes->push_back(init_table_node);
          new_nodes->push_back(lookup_node);
          new_nodes->push_back(default_value_node);
          new_nodes->push_back(dim_idx_node);
          new_nodes->push_back(expand_dims_node);

          return Status::OK();
        },
        {true}, &replaced_graph_def));

    NodeDef* init_op = nullptr;
    for (int i = 0; i < replaced_graph_def.node_size(); i++) {
      if (replaced_graph_def.node(i).name() == group_init_node &&
          replaced_graph_def.node(i).op() == "NoOp") {
        init_op = replaced_graph_def.mutable_node(i);
        break;
      }
    }
    if (!init_op) {
      // Init node
      init_op = replaced_graph_def.mutable_node()->Add();
      init_op->set_op("NoOp");
      init_op->set_name(group_init_node);
    }
    for (const string& name : init_table_node_names) {
      // Add a control dependency from each init table node to the group init
      // node.
      AddNodeInput(StrCat("^", name), init_op);
      refs[name]++;
    }

    // Erase inputs and outputs as they are not considered for deletion.
    for (const auto& output : context.output_names) {
      refs.erase(output);
    }

    for (const auto& input : context.input_names) {
      refs.erase(input);
    }

    // Add nodes with a reference count of 0 for deletion.
    for (auto entry : refs) {
      if (entry.second == 0) {
        removed_node_names.push_back(entry.first);
      }
    }

    while (!removed_node_names.empty()) {
      auto name = removed_node_names.back();
      removed_node_names.pop_back();

      int i = 0;
      while (i < replaced_graph_def.node_size()) {
        // Revisit this to see if we can safely remove RestoreV2 nodes.
        if ((replaced_graph_def.node(i).name() == name) &&
            (replaced_graph_def.node(i).op() != "RestoreV2")) {
          for (const auto& input : replaced_graph_def.node(i).input()) {
            auto parsed_input = StringReplace(input, "^", "", true);
            refs[parsed_input] -= 1;
            if (refs[parsed_input] == 0) {
              removed_node_names.push_back(parsed_input);
            }
          }
          TF_RETURN_IF_ERROR(RemoveNodeAtIndex(&replaced_graph_def, i));
          continue;
        }
        int j = 0;
        bool deleted_inputs = false;
        while (j < replaced_graph_def.node(i).input_size()) {
          if (replaced_graph_def.node(i).input(j) == name ||
              replaced_graph_def.node(i).input(j) == ("^" + name)) {
            TF_RETURN_IF_ERROR(
                RemoveInputAtIndex(replaced_graph_def.mutable_node(i), j));
            deleted_inputs = true;
            continue;
          }
          j++;
        }
        if (deleted_inputs) {
          if (replaced_graph_def.node(i).op() == "ConcatV2") {
            if (replaced_graph_def.node(i).input_size() > 2) {
              SetNodeAttr("N", replaced_graph_def.node(i).input_size() - 1,
                          replaced_graph_def.mutable_node(i));
            } else if (replaced_graph_def.node(i).input_size() == 2) {
              if (refs[replaced_graph_def.node(i).input(1)] != 1) {
                return errors::Internal(
                    "Expect axis tensor of ConcatV2 node to only be referenced "
                    "once.");
              }
              refs[replaced_graph_def.node(i).input(1)] -= 1;
              removed_node_names.push_back(replaced_graph_def.node(i).input(1));
              replaced_graph_def.mutable_node(i)->mutable_input()->RemoveLast();
              replaced_graph_def.mutable_node(i)->mutable_attr()->erase("N");
              replaced_graph_def.mutable_node(i)->set_op("Identity");
            } else {
              return errors::Internal(
                  "ConcatV2 should have at least two elements");
            }
          }
          if ((replaced_graph_def.node(i).op() == "Assign" ||
               replaced_graph_def.node(i).op() == "Reshape" ||
               replaced_graph_def.node(i).op() == "Equal" ||
               replaced_graph_def.node(i).op() == "Mean" ||
               replaced_graph_def.node(i).op() == "ScalarSummary") &&
              replaced_graph_def.node(i).input_size() == 1) {
            removed_node_names.push_back(replaced_graph_def.node(i).name());
          }
          if (!replaced_graph_def.node(i).input_size()) {
            removed_node_names.push_back(replaced_graph_def.node(i).name());
          }
        }
        i++;
      }
    }
    current_graph_def = replaced_graph_def;
  } while (any_match_found);
  *output_graph_def = current_graph_def;
  return Status::OK();
}

Status SparsifyGather(const GraphDef& input_graph_def,
                      const TransformFuncContext& context,
                      GraphDef* output_graph_def) {
  // clang-format off
  const OpTypePattern gather_pattern =
    {"Gather",
     {
       {"Identity",
        {
          {"Const|Variable|VariableV2"}
        }
       },
       {"*"},
     }
    };
  const OpTypePattern gather_v2_pattern =
    {"GatherV2",
      {
        {"Identity",
          {
            {"Const|Variable|VariableV2"}
          }
        },
        {"*"},
        // GatherV2's axis must be constant.
        {"Const"},
      }
    };
  // clang-format on
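
  // As a concrete example, gather_pattern above matches subgraphs of the form
  //   w = Const(); w_read = Identity(w); result = Gather(w_read, ids);
  // where `ids` may be produced by any op ({"*"}).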

  GraphDef cleaned_input_graph_def;
  RemoveAttributes(input_graph_def, {"_output_shapes"},
                   &cleaned_input_graph_def);

  GraphDef temp_output;

  std::unique_ptr<BundleReader> ckpt_reader;
  TF_RETURN_IF_ERROR(InitializeCheckpointReader(context, &ckpt_reader));

  std::unique_ptr<std::unordered_map<string, string> > shapes_and_slices;
  TF_RETURN_IF_ERROR(
      ObtainVariableInfo(cleaned_input_graph_def, &shapes_and_slices));

  TF_RETURN_IF_ERROR(SparsifyGatherInternal(
      cleaned_input_graph_def, shapes_and_slices, context, gather_pattern,
      ckpt_reader, &temp_output));

  TF_RETURN_IF_ERROR(SparsifyGatherInternal(temp_output, shapes_and_slices,
                                            context, gather_v2_pattern,
                                            ckpt_reader, output_graph_def));

  return Status::OK();
}

REGISTER_GRAPH_TRANSFORM("sparsify_gather", SparsifyGather);
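
// A sketch of a typical invocation through the graph_transforms tool (paths
// and node names below are placeholders, not taken from this file):
//
//   bazel run tensorflow/tools/graph_transforms:transform_graph -- \
//     --in_graph=model.pb --out_graph=sparse_model.pb \
//     --inputs='ids' --outputs='predictions' \
//     --transforms='sparsify_gather(input_checkpoint="model.ckpt",
//                                   group_init_node="init_all_tables")'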

}  // namespace graph_transforms
}  // namespace tensorflow