1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""TensorFlow Eager Execution: Sanity tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import shutil
import tempfile

from tensorflow.contrib.eager.python import tfe
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import numerics
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.summary import summary
from tensorflow.python.summary.writer import writer
34
35
class TFETest(test_util.TensorFlowTestCase):
  """Sanity checks for basic TensorFlow eager-execution behavior.

  The tests rely on eager execution having been enabled before they run (the
  `__main__` guard of this file does that), so results are read back directly
  with `.numpy()` instead of going through a session.
  """

  def testMatmul(self):
    # Ops take plain Python lists and hand back a value right away.
    x = [[2.]]
    y = math_ops.matmul(x, x)  # tf.matmul
    self.assertAllEqual([[4.]], y.numpy())

  def testInstantError(self):
    # Index 7 is outside [0, 3); the error must surface at the call itself.
    with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                 r'indices = 7 is not in \[0, 3\)'):
      array_ops.gather([0, 1, 2], 7)

  def testVariableError(self):
    # Constructing a classic variables.Variable is expected to be rejected.
    with self.assertRaisesRegexp(
        RuntimeError, r'Variable not supported in Eager mode'):
      variables.Variable(initial_value=1.0)

  def testGradients(self):

    def square(x):
      return math_ops.multiply(x, x)

    # d(x * x)/dx evaluated at x = 3 is 6.
    grad = tfe.gradients_function(square)
    self.assertEqual([6], [x.numpy() for x in grad(3)])

  def testGradOfGrad(self):

    def square(x):
      return math_ops.multiply(x, x)

    # Differentiating the gradient function again gives the second
    # derivative of x * x, which is the constant 2.
    grad = tfe.gradients_function(square)
    gradgrad = tfe.gradients_function(lambda x: grad(x)[0])
    self.assertEqual([2], [x.numpy() for x in gradgrad(3)])

  def testCustomGrad(self):

    @tfe.custom_gradient
    def f(x):
      y = math_ops.multiply(x, x)

      def grad_fn(_):
        # Deliberately not the true derivative: ignores the incoming
        # gradient and returns x + y so the test can tell that this
        # function, and not the built-in gradient, was used.
        return [x + y]

      return y, grad_fn

    # With x = 3: y = 9, so the custom gradient is 3 + 9 = 12.
    grad = tfe.gradients_function(f)
    self.assertEqual([12], [x.numpy() for x in grad(3)])

  def testGPU(self):
    if tfe.num_gpus() <= 0:
      self.skipTest('No GPUs available')

    # Tensor.gpu() moves a tensor to GPU.
    x = constant_op.constant([[1., 2.], [3., 4.]]).gpu()
    # Alternatively, ops.device() as a context manager places tensors and
    # operations.
    with ops.device('gpu:0'):
      x += 1.
    # Without a device context, heuristics are used to place ops.
    # In this case, math_ops.reduce_mean runs on the GPU.
    # The axes are materialized as a list because range() is only a list on
    # Python 2; on Python 3 it would otherwise be passed as a lazy range.
    reduction_indices = list(range(x.shape.ndims))
    m = math_ops.reduce_mean(x, reduction_indices)
    # m is on GPU, bring it back to CPU and compare.
    # x is now [[2, 3], [4, 5]], whose mean over all axes is 3.5.
    self.assertEqual(3.5, m.cpu().numpy())

  def testListDevices(self):
    # Expect at least one device.
    self.assertTrue(tfe.list_devices())

  def testAddCheckNumericsOpsRaisesError(self):
    with self.assertRaisesRegexp(
        RuntimeError,
        r'add_check_numerics_ops\(\) is not compatible with eager execution'):
      numerics.add_check_numerics_ops()

  def testClassicSummaryOpsErrorOut(self):
    x = constant_op.constant(42)
    x_summary = summary.scalar('x', x)
    y = constant_op.constant([1, 3, 3, 7])
    y_summary = summary.histogram('hist', y)

    # Both the explicit and the collect-everything merge must be rejected
    # with the same message.
    with self.assertRaisesRegexp(
        RuntimeError,
        r'Merging tf\.summary\.\* ops is not compatible with eager execution'):
      summary.merge([x_summary, y_summary])

    with self.assertRaisesRegexp(
        RuntimeError,
        r'Merging tf\.summary\.\* ops is not compatible with eager execution'):
      summary.merge_all()

  def testClassicSummaryFileWriterErrorsOut(self):
    # mkdtemp() creates the directory before FileWriter gets the chance to
    # raise, so register its removal instead of leaking it after the test.
    logdir = tempfile.mkdtemp()
    self.addCleanup(shutil.rmtree, logdir, ignore_errors=True)
    with self.assertRaisesRegexp(
        RuntimeError,
        r'tf\.summary\.FileWriter is not compatible with eager execution'):
      writer.FileWriter(logdir)
132
133
def _run_eager_tests():
  """Switches on eager execution, then hands control to the test runner."""
  # Order matters: eager mode has to be active before any test case executes.
  tfe.enable_eager_execution()
  test.main()


if __name__ == '__main__':
  _run_eager_tests()
137