1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Script Language Operators. See the @{$python/script_ops} guide.
17
18@@py_func
19"""
20
21# pylint: disable=g-bad-name
22from __future__ import absolute_import
23from __future__ import division
24from __future__ import print_function
25
26import threading
27
28import numpy as np
29import six
30
31from tensorflow.python import pywrap_tensorflow
32from tensorflow.python.eager import context
33from tensorflow.python.framework import function
34from tensorflow.python.framework import ops
35from tensorflow.python.ops import gen_script_ops
36from tensorflow.python.util import nest
37from tensorflow.python.util.tf_export import tf_export
38
39
class EagerFunc(object):
  """Wraps a Python callable so that it executes with eager execution enabled.

  `_internal_py_func` registers an instance of this class (instead of the raw
  callable) when `eager=True`, and `FuncRegistry.__call__` dispatches to it.
  """

  def __init__(self, func, Tout):
    """Stores the callable together with its declared output dtypes.

    Args:
      func: The Python callable to wrap.
      Tout: A list of output dtypes; an empty list when `func` returns None.
    """
    self._func = func
    self._out_dtypes = Tout

  def __call__(self, on_gpu, args):
    """Runs the wrapped callable eagerly on `args` and converts its outputs.

    Args:
      on_gpu: When True, every converted output is copied with `.gpu()`.
      args: Sequence of arguments, unpacked into the wrapped callable.

    Returns:
      None if the callable returned None. A list of tensors if it returned a
      list or tuple; items are paired with the declared dtypes via `zip`, so
      any surplus on either side is dropped. Otherwise a single tensor
      converted with the first declared dtype.
    """
    with context.eager_mode():
      outputs = self._func(*args)

      def _finalize(value, dtype):
        # Convert to a tensor of the declared dtype, then optionally move it.
        tensor = ops.convert_to_tensor(value, dtype=dtype)
        return tensor.gpu() if on_gpu else tensor

      if outputs is None:
        return None
      if isinstance(outputs, (tuple, list)):
        return [
            _finalize(value, dtype)
            for value, dtype in zip(outputs, self._out_dtypes)
        ]
      return _finalize(outputs, self._out_dtypes[0])
69
70
class FuncRegistry(object):
  """A helper class to keep track of registered py functions.

  FuncRegistry keeps a map from unique tokens (string) to python
  functions, which take numpy arrays and output numpy arrays. `EagerFunc`
  entries are the exception: they take and return tensors. Calling the
  registry with a token runs the corresponding function.
  """

  def __init__(self):
    self._lock = threading.Lock()
    self._unique_id = 0  # GUARDED_BY(self._lock)
    self._funcs = {}

  def insert(self, func):
    """Registers `func` and returns a unique token for this entry."""
    token = self._next_unique_token()
    self._funcs[token] = func
    return token

  def remove(self, token):
    """Removes the registered function corresponding to `token`.

    Removing a token that is not registered is a no-op.
    """
    self._funcs.pop(token, None)

  @staticmethod
  def _convert(value, dtype=None):
    """Converts an arg to numpy, avoiding dangerous string and unicode dtypes.

    Numpy pads with zeros when using string and unicode dtypes if different
    components of a tensor have different lengths.  This is bad: ignoring the
    padding is wrong for text data, and removing the padding is wrong for binary
    data.  To avoid this bug, we redo the conversion using an object dtype.
    Additionally, we convert unicode strings to UTF-8 encoded (byte-)strings
    for compatibility; this applies both to Python unicode values and to numpy
    arrays that already have a unicode dtype.

    Args:
      value: Value to convert to a numpy array.
      dtype: (Optional.) Desired NumPy type for the returned value.

    Returns:
      A numpy array. Byte-string and unicode inputs yield an object array of
      `bytes`; everything else yields `np.asarray(value, dtype, order="C")`.
    """
    result = np.asarray(value, dtype=dtype, order="C")
    if result.dtype.char == "S" and result is not value:
      # Rebuild from the original Python objects so each element keeps its
      # exact bytes instead of being zero-padded to a fixed width.
      return np.asarray(value, order="C", dtype=object)
    if result.dtype.char == "U":
      # Covers both freshly converted Python text and unicode ndarrays passed
      # through unchanged. `astype(np.bytes_)` would only handle ASCII, so
      # encode element-wise as UTF-8 instead. A fixed `otypes` lets
      # np.vectorize handle size-0 inputs and avoids an intermediate
      # fixed-width bytes array.
      encode = np.vectorize(lambda x: x.encode("utf8"), otypes=[object])
      return np.asarray(encode(result), order="C", dtype=object)
    return result

  def __call__(self, token, on_gpu, args):
    """Calls the registered function for `token` with args.

    Args:
      token: A key into this `FuncRegistry` identifying which function to call.
      on_gpu: A boolean indicating whether or not `token`'s corresponding
        operation was placed on GPU; only used if the function registered for
        `token` is an `EagerPyFunc`.
      args: The arguments to pass to the function registered for `token`.

    Returns:
      The output of the function registered for `token`.

    Raises:
      ValueError: if no function is registered for `token`.
    """
    # Use .get so an unknown token raises the documented ValueError rather than
    # a KeyError.
    func = self._funcs.get(token, None)
    if func is None:
      raise ValueError("callback %s is not found" % token)
    if isinstance(func, EagerFunc):
      return func(on_gpu, args)
    else:
      ret = func(*args)
      # Strings seem to lead to a memory leak here if they're not wrapped in a
      # list.
      if isinstance(ret, six.binary_type):
        ret = [ret]
      # Ensures that we return either a single numpy array or a list of numpy
      # arrays.
      if isinstance(ret, (tuple, list)):
        return [self._convert(x) for x in ret]
      else:
        return self._convert(ret)

  def size(self):
    """Returns how many functions are currently registered."""
    return len(self._funcs)

  def _next_unique_token(self):
    """Returns a unique token."""
    with self._lock:
      uid = self._unique_id
      self._unique_id += 1
    return "pyfunc_%d" % uid
166
# Global registry for py functions. Every callable passed to py_func /
# eager_py_func is stored here under a unique token, and the generated op
# refers to the callable only by that token.
_py_funcs = FuncRegistry()

# Hand the registry to the native runtime. NOTE(review): presumably the PyFunc
# kernels call back through it as `_py_funcs(token, on_gpu, args)` (the
# signature of FuncRegistry.__call__) — confirm against the C++ trampoline.
pywrap_tensorflow.InitializePyTrampoline(_py_funcs)
171
172
class CleanupFunc(object):
  """Unregisters a py-func token from `_py_funcs` when it is garbage-collected.

  `_internal_py_func` attaches one instance per registered function to the
  owning graph, so the registry entry lives exactly as long as that graph
  keeps the instance alive.
  """

  def __init__(self, token):
    self._token = token

  def __del__(self):
    # At interpreter shutdown the module global may already have been reset
    # to None; in that case the registry is gone and there is nothing to do.
    if _py_funcs is None:
      return
    _py_funcs.remove(self._token)
184
185
def _internal_py_func(func, inp, Tout, stateful=None, eager=False, name=None):
  """Registers `func` and builds the op that invokes it.

  Shared implementation of `py_func` and `eager_py_func`; see their
  documentation for the user-facing contract.

  Args:
    func: The Python callable to register.
    inp: List of input tensors for the op.
    Tout: A single dtype, or a list/tuple of dtypes.
    stateful: Only consulted when `eager` is False. A truthy value selects the
      stateful op, a falsy one the stateless op.
    eager: If True, wrap `func` in `EagerFunc` and emit the eager py-func op.
    name: Optional name for the op.

  Returns:
    The op's output list when `Tout` is a list or tuple, otherwise its first
    output.
  """
  is_list_or_tuple = isinstance(Tout, (list, tuple))
  if not is_list_or_tuple:
    Tout = [Tout]

  registered = EagerFunc(func, Tout) if eager else func
  token = _py_funcs.insert(registered)

  # pylint: disable=protected-access
  # The registered function should live as long as the graph it is used in.
  # Functions declared inside a _FuncGraph are tied to the outermost
  # enclosing graph instead.
  owner_graph = ops.get_default_graph()
  while isinstance(owner_graph, function._FuncGraph):
    owner_graph = owner_graph._outer_graph

  cleanup = CleanupFunc(token)

  # TODO(zhifengc): Consider adding a Graph method to collect
  # `cleanup` objects in one of its member.
  # Once the graph is destroyed, this list (and each CleanupFunc in it) is
  # released, and CleanupFunc.__del__ removes the token from the registry.
  if not hasattr(owner_graph, "_cleanup_py_funcs_used_in_graph"):
    owner_graph._cleanup_py_funcs_used_in_graph = []
  owner_graph._cleanup_py_funcs_used_in_graph.append(cleanup)

  if eager:
    make_op = gen_script_ops._eager_py_func
  elif stateful:
    make_op = gen_script_ops._py_func
  else:
    make_op = gen_script_ops._py_func_stateless
  result = make_op(input=inp, token=token, Tout=Tout, name=name)
  # pylint: enable=protected-access

  if is_list_or_tuple:
    return result
  return result[0]
235
236
def eager_py_func(func, inp, Tout, name=None):
  """Wraps a python function into a TensorFlow op.

  When the returned op is executed, `func` is invoked with eager execution
  enabled. Its inputs are Tensor objects, and it must return None or objects
  that can be converted to Tensor objects.

  This function has the same limitations as `py_func` with respect to
  serialization and distribution.

  Args:
    func: A Python function that accepts one `Tensor` argument per element of
      `inp`, with element types matching the corresponding `tf.Tensor` objects
      in `inp`. It returns a list of `Tensor` objects (or a single `Tensor`,
      or `None`) whose element types match the corresponding values in
      `Tout`.
    inp: A list of `Tensor` objects.
    Tout: A list or tuple of tensorflow data types, or a single tensorflow
      data type if there is only one, describing what `func` returns. Use an
      empty list if no value is returned (i.e., the return value is `None`).
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` or a single `Tensor` which `func` computes; an empty list
    if `func` returns None.
  """
  return _internal_py_func(func, inp, Tout, eager=True, name=name)
264
265
@tf_export("py_func")
def py_func(func, inp, Tout, stateful=True, name=None):
  """Wraps a python function and uses it as a TensorFlow op.

  Given a python function `func`, which takes numpy arrays as its
  arguments and returns numpy arrays as its outputs, wrap this function as an
  operation in a TensorFlow graph. The following snippet constructs a simple
  TensorFlow graph that invokes the `np.sinh()` NumPy function as an operation
  in the graph:

  ```python
  def my_func(x):
    # x will be a numpy array with the contents of the placeholder below
    return np.sinh(x)
  input = tf.placeholder(tf.float32)
  y = tf.py_func(my_func, [input], tf.float32)
  ```

  **N.B.** The `tf.py_func()` operation has the following known limitations:

  * The body of the function (i.e. `func`) will not be serialized in a
    `GraphDef`. Therefore, you should not use this function if you need to
    serialize your model and restore it in a different environment.

  * The operation must run in the same address space as the Python program
    that calls `tf.py_func()`. If you are using distributed TensorFlow, you
    must run a `tf.train.Server` in the same process as the program that calls
    `tf.py_func()` and you must pin the created operation to a device in that
    server (e.g. using `with tf.device():`).

  Args:
    func: A Python function, which accepts `ndarray` objects as arguments and
      returns a list of `ndarray` objects (or a single `ndarray`). This function
      must accept as many arguments as there are tensors in `inp`, and these
      argument types will match the corresponding `tf.Tensor` objects
      in `inp`. The returned `ndarray`s must match the number and types defined
      by `Tout`.
      Important Note: Input and output numpy `ndarray`s of `func` are not
      guaranteed to be copies. In some cases their underlying memory will be
      shared with the corresponding TensorFlow tensors.
      In-place modification or storing `func` input or return values in
      Python data structures without explicit (np.)copy
      can have non-deterministic consequences.
    inp: A list of `Tensor` objects.
    Tout: A list or tuple of tensorflow data types or a single tensorflow data
      type if there is only one, indicating what `func` returns.
    stateful: (Boolean.) If True, the function should be considered stateful.
      If a function is stateless, when given the same input it will return the
      same output and have no observable side effects. Optimizations such as
      common subexpression elimination are only performed on stateless
      operations.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` or a single `Tensor` which `func` computes. A single
    `Tensor` is returned when `Tout` is a single data type (in eager mode, when
    additionally `func` produced exactly one value).
  """
  if context.in_eager_mode():
    # Eager mode: call `func` directly on the NumPy values of the inputs.
    result = func(*[x.numpy() for x in inp])
    result = nest.flatten(result)
    result = [x if x is None else ops.convert_to_tensor(x) for x in result]
    # Graph mode returns a bare output when `Tout` is a single dtype (see
    # `_internal_py_func`). Mirror that here so the result structure does not
    # depend on the execution mode.
    if not isinstance(Tout, (list, tuple)) and len(result) == 1:
      return result[0]
    return result

  return _internal_py_func(
      func=func, inp=inp, Tout=Tout, stateful=stateful, eager=False, name=name)
330
331
# Register the graph ops behind `py_func` (presumably the ones built by
# gen_script_ops._py_func and _py_func_stateless) as not differentiable.
# NOTE(review): the op emitted by `eager_py_func` is not registered here —
# confirm whether that is intentional.
ops.NotDifferentiable("PyFunc")
ops.NotDifferentiable("PyFuncStateless")
334