1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ================================ 15"""Imports a protobuf model as a graph in Tensorboard.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import argparse 22import sys 23 24from tensorflow.core.framework import graph_pb2 25from tensorflow.python.client import session 26from tensorflow.python.framework import importer 27from tensorflow.python.framework import ops 28from tensorflow.python.platform import app 29from tensorflow.python.platform import gfile 30from tensorflow.python.summary import summary 31 32 33def import_to_tensorboard(model_dir, log_dir): 34 """View an imported protobuf model (`.pb` file) as a graph in Tensorboard. 35 36 Args: 37 model_dir: The location of the protobuf (`pb`) model to visualize 38 log_dir: The location for the Tensorboard log to begin visualization from. 39 40 Usage: 41 Call this function with your model location and desired log directory. 42 Launch Tensorboard by pointing it to the log directory. 43 View your imported `.pb` model as a graph. 44 """ 45 with session.Session(graph=ops.Graph()) as sess: 46 with gfile.FastGFile(model_dir, "rb") as f: 47 graph_def = graph_pb2.GraphDef() 48 graph_def.ParseFromString(f.read()) 49 importer.import_graph_def(graph_def) 50 51 pb_visual_writer = summary.FileWriter(log_dir) 52 pb_visual_writer.add_graph(sess.graph) 53 print("Model Imported. Visualize by running: " 54 "tensorboard --logdir={}".format(log_dir)) 55 56 57def main(unused_args): 58 import_to_tensorboard(FLAGS.model_dir, FLAGS.log_dir) 59 60if __name__ == "__main__": 61 parser = argparse.ArgumentParser() 62 parser.register("type", "bool", lambda v: v.lower() == "true") 63 parser.add_argument( 64 "--model_dir", 65 type=str, 66 default="", 67 required=True, 68 help="The location of the protobuf (\'pb\') model to visualize.") 69 parser.add_argument( 70 "--log_dir", 71 type=str, 72 default="", 73 required=True, 74 help="The location for the Tensorboard log to begin visualization from.") 75 FLAGS, unparsed = parser.parse_known_args() 76 app.run(main=main, argv=[sys.argv[0]] + unparsed) 77