1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for core.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import threading 22 23import numpy as np 24 25from tensorflow.core.protobuf import config_pb2 26from tensorflow.python import pywrap_tensorflow 27from tensorflow.python.eager import context 28from tensorflow.python.eager import core 29from tensorflow.python.eager import execute as execute_lib 30from tensorflow.python.eager import test 31from tensorflow.python.framework import constant_op 32from tensorflow.python.framework import dtypes 33from tensorflow.python.framework import errors 34from tensorflow.python.framework import ops 35from tensorflow.python.framework import test_util 36from tensorflow.python.ops import nn_ops 37 38 39def execute(op_name, num_outputs, inputs, attrs=None): 40 return execute_lib.execute( 41 op_name, num_outputs, inputs, attrs, context.context()) 42 43 44def truncated_normal(shape): 45 return execute( 46 b'TruncatedNormal', 47 1, 48 inputs=[shape], 49 attrs=('dtype', dtypes.float32.as_datatype_enum, 'T', 50 shape.dtype.as_datatype_enum, 'seed', 0, 'seed2', 0))[0] 51 52 53class TFETest(test_util.TensorFlowTestCase): 54 55 def testContext(self): 56 ctx = context.Context() 57 self.assertFalse(ctx.in_graph_mode()) 58 self.assertTrue(ctx.in_eager_mode()) 59 60 self.assertEqual('', ctx.scope_name) 61 ctx.scope_name = 'foo' 62 self.assertEqual('foo', ctx.scope_name) 63 64 self.assertIsNone(ctx.summary_writer_resource) 65 ctx.summary_writer_resource = 'mock' 66 self.assertEqual('mock', ctx.summary_writer_resource) 67 68 self.assertEqual('/job:localhost/replica:0/task:0/device:CPU:0', 69 ctx.device_name) 70 self.assertEqual(ctx.device_name, ctx.device_spec.to_string()) 71 with ctx.device('GPU:0'): 72 self.assertEqual('/job:localhost/replica:0/task:0/device:GPU:0', 73 ctx.device_name) 74 self.assertEqual(ctx.device_name, ctx.device_spec.to_string()) 75 with ctx.device(None): 76 self.assertEqual('', ctx.device_name) 77 self.assertEqual(ctx.device_name, ctx.device_spec.to_string()) 78 with ctx.device('CPU:0'): 79 self.assertEqual('/job:localhost/replica:0/task:0/device:CPU:0', 80 ctx.device_name) 81 self.assertEqual(ctx.device_name, ctx.device_spec.to_string()) 82 83 has_cpu_device = False 84 for x in ctx.devices(): 85 has_cpu_device = has_cpu_device or 'CPU' in x 86 self.assertTrue(has_cpu_device) 87 del ctx 88 89 def testRunMetadata(self): 90 context.enable_run_metadata() 91 t = constant_op.constant(1.0) 92 _ = t + t # Runs an operation which will be in the RunMetadata 93 run_metadata = context.export_run_metadata() 94 context.disable_run_metadata() 95 step_stats = run_metadata.step_stats 96 self.assertGreater(len(step_stats.dev_stats), 0) 97 cpu_stats = step_stats.dev_stats[0] 98 self.assertEqual('/job:localhost/replica:0/task:0/device:CPU:0', 99 cpu_stats.device) 100 self.assertEqual(len(cpu_stats.node_stats), 1) 101 self.assertEqual(cpu_stats.node_stats[0].node_name, 'Add') 102 103 def testContextStackContainsEagerMode(self): 104 # Eager execution has been enabled, and no other context 105 # switch has occurred, so `context_stack` should contain 106 # exactly one entry. 107 self.assertEqual(len(context.context_stack.stack), 1) 108 stack_entry = context.context_stack.stack[0] 109 110 # The entry should log that eager mode was entered. 111 self.assertIs(stack_entry.enter_context_fn, context.eager_mode) 112 113 # It is not possible to build a graph function when eager execution 114 # is enabled; the stack entry should reflect this fact. 115 self.assertFalse(stack_entry.is_building_function) 116 117 def testInt32GPU(self): 118 if not context.context().num_gpus(): 119 self.skipTest('No GPUs found') 120 with ops.device('gpu:0'): 121 xent = nn_ops.sparse_softmax_cross_entropy_with_logits( 122 logits=[[0.0, 0.0]], labels=[0]) 123 self.assertAllClose(xent, [0.69314718]) 124 125 def _runInThread(self, target, args): 126 t = threading.Thread(target=target, args=args) 127 try: 128 t.start() 129 t.join() 130 except Exception as e: 131 raise e 132 133 # Test that different thread local values are initialized to the same values 134 # in different threads. 135 def testContextThreadLocalMembers(self): 136 137 def get_context_values(ctx): 138 return [ 139 ctx.in_graph_mode(), 140 ctx.in_eager_mode(), ctx.scope_name, ctx.summary_writer_resource, 141 ctx.device_name, ctx.num_gpus() 142 ] 143 144 def get_values(ctx, values): 145 values.extend(get_context_values(ctx)) 146 147 context_values = [] 148 ctx = context.Context() 149 self._runInThread(get_values, (ctx, context_values)) 150 self.assertAllEqual(context_values, get_context_values(ctx)) 151 152 def testContextConfig(self): 153 if not context.context().num_gpus(): 154 self.skipTest('No GPUs found') 155 ctx = context.Context(config=config_pb2.ConfigProto( 156 device_count={'GPU': 0})) 157 self.assertEquals(0, ctx.num_gpus()) 158 159 def testTensorPlacement(self): 160 if not context.context().num_gpus(): 161 self.skipTest('No GPUs found') 162 163 x = constant_op.constant(1.).gpu() 164 with context.device('gpu:0'): 165 y = constant_op.constant(2.) 166 # Add would fail if t2 were not on GPU 167 result = execute( 168 b'Add', 1, inputs=[x, y], 169 attrs=('T', x.dtype.as_datatype_enum))[0].cpu().numpy() 170 self.assertEqual(3, result) 171 172 def testCopyBetweenDevices(self): 173 if not context.context().num_gpus(): 174 self.skipTest('No GPUs found') 175 176 x = constant_op.constant([[1., 2.], [3., 4.]]) 177 x = x.cpu() 178 x = x.gpu() 179 x = x.gpu() 180 x = x.cpu() 181 182 # Invalid device 183 with self.assertRaises(RuntimeError): 184 x.gpu(context.context().num_gpus() + 1) 185 186 def testCopyScope(self): 187 if not context.context().num_gpus(): 188 self.skipTest('No GPUs found') 189 constant = constant_op.constant(1.0) 190 with ops.device('gpu:0'): 191 with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): 192 c = constant + 1.0 193 self.assertAllEqual(c, 2.0) 194 195 def testNumpyForceCPU(self): 196 if not context.context().num_gpus(): 197 self.skipTest('No GPUs found') 198 199 cpu = constant_op.constant([[1., 2.], [3., 4.]]) 200 c2g = cpu.gpu() 201 self.assertAllEqual(c2g, cpu.numpy()) 202 203 def testCopyFromCPUToCPU(self): 204 ta = constant_op.constant([[1, 2], [3, 4]]) 205 tb = ta.cpu() 206 207 self.assertNotEqual(id(ta), id(tb)) 208 self.assertAllEqual(ta, tb.numpy()) 209 210 def testRegisterExceptionClass(self): 211 with self.assertRaises(TypeError): 212 pywrap_tensorflow.TFE_Py_RegisterExceptionClass(str) 213 pywrap_tensorflow.TFE_Py_RegisterExceptionClass(core._NotOkStatusException) # pylint: disable=protected-access 214 215 # TODO(agarwal): add tests passing incorrect typed values to attrs. 216 def testExecuteBasic(self): 217 three = constant_op.constant(3) 218 five = constant_op.constant(5) 219 product = execute( 220 b'Mul', 221 num_outputs=1, 222 inputs=[three, five], 223 attrs=('T', three.dtype.as_datatype_enum))[0] 224 self.assertAllEqual(15, product) 225 226 def testExecuteTooManyNumOutputs(self): 227 # num_outputs provided is 50, but only one output is produced. 228 # That should be okay. 229 product = execute( 230 b'Mul', 231 num_outputs=50, 232 inputs=[constant_op.constant(3), constant_op.constant(5)], 233 attrs=('T', dtypes.int32.as_datatype_enum))[0] 234 self.assertAllEqual(15, product) 235 236 def testMatMulGPU(self): 237 if not context.context().num_gpus(): 238 self.skipTest('No GPUs found') 239 three = constant_op.constant([[3.]]).gpu() 240 five = constant_op.constant([[5.]]).gpu() 241 product = execute( 242 b'MatMul', 243 num_outputs=1, 244 inputs=[three, five], 245 attrs=('transpose_a', False, 'transpose_b', False, 'T', 246 three.dtype.as_datatype_enum))[0] 247 self.assertAllEqual([[15.0]], product) 248 249 def testExecuteStringAttr(self): 250 checked_three = execute( 251 b'CheckNumerics', 252 num_outputs=1, 253 inputs=[constant_op.constant(3.)], 254 attrs=('message', 'just checking', 'T', 255 dtypes.float32.as_datatype_enum))[0] 256 self.assertEqual([[3]], checked_three.numpy()) 257 258 def testExecuteStringAttrBadValue(self): 259 with self.assertRaises(errors.InvalidArgumentError): 260 _ = execute( 261 b'CheckNumerics', 262 num_outputs=1, 263 inputs=[constant_op.constant(3.)], 264 attrs=('message', 1, 'T', dtypes.float32.as_datatype_enum)) 265 266 def testExecuteFloatAttr(self): 267 almost_equal = execute( 268 b'ApproximateEqual', 269 num_outputs=1, 270 inputs=[constant_op.constant(3.0), constant_op.constant(2.9)], 271 attrs=('tolerance', 0.3, 'T', dtypes.float32.as_datatype_enum))[0] 272 self.assertTrue(almost_equal) 273 274 def testExecuteFloatAttrBadValue(self): 275 with self.assertRaises(errors.InvalidArgumentError): 276 _ = execute( 277 b'ApproximateEqual', 278 num_outputs=1, 279 inputs=[constant_op.constant(3.0), constant_op.constant(2.9)], 280 attrs=('tolerance', '0.3', 'T', dtypes.float32.as_datatype_enum)) 281 282 def testExecuteIntAttr(self): 283 total = execute( 284 b'AddN', 285 num_outputs=1, 286 inputs=[constant_op.constant(3), constant_op.constant(4)], 287 attrs=('T', dtypes.int32.as_datatype_enum, 'N', 2))[0] 288 self.assertAllEqual(7, total) 289 290 def testExecuteIntAttrBadValue(self): 291 with self.assertRaises(errors.InvalidArgumentError): 292 _ = execute( 293 b'AddN', 294 num_outputs=1, 295 inputs=[constant_op.constant(3), constant_op.constant(4)], 296 attrs=('T', dtypes.int32.as_datatype_enum, 'N', '2')) 297 298 # Looks like we don't have an existing op with list(bool) attrs. 299 def testExecuteBoolAttr(self): 300 product = execute( 301 b'MatMul', 302 num_outputs=1, 303 inputs=[constant_op.constant([[3]]), 304 constant_op.constant([[5]])], 305 attrs=('transpose_a', True, 'transpose_b', False, 'T', 306 dtypes.int32.as_datatype_enum))[0] 307 self.assertAllEqual([[15]], product) 308 309 def testExecuteShapeAttr(self): 310 execute( 311 b'VarHandleOp', 312 num_outputs=1, 313 inputs=[], 314 attrs=('shape', [1, 2], 'dtype', dtypes.int32.as_datatype_enum, 315 'container', '', 'shared_name', '')) 316 317 def testExecuteShapeAttrBadValue(self): 318 with self.assertRaises(errors.InvalidArgumentError): 319 execute( 320 b'VarHandleOp', 321 num_outputs=1, 322 inputs=[], 323 attrs=('shape', 1, 'dtype', dtypes.int32.as_datatype_enum, 324 'container', '', 'shared_name', '')) 325 326 def testExecuteListStringAttr(self): 327 execute( 328 b'TensorSummary', 329 num_outputs=1, 330 inputs=[constant_op.constant(3.0)], 331 attrs=('T', dtypes.float32.as_datatype_enum, 'description', 332 'tensor_summary', 'labels', ['3', 333 'summary'], 'display_name', 'test')) 334 335 def testExecuteListStringAttrBadValue(self): 336 with self.assertRaises(errors.InvalidArgumentError): 337 execute( 338 b'TensorSummary', 339 num_outputs=1, 340 inputs=[constant_op.constant(3.0)], 341 attrs=('T', dtypes.float32.as_datatype_enum, 'description', '', 342 'labels', 3, 'display_name', 'test')) 343 344 def testExecuteListStringAttrBadListValue(self): 345 with self.assertRaises(errors.InvalidArgumentError): 346 execute( 347 b'TensorSummary', 348 num_outputs=1, 349 inputs=[constant_op.constant(3.0)], 350 attrs=('T', dtypes.float32.as_datatype_enum, 'description', '', 351 'labels', [3], 'display_name', 'test')) 352 353 def testExecuteListFloatAttr(self): 354 b = execute( 355 b'Bucketize', 356 num_outputs=1, 357 inputs=[constant_op.constant([3.0, 5.0, 7.0])], 358 attrs=('T', dtypes.float32.as_datatype_enum, 'boundaries', [4.0, 359 6.0]))[0] 360 self.assertAllEqual([0, 1, 2], b) 361 362 def testExecuteListFloatAttrBadValue(self): 363 with self.assertRaises(errors.InvalidArgumentError): 364 execute( 365 b'Bucketize', 366 num_outputs=1, 367 inputs=[constant_op.constant([3.0, 5.0, 7.0])], 368 attrs=('T', dtypes.float32.as_datatype_enum, 'boundaries', 4.0)) 369 370 def testExecuteListFloatAttrBadListValue(self): 371 with self.assertRaises(errors.InvalidArgumentError): 372 execute( 373 b'Bucketize', 374 num_outputs=1, 375 inputs=[constant_op.constant([3.0, 5.0, 7.0])], 376 attrs=('T', dtypes.float32.as_datatype_enum, 'boundaries', 377 ['4.0', '6.0'])) 378 379 def testExecuteListIntAttr(self): 380 b = execute( 381 b'Squeeze', 382 num_outputs=1, 383 inputs=[constant_op.constant([[[3.0]]])], 384 attrs=('T', dtypes.float32.as_datatype_enum, 'squeeze_dims', [0, 2]))[0] 385 self.assertAllEqual([3], b) 386 387 def testExecuteListIntAttrBadValue(self): 388 with self.assertRaises(errors.InvalidArgumentError): 389 execute( 390 b'Squeeze', 391 num_outputs=1, 392 inputs=[constant_op.constant([[[3.0]]])], 393 attrs=('T', dtypes.float32.as_datatype_enum, 'squeeze_dims', 0)) 394 395 def testExecuteListIntAttrBadListValue(self): 396 with self.assertRaises(errors.InvalidArgumentError): 397 execute( 398 b'Squeeze', 399 num_outputs=1, 400 inputs=[constant_op.constant([[[3.0]]])], 401 attrs=('T', dtypes.float32.as_datatype_enum, 'squeeze_dims', 402 ['0', '2'])) 403 404 def testExecuteListTypeListShapeAttr(self): 405 execute( 406 b'Barrier', 407 num_outputs=1, 408 inputs=[], 409 attrs=('component_types', [dtypes.float64.as_datatype_enum], 'shapes', 410 [[1, 2]], 'capacity', -1, 'container', '', 'shared_name', '')) 411 412 def testExecuteListTypeAttrBadValue(self): 413 with self.assertRaises(errors.InvalidArgumentError): 414 execute( 415 b'Barrier', 416 num_outputs=1, 417 inputs=[], 418 attrs=('component_types', dtypes.float64.as_datatype_enum, 'shapes', 419 [[1, 2]], 'capacity', -1, 'container', '', 'shared_name', '')) 420 421 def testExecuteListTypeAttrBadListValue(self): 422 with self.assertRaises(errors.InvalidArgumentError): 423 execute( 424 b'Barrier', 425 num_outputs=1, 426 inputs=[], 427 attrs=('component_types', '1', 'shapes', [[1, 2]], 'capacity', -1, 428 'container', '', 'shared_name', '')) 429 430 def testExecuteListShapeAttrBadValue(self): 431 with self.assertRaises(errors.InvalidArgumentError): 432 execute( 433 b'Barrier', 434 num_outputs=1, 435 inputs=[], 436 attrs=('component_types', [dtypes.float64.as_datatype_enum], 'shapes', 437 [1, 2], 'capacity', -1, 'container', '', 'shared_name', '')) 438 439 def testExecuteListShapeAttrBadListValue(self): 440 with self.assertRaises(errors.InvalidArgumentError): 441 execute( 442 b'Barrier', 443 num_outputs=1, 444 inputs=[], 445 attrs=('component_types', [dtypes.float64.as_datatype_enum], 'shapes', 446 [1], 'capacity', -1, 'container', '', 'shared_name', '')) 447 448 def testExecuteMultipleOutputs(self): 449 split_dim = 1 450 value = [[0, 1, 2], [3, 4, 5]] 451 x1, x2, x3 = execute( 452 b'Split', 453 num_outputs=3, 454 inputs=[constant_op.constant(split_dim), 455 constant_op.constant(value)], 456 attrs=('num_split', 3, 'T', dtypes.int32.as_datatype_enum)) 457 self.assertAllEqual([[0], [3]], x1) 458 self.assertAllEqual([[1], [4]], x2) 459 self.assertAllEqual([[2], [5]], x3) 460 461 def testExecuteBadNumOutputsArgument(self): 462 with self.assertRaises(TypeError): 463 execute( 464 b'Relu', [], 465 inputs=[constant_op.constant(3.0)], 466 attrs=('T', dtypes.float32.as_datatype_enum)) 467 468 def testExecuteUnknownOp(self): 469 with self.assertRaises(errors.NotFoundError): 470 execute(b'BlahBlahBlah', num_outputs=1, inputs=[], attrs=None) 471 472 def testExecuteUnknownAttr(self): 473 with self.assertRaises(errors.InvalidArgumentError): 474 execute( 475 b'Identity', 476 num_outputs=1, 477 inputs=[constant_op.constant(3)], 478 attrs=('T', dtypes.int32.as_datatype_enum, 'unknown_attr', 'blah')) 479 480 def testComposition(self): 481 482 def add(x, y): 483 return execute( 484 b'Add', 485 num_outputs=1, 486 inputs=[x, y], 487 attrs=('T', dtypes.int32.as_datatype_enum))[0] 488 489 x = constant_op.constant(1) 490 three_x = add(add(x, x), x) 491 self.assertEquals(dtypes.int32, three_x.dtype) 492 self.assertAllEqual(3, three_x) 493 494 def testOperationWithNoInputsRunsOnDevice(self): 495 if not context.context().num_gpus(): 496 self.skipTest('No GPUs found') 497 shape = constant_op.constant([], dtype=dtypes.int32) 498 499 # x: Run the "TruncatedNormal" op CPU and copy result to GPU. 500 x = truncated_normal(shape).gpu() 501 # y: Explicitly run the "TruncatedNormal" op on GPU. 502 with context.device('gpu:0'): 503 y = truncated_normal(shape) 504 # Add would fail if x and y were not on the same device. 505 execute( 506 b'Add', 1, inputs=[x, y], attrs=('T', x.dtype.as_datatype_enum)) 507 508 def testInvalidDevice(self): 509 with self.assertRaises(ValueError): 510 with context.device('pu:0'): 511 _ = constant_op.constant(1) 512 513 def testConvertMixedEagerTensors(self): 514 array = np.zeros((), dtype=np.float32) 515 tensor = constant_op.constant(0., dtype=dtypes.float32) 516 types, tensors = execute_lib.convert_to_mixed_eager_tensors( 517 [array, tensor], context.context()) 518 for typ, t in zip(types, tensors): 519 self.assertEquals(typ, dtypes.float32) 520 self.assertIsInstance(t, ops.EagerTensor) 521 522 523if __name__ == '__main__': 524 test.main() 525