1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""decorator_utils tests."""
16
17# pylint: disable=unused-import
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import functools
23
24from tensorflow.python.platform import test
25from tensorflow.python.platform import tf_logging as logging
26from tensorflow.python.util import decorator_utils
27
28
29def _test_function(unused_arg=0):
30  pass
31
32
class GetQualifiedNameTest(test.TestCase):
  """Checks `decorator_utils.get_qualified_name` on methods and functions."""

  def test_method(self):
    # For a method looked up on its class, the expected name is prefixed
    # with the class name.
    qualified = decorator_utils.get_qualified_name(
        GetQualifiedNameTest.test_method)
    self.assertEqual("GetQualifiedNameTest.test_method", qualified)

  def test_function(self):
    # For a module-level function, the expected name is just the bare name.
    qualified = decorator_utils.get_qualified_name(_test_function)
    self.assertEqual("_test_function", qualified)
43
44
class AddNoticeToDocstringTest(test.TestCase):
  """Tests for `decorator_utils.add_notice_to_docstring`."""

  def _check(self, doc, expected):
    """Asserts that annotating `doc` yields exactly `expected`.

    Every case uses the same fixed instructions, fallback text, suffix and
    notice, so individual tests only vary the input docstring.

    Args:
      doc: Docstring to annotate; may be None or empty.
      expected: The exact string `add_notice_to_docstring` should return.
    """
    self.assertEqual(
        decorator_utils.add_notice_to_docstring(
            doc=doc,
            instructions="Instructions",
            no_doc_str="Nothing here",
            suffix_str="(suffix)",
            notice=["Go away"]),
        expected)

  def test_regular(self):
    """Brief line, body and Args section, under several indent layouts."""
    expected = ("Brief (suffix)\n\nGo away\nInstructions\n\nDocstring\n\n"
                "Args:\n  arg1: desc")
    # No indent for main docstring
    self._check("Brief\n\nDocstring\n\nArgs:\n  arg1: desc", expected)
    # 2 space indent for main docstring, blank lines not indented
    self._check("Brief\n\n  Docstring\n\n  Args:\n    arg1: desc", expected)
    # 2 space indent for main docstring, blank lines indented as well.
    self._check("Brief\n  \n  Docstring\n  \n  Args:\n    arg1: desc", expected)
    # No indent for main docstring, first line blank. This must differ from
    # the indented case below, otherwise this layout goes untested.
    self._check("\nBrief\n\nDocstring\n\nArgs:\n  arg1: desc", expected)
    # 2 space indent, first line blank.
    self._check("\n  Brief\n  \n  Docstring\n  \n  Args:\n    arg1: desc",
                expected)

  def test_brief_only(self):
    """A one-line docstring, with assorted surrounding whitespace."""
    expected = "Brief (suffix)\n\nGo away\nInstructions"
    self._check("Brief", expected)
    self._check("Brief\n", expected)
    self._check("Brief\n  ", expected)
    self._check("\nBrief\n  ", expected)
    self._check("\n  Brief\n  ", expected)

  def test_no_docstring(self):
    """A missing or empty docstring falls back to `no_doc_str`."""
    expected = "Nothing here\n\nGo away\nInstructions"
    self._check(None, expected)
    self._check("", expected)

  def test_no_empty_line(self):
    """Body immediately follows the brief line with no blank separator."""
    expected = "Brief (suffix)\n\nGo away\nInstructions\n\nDocstring"
    # No second line indent
    self._check("Brief\nDocstring", expected)
    # 2 space second line indent
    self._check("Brief\n  Docstring", expected)
    # No second line indent, first line blank
    self._check("\nBrief\nDocstring", expected)
    # 2 space second line indent, first line blank
    self._check("\n  Brief\n  Docstring", expected)
96
97
class ValidateCallableTest(test.TestCase):
  """Checks which objects `decorator_utils.validate_callable` accepts."""

  def test_function(self):
    decorator_utils.validate_callable(_test_function, "test")

  def test_method(self):
    bound_method = self.test_method
    decorator_utils.validate_callable(bound_method, "test")

  def test_callable(self):

    class CallableObject(object):
      """Instances are callable only through `__call__`."""

      def __call__(self):
        pass

    decorator_utils.validate_callable(CallableObject(), "test")

  def test_partial(self):
    bound = functools.partial(_test_function, unused_arg=7)
    decorator_utils.validate_callable(bound, "test")

  def test_fail_non_callable(self):
    # A plain int is not callable, so validation is expected to raise.
    with self.assertRaises(ValueError):
      decorator_utils.validate_callable(0, "test")
122
123
if __name__ == "__main__":
  # Run this module's test cases through TensorFlow's `test.main` entry point.
  test.main()
126