1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for documentation parser."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import functools
22import os
23import sys
24
25from tensorflow.python.platform import googletest
26from tensorflow.python.util import tf_inspect
27from tensorflow.tools.docs import parser
28
29
30def test_function(unused_arg, unused_kwarg='default'):
31  """Docstring for test function."""
32  pass
33
34
35def test_function_with_args_kwargs(unused_arg, *unused_args, **unused_kwargs):
36  """Docstring for second test function."""
37  pass
38
39
class TestClass(object):
  """Docstring for TestClass itself."""

  # Fixture holding one member of each kind (method, nested class, property
  # and plain class attribute) for `test_docs_for_class` to document. Every
  # docstring below is compared against the parser output, and each body is
  # deliberately empty.

  def a_method(self, arg='default'):
    """Docstring for a method."""

  class ChildClass(object):
    """Docstring for a child class."""

  @property
  def a_property(self):
    """Docstring for a property."""

  CLASS_MEMBER = 'a class member'
57
58
class DummyVisitor(object):
  """Minimal stand-in for a doc-generator visitor.

  The tests hand instances of this to `parser.ReferenceResolver.from_visitor`.
  Only the `index` and `duplicate_of` attributes are populated, which is
  presumably all that method reads — confirm against `parser`.

  Args:
    index: Mapping of full symbol names to Python objects.
    duplicate_of: Mapping of duplicate symbol names to their master names.
  """

  def __init__(self, index, duplicate_of):
    # Both mappings are stored as given: no copying or validation.
    self.index, self.duplicate_of = index, duplicate_of
64
65
class ParserTest(googletest.TestCase):
  """Tests for reference resolution, page building and argspec helpers."""

  def _make_reference_resolver(self, index, duplicate_of=None, doc_index=None):
    """Builds a `ReferenceResolver` from a `DummyVisitor` over `index`.

    Args:
      index: Mapping of full symbol names to Python objects.
      duplicate_of: Optional mapping of duplicate names to master names.
        Defaults to no duplicates.
      doc_index: Optional mapping of doc ids to doc-info objects. Defaults to
        an empty index.

    Returns:
      A `parser.ReferenceResolver` built with `py_module_names=['tf']`.
    """
    visitor = DummyVisitor(
        index=index,
        duplicate_of={} if duplicate_of is None else duplicate_of)
    return parser.ReferenceResolver.from_visitor(
        visitor=visitor,
        doc_index={} if doc_index is None else doc_index,
        py_module_names=['tf'])

  def _make_parser_config(self, index, tree):
    """Builds the minimal `ParserConfig` shared by the `docs_for_object` tests.

    No duplicates, reverse index or guide index are configured. `base_dir` is
    '/', and the tests compare `defined_in` paths against
    `os.path.relpath(__file__, '/')`.

    Args:
      index: Mapping of full symbol names to Python objects.
      tree: Mapping of parent names to the list of their child member names.

    Returns:
      A `parser.ParserConfig`.
    """
    return parser.ParserConfig(
        reference_resolver=self._make_reference_resolver(index),
        duplicates={},
        duplicate_of={},
        tree=tree,
        index=index,
        reverse_index={},
        guide_index={},
        base_dir='/')

  def test_documentation_path(self):
    self.assertEqual('test.md', parser.documentation_path('test'))
    self.assertEqual('test/module.md', parser.documentation_path('test.module'))

  def test_replace_references(self):
    class HasOneMember(object):

      def foo(self):
        pass

    string = (
        'A @{tf.reference}, another @{tf.reference$with\nnewline}, a member '
        '@{tf.reference.foo}, and a @{tf.third$link `text` with `code` in '
        'it}.')
    duplicate_of = {'tf.third': 'tf.fourth'}
    index = {'tf.reference': HasOneMember,
             'tf.reference.foo': HasOneMember.foo,
             'tf.third': HasOneMember,
             'tf.fourth': HasOneMember}

    reference_resolver = self._make_reference_resolver(index, duplicate_of)

    result = reference_resolver.replace_references(string, '../..')
    self.assertEqual('A <a href="../../tf/reference.md">'
                     '<code>tf.reference</code></a>, '
                     'another <a href="../../tf/reference.md">'
                     'with\nnewline</a>, '
                     'a member <a href="../../tf/reference.md#foo">'
                     '<code>tf.reference.foo</code></a>, '
                     'and a <a href="../../tf/fourth.md">link '
                     '<code>text</code> with '
                     '<code>code</code> in it</a>.', result)

  def test_doc_replace_references(self):
    string = '@{$doc1} @{$doc1#abc} @{$doc1$link} @{$doc1#def$zelda} @{$do/c2}'

    class DocInfo(object):
      pass
    doc1 = DocInfo()
    doc1.title = 'Title1'
    doc1.url = 'URL1'
    doc2 = DocInfo()
    doc2.title = 'Two words'
    doc2.url = 'somewhere/else'
    doc_index = {'doc1': doc1, 'do/c2': doc2}

    reference_resolver = self._make_reference_resolver(
        index={}, doc_index=doc_index)
    result = reference_resolver.replace_references(string, 'python')
    self.assertEqual('<a href="../URL1">Title1</a> '
                     '<a href="../URL1#abc">Title1</a> '
                     '<a href="../URL1">link</a> '
                     '<a href="../URL1#def">zelda</a> '
                     '<a href="../somewhere/else">Two words</a>', result)

  def test_docs_for_class(self):

    index = {
        'TestClass': TestClass,
        'TestClass.a_method': TestClass.a_method,
        'TestClass.a_property': TestClass.a_property,
        'TestClass.ChildClass': TestClass.ChildClass,
        'TestClass.CLASS_MEMBER': TestClass.CLASS_MEMBER
    }
    tree = {
        'TestClass': ['a_method', 'a_property', 'ChildClass', 'CLASS_MEMBER']
    }
    parser_config = self._make_parser_config(index, tree)

    page_info = parser.docs_for_object(
        full_name='TestClass', py_object=TestClass, parser_config=parser_config)

    # Make sure the brief docstring is present
    self.assertEqual(
        tf_inspect.getdoc(TestClass).split('\n')[0], page_info.doc.brief)

    # Make sure the method is present. assertEqual rather than assertIs: on
    # Python 2 each attribute access creates a new unbound-method object.
    self.assertEqual(TestClass.a_method, page_info.methods[0].obj)

    # Make sure that the signature is extracted properly and omits self.
    self.assertEqual(["arg='default'"], page_info.methods[0].signature)

    # Make sure the property is present
    self.assertIs(TestClass.a_property, page_info.properties[0].obj)

    # Make sure the child class is listed.
    self.assertIs(TestClass.ChildClass, page_info.classes[0].obj)

    # Make sure this file is contained as the definition location.
    self.assertEqual(os.path.relpath(__file__, '/'), page_info.defined_in.path)

  def test_docs_for_module(self):
    # Get the current module.
    module = sys.modules[__name__]

    index = {
        'TestModule': module,
        'TestModule.test_function': test_function,
        'TestModule.test_function_with_args_kwargs':
        test_function_with_args_kwargs,
        'TestModule.TestClass': TestClass,
    }
    tree = {
        'TestModule': ['TestClass', 'test_function',
                       'test_function_with_args_kwargs']
    }
    parser_config = self._make_parser_config(index, tree)

    page_info = parser.docs_for_object(
        full_name='TestModule', py_object=module, parser_config=parser_config)

    # Make sure the brief docstring is present
    self.assertEqual(tf_inspect.getdoc(module).split('\n')[0],
                     page_info.doc.brief)

    # Make sure that the members are there
    funcs = {f_info.obj for f_info in page_info.functions}
    self.assertEqual({test_function, test_function_with_args_kwargs}, funcs)

    classes = {cls_info.obj for cls_info in page_info.classes}
    self.assertEqual({TestClass}, classes)

    # Make sure this file is contained as the definition location.
    self.assertEqual(os.path.relpath(__file__, '/'), page_info.defined_in.path)

  def test_docs_for_function(self):
    index = {
        'test_function': test_function
    }
    tree = {
        '': ['test_function']
    }
    parser_config = self._make_parser_config(index, tree)

    page_info = parser.docs_for_object(
        full_name='test_function',
        py_object=test_function,
        parser_config=parser_config)

    # Make sure the brief docstring is present
    self.assertEqual(
        tf_inspect.getdoc(test_function).split('\n')[0], page_info.doc.brief)

    # Make sure the extracted signature is good.
    self.assertEqual(['unused_arg', "unused_kwarg='default'"],
                     page_info.signature)

    # Make sure this file is contained as the definition location.
    self.assertEqual(os.path.relpath(__file__, '/'), page_info.defined_in.path)

  def test_docs_for_function_with_kwargs(self):
    index = {
        'test_function_with_args_kwargs': test_function_with_args_kwargs
    }
    tree = {
        '': ['test_function_with_args_kwargs']
    }
    parser_config = self._make_parser_config(index, tree)

    page_info = parser.docs_for_object(
        full_name='test_function_with_args_kwargs',
        py_object=test_function_with_args_kwargs,
        parser_config=parser_config)

    # Make sure the brief docstring is present
    self.assertEqual(
        tf_inspect.getdoc(test_function_with_args_kwargs).split('\n')[0],
        page_info.doc.brief)

    # Make sure the extracted signature is good.
    self.assertEqual(['unused_arg', '*unused_args', '**unused_kwargs'],
                     page_info.signature)

    # Make sure this file is contained as the definition location.
    self.assertEqual(os.path.relpath(__file__, '/'), page_info.defined_in.path)

  def test_parse_md_docstring(self):

    def test_function_with_fancy_docstring(arg):
      """Function with a fancy docstring.

      And a bunch of references: @{tf.reference}, another @{tf.reference},
          a member @{tf.reference.foo}, and a @{tf.third}.

      Args:
        arg: An argument.

      Raises:
        an exception

      Returns:
        arg: the input, and
        arg: the input, again.

      @compatibility(numpy)
      NumPy has nothing as awesome as this function.
      @end_compatibility

      @compatibility(theano)
      Theano has nothing as awesome as this function.

      Check it out.
      @end_compatibility

      """
      return arg, arg

    class HasOneMember(object):

      def foo(self):
        pass

    duplicate_of = {'tf.third': 'tf.fourth'}
    index = {
        'tf.fancy': test_function_with_fancy_docstring,
        'tf.reference': HasOneMember,
        'tf.reference.foo': HasOneMember.foo,
        'tf.third': HasOneMember,
        'tf.fourth': HasOneMember
    }

    reference_resolver = self._make_reference_resolver(index, duplicate_of)

    # pylint: disable=protected-access
    doc_info = parser._parse_md_docstring(test_function_with_fancy_docstring,
                                          '../..', reference_resolver)
    # pylint: enable=protected-access

    self.assertNotIn('@', doc_info.docstring)
    self.assertNotIn('compatibility', doc_info.docstring)
    self.assertNotIn('Raises:', doc_info.docstring)

    # Args, Raises and Returns sections.
    self.assertEqual(3, len(doc_info.function_details))
    self.assertEqual({'numpy', 'theano'}, set(doc_info.compatibility.keys()))

    self.assertEqual('NumPy has nothing as awesome as this function.\n',
                     doc_info.compatibility['numpy'])

  def test_generate_index(self):
    module = sys.modules[__name__]

    index = {
        'TestModule': module,
        'test_function': test_function,
        'TestModule.test_function': test_function,
        'TestModule.TestClass': TestClass,
        'TestModule.TestClass.a_method': TestClass.a_method,
        'TestModule.TestClass.a_property': TestClass.a_property,
        'TestModule.TestClass.ChildClass': TestClass.ChildClass,
    }
    duplicate_of = {
        'TestModule.test_function': 'test_function'
    }

    reference_resolver = self._make_reference_resolver(index, duplicate_of)

    docs = parser.generate_global_index('TestLibrary', index=index,
                                        reference_resolver=reference_resolver)

    # Make sure duplicates and non-top-level symbols are in the index, but
    # methods and properties are not.
    self.assertNotIn('a_method', docs)
    self.assertNotIn('a_property', docs)
    self.assertIn('TestModule.TestClass', docs)
    self.assertIn('TestModule.TestClass.ChildClass', docs)
    self.assertIn('TestModule.test_function', docs)
    # The `<code>` tag immediately before the name makes sure the top-level
    # `test_function` itself is listed, not just `TestModule.test_function`.
    # This depends on formatting, but should be stable.
    self.assertIn('<code>test_function', docs)

  def test_argspec_for_functools_partial(self):

    # pylint: disable=unused-argument
    def test_function_for_partial1(arg1, arg2, kwarg1=1, kwarg2=2):
      pass

    def test_function_for_partial2(arg1, arg2, *my_args, **my_kwargs):
      pass
    # pylint: enable=unused-argument

    # pylint: disable=protected-access
    # Make sure everything works for regular functions.
    expected = tf_inspect.ArgSpec(['arg1', 'arg2', 'kwarg1', 'kwarg2'], None,
                                  None, (1, 2))
    self.assertEqual(expected, parser._get_arg_spec(test_function_for_partial1))

    # Make sure a partial that binds nothing works.
    expected = tf_inspect.ArgSpec(['arg1', 'arg2', 'kwarg1', 'kwarg2'], None,
                                  None, (1, 2))
    partial = functools.partial(test_function_for_partial1)
    self.assertEqual(expected, parser._get_arg_spec(partial))

    # Make sure setting args from the front works.
    expected = tf_inspect.ArgSpec(['arg2', 'kwarg1', 'kwarg2'], None, None,
                                  (1, 2))
    partial = functools.partial(test_function_for_partial1, 1)
    self.assertEqual(expected, parser._get_arg_spec(partial))

    # Positional args may also fill parameters that have defaults.
    expected = tf_inspect.ArgSpec(['kwarg2'], None, None, (2,))
    partial = functools.partial(test_function_for_partial1, 1, 2, 3)
    self.assertEqual(expected, parser._get_arg_spec(partial))

    # Make sure setting kwargs works.
    expected = tf_inspect.ArgSpec(['arg1', 'arg2', 'kwarg2'], None, None, (2,))
    partial = functools.partial(test_function_for_partial1, kwarg1=0)
    self.assertEqual(expected, parser._get_arg_spec(partial))

    expected = tf_inspect.ArgSpec(['arg1', 'arg2', 'kwarg1'], None, None, (1,))
    partial = functools.partial(test_function_for_partial1, kwarg2=0)
    self.assertEqual(expected, parser._get_arg_spec(partial))

    expected = tf_inspect.ArgSpec(['arg1'], None, None, ())
    partial = functools.partial(test_function_for_partial1,
                                arg2=0, kwarg1=0, kwarg2=0)
    self.assertEqual(expected, parser._get_arg_spec(partial))

    # Make sure *args, **kwargs is accounted for.
    expected = tf_inspect.ArgSpec([], 'my_args', 'my_kwargs', ())
    partial = functools.partial(test_function_for_partial2, 0, 1)
    self.assertEqual(expected, parser._get_arg_spec(partial))

    # pylint: enable=protected-access

  def testSaveReferenceResolver(self):
    # `doc_index` holds an object the json module cannot encode, so this test
    # expects `to_json_file` not to write it. The same index is handed back
    # to `from_json_file` explicitly.
    you_cant_serialize_this = object()

    duplicate_of = {'AClass': ['AClass2']}
    doc_index = {'doc': you_cant_serialize_this}
    is_class = {
        'tf': False,
        'tf.AClass': True,
        'tf.AClass2': True,
        'tf.function': False
    }
    is_module = {
        'tf': True,
        'tf.AClass': False,
        'tf.AClass2': False,
        'tf.function': False
    }
    py_module_names = ['tf', 'tfdbg']

    resolver = parser.ReferenceResolver(duplicate_of, doc_index, is_class,
                                        is_module, py_module_names)

    outdir = googletest.GetTempDir()

    filepath = os.path.join(outdir, 'resolver.json')

    resolver.to_json_file(filepath)
    resolver2 = parser.ReferenceResolver.from_json_file(filepath, doc_index)

    # There are no __slots__, so all fields are visible in __dict__.
    self.assertEqual(resolver.__dict__, resolver2.__dict__)
484
# Op-style docstring fixture for `TestParseFunctionDetails`. It is written one
# element per line so every line's exact leading whitespace is explicit. The
# trailing '' element supplies the final newline.
RELU_DOC = '\n'.join([
    'Computes rectified linear: `max(features, 0)`',
    '',
    'Args:',
    '  features: A `Tensor`. Must be one of the following types: `float32`,',
    '    `float64`, `int32`, `int64`, `uint8`, `int16`, `int8`, `uint16`,',
    '    `half`.',
    '  name: A name for the operation (optional)',
    '',
    'Returns:',
    '  A `Tensor`. Has the same type as `features`',
    '',
])
496
497
class TestParseFunctionDetails(googletest.TestCase):
  """Tests for `parser._parse_function_details` on an op-style docstring."""

  def test_parse_function_details(self):
    # pylint: disable=protected-access
    summary, details = parser._parse_function_details(RELU_DOC)

    # One section each for "Args" and "Returns".
    self.assertEqual(2, len(details))
    args_section, returns_section = details

    self.assertEqual('Args', args_section.keyword)
    self.assertEqual(0, len(args_section.header))
    self.assertEqual(2, len(args_section.items))
    self.assertEqual('features', args_section.items[0][0])
    self.assertEqual('name', args_section.items[1][0])
    self.assertEqual('A name for the operation (optional)\n\n',
                     args_section.items[1][1])

    self.assertEqual('Returns', returns_section.keyword)

    doc_lines = RELU_DOC.split('\n')
    self.assertEqual(doc_lines[0] + '\n\n', summary)
    self.assertEqual(doc_lines[-2] + '\n', returns_section.header)

    # Re-joining the summary with the stringified sections must reproduce the
    # input exactly, i.e. the parse is lossless.
    self.assertEqual(RELU_DOC, summary + ''.join(map(str, details)))
522
523
class TestGenerateSignature(googletest.TestCase):
  """Tests for rendering default values in `parser._generate_signature`."""

  def setUp(self):
    super(TestGenerateSignature, self).setUp()
    # Doc generation is Python-2 only. Skip explicitly so these tests are
    # reported as skipped under Python 3, rather than silently "passing"
    # after an early return.
    if sys.version_info >= (3, 0):
      self.skipTest('Doc generation is not supported from python3.')

  def test_known_object(self):
    known_object = object()
    # The reverse index is keyed by `id()` of the default value.
    reverse_index = {id(known_object): 'location.of.object.in.api'}

    def example_fun(arg=known_object):  # pylint: disable=unused-argument
      pass

    sig = parser._generate_signature(example_fun, reverse_index)  # pylint: disable=protected-access
    self.assertEqual(sig, ['arg=location.of.object.in.api'])

  def test_literals(self):

    def example_fun(a=5, b=5.0, c=None, d=True, e='hello', f=(1, (2, 3))):  # pylint: disable=g-bad-name, unused-argument
      pass

    sig = parser._generate_signature(example_fun, reverse_index={})  # pylint: disable=protected-access
    self.assertEqual(
        sig, ['a=5', 'b=5.0', 'c=None', 'd=True', "e='hello'", 'f=(1, (2, 3))'])

  def test_dotted_name(self):
    # pylint: disable=g-bad-name
    class a(object):

      class b(object):

        class c(object):

          class d(object):

            def __init__(self, *args):
              pass
    # pylint: enable=g-bad-name

    e = {'f': 1}

    def example_fun(arg1=a.b.c.d, arg2=a.b.c.d(1, 2), arg3=e['f']):  # pylint: disable=unused-argument
      pass

    sig = parser._generate_signature(example_fun, reverse_index={})  # pylint: disable=protected-access
    self.assertEqual(sig, ['arg1=a.b.c.d', 'arg2=a.b.c.d(1, 2)', "arg3=e['f']"])
577
578
if __name__ == '__main__':
  # When executed as a script, hand control to the googletest test runner.
  googletest.main()
581