1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""Generates and prints out imports and constants for new TensorFlow python api.
16"""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import argparse
22import collections
23import os
24import sys
25
26# This import is needed so that we can traverse over TensorFlow modules.
27import tensorflow as tf  # pylint: disable=unused-import
28from tensorflow.python.util import tf_decorator
29
30
# Name of the module attribute listing exported constants. get_api_imports
# iterates it as (export names, constant name) pairs.
_API_CONSTANTS_ATTR = '_tf_api_constants'
# Name of the attribute (looked up in a symbol's __dict__) listing the public
# API names that symbol is exported under.
_API_NAMES_ATTR = '_tf_api_names'
# Every output file path must contain this fragment; the generated module
# name is derived from the directories that follow its last occurrence.
_API_DIR = '/api/'
# Appended to the root 'tf' module's imports so all of contrib is exposed.
_CONTRIB_IMPORT = 'from tensorflow import contrib'
# Module docstring written at the top of each generated file that has imports.
_GENERATED_FILE_HEADER = """\"\"\"Imports for Python API.

This file is MACHINE GENERATED! Do not edit.
Generated by: tensorflow/tools/api/generator/create_python_api.py script.
\"\"\"
"""
41
42
def format_import(source_module_name, source_name, dest_name):
  """Builds the text of a single Python import statement.

  Args:
    source_module_name: (string) Module to import from. When it is empty
      (or otherwise falsy) a plain `import` statement is produced instead of
      a `from ... import ...` statement.
    source_name: (string) Name of the symbol or module being imported.
    dest_name: (string) Name to bind the import to; an `as` clause is added
      only when it differs from `source_name`.

  Returns:
    The import statement as a string, e.g. 'from foo import bar as baz'.
  """
  if source_module_name:
    statement = 'from %s import %s' % (source_module_name, source_name)
  else:
    statement = 'import %s' % source_name
  # Only alias when the destination name actually differs.
  if source_name != dest_name:
    statement += ' as %s' % (dest_name,)
  return statement
65
66
def get_api_imports():
  """Get a map from destination module to formatted imports.

  Scans every loaded module whose name contains 'tensorflow.' (contrib
  modules are skipped) and collects an import for:
    * each constant listed in the module's `_tf_api_constants` attribute;
    * each symbol defined in that module whose tf_decorator-unwrapped object
      has `_tf_api_names` in its __dict__.
  An export name 'a.b.c' is imported into destination module 'tf.a.b' under
  the name 'c'. Each intermediate destination module is also imported into
  its parent module so the generated package tree is reachable from 'tf'.

  Returns:
    A collections.defaultdict(list) where
      key: (string) destination module (for e.g. tf or tf.consts).
      value: List of import statement strings
          (for e.g. 'from foo import bar' or 'from . import consts').
  """
  module_imports = collections.defaultdict(list)
  # Traverse over everything imported above. Specifically,
  # we want to traverse over TensorFlow Python modules.
  # Iterate over a snapshot: the getattr() calls below can trigger lazy
  # imports that add entries to sys.modules, and changing a dict's size
  # while iterating over it raises RuntimeError on Python 3.
  for module in list(sys.modules.values()):
    # sys.modules may contain None placeholders or objects without a
    # usable __name__; skip those along with non-tensorflow modules.
    module_name = getattr(module, '__name__', None)
    if not module or not module_name or 'tensorflow.' not in module_name:
      continue
    # Do not generate __init__.py files for contrib modules for now.
    if '.contrib.' in module_name or module_name.endswith('.contrib'):
      continue

    for module_contents_name in dir(module):
      attr = getattr(module, module_contents_name)

      # If attr is _tf_api_constants attribute, then add the constants.
      # Each entry is an (export names, constant name) pair.
      if module_contents_name == _API_CONSTANTS_ATTR:
        for exports, value in attr:
          for export in exports:
            names = ['tf'] + export.split('.')
            dest_module = '.'.join(names[:-1])
            import_str = format_import(module_name, value, names[-1])
            module_imports[dest_module].append(import_str)
        continue

      _, attr = tf_decorator.unwrap(attr)
      # If attr is a symbol with _tf_api_names attribute, then
      # add import for it.
      if hasattr(attr, '__dict__') and _API_NAMES_ATTR in attr.__dict__:
        # The same op might be accessible from multiple modules.
        # We only want to consider location where function was defined.
        if attr.__module__ != module_name:
          continue

        for export in attr._tf_api_names:  # pylint: disable=protected-access
          names = ['tf'] + export.split('.')
          dest_module = '.'.join(names[:-1])
          import_str = format_import(
              module_name, module_contents_name, names[-1])
          module_imports[dest_module].append(import_str)

  # Import all required modules in their parent modules.
  # For e.g. if we import 'tf.foo.bar.Value', then we also import 'foo' in
  # 'tf' and 'bar' in 'tf.foo'. Use an explicit relative import: a bare
  # 'import bar' is an absolute import on Python 3 (or with absolute_import)
  # and would not find the sibling package generated next to the parent.
  # Iterate over a snapshot of the keys because new parents may be added.
  for dest_module in set(module_imports):
    dest_module_split = dest_module.split('.')
    for dest_submodule_index in range(1, len(dest_module_split)):
      parent_module = '.'.join(dest_module_split[:dest_submodule_index])
      submodule_name = dest_module_split[dest_submodule_index]
      submodule_import = format_import('.', submodule_name, submodule_name)
      if submodule_import not in module_imports[parent_module]:
        module_imports[parent_module].append(submodule_import)

  return module_imports
132
133
def create_api_files(output_files):
  """Creates __init__.py files for the Python API.

  Every path in `output_files` is first created (empty, if it does not
  exist yet) so that each expected build output exists. Then each file whose
  module has imports is overwritten with the generated header followed by
  one import statement per line.

  Args:
    output_files: List of __init__.py file paths to create.
      Each file must be under api/ directory.

  Raises:
    ValueError: if an output file is not under api/ directory,
      or output_files list is missing a required file.
  """
  module_name_to_file_path = {}
  for output_file in output_files:
    if _API_DIR not in output_file:
      raise ValueError(
          'Output files must be in api/ directory, found %s.' % output_file)
    # Get the module name that corresponds to output_file.
    # First get module directory under the last occurrence of _API_DIR.
    module_dir = os.path.dirname(
        output_file[output_file.rfind(_API_DIR) + len(_API_DIR):])
    # Convert / to . and prefix with tf; the api/ root itself maps to 'tf'.
    module_name = '.'.join(['tf', module_dir.replace('/', '.')]).strip('.')
    module_name_to_file_path[module_name] = output_file

  # Create file for each expected output in genrule. Append mode creates a
  # missing file without truncating an existing one.
  for file_path in module_name_to_file_path.values():
    file_dir = os.path.dirname(file_path)
    if not os.path.isdir(file_dir):
      os.makedirs(file_dir)
    open(file_path, 'a').close()

  module_imports = get_api_imports()
  module_imports['tf'].append(_CONTRIB_IMPORT)  # Include all of contrib.

  # Add imports to output files.
  missing_output_files = []
  for module, exports in module_imports.items():
    # Make sure genrule output file list is in sync with API exports.
    if module not in module_name_to_file_path:
      # 'tf' maps to api/__init__.py and 'tf.a.b' to api/a/b/__init__.py.
      module_file_path = '"%s"' % '/'.join(
          ['api'] + module.split('.')[1:] + ['__init__.py'])
      missing_output_files.append(module_file_path)
      continue
    with open(module_name_to_file_path[module], 'w') as fp:
      fp.write(_GENERATED_FILE_HEADER + '\n'.join(exports))

  if missing_output_files:
    # Adjacent literals are concatenated before '%' is applied.
    raise ValueError(
        'Missing outputs for python_api_gen genrule:\n%s.\n'
        'Make sure all required outputs are in the '
        'tensorflow/tools/api/generator/BUILD file.' %
        ',\n'.join(sorted(missing_output_files)))
186
187
def main(output_files):
  """Runs the generator, writing every requested API __init__.py file.

  Args:
    output_files: List of __init__.py paths to generate, in the form
      accepted by create_api_files.
  """
  create_api_files(output_files=output_files)


if __name__ == '__main__':
  arg_parser = argparse.ArgumentParser()
  arg_parser.add_argument(
      'outputs',
      metavar='O',
      type=str,
      nargs='+',
      help='Python files that we expect this script to output.')
  main(arg_parser.parse_args().outputs)
198