# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Script Language Operators. See the @{$python/script_ops} guide.

@@py_func
"""

# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import threading

import numpy as np
import six

from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_script_ops
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


class EagerFunc(object):
  """A wrapper for a function owned by an EagerPyFunc."""

  def __init__(self, func, Tout):
    """Constructs an EagerFunc.

    Args:
      func: The function to wrap.
      Tout: A list of datatypes for the output; an empty list if the output is
        None.
    """
    self._func = func
    self._out_dtypes = Tout

  def __call__(self, on_gpu, args):
    """Passes `args` to `self._func`, which is executed eagerly.

    Args:
      on_gpu: Whether the corresponding op was placed on GPU; if so, outputs
        are copied to the GPU after conversion.
      args: Positional arguments forwarded to `self._func`.

    Returns:
      The function's output converted to tensor(s) of the dtypes given at
      construction, or `None` if the function returned `None`.
    """
    with context.eager_mode():
      ret = self._func(*args)
      # Converted outputs start on the host; mirror the op's placement by
      # copying to GPU when requested.
      maybe_copy_to_gpu = lambda x: x if not on_gpu else x.gpu()
      if isinstance(ret, (tuple, list)):
        return [
            maybe_copy_to_gpu(ops.convert_to_tensor(x, dtype=dtype))
            for (x, dtype) in zip(ret, self._out_dtypes)
        ]
      elif ret is None:
        return ret
      else:
        return maybe_copy_to_gpu(
            ops.convert_to_tensor(ret, dtype=self._out_dtypes[0]))


class FuncRegistry(object):
  """A helper class to keep track of registered py functions.

  FuncRegistry keeps a map from unique tokens (string) to python
  functions, which takes numpy arrays and outputs numpy arrays.
  """

  def __init__(self):
    self._lock = threading.Lock()
    self._unique_id = 0  # GUARDED_BY(self._lock)
    self._funcs = {}

  def insert(self, func):
    """Registers `func` and returns a unique token for this entry."""
    token = self._next_unique_token()
    self._funcs[token] = func
    return token

  def remove(self, token):
    """Removes the registered function corresponding to `token`."""
    # pop() with a default so removing an already-removed token is a no-op.
    self._funcs.pop(token, None)

  @staticmethod
  def _convert(value, dtype=None):
    """Converts an arg to numpy, avoiding dangerous string and unicode dtypes.

    Numpy pads with zeros when using string and unicode dtypes if different
    components of a tensor have different lengths.  This is bad: ignoring the
    padding is wrong for text data, and removing the padding is wrong for
    binary data.  To avoid this bug, we redo the conversion using an object
    dtype.  Additionally, we convert unicode strings to (byte-)strings for
    compatibility.

    Args:
      value: Value to convert to a numpy array.
      dtype: (Optional.) Desired NumPy type for the returned value.

    Returns:
      A numpy array.
    """
    result = np.asarray(value, dtype=dtype, order="C")
    # `result is not value` means asarray had to build a new array, i.e. the
    # input was not already an ndarray of that dtype; redo with dtype=object
    # to avoid zero-padding of ragged string data.
    if result.dtype.char == "S" and result is not value:
      return np.asarray(value, order="C", dtype=object)
    elif result.dtype.char == "U" and result is not value:
      value = np.vectorize(lambda x: x.encode("utf8"))(value)
      return np.asarray(value, order="C", dtype=object)
    elif result.dtype.char == "U":
      # Unicode ndarray passed straight through: re-encode to byte strings.
      return result.astype(np.bytes_)
    else:
      return result

  def __call__(self, token, on_gpu, args):
    """Calls the registered function for `token` with args.

    Args:
      token: A key into this `FuncRegistry` identifying which function to call.
      on_gpu: A boolean indicating whether or not `token`'s corresponding
        operation was placed on GPU; only used if the function registered for
        `token` is an `EagerPyFunc`.
      args: The arguments to pass to the function registered for `token`.

    Returns:
      The output of the function registered for `token`.

    Raises:
      ValueError: if no function is registered for `token`.
    """
    # Use .get() so an unknown token raises the documented ValueError below;
    # plain indexing would raise an undocumented KeyError first, making this
    # check unreachable.
    func = self._funcs.get(token, None)
    if func is None:
      raise ValueError("callback %s is not found" % token)
    if isinstance(func, EagerFunc):
      return func(on_gpu, args)
    else:
      ret = func(*args)
      # Strings seem to lead to a memory leak here if they're not wrapped in a
      # list.
      if isinstance(ret, six.binary_type):
        ret = [ret]
      # Ensures that we return either a single numpy array or a list of numpy
      # arrays.
      if isinstance(ret, (tuple, list)):
        return [self._convert(x) for x in ret]
      else:
        return self._convert(ret)

  def size(self):
    """Returns how many functions are currently registered."""
    return len(self._funcs)

  def _next_unique_token(self):
    """Returns a unique token."""
    with self._lock:
      uid = self._unique_id
      self._unique_id += 1
    return "pyfunc_%d" % uid


# Global registry for py functions.
_py_funcs = FuncRegistry()

# Hand the registry to the native layer so that PyFunc kernels can invoke
# registered Python callables (via FuncRegistry.__call__) by token.
pywrap_tensorflow.InitializePyTrampoline(_py_funcs)


class CleanupFunc(object):
  """A helper class to remove a registered function from _py_funcs."""

  def __init__(self, token):
    # Token identifying the registry entry to remove on destruction.
    self._token = token

  def __del__(self):
    if _py_funcs is not None:
      # If _py_funcs is None, the program is most likely in shutdown, and the
      # _py_funcs object has been destroyed already.
      _py_funcs.remove(self._token)


def _internal_py_func(func, inp, Tout, stateful=None, eager=False, name=None):
  """See documentation for py_func and eager_py_func.

  Registers `func`, ties its lifetime to the enclosing graph, and emits the
  appropriate generated op (`_eager_py_func`, `_py_func`, or
  `_py_func_stateless`).
  """

  # Remember whether the caller passed a list/tuple of dtypes so we can
  # return a matching shape (list vs. single tensor) at the end.
  is_list_or_tuple = False
  if isinstance(Tout, (list, tuple)):
    is_list_or_tuple = True
  else:
    Tout = [Tout]

  if eager:
    func = EagerFunc(func, Tout)

  token = _py_funcs.insert(func)
  # We tie the registered function's lifetime with the current default graph,
  # i.e., when the current graph is destroyed, we remove its py funcs.
  graph = ops.get_default_graph()

  # pylint: disable=protected-access
  while isinstance(graph, function._FuncGraph):
    # If the py_func was declared inside a _FuncGraph, its lifetime should be
    # bound to that of the outer graph instead.
    graph = graph._outer_graph

  cleanup = CleanupFunc(token)

  # TODO(zhifengc): Consider adding a Graph method to collect
  # `cleanup` objects in one of its member.
  if not hasattr(graph, "_cleanup_py_funcs_used_in_graph"):
    graph._cleanup_py_funcs_used_in_graph = []

  # When `graph` is destroyed, elements in _cleanup_py_funcs_used_in_graph
  # will be destroyed and their __del__ will remove the 'token' from
  # the funcs registry.
  graph._cleanup_py_funcs_used_in_graph.append(cleanup)
  # pylint: enable=protected-access

  # pylint: disable=protected-access
  if eager:
    result = gen_script_ops._eager_py_func(
        input=inp, token=token, Tout=Tout, name=name)
  else:
    if stateful:
      result = gen_script_ops._py_func(
          input=inp, token=token, Tout=Tout, name=name)
    else:
      result = gen_script_ops._py_func_stateless(
          input=inp, token=token, Tout=Tout, name=name)
  # pylint: enable=protected-access
  return result if is_list_or_tuple else result[0]


def eager_py_func(func, inp, Tout, name=None):
  """Wraps a python function into a TensorFlow op.

  When the returned op is executed, `func` is invoked with eager execution
  enabled. Inputs are Tensor objects and func must return None or objects
  that may be converted to Tensor objects.

  This function has the same limitations as `py_func` with respect to
  serialization and distribution.

  Args:
    func: A Python function which accepts a list of `Tensor` objects
      having element types that match the corresponding `tf.Tensor` objects
      in `inp` and returns a list of `Tensor` objects (or a single
      `Tensor`, or `None`) having element types that match the
      corresponding values in `Tout`.
    inp: A list of `Tensor` objects.
    Tout: A list or tuple of tensorflow data types or a single tensorflow data
      type if there is only one, indicating what `func` returns; an empty list
      if no value is returned (i.e., if the return value is `None`).
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` or a single `Tensor` which `func` computes; an empty list
    if `func` returns None.
  """
  return _internal_py_func(func=func, inp=inp, Tout=Tout, eager=True, name=name)


@tf_export("py_func")
def py_func(func, inp, Tout, stateful=True, name=None):
  """Wraps a python function and uses it as a TensorFlow op.

  Given a python function `func`, which takes numpy arrays as its
  arguments and returns numpy arrays as its outputs, wrap this function as an
  operation in a TensorFlow graph. The following snippet constructs a simple
  TensorFlow graph that invokes the `np.sinh()` NumPy function as a operation
  in the graph:

  ```python
  def my_func(x):
    # x will be a numpy array with the contents of the placeholder below
    return np.sinh(x)
  input = tf.placeholder(tf.float32)
  y = tf.py_func(my_func, [input], tf.float32)
  ```

  **N.B.** The `tf.py_func()` operation has the following known limitations:

  * The body of the function (i.e. `func`) will not be serialized in a
    `GraphDef`. Therefore, you should not use this function if you need to
    serialize your model and restore it in a different environment.

  * The operation must run in the same address space as the Python program
    that calls `tf.py_func()`. If you are using distributed TensorFlow, you
    must run a `tf.train.Server` in the same process as the program that calls
    `tf.py_func()` and you must pin the created operation to a device in that
    server (e.g. using `with tf.device():`).

  Args:
    func: A Python function, which accepts `ndarray` objects as arguments and
      returns a list of `ndarray` objects (or a single `ndarray`). This function
      must accept as many arguments as there are tensors in `inp`, and these
      argument types will match the corresponding `tf.Tensor` objects
      in `inp`. The returns `ndarray`s must match the number and types defined
      `Tout`.
      Important Note: Input and output numpy `ndarray`s of `func` are not
      guaranteed to be copies. In some cases their underlying memory will be
      shared with the corresponding TensorFlow tensors.
      In-place modification or storing `func` input or return values in
      python datastructures without explicit (np.)copy
      can have non-deterministic consequences.
    inp: A list of `Tensor` objects.
    Tout: A list or tuple of tensorflow data types or a single tensorflow data
      type if there is only one, indicating what `func` returns.
    stateful: (Boolean.) If True, the function should be considered stateful.
      If a function is stateless, when given the same input it will return the
      same output and have no observable side effects. Optimizations such as
      common subexpression elimination are only performed on stateless
      operations.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` or a single `Tensor` which `func` computes.
  """
  if context.in_eager_mode():
    # In eager mode, call `func` immediately on the inputs' numpy values.
    # Note: on this path `Tout` and `stateful` are not consulted; output
    # dtypes come from convert_to_tensor's inference.
    result = func(*[x.numpy() for x in inp])
    result = nest.flatten(result)

    return [x if x is None else ops.convert_to_tensor(x) for x in result]

  return _internal_py_func(
      func=func, inp=inp, Tout=Tout, stateful=stateful, eager=False, name=name)


# No gradient is defined for py funcs; `func` is an opaque Python callable.
ops.NotDifferentiable("PyFunc")
ops.NotDifferentiable("PyFuncStateless")