1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python.eager import backprop
21from tensorflow.python.eager import graph_callable
22from tensorflow.python.eager import test
23from tensorflow.python.framework import constant_op
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import function
26from tensorflow.python.framework import tensor_shape
27from tensorflow.python.ops import init_ops
28from tensorflow.python.ops import math_ops
29from tensorflow.python.ops import variable_scope
30
31
class GraphCallableTest(test.TestCase):
  """Tests for `graph_callable`: graph functions that may own variables.

  Each test decorates a Python function with `graph_callable.graph_callable`,
  passing the `ShapeAndDtype` signature of its inputs, then calls the result
  eagerly and checks outputs and variable side effects.
  """

  def testBasic(self):

    @graph_callable.graph_callable(
        [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
    def my_function(x):
      v = variable_scope.get_variable(
          "v", initializer=init_ops.zeros_initializer(), shape=())
      return v + x

    self.assertEqual(
        2, my_function(constant_op.constant(2, dtype=dtypes.float32)).numpy())

    # Variables created inside the decorated function are exposed on the
    # callable and can be mutated between calls.
    my_function.variables[0].assign(1.)
    self.assertEqual(
        3, my_function(constant_op.constant(2, dtype=dtypes.float32)).numpy())

  def testFunctionWithoutReturnValue(self):

    @graph_callable.graph_callable(
        [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
    def my_function(x):
      v = variable_scope.get_variable(
          "v", initializer=init_ops.zeros_initializer(), shape=())
      v.assign(x)

    # The call's only observable effect is the variable update.
    my_function(constant_op.constant(4, dtype=dtypes.float32))
    self.assertAllEqual(4, my_function.variables[0].read_value())

  def testFunctionWithoutReturnValueAndArgs(self):

    @graph_callable.graph_callable([])
    def my_function():
      v = variable_scope.get_variable(
          "v", initializer=init_ops.zeros_initializer(), shape=())
      v.assign(4)

    my_function()
    self.assertAllEqual(4, my_function.variables[0].read_value())

  def testVariableAPI(self):

    @graph_callable.graph_callable(
        [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
    def my_function(x):
      v = variable_scope.get_variable(
          "v", initializer=init_ops.zeros_initializer(), shape=())
      # Exercise explicit read_value() rather than implicit conversion.
      return v.read_value() + x

    self.assertEqual(
        2, my_function(constant_op.constant(2, dtype=dtypes.float32)).numpy())

    my_function.variables[0].assign(1.)
    self.assertEqual(
        3, my_function(constant_op.constant(2, dtype=dtypes.float32)).numpy())

  def testTensorShape(self):

    # NOTE: the shape must be a tuple; `(1)` would just be the int 1.
    @graph_callable.graph_callable(
        [graph_callable.ShapeAndDtype(shape=(1,), dtype=dtypes.float32)])
    def my_function(x):
      _ = x.get_shape()
      v = variable_scope.get_variable(
          "v", initializer=init_ops.zeros_initializer(), shape=[x.shape[0]])
      self.assertEqual(v.shape[0], x.shape[0])
      return v + x

    self.assertEqual([2.],
                     my_function(
                         constant_op.constant([2.],
                                              dtype=dtypes.float32)).numpy())

  def testUpdatesAreOrdered(self):

    @graph_callable.graph_callable(
        [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
    def my_function(x):
      v = variable_scope.get_variable(
          "v", initializer=init_ops.zeros_initializer(), shape=())
      # The two assigns must run in program order: (x + 1) * x == 6 for x=2.
      v.assign(x + 1)
      v.assign(v * x)
      return v.read_value()

    self.assertAllEqual(my_function(constant_op.constant(2.0)), 6.0)

  def testEmptyInitializer(self):

    @graph_callable.graph_callable(
        [graph_callable.ShapeAndDtype(shape=(1,), dtype=dtypes.float32)])
    def my_function(x):
      # No initializer: the variable still participates in the graph, but its
      # value is multiplied by zero so the output is deterministic.
      v = variable_scope.get_variable("v", shape=[1])
      return x + 0 * v

    self.assertEqual([2.],
                     my_function(
                         constant_op.constant([2.],
                                              dtype=dtypes.float32)).numpy())

  def testMismatchingNumArgs(self):
    # Raw strings keep the regex escapes literal without needing a
    # pylint anomalous-backslash suppression.
    with self.assertRaisesRegexp(TypeError,
                                 r"The number of arguments accepted by the "
                                 r"decorated function `my_function` \(2\) must "
                                 r"match the number of ShapeAndDtype objects "
                                 r"passed to the graph_callable\(\) decorator "
                                 r"\(1\)."):
      @graph_callable.graph_callable([
          graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
      def my_function(x, y):  # pylint: disable=unused-variable
        return x + y

  def testPureFunction(self):

    @graph_callable.graph_callable(
        [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.int32)])
    def f(x):
      return math_ops.add(x, constant_op.constant(3))

    self.assertAllEqual(5, f(constant_op.constant(2)))

  def testNestedFunction(self):
    # TensorFlow function (which is what would be used in TensorFlow graph
    # construction).
    @function.Defun(dtypes.int32, dtypes.int32)
    def add(a, b):
      return math_ops.add(a, b)

    # A graph_callable that will invoke the TensorFlow function.
    @graph_callable.graph_callable(
        [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.int32)])
    def add_one(x):
      return add(x, 1)

    self.assertAllEqual(3, add_one(constant_op.constant(2)))

  # TODO(ashankar): Make this work.
  # The problem is that the two graph_callables (for add_one and add_two)
  # are both trying to register the FunctionDef corresponding to "add".
  def DISABLED_testRepeatedUseOfSubFunction(self):

    @function.Defun(dtypes.int32, dtypes.int32)
    def add(a, b):
      return math_ops.add(a, b)

    @graph_callable.graph_callable(
        [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.int32)])
    def add_one(x):
      return add(x, 1)

    @graph_callable.graph_callable(
        [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.int32)])
    def add_two(x):
      return add(x, 2)

    two = constant_op.constant(2)
    self.assertAllEqual(3, add_one(two))
    self.assertAllEqual(4, add_two(two))

  def testNestedSequenceInputs(self):
    # The signature mirrors the nested structure of the eventual inputs.
    sd = graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)
    @graph_callable.graph_callable([[sd, tuple([sd, sd]), sd]])
    def my_op(inputs):
      a, b, c = inputs
      e, f = b
      v = variable_scope.get_variable(
          "my_v", initializer=init_ops.zeros_initializer(), shape=())
      return [a + a + v, tuple([e + e, f + f]), c + c], a + e + f + c + v

    inputs = [constant_op.constant(1.),
              [constant_op.constant(2.), constant_op.constant(3.)],
              constant_op.constant(4.)]
    ret = my_op(inputs)
    # A length is an int; compare against the int literal 2.
    self.assertEqual(len(ret), 2)
    self.assertAllEqual(ret[1], 10.)

    my_op.variables[0].assign(1.)
    ret = my_op(inputs)
    self.assertAllEqual(ret[1], 11.)

  def testVariableShapeIsTensorShape(self):
    @graph_callable.graph_callable([])
    def my_function():
      v = variable_scope.get_variable(
          "v", initializer=init_ops.zeros_initializer(), shape=())
      self.assertIsInstance(v.get_shape(), tensor_shape.TensorShape)

    my_function()

  def testIncorrectlyShapedInputs(self):
    @graph_callable.graph_callable(
        [graph_callable.ShapeAndDtype(shape=(3,), dtype=dtypes.float32)])
    def my_function(x):
      v = variable_scope.get_variable(
          "v", initializer=init_ops.zeros_initializer(), shape=())
      return v + x

    # A 2-element input does not match the declared 3-element signature.
    with self.assertRaises(ValueError):
      my_function([1, 2])

    # A correctly shaped input still works after the failed call.
    self.assertTrue(([1, 2, 3] == my_function(
        constant_op.constant([1, 2, 3], dtype=dtypes.float32)).numpy()).all())

  def testGradients(self):
    @graph_callable.graph_callable([])
    def my_function():
      v = variable_scope.get_variable(
          "v", initializer=init_ops.constant_initializer(3.), shape=())
      return v * v

    # d(v*v)/dv == 2*v == 6 at v == 3.
    grad_fn = backprop.implicit_grad(my_function)
    grads_and_vars = list(zip(*grad_fn()))
    self.assertAllEqual(6., grads_and_vars[0][0])
246
247
# Run the tests under TensorFlow's eager test harness when executed directly.
if __name__ == "__main__":
  test.main()
250