# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for conversion module.

These tests exercise `conversion.entity_to_graph` and the dependency cache
that a `conversion.ConversionMap` accumulates while functions are converted.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gast

from tensorflow.contrib.py2tf.impl import conversion
from tensorflow.python.platform import test


def _new_conversion_map():
  """Returns a fresh ConversionMap configured the same way for every test.

  The positional arguments are passed through unchanged from the original
  tests; their meaning is defined by `conversion.ConversionMap`.
  """
  return conversion.ConversionMap(True, (), (), None)


class ConversionTest(test.TestCase):

  def test_entity_to_graph_unsupported_types(self):
    # Build the map outside the assertRaises context so that only
    # entity_to_graph (and not the ConversionMap constructor) can produce the
    # expected ValueError.
    conversion_map = _new_conversion_map()
    with self.assertRaises(ValueError):
      # A plain string is expected to be rejected as a conversion entity.
      conversion.entity_to_graph('dummy', conversion_map, None, None)

  def test_entity_to_graph_callable(self):

    def f(a):
      return a

    conversion_map = _new_conversion_map()
    # `node` rather than `ast`, to avoid shadowing the stdlib module name.
    node, new_name = conversion.entity_to_graph(f, conversion_map, None, None)
    self.assertIsInstance(node, gast.FunctionDef)
    # The converted function's name is the original name with a 'tf__' prefix.
    self.assertEqual('tf__f', new_name)

  def test_entity_to_graph_call_tree(self):

    def g(a):
      return a

    def f(a):
      return g(a)

    conversion_map = _new_conversion_map()
    conversion.entity_to_graph(f, conversion_map, None, None)

    # Converting f is expected to also convert its callee g and record both
    # in the map's dependency cache.
    cache = conversion_map.dependency_cache
    self.assertIn(f, cache)
    self.assertIn(g, cache)
    self.assertEqual('tf__f', cache[f].name)
    # f's body is the single statement `return g(a)`; in the converted node
    # the call target must refer to the converted callee's name.
    self.assertEqual('tf__g', cache[f].body[0].value.func.id)
    self.assertEqual('tf__g', cache[g].name)


if __name__ == '__main__':
  test.main()