# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow Eager Execution: Sanity tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tempfile

from tensorflow.contrib.eager.python import tfe
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import numerics
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.summary import summary
from tensorflow.python.summary.writer import writer


class TFETest(test_util.TensorFlowTestCase):
  """Sanity checks for eager execution.

  Every test assumes eager execution is already enabled: the `__main__` guard
  at the bottom of this file calls `tfe.enable_eager_execution()` before
  `test.main()`.
  """

  def testMatmul(self):
    """matmul accepts a nested Python list and returns a value with .numpy()."""
    x = [[2.]]
    y = math_ops.matmul(x, x)  # tf.matmul
    self.assertAllEqual([[4.]], y.numpy())

  def testInstantError(self):
    """An out-of-range gather index raises InvalidArgumentError at call time."""
    with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                 r'indices = 7 is not in \[0, 3\)'):
      array_ops.gather([0, 1, 2], 7)

  def testVariableError(self):
    """Constructing a `variables.Variable` is rejected in eager mode."""
    with self.assertRaisesRegexp(
        RuntimeError, r'Variable not supported in Eager mode'):
      variables.Variable(initial_value=1.0)

  def testGradients(self):
    """gradients_function returns a list of gradients; d(x*x)/dx at 3 is 6."""

    def square(x):
      return math_ops.multiply(x, x)

    grad = tfe.gradients_function(square)
    self.assertEqual([6], [x.numpy() for x in grad(3)])

  def testGradOfGrad(self):
    """gradients_function composes: the second derivative of x*x is 2."""

    def square(x):
      return math_ops.multiply(x, x)

    grad = tfe.gradients_function(square)
    # grad(x) returns a list with one entry per argument; differentiate the
    # first (and only) entry again.
    gradgrad = tfe.gradients_function(lambda x: grad(x)[0])
    self.assertEqual([2], [x.numpy() for x in gradgrad(3)])

  def testCustomGrad(self):
    """A @tfe.custom_gradient grad_fn replaces the true derivative.

    grad_fn ignores the upstream gradient and returns x + y, which is
    3 + 9 = 12 at x = 3, instead of the analytic derivative 2 * x = 6.
    """

    @tfe.custom_gradient
    def f(x):
      y = math_ops.multiply(x, x)

      def grad_fn(_):
        return [x + y]

      return y, grad_fn

    grad = tfe.gradients_function(f)
    self.assertEqual([12], [x.numpy() for x in grad(3)])

  def testGPU(self):
    """Tensors and ops can be placed on a GPU; skipped when none is present."""
    if tfe.num_gpus() <= 0:
      self.skipTest('No GPUs available')

    # tf.Tensor.gpu() returns the tensor placed on the GPU.
    x = constant_op.constant([[1., 2.], [3., 4.]]).gpu()
    # Alternatively, tf.device() as a context manager places tensors and
    # operations.
    with ops.device('gpu:0'):
      x += 1.
    # Without a device context, heuristics are used to place ops.
    # In this case, math_ops.reduce_mean runs on the GPU.
    # Reduce over every axis; list() passes concrete axes rather than a
    # Python 3 range object (Python 2's range already returned a list).
    reduction_indices = list(range(x.shape.ndims))
    m = math_ops.reduce_mean(x, reduction_indices)
    # m is on GPU, bring it back to CPU and compare.
    # Mean of [[2., 3.], [4., 5.]] is 3.5.
    self.assertEqual(3.5, m.cpu().numpy())

  def testListDevices(self):
    """tfe.list_devices() reports a non-empty device list."""
    # Expect at least one device.
    self.assertTrue(tfe.list_devices())

  def testAddCheckNumericsOpsRaisesError(self):
    """The graph-only add_check_numerics_ops() is rejected in eager mode."""
    with self.assertRaisesRegexp(
        RuntimeError,
        r'add_check_numerics_ops\(\) is not compatible with eager execution'):
      numerics.add_check_numerics_ops()

  def testClassicSummaryOpsErrorOut(self):
    """summary.merge / merge_all raise in eager mode.

    Creating the individual scalar/histogram summaries happens outside the
    assertRaises contexts, so only the merge calls are expected to fail.
    """
    x = constant_op.constant(42)
    x_summary = summary.scalar('x', x)
    y = constant_op.constant([1, 3, 3, 7])
    y_summary = summary.histogram('hist', y)

    with self.assertRaisesRegexp(
        RuntimeError,
        r'Merging tf\.summary\.\* ops is not compatible with eager execution'):
      summary.merge([x_summary, y_summary])

    with self.assertRaisesRegexp(
        RuntimeError,
        r'Merging tf\.summary\.\* ops is not compatible with eager execution'):
      summary.merge_all()

  def testClassicSummaryFileWriterErrorsOut(self):
    """Constructing a classic tf.summary.FileWriter raises in eager mode."""
    with self.assertRaisesRegexp(
        RuntimeError,
        r'tf\.summary\.FileWriter is not compatible with eager execution'):
      writer.FileWriter(tempfile.mkdtemp())


if __name__ == '__main__':
  # All tests above assume eager mode, so enable it before running them.
  tfe.enable_eager_execution()
  test.main()