1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Base TFDecorator class and utility functions for working with decorators.
16
17There are two ways to create decorators that TensorFlow can introspect into.
18This is important for documentation generation purposes, so that function
19signatures aren't obscured by the (*args, **kwds) signature that decorators
20often provide.
21
221. Call `tf_decorator.make_decorator` on your wrapper function. If your
23decorator is stateless, or can capture all of the variables it needs to work
24with through lexical closure, this is the simplest option. Create your wrapper
25function as usual, but instead of returning it, return
26`tf_decorator.make_decorator(target, your_wrapper)`. This will attach some
27decorator introspection metadata onto your wrapper and return it.
28
29Example:
30
31  def print_hello_before_calling(target):
32    def wrapper(*args, **kwargs):
33      print('hello')
34      return target(*args, **kwargs)
35    return tf_decorator.make_decorator(target, wrapper)
36
372. Derive from TFDecorator. If your decorator needs to be stateful, you can
38implement it in terms of a TFDecorator. Store whatever state you need in your
39derived class, and implement the `__call__` method to do your work before
40calling into your target. You can retrieve the target via
41`super(MyDecoratorClass, self).decorated_target`, and call it with whatever
42parameters it needs.
43
44Example:
45
46  class CallCounter(tf_decorator.TFDecorator):
47    def __init__(self, target):
48      super(CallCounter, self).__init__('count_calls', target)
49      self.call_count = 0
50
51    def __call__(self, *args, **kwargs):
52      self.call_count += 1
53      return super(CallCounter, self).decorated_target(*args, **kwargs)
54
55  def count_calls(target):
56    return CallCounter(target)
57"""
58from __future__ import absolute_import
59from __future__ import division
60from __future__ import print_function
61
62import functools as _functools
63import traceback as _traceback
64
65
def make_decorator(target,
                   decorator_func,
                   decorator_name=None,
                   decorator_doc='',
                   decorator_argspec=None):
  """Make a decorator from a wrapper and a target.

  Attaches a `TFDecorator` describing this decoration to `decorator_func`
  (as its `_tf_decorator` attribute). It then copies the target's identifying
  metadata onto `decorator_func`, so introspection sees the wrapped callable's
  name, qualified name, module and documentation rather than the wrapper's.

  Args:
    target: The final callable to be wrapped.
    decorator_func: The wrapper function. It is modified in place.
    decorator_name: The name of the decorator. If `None`, the name of the
      function calling make_decorator.
    decorator_doc: Documentation specific to this application of
      `decorator_func` to `target`.
    decorator_argspec: The new callable signature of this decorator.

  Returns:
    The `decorator_func` argument with new metadata attached.
  """
  if decorator_name is None:
    # extract_stack lists frames oldest-first, so with limit=2 the first entry
    # is the frame of whoever called make_decorator.
    frame = _traceback.extract_stack(limit=2)[0]
    # frame name is tuple[2] in python2, and object.name in python3
    decorator_name = getattr(frame, 'name', frame[2])  # Caller's name
  decorator = TFDecorator(decorator_name, target, decorator_doc,
                          decorator_argspec)
  setattr(decorator_func, '_tf_decorator', decorator)
  # Objects that are callables (e.g., a functools.partial object) may not have
  # the following attributes.
  if hasattr(target, '__name__'):
    decorator_func.__name__ = target.__name__
  # Keep the qualified name (Python 3 only) consistent with __name__.
  # Otherwise the wrapper would still report e.g. 'outer.<locals>.wrapper'.
  if hasattr(target, '__qualname__'):
    decorator_func.__qualname__ = target.__qualname__
  if hasattr(target, '__module__'):
    decorator_func.__module__ = target.__module__
  if hasattr(target, '__doc__'):
    # TFDecorator has already resolved the doc: decorator_doc if non-empty,
    # else the target's docstring, else ''.
    decorator_func.__doc__ = decorator.__doc__
  decorator_func.__wrapped__ = target
  return decorator_func
102
103
def unwrap(maybe_tf_decorator):
  """Peels TFDecorator layers off a callable until the undecorated target.

  A layer is either an object that is itself a `TFDecorator`, or an object
  carrying a `TFDecorator` in its `_tf_decorator` attribute (as attached by
  `make_decorator`). Each layer's `decorated_target` is followed to reach the
  next one.

  Args:
    maybe_tf_decorator: Any callable object.

  Returns:
    A `(decorators, target)` tuple. `decorators` is a list of the
    TFDecorator-derived objects applied to `target`, ordered from outermost
    to innermost; it is empty when `maybe_tf_decorator` carries no
    TFDecorator. `target` is the final, undecorated callable.
  """
  found = []
  target = maybe_tf_decorator
  while True:
    if isinstance(target, TFDecorator):
      layer = target
    elif hasattr(target, '_tf_decorator'):
      layer = getattr(target, '_tf_decorator')
    else:
      # Nothing left to peel: `target` is the innermost callable.
      return found, target
    found.append(layer)
    target = layer.decorated_target
129
130
class TFDecorator(object):
  """Base class for all TensorFlow decorators.

  TFDecorator captures and exposes the wrapped target, and provides details
  about the current decorator.
  """

  def __init__(self,
               decorator_name,
               target,
               decorator_doc='',
               decorator_argspec=None):
    """Creates a decorator record wrapping `target`.

    Args:
      decorator_name: The name of the decorator.
      target: The callable being wrapped.
      decorator_doc: Documentation specific to this application of the
        decorator. When non-empty it becomes `self.__doc__`; otherwise the
        target's docstring is used, or '' if the target has none.
      decorator_argspec: The new callable signature of this decorator.
    """
    self._decorated_target = target
    self._decorator_name = decorator_name
    self._decorator_doc = decorator_doc
    self._decorator_argspec = decorator_argspec
    # Callables such as functools.partial objects may lack __name__.
    if hasattr(target, '__name__'):
      self.__name__ = target.__name__
    if self._decorator_doc:
      self.__doc__ = self._decorator_doc
    elif hasattr(target, '__doc__') and target.__doc__:
      self.__doc__ = target.__doc__
    else:
      self.__doc__ = ''

  def __get__(self, obj, objtype=None):
    """Binds this decorator like a method when stored on a class.

    Args:
      obj: The instance the attribute was looked up on, or `None` when it was
        looked up on the class itself.
      objtype: The owning class. Unused; optional per the descriptor protocol.

    Returns:
      `self` when accessed through the class, so `Cls.method(instance, ...)`
      behaves like an unbound function. Otherwise, a callable that passes
      `obj` as the first argument to `self.__call__`, so that `__call__`
      overrides in subclasses still run.
    """
    if obj is None:
      # Without this guard, `None` would be bound as the first argument and
      # `Cls.method(instance, x)` would call `target(None, instance, x)`.
      return self
    return _functools.partial(self.__call__, obj)

  def __call__(self, *args, **kwargs):
    """Forwards the call unchanged to the decorated target."""
    return self._decorated_target(*args, **kwargs)

  @property
  def decorated_target(self):
    """The callable wrapped by this decorator."""
    return self._decorated_target

  @property
  def decorator_name(self):
    """The name given to this decorator at construction."""
    return self._decorator_name

  @property
  def decorator_doc(self):
    """The decorator-specific documentation given at construction."""
    return self._decorator_doc

  @property
  def decorator_argspec(self):
    """The callable signature given at construction (may be `None`)."""
    return self._decorator_argspec
177