# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Generates __init__.py files with imports and constants for the TF Python API.

Each symbol tagged with `_tf_api_names` (or listed in a module's
`_tf_api_constants`) is re-exported from the generated `api/.../__init__.py`
file that corresponds to its public `tf.*` module path.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import collections
import os
import sys

# This import is needed so that we can traverse over TensorFlow modules.
import tensorflow as tf  # pylint: disable=unused-import
from tensorflow.python.util import tf_decorator


_API_CONSTANTS_ATTR = '_tf_api_constants'
_API_NAMES_ATTR = '_tf_api_names'
_API_DIR = '/api/'
_CONTRIB_IMPORT = 'from tensorflow import contrib'
_GENERATED_FILE_HEADER = """\"\"\"Imports for Python API.

This file is MACHINE GENERATED! Do not edit.
Generated by: tensorflow/tools/api/generator/create_python_api.py script.
\"\"\"
"""


def format_import(source_module_name, source_name, dest_name):
  """Formats import statement.

  Args:
    source_module_name: (string) Source module to import from.
      If empty, a plain `import` statement is produced.
    source_name: (string) Source symbol name to import.
    dest_name: (string) Destination alias name.

  Returns:
    An import statement string.
  """
  if source_module_name:
    if source_name == dest_name:
      return 'from %s import %s' % (source_module_name, source_name)
    else:
      return 'from %s import %s as %s' % (
          source_module_name, source_name, dest_name)
  else:
    if source_name == dest_name:
      return 'import %s' % source_name
    else:
      return 'import %s as %s' % (source_name, dest_name)


def get_api_imports():
  """Get a map from destination module to formatted imports.

  Returns:
    A dictionary where
      key: (string) destination module (for e.g. tf or tf.consts).
      value: List of strings representing module imports
          (for e.g. 'from foo import bar') and constant
          assignments (for e.g. 'FOO = 123').
  """
  module_imports = collections.defaultdict(list)
  # Traverse over everything imported above. Specifically,
  # we want to traverse over TensorFlow Python modules.
  # Iterate over a snapshot: getattr() below may trigger lazy imports, which
  # would mutate sys.modules mid-iteration (a RuntimeError on Python 3).
  for module in list(sys.modules.values()):
    # Only look at tensorflow modules.
    if not module or 'tensorflow.' not in module.__name__:
      continue
    # Do not generate __init__.py files for contrib modules for now.
    if '.contrib.' in module.__name__ or module.__name__.endswith('.contrib'):
      continue

    for module_contents_name in dir(module):
      attr = getattr(module, module_contents_name)

      # If attr is _tf_api_constants attribute, then add the constants.
      if module_contents_name == _API_CONSTANTS_ATTR:
        for exports, value in attr:
          for export in exports:
            names = ['tf'] + export.split('.')
            dest_module = '.'.join(names[:-1])
            import_str = format_import(module.__name__, value, names[-1])
            module_imports[dest_module].append(import_str)
        continue

      # Look through decorators so the _tf_api_names tag on the wrapped
      # target is found.
      _, attr = tf_decorator.unwrap(attr)
      # If attr is a symbol with _tf_api_names attribute, then
      # add import for it.
      if hasattr(attr, '__dict__') and _API_NAMES_ATTR in attr.__dict__:
        # The same op might be accessible from multiple modules.
        # We only want to consider location where function was defined.
        if attr.__module__ != module.__name__:
          continue

        for export in attr._tf_api_names:  # pylint: disable=protected-access
          names = ['tf'] + export.split('.')
          dest_module = '.'.join(names[:-1])
          import_str = format_import(
              module.__name__, module_contents_name, names[-1])
          module_imports[dest_module].append(import_str)

  # Import all required modules in their parent modules.
  # For e.g. if we import 'tf.foo.bar.Value'. Then, we also
  # import 'bar' in 'tf.foo'.
  dest_modules = set(module_imports.keys())
  for dest_module in dest_modules:
    dest_module_split = dest_module.split('.')
    for dest_submodule_index in range(1, len(dest_module_split)):
      dest_submodule = '.'.join(dest_module_split[:dest_submodule_index])
      submodule_import = format_import(
          '', dest_module_split[dest_submodule_index],
          dest_module_split[dest_submodule_index])
      if submodule_import not in module_imports[dest_submodule]:
        module_imports[dest_submodule].append(submodule_import)

  return module_imports


def _get_module_name(output_file):
  """Maps an output __init__.py path under _API_DIR to its 'tf...' module name.

  For example '.../api/foo/bar/__init__.py' -> 'tf.foo.bar' and
  '.../api/__init__.py' -> 'tf'.
  """
  # Module directory relative to the last occurrence of _API_DIR.
  module_dir = os.path.dirname(
      output_file[output_file.rfind(_API_DIR) + len(_API_DIR):])
  # Convert / to . and prefix with tf. strip('.') handles the root 'tf.' case.
  return '.'.join(['tf', module_dir.replace('/', '.')]).strip('.')


def _get_expected_output_path(module_name):
  """Returns the quoted genrule output path expected for a 'tf...' module."""
  # Drop the leading 'tf' component; the root module maps to api/__init__.py.
  module_path = module_name.split('.')[1:]
  return '"%s"' % '/'.join(['api'] + module_path + ['__init__.py'])


def create_api_files(output_files):
  """Creates __init__.py files for the Python API.

  Args:
    output_files: List of __init__.py file paths to create.
        Each file must be under api/ directory.

  Raises:
    ValueError: if an output file is not under api/ directory,
        or output_files list is missing a required file.
  """
  module_name_to_file_path = {}
  for output_file in output_files:
    if _API_DIR not in output_file:
      raise ValueError(
          'Output files must be in api/ directory, found %s.' % output_file)
    module_name_to_file_path[_get_module_name(output_file)] = output_file

  # Create file for each expected output in genrule, so every declared
  # output exists even if nothing is exported into it.
  for file_path in module_name_to_file_path.values():
    file_dir = os.path.dirname(file_path)
    if not os.path.isdir(file_dir):
      os.makedirs(file_dir)
    with open(file_path, 'a'):
      pass

  module_imports = get_api_imports()
  module_imports['tf'].append(_CONTRIB_IMPORT)  # Include all of contrib.

  # Add imports to output files.
  missing_output_files = []
  for module, exports in module_imports.items():
    # Make sure genrule output file list is in sync with API exports.
    if module not in module_name_to_file_path:
      missing_output_files.append(_get_expected_output_path(module))
      continue
    with open(module_name_to_file_path[module], 'w') as fp:
      fp.write(_GENERATED_FILE_HEADER + '\n'.join(exports))

  if missing_output_files:
    raise ValueError(
        'Missing outputs for python_api_gen genrule:\n%s.\n'
        'Make sure all required outputs are in the '
        'tensorflow/tools/api/generator/BUILD file.' %
        ',\n'.join(sorted(missing_output_files)))


def main(output_files):
  """Entry point: generates the API __init__.py files listed in output_files."""
  create_api_files(output_files)


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      'outputs', metavar='O', type=str, nargs='+',
      help='Python files that we expect this script to output.')
  args = parser.parse_args()
  main(args.outputs)