# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for core."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import threading

import numpy as np

from tensorflow.core.protobuf import config_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.eager import core
from tensorflow.python.eager import execute as execute_lib
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import nn_ops


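# Convenience wrapper around execute_lib.execute that always targets the
# default eager context. Op attributes are passed throughout these tests as a
# flat tuple of alternating attribute names and values.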
def execute(op_name, num_outputs, inputs, attrs=None):
  return execute_lib.execute(
      op_name, num_outputs, inputs, attrs, context.context())


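# Runs the TruncatedNormal op through the wrapper above. As an illustration
# (a hypothetical call, not taken from the tests below):
#   truncated_normal(constant_op.constant([2, 3]))
# would return a float32 tensor of shape [2, 3].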
def truncated_normal(shape):
  return execute(
      b'TruncatedNormal',
      1,
      inputs=[shape],
      attrs=('dtype', dtypes.float32.as_datatype_enum, 'T',
             shape.dtype.as_datatype_enum, 'seed', 0, 'seed2', 0))[0]


class TFETest(test_util.TensorFlowTestCase):

  def testContext(self):
    ctx = context.Context()
    self.assertFalse(ctx.in_graph_mode())
    self.assertTrue(ctx.in_eager_mode())

    self.assertEqual('', ctx.scope_name)
    ctx.scope_name = 'foo'
    self.assertEqual('foo', ctx.scope_name)

    self.assertIsNone(ctx.summary_writer_resource)
    ctx.summary_writer_resource = 'mock'
    self.assertEqual('mock', ctx.summary_writer_resource)

    self.assertEqual('/job:localhost/replica:0/task:0/device:CPU:0',
                     ctx.device_name)
    self.assertEqual(ctx.device_name, ctx.device_spec.to_string())
    with ctx.device('GPU:0'):
      self.assertEqual('/job:localhost/replica:0/task:0/device:GPU:0',
                       ctx.device_name)
      self.assertEqual(ctx.device_name, ctx.device_spec.to_string())
      with ctx.device(None):
        self.assertEqual('', ctx.device_name)
        self.assertEqual(ctx.device_name, ctx.device_spec.to_string())
        with ctx.device('CPU:0'):
          self.assertEqual('/job:localhost/replica:0/task:0/device:CPU:0',
                           ctx.device_name)
          self.assertEqual(ctx.device_name, ctx.device_spec.to_string())

    has_cpu_device = False
    for x in ctx.devices():
      has_cpu_device = has_cpu_device or 'CPU' in x
    self.assertTrue(has_cpu_device)
    del ctx

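  # enable_run_metadata()/export_run_metadata() collect and return per-device
  # step stats for ops executed eagerly while collection is enabled.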
  def testRunMetadata(self):
    context.enable_run_metadata()
    t = constant_op.constant(1.0)
    _ = t + t  # Runs an operation which will be in the RunMetadata
    run_metadata = context.export_run_metadata()
    context.disable_run_metadata()
    step_stats = run_metadata.step_stats
    self.assertGreater(len(step_stats.dev_stats), 0)
    cpu_stats = step_stats.dev_stats[0]
    self.assertEqual('/job:localhost/replica:0/task:0/device:CPU:0',
                     cpu_stats.device)
    self.assertEqual(len(cpu_stats.node_stats), 1)
    self.assertEqual(cpu_stats.node_stats[0].node_name, 'Add')

  def testContextStackContainsEagerMode(self):
    # Eager execution has been enabled, and no other context
    # switch has occurred, so `context_stack` should contain
    # exactly one entry.
    self.assertEqual(len(context.context_stack.stack), 1)
    stack_entry = context.context_stack.stack[0]

    # The entry should log that eager mode was entered.
    self.assertIs(stack_entry.enter_context_fn, context.eager_mode)

    # It is not possible to build a graph function when eager execution
    # is enabled; the stack entry should reflect this fact.
    self.assertFalse(stack_entry.is_building_function)

  def testInt32GPU(self):
    if not context.context().num_gpus():
      self.skipTest('No GPUs found')
    with ops.device('gpu:0'):
      xent = nn_ops.sparse_softmax_cross_entropy_with_logits(
          logits=[[0.0, 0.0]], labels=[0])
    self.assertAllClose(xent, [0.69314718])

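  # Runs `target` on a separate thread and blocks until it finishes, so the
  # test below can observe how a fresh thread sees the Context's thread-local
  # state.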
  def _runInThread(self, target, args):
    t = threading.Thread(target=target, args=args)
    t.start()
    t.join()

  # Tests that a Context's thread-local members are initialized to the same
  # values in different threads.
  def testContextThreadLocalMembers(self):

    def get_context_values(ctx):
      return [
          ctx.in_graph_mode(),
          ctx.in_eager_mode(), ctx.scope_name, ctx.summary_writer_resource,
          ctx.device_name, ctx.num_gpus()
      ]

    def get_values(ctx, values):
      values.extend(get_context_values(ctx))

    context_values = []
    ctx = context.Context()
    self._runInThread(get_values, (ctx, context_values))
    self.assertAllEqual(context_values, get_context_values(ctx))

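  # A ConfigProto supplied to Context can restrict the visible devices; with
  # device_count {'GPU': 0} the new context should report no GPUs even though
  # the default context (checked above) has at least one.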
  def testContextConfig(self):
    if not context.context().num_gpus():
      self.skipTest('No GPUs found')
    ctx = context.Context(config=config_pb2.ConfigProto(
        device_count={'GPU': 0}))
    self.assertEqual(0, ctx.num_gpus())

  def testTensorPlacement(self):
    if not context.context().num_gpus():
      self.skipTest('No GPUs found')

    x = constant_op.constant(1.).gpu()
    with context.device('gpu:0'):
      y = constant_op.constant(2.)
    # Add would fail if y were not on the GPU.
    result = execute(
        b'Add', 1, inputs=[x, y],
        attrs=('T', x.dtype.as_datatype_enum))[0].cpu().numpy()
    self.assertEqual(3, result)

  def testCopyBetweenDevices(self):
    if not context.context().num_gpus():
      self.skipTest('No GPUs found')

    x = constant_op.constant([[1., 2.], [3., 4.]])
    x = x.cpu()
    x = x.gpu()
    x = x.gpu()
    x = x.cpu()

    # Copying to a nonexistent device is an error.
    with self.assertRaises(RuntimeError):
      x.gpu(context.context().num_gpus() + 1)

  def testCopyScope(self):
    if not context.context().num_gpus():
      self.skipTest('No GPUs found')
    constant = constant_op.constant(1.0)
    with ops.device('gpu:0'):
      with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
        c = constant + 1.0
    self.assertAllEqual(c, 2.0)

  def testNumpyForceCPU(self):
    if not context.context().num_gpus():
      self.skipTest('No GPUs found')

    cpu = constant_op.constant([[1., 2.], [3., 4.]])
    c2g = cpu.gpu()
    self.assertAllEqual(c2g, cpu.numpy())

  def testCopyFromCPUToCPU(self):
    ta = constant_op.constant([[1, 2], [3, 4]])
    tb = ta.cpu()

    self.assertNotEqual(id(ta), id(tb))
    self.assertAllEqual(ta, tb.numpy())

  def testRegisterExceptionClass(self):
    with self.assertRaises(TypeError):
      pywrap_tensorflow.TFE_Py_RegisterExceptionClass(str)
    pywrap_tensorflow.TFE_Py_RegisterExceptionClass(core._NotOkStatusException)  # pylint: disable=protected-access

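  # The tests below exercise the low-level execute() path directly. Each test
  # passes op attributes as a flat (name, value, ...) tuple; the *BadValue
  # variants check that malformed attribute values are rejected with
  # errors.InvalidArgumentError.
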
  # TODO(agarwal): add tests passing incorrectly typed values to attrs.
  def testExecuteBasic(self):
    three = constant_op.constant(3)
    five = constant_op.constant(5)
    product = execute(
        b'Mul',
        num_outputs=1,
        inputs=[three, five],
        attrs=('T', three.dtype.as_datatype_enum))[0]
    self.assertAllEqual(15, product)

  def testExecuteTooManyNumOutputs(self):
    # num_outputs provided is 50, but only one output is produced.
    # That should be okay.
    product = execute(
        b'Mul',
        num_outputs=50,
        inputs=[constant_op.constant(3), constant_op.constant(5)],
        attrs=('T', dtypes.int32.as_datatype_enum))[0]
    self.assertAllEqual(15, product)

  def testMatMulGPU(self):
    if not context.context().num_gpus():
      self.skipTest('No GPUs found')
    three = constant_op.constant([[3.]]).gpu()
    five = constant_op.constant([[5.]]).gpu()
    product = execute(
        b'MatMul',
        num_outputs=1,
        inputs=[three, five],
        attrs=('transpose_a', False, 'transpose_b', False, 'T',
               three.dtype.as_datatype_enum))[0]
    self.assertAllEqual([[15.0]], product)

  def testExecuteStringAttr(self):
    checked_three = execute(
        b'CheckNumerics',
        num_outputs=1,
        inputs=[constant_op.constant(3.)],
        attrs=('message', 'just checking', 'T',
               dtypes.float32.as_datatype_enum))[0]
    self.assertEqual([[3]], checked_three.numpy())

  def testExecuteStringAttrBadValue(self):
    with self.assertRaises(errors.InvalidArgumentError):
      _ = execute(
          b'CheckNumerics',
          num_outputs=1,
          inputs=[constant_op.constant(3.)],
          attrs=('message', 1, 'T', dtypes.float32.as_datatype_enum))

  def testExecuteFloatAttr(self):
    almost_equal = execute(
        b'ApproximateEqual',
        num_outputs=1,
        inputs=[constant_op.constant(3.0), constant_op.constant(2.9)],
        attrs=('tolerance', 0.3, 'T', dtypes.float32.as_datatype_enum))[0]
    self.assertTrue(almost_equal)

  def testExecuteFloatAttrBadValue(self):
    with self.assertRaises(errors.InvalidArgumentError):
      _ = execute(
          b'ApproximateEqual',
          num_outputs=1,
          inputs=[constant_op.constant(3.0), constant_op.constant(2.9)],
          attrs=('tolerance', '0.3', 'T', dtypes.float32.as_datatype_enum))

  def testExecuteIntAttr(self):
    total = execute(
        b'AddN',
        num_outputs=1,
        inputs=[constant_op.constant(3), constant_op.constant(4)],
        attrs=('T', dtypes.int32.as_datatype_enum, 'N', 2))[0]
    self.assertAllEqual(7, total)

  def testExecuteIntAttrBadValue(self):
    with self.assertRaises(errors.InvalidArgumentError):
      _ = execute(
          b'AddN',
          num_outputs=1,
          inputs=[constant_op.constant(3), constant_op.constant(4)],
          attrs=('T', dtypes.int32.as_datatype_enum, 'N', '2'))

  # Looks like we don't have an existing op with list(bool) attrs.
  def testExecuteBoolAttr(self):
    product = execute(
        b'MatMul',
        num_outputs=1,
        inputs=[constant_op.constant([[3]]),
                constant_op.constant([[5]])],
        attrs=('transpose_a', True, 'transpose_b', False, 'T',
               dtypes.int32.as_datatype_enum))[0]
    self.assertAllEqual([[15]], product)

  def testExecuteShapeAttr(self):
    execute(
        b'VarHandleOp',
        num_outputs=1,
        inputs=[],
        attrs=('shape', [1, 2], 'dtype', dtypes.int32.as_datatype_enum,
               'container', '', 'shared_name', ''))

  def testExecuteShapeAttrBadValue(self):
    with self.assertRaises(errors.InvalidArgumentError):
      execute(
          b'VarHandleOp',
          num_outputs=1,
          inputs=[],
          attrs=('shape', 1, 'dtype', dtypes.int32.as_datatype_enum,
                 'container', '', 'shared_name', ''))

  def testExecuteListStringAttr(self):
    execute(
        b'TensorSummary',
        num_outputs=1,
        inputs=[constant_op.constant(3.0)],
        attrs=('T', dtypes.float32.as_datatype_enum, 'description',
               'tensor_summary', 'labels', ['3', 'summary'],
               'display_name', 'test'))

  def testExecuteListStringAttrBadValue(self):
    with self.assertRaises(errors.InvalidArgumentError):
      execute(
          b'TensorSummary',
          num_outputs=1,
          inputs=[constant_op.constant(3.0)],
          attrs=('T', dtypes.float32.as_datatype_enum, 'description', '',
                 'labels', 3, 'display_name', 'test'))

  def testExecuteListStringAttrBadListValue(self):
    with self.assertRaises(errors.InvalidArgumentError):
      execute(
          b'TensorSummary',
          num_outputs=1,
          inputs=[constant_op.constant(3.0)],
          attrs=('T', dtypes.float32.as_datatype_enum, 'description', '',
                 'labels', [3], 'display_name', 'test'))

  def testExecuteListFloatAttr(self):
    b = execute(
        b'Bucketize',
        num_outputs=1,
        inputs=[constant_op.constant([3.0, 5.0, 7.0])],
        attrs=('T', dtypes.float32.as_datatype_enum,
               'boundaries', [4.0, 6.0]))[0]
    self.assertAllEqual([0, 1, 2], b)

  def testExecuteListFloatAttrBadValue(self):
    with self.assertRaises(errors.InvalidArgumentError):
      execute(
          b'Bucketize',
          num_outputs=1,
          inputs=[constant_op.constant([3.0, 5.0, 7.0])],
          attrs=('T', dtypes.float32.as_datatype_enum, 'boundaries', 4.0))

  def testExecuteListFloatAttrBadListValue(self):
    with self.assertRaises(errors.InvalidArgumentError):
      execute(
          b'Bucketize',
          num_outputs=1,
          inputs=[constant_op.constant([3.0, 5.0, 7.0])],
          attrs=('T', dtypes.float32.as_datatype_enum, 'boundaries',
                 ['4.0', '6.0']))

  def testExecuteListIntAttr(self):
    b = execute(
        b'Squeeze',
        num_outputs=1,
        inputs=[constant_op.constant([[[3.0]]])],
        attrs=('T', dtypes.float32.as_datatype_enum, 'squeeze_dims', [0, 2]))[0]
    self.assertAllEqual([3], b)

  def testExecuteListIntAttrBadValue(self):
    with self.assertRaises(errors.InvalidArgumentError):
      execute(
          b'Squeeze',
          num_outputs=1,
          inputs=[constant_op.constant([[[3.0]]])],
          attrs=('T', dtypes.float32.as_datatype_enum, 'squeeze_dims', 0))

  def testExecuteListIntAttrBadListValue(self):
    with self.assertRaises(errors.InvalidArgumentError):
      execute(
          b'Squeeze',
          num_outputs=1,
          inputs=[constant_op.constant([[[3.0]]])],
          attrs=('T', dtypes.float32.as_datatype_enum, 'squeeze_dims',
                 ['0', '2']))

  def testExecuteListTypeListShapeAttr(self):
    execute(
        b'Barrier',
        num_outputs=1,
        inputs=[],
        attrs=('component_types', [dtypes.float64.as_datatype_enum], 'shapes',
               [[1, 2]], 'capacity', -1, 'container', '', 'shared_name', ''))

  def testExecuteListTypeAttrBadValue(self):
    with self.assertRaises(errors.InvalidArgumentError):
      execute(
          b'Barrier',
          num_outputs=1,
          inputs=[],
          attrs=('component_types', dtypes.float64.as_datatype_enum, 'shapes',
                 [[1, 2]], 'capacity', -1, 'container', '', 'shared_name', ''))

  def testExecuteListTypeAttrBadListValue(self):
    with self.assertRaises(errors.InvalidArgumentError):
      execute(
          b'Barrier',
          num_outputs=1,
          inputs=[],
          attrs=('component_types', '1', 'shapes', [[1, 2]], 'capacity', -1,
                 'container', '', 'shared_name', ''))

  def testExecuteListShapeAttrBadValue(self):
    with self.assertRaises(errors.InvalidArgumentError):
      execute(
          b'Barrier',
          num_outputs=1,
          inputs=[],
          attrs=('component_types', [dtypes.float64.as_datatype_enum], 'shapes',
                 [1, 2], 'capacity', -1, 'container', '', 'shared_name', ''))

  def testExecuteListShapeAttrBadListValue(self):
    with self.assertRaises(errors.InvalidArgumentError):
      execute(
          b'Barrier',
          num_outputs=1,
          inputs=[],
          attrs=('component_types', [dtypes.float64.as_datatype_enum], 'shapes',
                 [1], 'capacity', -1, 'container', '', 'shared_name', ''))

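  # Ops with multiple outputs (here Split with num_split=3) return a list of
  # that many tensors, which can be unpacked directly.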
  def testExecuteMultipleOutputs(self):
    split_dim = 1
    value = [[0, 1, 2], [3, 4, 5]]
    x1, x2, x3 = execute(
        b'Split',
        num_outputs=3,
        inputs=[constant_op.constant(split_dim),
                constant_op.constant(value)],
        attrs=('num_split', 3, 'T', dtypes.int32.as_datatype_enum))
    self.assertAllEqual([[0], [3]], x1)
    self.assertAllEqual([[1], [4]], x2)
    self.assertAllEqual([[2], [5]], x3)

  def testExecuteBadNumOutputsArgument(self):
    with self.assertRaises(TypeError):
      execute(
          b'Relu', [],
          inputs=[constant_op.constant(3.0)],
          attrs=('T', dtypes.float32.as_datatype_enum))

  def testExecuteUnknownOp(self):
    with self.assertRaises(errors.NotFoundError):
      execute(b'BlahBlahBlah', num_outputs=1, inputs=[], attrs=None)

  def testExecuteUnknownAttr(self):
    with self.assertRaises(errors.InvalidArgumentError):
      execute(
          b'Identity',
          num_outputs=1,
          inputs=[constant_op.constant(3)],
          attrs=('T', dtypes.int32.as_datatype_enum, 'unknown_attr', 'blah'))

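  # EagerTensors returned by execute() can be fed straight back in as inputs
  # to further execute() calls.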
  def testComposition(self):

    def add(x, y):
      return execute(
          b'Add',
          num_outputs=1,
          inputs=[x, y],
          attrs=('T', dtypes.int32.as_datatype_enum))[0]

    x = constant_op.constant(1)
    three_x = add(add(x, x), x)
    self.assertEqual(dtypes.int32, three_x.dtype)
    self.assertAllEqual(3, three_x)

  def testOperationWithNoInputsRunsOnDevice(self):
    if not context.context().num_gpus():
      self.skipTest('No GPUs found')
    shape = constant_op.constant([], dtype=dtypes.int32)

    # x: Run the "TruncatedNormal" op on CPU and copy the result to GPU.
    x = truncated_normal(shape).gpu()
    # y: Explicitly run the "TruncatedNormal" op on GPU.
    with context.device('gpu:0'):
      y = truncated_normal(shape)
    # Add would fail if x and y were not on the same device.
    execute(
        b'Add', 1, inputs=[x, y], attrs=('T', x.dtype.as_datatype_enum))

  def testInvalidDevice(self):
    with self.assertRaises(ValueError):
      with context.device('pu:0'):
        _ = constant_op.constant(1)

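  # convert_to_mixed_eager_tensors should wrap the numpy array as an
  # EagerTensor and pass the existing EagerTensor through, returning parallel
  # lists of dtypes and tensors.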
  def testConvertMixedEagerTensors(self):
    array = np.zeros((), dtype=np.float32)
    tensor = constant_op.constant(0., dtype=dtypes.float32)
    types, tensors = execute_lib.convert_to_mixed_eager_tensors(
        [array, tensor], context.context())
    for typ, t in zip(types, tensors):
      self.assertEqual(typ, dtypes.float32)
      self.assertIsInstance(t, ops.EagerTensor)


if __name__ == '__main__':
  test.main()