1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ================================
15"""Imports a protobuf model as a graph in Tensorboard."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import argparse
22import sys
23
24from tensorflow.core.framework import graph_pb2
25from tensorflow.python.client import session
26from tensorflow.python.framework import importer
27from tensorflow.python.framework import ops
28from tensorflow.python.platform import app
29from tensorflow.python.platform import gfile
30from tensorflow.python.summary import summary
31
32
def import_to_tensorboard(model_dir, log_dir):
  """View an imported protobuf model (`.pb` file) as a graph in Tensorboard.

  Reads a serialized `GraphDef` from `model_dir`, imports it into a fresh
  graph, and writes that graph to an event file under `log_dir`.

  Args:
    model_dir: The location of the protobuf (`pb`) model to visualize. Despite
      the name, this path is opened and read as a single binary file.
    log_dir: The location for the Tensorboard log to begin visualization from.

  Usage:
    Call this function with your model location and desired log directory.
    Launch Tensorboard by pointing it to the log directory.
    View your imported `.pb` model as a graph.
  """
  # A dedicated graph keeps the imported nodes separate from any graph the
  # caller may already have built.
  with session.Session(graph=ops.Graph()) as sess:
    with gfile.FastGFile(model_dir, "rb") as f:
      graph_def = graph_pb2.GraphDef()
      graph_def.ParseFromString(f.read())
    # The file handle is no longer needed once the bytes are parsed.
    importer.import_graph_def(graph_def)

    pb_visual_writer = summary.FileWriter(log_dir)
    try:
      pb_visual_writer.add_graph(sess.graph)
    finally:
      # Close the writer so the graph event is flushed to disk and the event
      # file handle is released before we report success.
      pb_visual_writer.close()
    print("Model Imported. Visualize by running: "
          "tensorboard --logdir={}".format(log_dir))
55
56
def main(unused_args):
  """Imports the model named by the parsed command-line flags.

  Reads `model_dir` and `log_dir` from the module-level `FLAGS` and forwards
  them to `import_to_tensorboard`.

  Args:
    unused_args: Positional arguments supplied by the caller; ignored.
  """
  model_path, event_dir = FLAGS.model_dir, FLAGS.log_dir
  import_to_tensorboard(model_path, event_dir)
59
if __name__ == "__main__":
  # Script entry point: declare the two required string flags, parse them into
  # the module-level FLAGS that `main` reads, and pass any arguments argparse
  # did not recognize on to `app.run`.
  parser = argparse.ArgumentParser()
  parser.register("type", "bool", lambda text: text.lower() == "true")
  for flag_name, flag_help in (
      ("--model_dir",
       "The location of the protobuf ('pb') model to visualize."),
      ("--log_dir",
       "The location for the Tensorboard log to begin visualization from."),
  ):
    parser.add_argument(
        flag_name, type=str, default="", required=True, help=flag_help)
  FLAGS, unparsed = parser.parse_known_args()
  app.run(main=main, argv=sys.argv[:1] + unparsed)
77