1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Computes Receptive Field (RF) information given a graph protobuf.
16
For an example of usage, see the accompanying file compute_rf.sh.
18"""
19
20from __future__ import absolute_import
21from __future__ import division
22from __future__ import print_function
23
24import argparse
25import sys
26
27from google.protobuf import text_format
28
29from tensorflow.contrib.receptive_field import receptive_field_api as receptive_field
30from tensorflow.core.framework import graph_pb2
31from tensorflow.python.platform import app
32from tensorflow.python.platform import gfile
33from tensorflow.python.platform import tf_logging as logging
34
35cmd_args = None
36
37
def _load_graphdef(path):
  """Helper function to load GraphDef from file.

  Args:
    path: Path to a file holding a GraphDef in protobuf text format (pbtxt).

  Returns:
    graph_def: A GraphDef object populated from the contents of `path`.

  Raises:
    text_format.ParseError: If the file contents are not a valid text-format
      GraphDef.
  """
  graph_def = graph_pb2.GraphDef()
  # Read through a context manager so the file handle is always released,
  # even if the read fails, rather than waiting for garbage collection.
  with gfile.Open(path) as f:
    pbstr = f.read()
  text_format.Parse(pbstr, graph_def)
  return graph_def
51
52
def main(unused_argv):
  """Computes receptive field parameters for a graph and reports them.

  Loads the GraphDef at `cmd_args.graph_path`, computes the receptive field
  parameters between `cmd_args.input_node` and `cmd_args.output_node`, logs
  each value, and writes the same lines to `cmd_args.output_path`.

  Args:
    unused_argv: Arguments forwarded by `app.run`; not used.
  """
  graph_def = _load_graphdef(cmd_args.graph_path)

  (receptive_field_x, receptive_field_y, effective_stride_x, effective_stride_y,
   effective_padding_x, effective_padding_y
  ) = receptive_field.compute_receptive_field_from_graph_def(
      graph_def, cmd_args.input_node, cmd_args.output_node)

  # Single source of truth for the report, so the log output and the output
  # file cannot drift apart.
  results = (
      ('Receptive field size (horizontal)', receptive_field_x),
      ('Receptive field size (vertical)', receptive_field_y),
      ('Effective stride (horizontal)', effective_stride_x),
      ('Effective stride (vertical)', effective_stride_y),
      ('Effective padding (horizontal)', effective_padding_x),
      ('Effective padding (vertical)', effective_padding_y),
  )

  for label, value in results:
    logging.info('%s = %s', label, value)

  # The context manager guarantees the file is flushed and closed even if one
  # of the writes raises.
  with gfile.GFile(cmd_args.output_path, 'w') as f:
    for label, value in results:
      f.write('%s = %s\n' % (label, value))
77
78
if __name__ == '__main__':
  # (flag, help text) pairs for the command-line flags. Every flag is a string
  # that defaults to the empty string; they are registered in this order.
  flag_specs = [
      ('--graph_path', 'Graph path (pbtxt format).'),
      ('--output_path',
       'Path to output text file where RF information will be written to.'),
      ('--input_node', 'Name of input node.'),
      ('--output_node', 'Name of output node.'),
  ]

  parser = argparse.ArgumentParser()
  # Registers a 'bool' type that treats the case-insensitive string 'true' as
  # True and anything else as False.
  parser.register('type', 'bool', lambda value: value.lower() == 'true')
  for flag_name, flag_help in flag_specs:
    parser.add_argument(flag_name, type=str, default='', help=flag_help)

  # Recognized flags are stored in the module-level `cmd_args` read by main();
  # anything unrecognized is handed on to app.run after the program name.
  cmd_args, unparsed = parser.parse_known_args()
  app.run(main=main, argv=[sys.argv[0]] + unparsed)
95