# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the `graph_callable` decorator."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.eager import backprop
from tensorflow.python.eager import graph_callable
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope


class GraphCallableTest(test.TestCase):
  """Exercises `graph_callable.graph_callable` on functions with variables."""

  def testBasic(self):
    """A variable created inside the function is exposed via `.variables`."""

    @graph_callable.graph_callable(
        [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
    def my_function(x):
      v = variable_scope.get_variable(
          "v", initializer=init_ops.zeros_initializer(), shape=())
      return v + x

    self.assertEqual(
        2, my_function(constant_op.constant(2, dtype=dtypes.float32)).numpy())

    # Assigning to the exposed variable must be visible on the next call.
    my_function.variables[0].assign(1.)
    self.assertEqual(
        3, my_function(constant_op.constant(2, dtype=dtypes.float32)).numpy())

  def testFunctionWithoutReturnValue(self):
    """A function that returns nothing still runs its side effects."""

    @graph_callable.graph_callable(
        [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
    def my_function(x):
      v = variable_scope.get_variable(
          "v", initializer=init_ops.zeros_initializer(), shape=())
      v.assign(x)

    my_function(constant_op.constant(4, dtype=dtypes.float32))
    self.assertAllEqual(4, my_function.variables[0].read_value())

  def testFunctionWithoutReturnValueAndArgs(self):
    """A zero-argument, no-return function still runs its side effects."""

    @graph_callable.graph_callable([])
    def my_function():
      v = variable_scope.get_variable(
          "v", initializer=init_ops.zeros_initializer(), shape=())
      v.assign(4)

    my_function()
    self.assertAllEqual(4, my_function.variables[0].read_value())

  def testVariableAPI(self):
    """`read_value()` inside the function sees the current variable value."""

    @graph_callable.graph_callable(
        [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
    def my_function(x):
      v = variable_scope.get_variable(
          "v", initializer=init_ops.zeros_initializer(), shape=())
      return v.read_value() + x

    self.assertEqual(
        2, my_function(constant_op.constant(2, dtype=dtypes.float32)).numpy())

    my_function.variables[0].assign(1.)
    self.assertEqual(
        3, my_function(constant_op.constant(2, dtype=dtypes.float32)).numpy())

  def testTensorShape(self):
    """Input shapes are static and usable to size variables."""

    # `(1,)` is a rank-1 shape with one element; a bare `(1)` is just the int 1.
    @graph_callable.graph_callable(
        [graph_callable.ShapeAndDtype(shape=(1,), dtype=dtypes.float32)])
    def my_function(x):
      _ = x.get_shape()  # Only checks that get_shape() is callable here.
      v = variable_scope.get_variable(
          "v", initializer=init_ops.zeros_initializer(), shape=[x.shape[0]])
      self.assertEqual(v.shape[0], x.shape[0])
      return v + x

    self.assertEqual([2.],
                     my_function(
                         constant_op.constant([2.],
                                              dtype=dtypes.float32)).numpy())

  def testUpdatesAreOrdered(self):
    """Stateful ops run in program order."""

    @graph_callable.graph_callable(
        [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
    def my_function(x):
      v = variable_scope.get_variable(
          "v", initializer=init_ops.zeros_initializer(), shape=())
      v.assign(x + 1)
      v.assign(v * x)
      return v.read_value()

    # With x=2: v = 2 + 1 = 3, then v = 3 * 2 = 6 only if the assigns are
    # executed in the order they were written.
    self.assertAllEqual(my_function(constant_op.constant(2.0)), 6.0)

  def testEmptyInitializer(self):
    """`get_variable` works without an explicit initializer."""

    @graph_callable.graph_callable(
        [graph_callable.ShapeAndDtype(shape=(1,), dtype=dtypes.float32)])
    def my_function(x):
      v = variable_scope.get_variable("v", shape=[1])
      # Multiplying by 0 keeps the result independent of v's initial value.
      return x + 0 * v

    self.assertEqual([2.],
                     my_function(
                         constant_op.constant([2.],
                                              dtype=dtypes.float32)).numpy())

  def testMismatchingNumArgs(self):
    """Decorating a function whose arity differs from the spec raises."""
    with self.assertRaisesRegexp(TypeError,
                                 r"The number of arguments accepted by the "
                                 r"decorated function `my_function` \(2\) must "
                                 r"match the number of ShapeAndDtype objects "
                                 r"passed to the graph_callable\(\) decorator "
                                 r"\(1\)."):
      @graph_callable.graph_callable([
          graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
      def my_function(x, y):  # pylint: disable=unused-variable
        return x + y

  def testPureFunction(self):
    """A function with no variables works as a plain computation."""

    @graph_callable.graph_callable(
        [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.int32)])
    def f(x):
      return math_ops.add(x, constant_op.constant(3))

    self.assertAllEqual(5, f(constant_op.constant(2)))

  def testNestedFunction(self):
    """A graph_callable may invoke a `function.Defun`."""
    # TensorFlow function (which is what would be used in TensorFlow graph
    # construction).
    @function.Defun(dtypes.int32, dtypes.int32)
    def add(a, b):
      return math_ops.add(a, b)

    # A graph_callable that will invoke the TensorFlow function.
    @graph_callable.graph_callable(
        [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.int32)])
    def add_one(x):
      return add(x, 1)

    self.assertAllEqual(3, add_one(constant_op.constant(2)))

  # TODO(ashankar): Make this work.
  # The problem is that the two graph_callables (for add_one and add_two)
  # are both trying to register the FunctionDef corresponding to "add".
  def DISABLED_testRepeatedUseOfSubFunction(self):

    @function.Defun(dtypes.int32, dtypes.int32)
    def add(a, b):
      return math_ops.add(a, b)

    @graph_callable.graph_callable(
        [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.int32)])
    def add_one(x):
      return add(x, 1)

    @graph_callable.graph_callable(
        [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.int32)])
    def add_two(x):
      return add(x, 2)

    two = constant_op.constant(2)
    self.assertAllEqual(3, add_one(two))
    self.assertAllEqual(4, add_two(two))

  def testNestedSequenceInputs(self):
    """Nested list/tuple inputs and outputs round-trip through the call."""
    sd = graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)
    @graph_callable.graph_callable([[sd, tuple([sd, sd]), sd]])
    def my_op(inputs):
      a, b, c = inputs
      e, f = b
      v = variable_scope.get_variable(
          "my_v", initializer=init_ops.zeros_initializer(), shape=())
      return [a + a + v, tuple([e + e, f + f]), c + c], a + e + f + c + v

    inputs = [constant_op.constant(1.),
              [constant_op.constant(2.), constant_op.constant(3.)],
              constant_op.constant(4.)]
    ret = my_op(inputs)
    # The top-level output is a pair; compare its length against an int.
    self.assertEqual(len(ret), 2)
    # With a=1, e=2, f=3, c=4 and v=0.
    self.assertAllEqual(ret[0][0], 2.)
    self.assertAllEqual(ret[0][1][0], 4.)
    self.assertAllEqual(ret[0][1][1], 6.)
    self.assertAllEqual(ret[0][2], 8.)
    self.assertAllEqual(ret[1], 10.)

    my_op.variables[0].assign(1.)
    ret = my_op(inputs)
    self.assertAllEqual(ret[0][0], 3.)
    self.assertAllEqual(ret[1], 11.)

  def testVariableShapeIsTensorShape(self):
    """`get_shape()` on a created variable returns a `TensorShape`."""
    @graph_callable.graph_callable([])
    def my_function():
      v = variable_scope.get_variable(
          "v", initializer=init_ops.zeros_initializer(), shape=())
      self.assertIsInstance(v.get_shape(), tensor_shape.TensorShape)

    my_function()

  def testIncorrectlyShapedInputs(self):
    """Inputs that do not match the declared shape raise ValueError."""
    @graph_callable.graph_callable(
        [graph_callable.ShapeAndDtype(shape=(3,), dtype=dtypes.float32)])
    def my_function(x):
      v = variable_scope.get_variable(
          "v", initializer=init_ops.zeros_initializer(), shape=())
      return v + x

    with self.assertRaises(ValueError):
      my_function([1, 2])

    self.assertAllEqual(
        [1., 2., 3.],
        my_function(
            constant_op.constant([1, 2, 3], dtype=dtypes.float32)).numpy())

  def testGradients(self):
    """`backprop.implicit_grad` differentiates through the variable."""
    @graph_callable.graph_callable([])
    def my_function():
      v = variable_scope.get_variable(
          "v", initializer=init_ops.constant_initializer(3.), shape=())
      return v * v

    grad_fn = backprop.implicit_grad(my_function)
    # implicit_grad yields (gradient, variable) pairs; d(v*v)/dv = 2v = 6.
    grad, _ = grad_fn()[0]
    self.assertAllEqual(6., grad)


if __name__ == "__main__":
  test.main()