1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Simple script to write the Inception-ResNet-v2 model graph to a file.
"""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import argparse
23import sys
24
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import graph_io
27from tensorflow.python.framework import ops
28from tensorflow.python.ops import array_ops
29from tensorflow.python.platform import app
30from nets import inception
31
# Parsed command-line arguments (the namespace returned by
# `parser.parse_known_args()`). Assigned in the `__main__` block below and
# read by `main`; stays None if this module is imported rather than run.
cmd_args = None
33
34
def main(unused_argv):
  """Builds the Inception-ResNet-v2 base network and writes its graph.

  The graph is written to `cmd_args.graph_dir` under the name
  `cmd_args.graph_filename`, using `graph_io.write_graph`'s default
  serialization.

  Args:
    unused_argv: Leftover command-line arguments forwarded by `app.run`;
      not used.
  """
  del unused_argv  # Unused.

  graph = ops.Graph()
  with graph.as_default():
    # Input placeholder with shape (1, None, None, 3): batch size fixed at 1,
    # the two middle dimensions left unspecified, last dimension 3
    # (presumably NHWC image layout -- confirm against nets.inception).
    input_images = array_ops.placeholder(
        dtypes.float32, shape=(1, None, None, 3), name='input_image')
    inception.inception_resnet_v2_base(input_images)

  graph_def = graph.as_graph_def()
  graph_io.write_graph(graph_def, cmd_args.graph_dir, cmd_args.graph_filename)
45
46
if __name__ == '__main__':
  arg_parser = argparse.ArgumentParser()
  # Registers a 'bool' type converter; none of the flags below use it.
  arg_parser.register('type', 'bool', lambda value: value.lower() == 'true')

  # (flag name, default value, help text) for each string-valued flag.
  _STRING_FLAGS = (
      ('--graph_dir', '/tmp', 'Directory where graph will be saved.'),
      ('--graph_filename', 'graph.pbtxt',
       'Filename of graph that will be saved.'),
  )
  for flag_name, flag_default, flag_help in _STRING_FLAGS:
    arg_parser.add_argument(
        flag_name, type=str, default=flag_default, help=flag_help)

  # Known flags land in the module-level `cmd_args` read by `main`; anything
  # unrecognized is handed on to `app.run` after the program name.
  cmd_args, unparsed = arg_parser.parse_known_args()
  app.run(main=main, argv=[sys.argv[0]] + unparsed)
62