# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for operations in eager execution."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import numpy as np
21
22from tensorflow.python.eager import context
23from tensorflow.python.eager import execute
24from tensorflow.python.eager import test
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import ops
28from tensorflow.python.framework import tensor_shape
29from tensorflow.python.framework import test_util
30from tensorflow.python.layers import core
31from tensorflow.python.ops import array_ops
32from tensorflow.python.ops import control_flow_ops
33from tensorflow.python.ops import math_ops
34from tensorflow.python.ops import random_ops
35from tensorflow.python.ops import resource_variable_ops
36from tensorflow.python.ops import sparse_ops
37
38
39class OpsTest(test_util.TensorFlowTestCase):
40
  def testExecuteBasic(self):
    three = constant_op.constant(3)
    five = constant_op.constant(5)
    product = three * five
    self.assertAllEqual(15, product)

  def testMatMulGPU(self):
    if not context.context().num_gpus():
      self.skipTest('No GPUs found')
    three = constant_op.constant([[3.]]).gpu()
    five = constant_op.constant([[5.]]).gpu()
    product = math_ops.matmul(three, five)
    self.assertEqual([[15.0]], product.numpy())

  def testExecuteStringAttr(self):
    three = constant_op.constant(3.0)
    checked_three = array_ops.check_numerics(three,
                                             message='just checking')
    self.assertEqual(3, checked_three.numpy())

  def testExecuteFloatAttr(self):
    three = constant_op.constant(3.0)
    almost_three = constant_op.constant(2.8)
    almost_equal = math_ops.approximate_equal(
        three, almost_three, tolerance=0.3)
    self.assertTrue(almost_equal)

  def testExecuteIntAttr(self):
    three = constant_op.constant(3)
    four = constant_op.constant(4)
    total = math_ops.add_n([three, four])
    self.assertAllEqual(7, total)

  def testExecuteBoolAttr(self):
    three = constant_op.constant([[3]])
    five = constant_op.constant([[5]])
    product = math_ops.matmul(three, five, transpose_a=True)
    self.assertAllEqual([[15]], product)

  def testExecuteOneListOutput(self):
    split_dim = constant_op.constant(1)
    value = constant_op.constant([[0, 1, 2], [3, 4, 5]])
    x1, x2, x3 = array_ops.split(value, 3, axis=split_dim)
    self.assertAllEqual([[0], [3]], x1)
    self.assertAllEqual([[1], [4]], x2)
    self.assertAllEqual([[2], [5]], x3)

  def testGraphMode(self):
    graph = ops.Graph()
    with graph.as_default(), context.graph_mode():
      array_ops.placeholder(dtypes.int32)
    self.assertEqual(1, len(graph.get_operations()))

  # See comments on handling of int32 tensors on GPU in
  # EagerTensor.__init__.
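  # int32 tensors are expected to be kept in host memory even when a GPU
  # device is active, so the addition below should work under '/gpu:0'.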
  def testInt32CPUDefault(self):
    if not context.context().num_gpus():
      self.skipTest('No GPUs found')
    with context.device('/gpu:0'):
      r = constant_op.constant(1) + constant_op.constant(2)
    self.assertAllEqual(r, 3)

  def testExecuteListOutputLen1(self):
    split_dim = constant_op.constant(1)
    value = constant_op.constant([[0, 1, 2], [3, 4, 5]])
    result = array_ops.split(value, 1, axis=split_dim)
    self.assertIsInstance(result, list)
    self.assertEqual(1, len(result))
    self.assertAllEqual([[0, 1, 2], [3, 4, 5]], result[0])

  def testExecuteListOutputLen0(self):
    empty = constant_op.constant([], dtype=dtypes.int32)
    result = array_ops.unstack(empty, 0)
    self.assertIsInstance(result, list)
    self.assertEqual(0, len(result))

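  # listdiff returns a structured result whose outputs are expected to be
  # reachable both by attribute name (result.out, result.idx) and by tuple
  # unpacking; the assertions below check both access paths.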
  def testExecuteMultipleNonListOutput(self):
    x = constant_op.constant([1, 2, 3, 4, 5, 6])
    y = constant_op.constant([1, 3, 5])
    result = array_ops.listdiff(x, y)
    out, idx = result
    self.assertIs(out, result.out)
    self.assertIs(idx, result.idx)
    self.assertAllEqual([2, 4, 6], out)
    self.assertAllEqual([1, 3, 5], idx)

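  # _sparse_split splits the 2x7 sparse tensor along dimension 1 into two
  # pieces, which should put columns 0-3 in the first piece and columns 4-6
  # in the second; the hard-coded expectations below encode that split.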
  def testExecuteMultipleListOutput(self):
    split_dim = constant_op.constant(1, dtype=dtypes.int64)
    indices = constant_op.constant([[0, 2], [0, 4], [0, 5], [1, 0], [1, 1]],
                                   dtype=dtypes.int64)
    values = constant_op.constant([2, 3, 5, 7, 11])
    shape = constant_op.constant([2, 7], dtype=dtypes.int64)
    result = sparse_ops.gen_sparse_ops._sparse_split(  # pylint: disable=protected-access
        split_dim, indices, values, shape, num_split=2)
    output_indices, output_values, output_shape = result
    self.assertEqual(2, len(output_indices))
    self.assertEqual(2, len(output_values))
    self.assertEqual(2, len(output_shape))
    self.assertEqual(output_indices, result.output_indices)
    self.assertEqual(output_values, result.output_values)
    self.assertEqual(output_shape, result.output_shape)
    self.assertAllEqual([[0, 2], [1, 0], [1, 1]], output_indices[0])
    self.assertAllEqual([[0, 0], [0, 1]], output_indices[1])
    self.assertAllEqual([2, 7, 11], output_values[0])
    self.assertAllEqual([3, 5], output_values[1])
    self.assertAllEqual([2, 4], output_shape[0])
    self.assertAllEqual([2, 3], output_shape[1])

  # TODO(josh11b): Test an op that has multiple outputs, some but not
  # all of which are lists. Examples: barrier_take_many (currently
  # unsupported since it uses a type list) or sdca_optimizer (I don't
  # have an example of legal inputs & outputs).

  def testComposition(self):
    x = constant_op.constant(1, dtype=dtypes.int32)
    three_x = x + x + x
    self.assertEqual(dtypes.int32, three_x.dtype)
    self.assertAllEqual(3, three_x)

  def testOperatorOverrides(self):
    # TODO(henrytan): test with negative number.
    a = constant_op.constant([1])
    b = constant_op.constant([2])

    self.assertAllEqual((-a), [-1])
    self.assertAllEqual(abs(b), [2])

    self.assertAllEqual((a + b), [3])
    self.assertAllEqual((a - b), [-1])
    self.assertAllEqual((a * b), [2])
    self.assertAllEqual((a * a), [1])

    self.assertAllEqual((a**b), [1])
    self.assertAllEqual((a / b), [1 / 2])
    self.assertAllEqual((a / a), [1])
    self.assertAllEqual((a % b), [1])

    self.assertAllEqual((a < b), [True])
    self.assertAllEqual((a <= b), [True])
    self.assertAllEqual((a > b), [False])
    self.assertAllEqual((a >= b), [False])
    self.assertAllEqual((a == b), False)
    self.assertAllEqual((a != b), True)

    self.assertAllEqual(1, a[constant_op.constant(0)])

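  # Eager tensors are expected to support numpy-style basic slicing; each
  # slice of the tensor below is compared against the same slice of the
  # source ndarray.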
  def test_basic_slice(self):
    npt = np.arange(1, 19, dtype=np.float32).reshape(3, 2, 3)
    t = constant_op.constant(npt)

    self.assertAllEqual(npt[:, :, :], t[:, :, :])
    self.assertAllEqual(npt[::, ::, ::], t[::, ::, ::])
    self.assertAllEqual(npt[::1, ::1, ::1], t[::1, ::1, ::1])
    self.assertAllEqual(npt[::1, ::5, ::2], t[::1, ::5, ::2])
    self.assertAllEqual(npt[::-1, :, :], t[::-1, :, :])
    self.assertAllEqual(npt[:, ::-1, :], t[:, ::-1, :])
    self.assertAllEqual(npt[:, :, ::-1], t[:, :, ::-1])
    self.assertAllEqual(npt[-2::-1, :, ::1], t[-2::-1, :, ::1])
    self.assertAllEqual(npt[-2::-1, :, ::2], t[-2::-1, :, ::2])

  def testDegenerateSlices(self):
    npt = np.arange(1, 19, dtype=np.float32).reshape(3, 2, 3)
    t = constant_op.constant(npt)
    # degenerate by offering a forward interval with a negative stride
    self.assertAllEqual(npt[0:-1:-1, :, :], t[0:-1:-1, :, :])
    # degenerate with a reverse interval with a positive stride
    self.assertAllEqual(npt[-1:0, :, :], t[-1:0, :, :])
    # empty interval in every dimension
    self.assertAllEqual(npt[-1:0, 2:2, 2:3:-1], t[-1:0, 2:2, 2:3:-1])

  def testEllipsis(self):
    npt = np.array(
        [[[[[1, 2], [3, 4], [5, 6]]], [[[7, 8], [9, 10], [11, 12]]]]])
    t = constant_op.constant(npt)

    self.assertAllEqual(npt[0:], t[0:])
    # implicit ellipsis
    self.assertAllEqual(npt[0:, ...], t[0:, ...])
    # ellipsis alone
    self.assertAllEqual(npt[...], t[...])
    # ellipsis at end
    self.assertAllEqual(npt[0:1, ...], t[0:1, ...])
    # ellipsis at begin
    self.assertAllEqual(npt[..., 0:1], t[..., 0:1])
    # ellipsis at middle
    self.assertAllEqual(npt[0:1, ..., 0:1], t[0:1, ..., 0:1])

  def testShrink(self):
    npt = np.array([[[[[1, 2, 4, 5], [5, 6, 7, 8], [9, 10, 11, 12]]],
                     [[[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]]]])
    t = constant_op.constant(npt)
    self.assertAllEqual(npt[:, :, :, :, 3], t[:, :, :, :, 3])
    self.assertAllEqual(npt[..., 3], t[..., 3])
    self.assertAllEqual(npt[:, 0], t[:, 0])
    self.assertAllEqual(npt[:, :, 0], t[:, :, 0])

  def testOpWithInputsOnDifferentDevices(self):
    if not context.context().num_gpus():
      self.skipTest('No GPUs found')

    # The GPU kernel for the Reshape op requires that the
    # shape input be on CPU.
    value = constant_op.constant([1., 2.]).gpu()
    shape = constant_op.constant([2, 1])
    reshaped = array_ops.reshape(value, shape)
    self.assertAllEqual([[1], [2]], reshaped.cpu())

  def testInt64(self):
    # The dims input of Fill is passed here as an int64 tensor; this checks
    # that an op taking int64 inputs executes correctly in eager mode.
    self.assertAllEqual(
        [1.0, 1.0],
        array_ops.fill(constant_op.constant([2], dtype=dtypes.int64),
                       constant_op.constant(1)))

  def testOutputOnHostMemory(self):
    if not context.context().num_gpus():
      self.skipTest('No GPUs found')
    # The Shape op kernel on GPU places the output in host memory.
    value = constant_op.constant([1.]).gpu()
    shape = array_ops.shape(value)
    self.assertEqual([1], shape.numpy())

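  # Under DEVICE_PLACEMENT_SILENT the runtime is expected to copy inputs
  # between devices implicitly, so adding a CPU tensor to a GPU tensor
  # below should succeed rather than raise a placement error.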
  def testSilentCopy(self):
    if not context.context().num_gpus():
      self.skipTest('No GPUs found')
    # Temporarily replace the context
    # pylint: disable=protected-access
    del context._context
    try:
      context._context = context.Context(
          device_policy=context.DEVICE_PLACEMENT_SILENT)
      cpu_tensor = constant_op.constant(1.0)
      gpu_tensor = cpu_tensor.gpu()
      self.assertAllEqual(cpu_tensor + gpu_tensor, 2.0)
    finally:
      del context._context
      context._context = context.Context()
    # pylint: enable=protected-access

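  # An empty shape tensor describes a scalar, so each sample drawn below
  # should be a 0-d float32 tensor.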
  def testRandomUniform(self):
    scalar_shape = constant_op.constant([], dtype=dtypes.int32)

    x = random_ops.random_uniform(scalar_shape)
    self.assertEqual(0, x.shape.ndims)
    self.assertEqual(dtypes.float32, x.dtype)

    x = random_ops.random_uniform(
        scalar_shape, minval=constant_op.constant(5.),
        maxval=constant_op.constant(6.))
    self.assertLess(x, 6)
    self.assertGreaterEqual(x, 5)

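  # args_to_matching_eager converts its inputs to a single common dtype; the
  # default dtype passed in should only take effect when the inputs (plain
  # Python lists here) do not already determine one.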
  def testArgsToMatchingEagerDefault(self):
    # Uses default
    ctx = context.context()
    t, r = execute.args_to_matching_eager([[3, 4]], ctx, dtypes.int32)
    self.assertEqual(t, dtypes.int32)
    self.assertEqual(r[0].dtype, dtypes.int32)
    t, r = execute.args_to_matching_eager([[3, 4]], ctx, dtypes.int64)
    self.assertEqual(t, dtypes.int64)
    self.assertEqual(r[0].dtype, dtypes.int64)
    # Doesn't use default
    t, r = execute.args_to_matching_eager(
        [['string', 'arg']], ctx, dtypes.int32)
    self.assertEqual(t, dtypes.string)
    self.assertEqual(r[0].dtype, dtypes.string)

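  # core.Flatten collapses every dimension except the leading (batch)
  # dimension, so the (2, 2, 2) input below should come out with shape (2, 4).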
  def testFlattenLayer(self):
    flatten_layer = core.Flatten()
    x = constant_op.constant([[[-10, -20], [-30, -40]], [[10, 20], [30, 40]]])
    y = flatten_layer(x)
    self.assertAllEqual([[-10, -20, -30, -40], [10, 20, 30, 40]], y)

  def testIdentity(self):
    self.assertAllEqual(2, array_ops.identity(2))

  def testIdentityOnVariable(self):
    if not context.context().num_gpus():
      self.skipTest('No GPUs found')
    with context.device('/gpu:0'):
      v = resource_variable_ops.ResourceVariable(True)
    self.assertAllEqual(True, array_ops.identity(v))

  def testIncompatibleSetShape(self):
    x = constant_op.constant(1)
    with self.assertRaises(ValueError):
      x.set_shape((1, 2))

  def testCompatibleSetShape(self):
    x = constant_op.constant([[1, 2]])
    x.set_shape(tensor_shape.TensorShape([None, 2]))
    self.assertEqual(x.get_shape(), (1, 2))

  def testCastScalarToPrimitiveTypes(self):
    x = constant_op.constant(1.3)
    self.assertIsInstance(int(x), int)
    self.assertEqual(int(x), 1)
    self.assertIsInstance(float(x), float)
    self.assertAllClose(float(x), 1.3)

  def testCastNonScalarToPrimitiveTypesFails(self):
    x = constant_op.constant([1.3, 2])
    with self.assertRaises(TypeError):
      int(x)
    with self.assertRaises(TypeError):
      float(x)

  def testFormatString(self):
    x = constant_op.constant(3.1415)
    self.assertEqual('3.14', '{:.2f}'.format(x))

  def testNoOpIsNone(self):
    self.assertIsNone(control_flow_ops.no_op())

if __name__ == '__main__':
  test.main()