1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Imports unittest as a replacement for testing.pybase.googletest."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import atexit
22import itertools
23import os
24import sys
25import tempfile
26
27# go/tf-wildcard-import
28# pylint: disable=wildcard-import
29from unittest import *
30# pylint: enable=wildcard-import
31
32from tensorflow.python.framework import errors
33from tensorflow.python.lib.io import file_io
34from tensorflow.python.platform import app
35from tensorflow.python.platform import benchmark
36from tensorflow.python.platform import tf_logging as logging
37from tensorflow.python.util import tf_decorator
38from tensorflow.python.util import tf_inspect
39
40
# Alias for benchmark.TensorFlowBenchmark, exported from this module.
Benchmark = benchmark.TensorFlowBenchmark  # pylint: disable=invalid-name

# Keep a handle on unittest.main (brought in by `from unittest import *`
# above) before this module defines its own `main` further down.
unittest_main = main

# We keep a global variable in this module to make sure we create the temporary
# directory only once per test binary invocation. GetTempDir() fills it in on
# first use; the empty string means "not created yet".
_googletest_temp_dir = ''
48
49
50# pylint: disable=invalid-name
51# pylint: disable=undefined-variable
def g_main(argv):
  """Delegate to unittest.main after redefining testLoader.

  Implements the test-sharding environment protocol:

    * If TEST_SHARD_STATUS_FILE is set, that file is created (or truncated).
      If it cannot be written, an error is printed to stderr and the process
      exits with status 1.
    * If both TEST_TOTAL_SHARDS and TEST_SHARD_INDEX are set, only the test
      methods assigned to this shard are loaded. Each test case class's
      method names are sorted and dealt out round-robin across the shards,
      using one counter shared by every class this loader visits. A
      non-positive shard count or an out-of-range shard index prints an
      error to stderr and exits with status 1.
    * Otherwise unittest_main runs with its default loader.

  Args:
    argv: Command-line arguments passed through to unittest_main.

  Returns:
    The result of unittest_main. By default unittest.main exits the process
    instead of returning.
  """
  if 'TEST_SHARD_STATUS_FILE' in os.environ:
    status_file = os.environ['TEST_SHARD_STATUS_FILE']
    try:
      # Create the file, or truncate it if it already exists.
      with open(status_file, 'w') as f:
        f.write('')
    except IOError:
      sys.stderr.write('Error opening TEST_SHARD_STATUS_FILE (%s). Exiting.\n'
                       % status_file)
      sys.exit(1)

  if ('TEST_TOTAL_SHARDS' not in os.environ or
      'TEST_SHARD_INDEX' not in os.environ):
    return unittest_main(argv=argv)

  total_shards = int(os.environ['TEST_TOTAL_SHARDS'])
  shard_index = int(os.environ['TEST_SHARD_INDEX'])
  if total_shards <= 0 or not 0 <= shard_index < total_shards:
    # Without this check no test would ever land in this shard, so the run
    # could look like it passed. With zero shards, itertools.cycle over an
    # empty range would instead raise StopIteration from the loader.
    sys.stderr.write('Invalid test sharding: TEST_SHARD_INDEX=%d, '
                     'TEST_TOTAL_SHARDS=%d. Exiting.\n'
                     % (shard_index, total_shards))
    sys.exit(1)

  base_loader = TestLoader()
  delegate_get_names = base_loader.getTestCaseNames
  # Shared by all test case classes, so test methods are dealt out
  # round-robin across every class loaded rather than restarting per class.
  bucket_iterator = itertools.cycle(range(total_shards))

  def getShardedTestCaseNames(testCaseClass):
    """Returns the test method names of testCaseClass owned by this shard."""
    filtered_names = []
    for testcase in sorted(delegate_get_names(testCaseClass)):
      if next(bucket_iterator) == shard_index:
        filtered_names.append(testcase)
    return filtered_names

  # Override getTestCaseNames
  base_loader.getTestCaseNames = getShardedTestCaseNames

  return unittest_main(argv=argv, testLoader=base_loader)
90
91
92# Redefine main to allow running benchmarks
def main(argv=None):  # pylint: disable=function-redefined
  """Redefined entry point that allows running benchmarks.

  Control passes to benchmark.benchmarks_main, together with a callback that
  runs g_main through app.run.

  Args:
    argv: Command-line arguments for the test program. When None, sys.argv
      is used; it is read at the time the callback runs.
  """

  def run_g_main():
    return app.run(main=g_main, argv=sys.argv if argv is None else argv)

  benchmark.benchmarks_main(true_main=run_g_main)
100
101
def GetTempDir():
  """Return a temporary directory for tests to use.

  The directory is created on the first call and cached in the module-level
  `_googletest_temp_dir`; later calls return the same path. Its name is
  prefixed with the basename of the source file of the outermost frame on
  the call stack, minus a trailing '.py'. An atexit hook removes the
  directory recursively when the interpreter exits.

  Returns:
    The path of the temporary directory.
  """
  global _googletest_temp_dir
  if not _googletest_temp_dir:
    # The outermost frame on the stack (normally the test's main script).
    first_frame = tf_inspect.stack()[-1][0]
    prefix = os.path.basename(tf_inspect.getfile(first_frame))
    # Remove a literal '.py' suffix. str.rstrip('.py') would strip any
    # trailing run of '.', 'p' and 'y' characters ('happy.py' -> 'ha').
    if prefix.endswith('.py'):
      prefix = prefix[:-len('.py')]
    # mkdtemp places the directory under tempfile.gettempdir() by default.
    temp_dir = tempfile.mkdtemp(prefix=prefix)

    def delete_temp_dir(dirname=temp_dir):
      """Best-effort removal at exit; failures are logged, not raised."""
      try:
        file_io.delete_recursively(dirname)
      except errors.OpError as e:
        logging.error('Error removing %s: %s', dirname, e)

    atexit.register(delete_temp_dir)
    _googletest_temp_dir = temp_dir

  return _googletest_temp_dir
121
122
123def test_src_dir_path(relative_path):
124  """Creates an absolute test srcdir path given a relative path.
125
126  Args:
127    relative_path: a path relative to tensorflow root.
128      e.g. "contrib/session_bundle/example".
129
130  Returns:
131    An absolute path to the linked in runfiles.
132  """
133  return os.path.join(os.environ['TEST_SRCDIR'],
134                      'org_tensorflow/tensorflow', relative_path)
135
136
def StatefulSessionAvailable():
  """Reports whether a stateful session is available; always False here."""
  available = False
  return available
139
140
class StubOutForTesting(object):
  """Support class for stubbing methods out for unit testing.

  Sample Usage:

  You want os.path.exists() to always return true during testing.

     stubs = StubOutForTesting()
     stubs.Set(os.path, 'exists', lambda x: 1)
       ...
     stubs.CleanUp()

  The above changes os.path.exists into a lambda that returns 1.  Once
  the ... part of the code finishes, the CleanUp() looks up the old
  value of os.path.exists and restores it.

  The object is also a context manager: leaving a `with` block calls
  CleanUp().
  """

  # Undo-record marker meaning "the attribute was not in the target's own
  # __dict__" (for example, it was inherited from a base class). Undoing such
  # a stub deletes the override so normal attribute lookup resumes, instead
  # of copying the looked-up value onto the target.
  _ABSENT = object()

  def __init__(self):
    # Undo records for Set(): (parent, old_child, child_name) tuples.
    self.cache = []
    # Undo records for SmartSet(): (target, attr_name, old_attr) tuples.
    self.stubs = []

  def __del__(self):
    """Do not rely on the destructor to undo your stubs.

    You cannot guarantee exactly when the destructor will get called without
    relying on implementation details of a Python VM that may change.
    """
    self.CleanUp()

  # __enter__ and __exit__ allow use as a context manager.
  def __enter__(self):
    return self

  def __exit__(self, unused_exc_type, unused_exc_value, unused_tb):
    self.CleanUp()

  def CleanUp(self):
    """Undoes all SmartSet() & Set() calls, restoring original definitions."""
    self.SmartUnsetAll()
    self.UnsetAll()

  @classmethod
  def _Restore(cls, target, name, old_value):
    """Undoes one stub: puts old_value back, or deletes the override."""
    if old_value is cls._ABSENT:
      try:
        delattr(target, name)
      except AttributeError:
        # The override is already gone, so lookup is already restored.
        pass
    else:
      setattr(target, name, old_value)

  def SmartSet(self, obj, attr_name, new_attr):
    """Replace obj.attr_name with new_attr.

    This method is smart and works at the module, class, and instance level
    while preserving proper inheritance. It will not stub out C types however
    unless that has been explicitly allowed by the type.

    This method supports the case where attr_name is a staticmethod or a
    classmethod of obj.

    Notes:
      - If obj is an instance, then it is its class that will actually be
        stubbed, unless the attribute lives in the instance's own __dict__.
        Note that the method Set() does not do that: if obj is an instance,
        it (and not its class) will be stubbed.
      - When a class is stubbed, the stub goes on the most-derived class
        (obj itself, or obj's class). The raw entry of that class's own
        __dict__ is recorded for undo. If the attribute was inherited, undo
        deletes the stub so the base-class definition is visible again.
      - Modules, and attributes in an instance's own __dict__, are read with
        getattr() and restored with setattr().

    Args:
      obj: The object whose attributes we want to modify.
      attr_name: The name of the attribute to modify.
      new_attr: The new value for the attribute.

    Raises:
      AttributeError: If the attribute cannot be found.
      TypeError: If the target rejects the assignment (e.g. a built-in type).
        Nothing is recorded for undo in that case.
    """
    _, obj = tf_decorator.unwrap(obj)
    if (tf_inspect.ismodule(obj) or
        (not tf_inspect.isclass(obj) and attr_name in obj.__dict__)):
      target = obj
      orig_attr = getattr(obj, attr_name)
      # Calling getattr() on a staticmethod transforms it to a 'normal'
      # function. We need to ensure that we put it back as a staticmethod.
      if isinstance(obj.__dict__.get(attr_name), staticmethod):
        orig_attr = staticmethod(orig_attr)
    else:
      if getattr(obj, attr_name, self._ABSENT) is self._ABSENT:
        raise AttributeError('Attribute not found.')
      target = obj if tf_inspect.isclass(obj) else obj.__class__
      # Record the raw class-dict entry instead of getattr(obj, attr_name).
      # getattr() unwraps staticmethods, binds classmethods and, when obj is
      # an instance, binds plain methods to obj. Restoring any of those onto
      # the class would corrupt it.
      orig_attr = target.__dict__.get(attr_name, self._ABSENT)

    setattr(target, attr_name, new_attr)
    # Record the undo step only after the stub is installed, so a rejected
    # setattr() (e.g. on a built-in type) cannot make CleanUp() fail later.
    self.stubs.append((target, attr_name, orig_attr))

  def SmartUnsetAll(self):
    """Reverses SmartSet() calls, restoring things to original definitions.

    CleanUp() calls this. CleanUp() in turn runs on context-manager exit and
    from the destructor, which should not be relied upon.

    It is okay to call SmartUnsetAll() repeatedly, as later calls have
    no effect if no SmartSet() calls have been made.
    """
    # Undo in reverse order, so that repeated stubs of the same attribute end
    # with the original definition restored.
    for target, attr_name, old_attr in reversed(self.stubs):
      self._Restore(target, attr_name, old_attr)
    self.stubs = []

  def Set(self, parent, child_name, new_child):
    """In parent, replace child_name's old definition with new_child.

    The parent could be a module when the child is a function at
    module scope.  Or the parent could be a class when a class' method
    is being replaced.  The named child is set to new_child, while the
    prior definition is saved away for later, when UnsetAll() is
    called.

    This method supports the case where child_name is a staticmethod or a
    classmethod of parent. When parent is a class, the raw entry of the
    class's own __dict__ is recorded, so staticmethods and classmethods are
    restored unchanged. If the attribute was only inherited, UnsetAll()
    deletes the override rather than copying the inherited value onto
    parent.

    Args:
      parent: The context in which the attribute child_name is to be changed.
      child_name: The name of the attribute to change.
      new_child: The new value of the attribute.

    Raises:
      AttributeError: If parent has no attribute child_name.
      TypeError: If parent rejects the assignment (e.g. a built-in type).
        Nothing is recorded for undo in that case.
    """
    # getattr() also validates that the attribute exists.
    old_child = getattr(parent, child_name)
    if isinstance(parent, type):
      old_child = parent.__dict__.get(child_name, self._ABSENT)
    elif isinstance(parent.__dict__.get(child_name), staticmethod):
      # Calling getattr() on a staticmethod transforms it to a 'normal'
      # function; put it back as a staticmethod on undo.
      old_child = staticmethod(old_child)

    setattr(parent, child_name, new_child)
    # Record the undo step only after the stub is installed (see SmartSet).
    self.cache.append((parent, old_child, child_name))

  def UnsetAll(self):
    """Reverses Set() calls, restoring things to their original definitions.

    CleanUp() calls this. CleanUp() in turn runs on context-manager exit and
    from the destructor, which should not be relied upon.

    It is okay to call UnsetAll() repeatedly, as later calls have no
    effect if no Set() calls have been made.
    """
    # Undo calls to Set() in reverse order, in case Set() was called on the
    # same arguments repeatedly (want the original call to be last one undone)
    for parent, old_child, child_name in reversed(self.cache):
      self._Restore(parent, child_name, old_child)
    self.cache = []
299