1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for tf upgrader.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20import os 21import tempfile 22import six 23from tensorflow.python.framework import test_util 24from tensorflow.python.platform import test as test_lib 25from tensorflow.tools.compatibility import tf_upgrade 26 27 28class TestUpgrade(test_util.TensorFlowTestCase): 29 """Test various APIs that have been changed in 1.0. 30 31 We also test whether a converted file is executable. test_file_v0_11.py 32 aims to exhaustively test that API changes are convertible and actually 33 work when run with current TensorFlow. 34 """ 35 36 def _upgrade(self, old_file_text): 37 in_file = six.StringIO(old_file_text) 38 out_file = six.StringIO() 39 upgrader = tf_upgrade.ASTCodeUpgrader(tf_upgrade.TFAPIChangeSpec()) 40 count, report, errors = ( 41 upgrader.process_opened_file("test.py", in_file, 42 "test_out.py", out_file)) 43 return count, report, errors, out_file.getvalue() 44 45 def testParseError(self): 46 _, report, unused_errors, unused_new_text = self._upgrade( 47 "import tensorflow as tf\na + \n") 48 self.assertTrue(report.find("Failed to parse") != -1) 49 50 def testReport(self): 51 text = "tf.mul(a, b)\n" 52 _, report, unused_errors, unused_new_text = self._upgrade(text) 53 # This is not a complete test, but it is a sanity test that a report 54 # is generating information. 55 self.assertTrue(report.find("Renamed function `tf.mul` to `tf.multiply`")) 56 57 def testRename(self): 58 text = "tf.mul(a, tf.sub(b, c))\n" 59 _, unused_report, unused_errors, new_text = self._upgrade(text) 60 self.assertEqual(new_text, "tf.multiply(a, tf.subtract(b, c))\n") 61 62 def testRenamePack(self): 63 text = "tf.pack(a)\n" 64 _, unused_report, unused_errors, new_text = self._upgrade(text) 65 self.assertEqual(new_text, "tf.stack(a)\n") 66 text = "tf.unpack(a)\n" 67 _, unused_report, unused_errors, new_text = self._upgrade(text) 68 self.assertEqual(new_text, "tf.unstack(a)\n") 69 70 def testReorder(self): 71 text = "tf.concat(a, b)\ntf.split(a, b, c)\n" 72 _, unused_report, unused_errors, new_text = self._upgrade(text) 73 self.assertEqual(new_text, "tf.concat(axis=a, values=b)\n" 74 "tf.split(axis=a, num_or_size_splits=b, value=c)\n") 75 76 def testConcatReorderWithKeywordArgs(self): 77 text = "tf.concat(concat_dim=a, values=b)\n" 78 _, unused_report, unused_errors, new_text = self._upgrade(text) 79 self.assertEqual(new_text, "tf.concat(axis=a, values=b)\n") 80 text = "tf.concat(values=b, concat_dim=a)\n" 81 _, unused_report, unused_errors, new_text = self._upgrade(text) 82 self.assertEqual(new_text, "tf.concat(values=b, axis=a)\n") 83 text = "tf.concat(a, values=b)\n" 84 _, unused_report, unused_errors, new_text = self._upgrade(text) 85 self.assertEqual(new_text, "tf.concat(axis=a, values=b)\n") 86 87 def testConcatReorderNested(self): 88 text = "tf.concat(a, tf.concat(c, d))\n" 89 _, unused_report, unused_errors, new_text = self._upgrade(text) 90 self.assertEqual( 91 new_text, "tf.concat(axis=a, values=tf.concat(axis=c, values=d))\n") 92 93 def testInitializers(self): 94 text = ("tf.zeros_initializer;tf.zeros_initializer ()\n" 95 "tf.ones_initializer;tf.ones_initializer ()\n") 96 _, unused_report, unused_errors, new_text = self._upgrade(text) 97 self.assertEqual( 98 new_text, "tf.zeros_initializer();tf.zeros_initializer ()\n" 99 "tf.ones_initializer();tf.ones_initializer ()\n") 100 101 def testKeyword(self): 102 text = "tf.reduce_any(a, reduction_indices=[1, 2])\n" 103 _, unused_report, unused_errors, new_text = self._upgrade(text) 104 self.assertEqual(new_text, "tf.reduce_any(a, axis=[1, 2])\n") 105 106 def testComplexExpression(self): 107 text = "(foo + bar)[a].word()" 108 _ = self._upgrade(text) 109 110 def testReverse(self): 111 text = "tf.reverse(a, b)\n" 112 _, unused_report, errors, new_text = self._upgrade(text) 113 self.assertEqual(new_text, new_text) 114 self.assertEqual(errors, ["test.py:1: tf.reverse requires manual check."]) 115 116 def testListComprehension(self): 117 def _test(input, output): # pylint: disable=redefined-builtin 118 _, unused_report, errors, new_text = self._upgrade(input) 119 self.assertEqual(new_text, output) 120 _test("tf.concat(0, \t[x for x in y])\n", 121 "tf.concat(axis=0, \tvalues=[x for x in y])\n") 122 _test("tf.concat(0,[x for x in y])\n", 123 "tf.concat(axis=0,values=[x for x in y])\n") 124 _test("tf.concat(0,[\nx for x in y])\n", 125 "tf.concat(axis=0,values=[\nx for x in y])\n") 126 _test("tf.concat(0,[\n \tx for x in y])\n", 127 "tf.concat(axis=0,values=[\n \tx for x in y])\n") 128 129 # TODO(aselle): Explicitly not testing command line interface and process_tree 130 # for now, since this is a one off utility. 131 132 133class TestUpgradeFiles(test_util.TensorFlowTestCase): 134 135 def testInplace(self): 136 """Check to make sure we don't have a file system race.""" 137 temp_file = tempfile.NamedTemporaryFile("w", delete=False) 138 original = "tf.mul(a, b)\n" 139 upgraded = "tf.multiply(a, b)\n" 140 temp_file.write(original) 141 temp_file.close() 142 upgrader = tf_upgrade.ASTCodeUpgrader(tf_upgrade.TFAPIChangeSpec()) 143 upgrader.process_file(temp_file.name, temp_file.name) 144 self.assertAllEqual(open(temp_file.name).read(), upgraded) 145 os.unlink(temp_file.name) 146 147 148if __name__ == "__main__": 149 test_lib.main() 150