1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""A simple script for inspecting checkpoint files."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import argparse
22import sys
23
24from tensorflow.contrib.framework.python.framework import checkpoint_utils
25from tensorflow.python.platform import app
26
# Parsed command-line flags. Rebound to the argparse namespace returned by
# `parser.parse_known_args()` in the `__main__` block before `main` runs;
# stays None when this module is imported rather than executed as a script.
FLAGS = None
28
29
def print_tensors_in_checkpoint_file(file_name, tensor_name):
  """Print a checkpoint's variable listing, or the value of one tensor.

  When `tensor_name` is empty (or otherwise falsy), every variable returned
  by `checkpoint_utils.list_variables` is printed on its own line as
  `<name>\t<shape>`. Otherwise a `tensor_name:` header is printed, followed
  by the value returned by `checkpoint_utils.load_variable`.

  Errors are not propagated: any exception is caught and its message is
  printed. An extra hint is printed when the message mentions corrupted
  compressed block contents.

  Args:
    file_name: Checkpoint path handed unchanged to `checkpoint_utils`.
    tensor_name: Name of the tensor to print, or a falsy value to list
      all variables instead.
  """
  try:
    if tensor_name:
      print("tensor_name: ", tensor_name)
      print(checkpoint_utils.load_variable(file_name, tensor_name))
      return
    for var_name, var_shape in checkpoint_utils.list_variables(file_name):
      print("%s\t%s" % (var_name, str(var_shape)))
  except Exception as err:  # pylint: disable=broad-except
    # Best-effort CLI behaviour: report the failure instead of crashing.
    message = str(err)
    print(message)
    if "corrupted compressed block contents" in message:
      print("It's likely that your checkpoint file has been compressed "
            "with SNAPPY.")
55
56
def main(unused_argv):
  """Script entry point run by `app.run` after flags are parsed.

  Delegates to `print_tensors_in_checkpoint_file` when `--file_name` was
  given; otherwise prints a usage line and exits with status 1.

  Args:
    unused_argv: Remaining command-line arguments passed in by `app.run`;
      not used.
  """
  checkpoint_path = FLAGS.file_name
  if checkpoint_path:
    print_tensors_in_checkpoint_file(checkpoint_path, FLAGS.tensor_name)
    return
  print("Usage: inspect_checkpoint --file_name=<checkpoint_file_name "
        "or directory> [--tensor_name=tensor_to_print]")
  sys.exit(1)
64
65
if __name__ == "__main__":
  parser = argparse.ArgumentParser(
      description="Print the variables stored in a TensorFlow checkpoint, "
      "or the value of a single tensor.")
  parser.add_argument(
      "--file_name",
      type=str,
      default="",
      # Kept consistent with the usage message printed by `main`.
      help="Checkpoint filename or directory"
  )
  parser.add_argument(
      "--tensor_name",
      type=str,
      default="",
      help="Name of the tensor to inspect; if omitted, all variable names "
      "and shapes are listed"
  )
  # Only the flags declared above are consumed here; anything argparse does
  # not recognize is forwarded to `app.run` after the program name.
  FLAGS, unparsed = parser.parse_known_args()
  app.run(main=main, argv=[sys.argv[0]] + unparsed)
83