1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15r"""Removes unneeded nodes from a GraphDef file.
16
This script is designed to help streamline models, by taking the input and
output nodes that will be used by an application and figuring out the smallest
set of operations that are required to run for those nodes. The resulting
minimal graph is then saved out.
21
22The advantages of running this script are:
23 - You may be able to shrink the file size.
24 - Operations that are unsupported on your platform but still present can be
25   safely removed.
26The resulting graph may not be as flexible as the original though, since any
27input nodes that weren't explicitly mentioned may not be accessible any more.
28
29An example of command-line usage is:
30bazel build tensorflow/python/tools:strip_unused && \
31bazel-bin/tensorflow/python/tools/strip_unused \
32--input_graph=some_graph_def.pb \
33--output_graph=/tmp/stripped_graph.pb \
--input_node_names=input0 \
35--output_node_names=softmax
36
37You can also look at strip_unused_test.py for an example of how to use it.
38
39"""
40from __future__ import absolute_import
41from __future__ import division
42from __future__ import print_function
43
44import argparse
45import sys
46
47from tensorflow.python.framework import dtypes
48from tensorflow.python.platform import app
49from tensorflow.python.tools import strip_unused_lib
50
# Parsed command-line flags. Stays None when this module is merely imported;
# the __main__ block rebinds it to the namespace from parse_known_args(), and
# main() reads its attributes.
FLAGS = None
52
53
def main(unused_args):
  """Runs strip_unused_lib.strip_unused_from_files with the parsed flags.

  Every setting (input/output paths, their binary-format switches, the
  comma-separated node name lists and the placeholder type enum) is taken
  from the module-level FLAGS namespace.

  Args:
    unused_args: Argument list handed over by `app.run`; not consulted.
  """
  flags = FLAGS
  strip_unused_lib.strip_unused_from_files(
      flags.input_graph,
      flags.input_binary,
      flags.output_graph,
      flags.output_binary,
      flags.input_node_names,
      flags.output_node_names,
      flags.placeholder_type_enum)
62
63
# Spellings accepted for 'bool'-typed flags, compared case-insensitively.
_TRUE_STRINGS = frozenset(('true', 't', 'yes', 'y', 'on', '1'))
_FALSE_STRINGS = frozenset(('false', 'f', 'no', 'n', 'off', '0'))


def _str_to_bool(value):
  """Converts a command-line string into a bool for 'bool'-typed flags.

  Unlike a plain `value.lower() == 'true'` check, unrecognised input is
  rejected instead of silently becoming False, so a typo such as
  `--output_binary=ture` is reported rather than flipping the output format.

  Args:
    value: The raw flag value as given on the command line.

  Returns:
    True for true/t/yes/y/on/1 and False for false/f/no/n/off/0
    (case-insensitive, surrounding whitespace ignored).

  Raises:
    argparse.ArgumentTypeError: If `value` is not one of those spellings;
      argparse reports this as a usage error.
  """
  normalized = value.strip().lower()
  if normalized in _TRUE_STRINGS:
    return True
  if normalized in _FALSE_STRINGS:
    return False
  raise argparse.ArgumentTypeError(
      'Boolean value expected (true/false), got %r.' % (value,))


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  # Arguments declared with type='bool' are converted by _str_to_bool.
  parser.register('type', 'bool', _str_to_bool)
  parser.add_argument(
      '--input_graph',
      type=str,
      default='',
      help='TensorFlow \'GraphDef\' file to load.')
  # nargs='?' with const=True lets a bool flag be given bare
  # (`--input_binary`) to mean True, or with an explicit value
  # (`--input_binary=false`).
  parser.add_argument(
      '--input_binary',
      nargs='?',
      const=True,
      type='bool',
      default=False,
      help='Whether the input files are in binary format.')
  parser.add_argument(
      '--output_graph',
      type=str,
      default='',
      help='Output \'GraphDef\' file name.')
  parser.add_argument(
      '--output_binary',
      nargs='?',
      const=True,
      type='bool',
      default=True,
      help='Whether to write a binary format graph.')
  parser.add_argument(
      '--input_node_names',
      type=str,
      default='',
      help='The name of the input nodes, comma separated.')
  parser.add_argument(
      '--output_node_names',
      type=str,
      default='',
      help='The name of the output nodes, comma separated.')
  # Defaults to the DataType enum value of float32.
  parser.add_argument(
      '--placeholder_type_enum',
      type=int,
      default=dtypes.float32.as_datatype_enum,
      help='The AttrValue enum to use for placeholders.')
  # Arguments this parser does not recognise are not an error here; they are
  # forwarded to app.run after the program name.
  FLAGS, unparsed = parser.parse_known_args()
  app.run(main=main, argv=[sys.argv[0]] + unparsed)
108