1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import numpy as np
21
22from tensorflow.python.framework import constant_op
23from tensorflow.python.framework import dtypes as dtypes_lib
24from tensorflow.python.framework import ops
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops import gradient_checker
27from tensorflow.python.ops import gradients_impl
28from tensorflow.python.platform import test
29from tensorflow.python.platform import tf_logging
30
31
class MatrixDiagTest(test.TestCase):
  """Tests for `array_ops.matrix_diag`, which builds (batched) diagonal matrices."""

  def testVector(self):
    with self.test_session(use_gpu=True):
      v = np.array([1.0, 2.0, 3.0])
      mat = np.diag(v)
      v_diag = array_ops.matrix_diag(v)
      self.assertEqual((3, 3), v_diag.get_shape())
      self.assertAllEqual(v_diag.eval(), mat)

  def _testBatchVector(self, dtype):
    """Checks a batch of two length-3 vectors cast to numpy `dtype`."""
    with self.test_session(use_gpu=True):
      v_batch = np.array([[1.0, 0.0, 3.0], [4.0, 5.0, 6.0]]).astype(dtype)
      mat_batch = np.array([[[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 3.0]],
                            [[4.0, 0.0, 0.0], [0.0, 5.0, 0.0],
                             [0.0, 0.0, 6.0]]]).astype(dtype)
      v_batch_diag = array_ops.matrix_diag(v_batch)
      self.assertEqual((2, 3, 3), v_batch_diag.get_shape())
      self.assertAllEqual(v_batch_diag.eval(), mat_batch)

  def testBatchVector(self):
    # `np.bool_` instead of the deprecated `np.bool` alias, which NumPy 1.24
    # removed (it was only ever the builtin `bool`).
    self._testBatchVector(np.float32)
    self._testBatchVector(np.float64)
    self._testBatchVector(np.int32)
    self._testBatchVector(np.int64)
    self._testBatchVector(np.bool_)

  def testInvalidShape(self):
    # A scalar is rejected at graph-construction time by shape inference.
    with self.assertRaisesRegexp(ValueError, "must be at least rank 1"):
      array_ops.matrix_diag(0)

  def testInvalidShapeAtEval(self):
    # With an unknown static shape the check is deferred to the kernel.
    with self.test_session(use_gpu=True):
      v = array_ops.placeholder(dtype=dtypes_lib.float32)
      with self.assertRaisesOpError("input must be at least 1-dim"):
        array_ops.matrix_diag(v).eval(feed_dict={v: 0.0})

  def testGrad(self):
    """Compares analytic and numeric gradients of matrix_diag."""
    shapes = ((3,), (7, 4))
    with self.test_session(use_gpu=True):
      for shape in shapes:
        x = constant_op.constant(np.random.rand(*shape), np.float32)
        y = array_ops.matrix_diag(x)
        error = gradient_checker.compute_gradient_error(x,
                                                        x.get_shape().as_list(),
                                                        y,
                                                        y.get_shape().as_list())
        self.assertLess(error, 1e-4)
80
81
class MatrixSetDiagTest(test.TestCase):
  """Tests for `array_ops.matrix_set_diag`, which replaces main diagonals."""

  def testSquare(self):
    with self.test_session(use_gpu=True):
      v = np.array([1.0, 2.0, 3.0])
      mat = np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [1.0, 1.0, 1.0]])
      mat_set_diag = np.array([[1.0, 1.0, 0.0], [1.0, 2.0, 1.0],
                               [1.0, 1.0, 3.0]])
      output = array_ops.matrix_set_diag(mat, v)
      self.assertEqual((3, 3), output.get_shape())
      self.assertAllEqual(mat_set_diag, output.eval())

  def testRectangular(self):
    # Covers both a wide (2x3) and a tall (3x2) matrix; the diagonal has
    # min(rows, cols) entries in each case.
    with self.test_session(use_gpu=True):
      v = np.array([3.0, 4.0])
      mat = np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0]])
      expected = np.array([[3.0, 1.0, 0.0], [1.0, 4.0, 1.0]])
      output = array_ops.matrix_set_diag(mat, v)
      self.assertEqual((2, 3), output.get_shape())
      self.assertAllEqual(expected, output.eval())

      v = np.array([3.0, 4.0])
      mat = np.array([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
      expected = np.array([[3.0, 1.0], [1.0, 4.0], [1.0, 1.0]])
      output = array_ops.matrix_set_diag(mat, v)
      self.assertEqual((3, 2), output.get_shape())
      self.assertAllEqual(expected, output.eval())

  def _testSquareBatch(self, dtype):
    """Checks a batch of two 3x3 matrices cast to numpy `dtype`."""
    with self.test_session(use_gpu=True):
      v_batch = np.array([[-1.0, 0.0, -3.0], [-4.0, -5.0, -6.0]]).astype(dtype)
      mat_batch = np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0], [1.0, 0.0, 3.0]],
                            [[4.0, 0.0, 4.0], [0.0, 5.0, 0.0],
                             [2.0, 0.0, 6.0]]]).astype(dtype)

      mat_set_diag_batch = np.array([[[-1.0, 0.0, 3.0], [0.0, 0.0, 0.0],
                                      [1.0, 0.0, -3.0]],
                                     [[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0],
                                      [2.0, 0.0, -6.0]]]).astype(dtype)

      output = array_ops.matrix_set_diag(mat_batch, v_batch)
      self.assertEqual((2, 3, 3), output.get_shape())
      self.assertAllEqual(mat_set_diag_batch, output.eval())

  def testSquareBatch(self):
    # `np.bool_` instead of the deprecated `np.bool` alias, which NumPy 1.24
    # removed (it was only ever the builtin `bool`).
    self._testSquareBatch(np.float32)
    self._testSquareBatch(np.float64)
    self._testSquareBatch(np.int32)
    self._testSquareBatch(np.int64)
    self._testSquareBatch(np.bool_)

  def testRectangularBatch(self):
    with self.test_session(use_gpu=True):
      v_batch = np.array([[-1.0, -2.0], [-4.0, -5.0]])
      mat_batch = np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0]],
                            [[4.0, 0.0, 4.0], [0.0, 5.0, 0.0]]])

      mat_set_diag_batch = np.array([[[-1.0, 0.0, 3.0], [0.0, -2.0, 0.0]],
                                     [[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0]]])
      output = array_ops.matrix_set_diag(mat_batch, v_batch)
      self.assertEqual((2, 2, 3), output.get_shape())
      self.assertAllEqual(mat_set_diag_batch, output.eval())

  def testInvalidShape(self):
    # Static shape inference rejects a scalar input and a scalar diagonal.
    with self.assertRaisesRegexp(ValueError, "must be at least rank 2"):
      array_ops.matrix_set_diag(0, [0])
    with self.assertRaisesRegexp(ValueError, "must be at least rank 1"):
      array_ops.matrix_set_diag([[0]], 0)

  def testInvalidShapeAtEval(self):
    # With unknown static shapes the same errors come from the kernel.
    with self.test_session(use_gpu=True):
      v = array_ops.placeholder(dtype=dtypes_lib.float32)
      with self.assertRaisesOpError("input must be at least 2-dim"):
        array_ops.matrix_set_diag(v, [v]).eval(feed_dict={v: 0.0})
      with self.assertRaisesOpError(
          r"but received input shape: \[1,1\] and diagonal shape: \[\]"):
        array_ops.matrix_set_diag([[v]], v).eval(feed_dict={v: 0.0})

  def testGrad(self):
    """Checks gradients w.r.t. both the input matrix and the new diagonal."""
    shapes = ((3, 4, 4), (3, 3, 4), (3, 4, 3), (7, 4, 8, 8))
    with self.test_session(use_gpu=True):
      for shape in shapes:
        x = constant_op.constant(
            np.random.rand(*shape), dtype=dtypes_lib.float32)
        # The diagonal has min(rows, cols) entries per batch element.
        diag_shape = shape[:-2] + (min(shape[-2:]),)
        x_diag = constant_op.constant(
            np.random.rand(*diag_shape), dtype=dtypes_lib.float32)
        y = array_ops.matrix_set_diag(x, x_diag)
        error_x = gradient_checker.compute_gradient_error(
            x,
            x.get_shape().as_list(), y,
            y.get_shape().as_list())
        self.assertLess(error_x, 1e-4)
        error_x_diag = gradient_checker.compute_gradient_error(
            x_diag,
            x_diag.get_shape().as_list(), y,
            y.get_shape().as_list())
        self.assertLess(error_x_diag, 1e-4)

  def testGradWithNoShapeInformation(self):
    """Gradients must be computable when no static shapes are known."""
    with self.test_session(use_gpu=True) as sess:
      v = array_ops.placeholder(dtype=dtypes_lib.float32)
      mat = array_ops.placeholder(dtype=dtypes_lib.float32)
      grad_input = array_ops.placeholder(dtype=dtypes_lib.float32)
      output = array_ops.matrix_set_diag(mat, v)
      grads = gradients_impl.gradients(output, [mat, v], grad_ys=grad_input)
      grad_input_val = np.random.rand(3, 3).astype(np.float32)
      grad_vals = sess.run(
          grads,
          feed_dict={
              v: 2 * np.ones(3),
              mat: np.ones((3, 3)),
              grad_input: grad_input_val
          })
      # d/dv picks out the diagonal; d/dmat is the upstream grad with the
      # diagonal zeroed (those entries were overwritten).
      self.assertAllEqual(np.diag(grad_input_val), grad_vals[1])
      self.assertAllEqual(grad_input_val - np.diag(np.diag(grad_input_val)),
                          grad_vals[0])
199
200
class MatrixDiagPartTest(test.TestCase):
  """Tests for `array_ops.matrix_diag_part`, which extracts main diagonals."""

  def testSquare(self):
    with self.test_session(use_gpu=True):
      v = np.array([1.0, 2.0, 3.0])
      mat = np.diag(v)
      mat_diag = array_ops.matrix_diag_part(mat)
      self.assertEqual((3,), mat_diag.get_shape())
      self.assertAllEqual(mat_diag.eval(), v)

  def testRectangular(self):
    # The diagonal of a non-square matrix has min(rows, cols) entries; check
    # the static shape as well, consistent with the other tests here.
    with self.test_session(use_gpu=True):
      mat = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
      mat_diag = array_ops.matrix_diag_part(mat)
      self.assertEqual((2,), mat_diag.get_shape())
      self.assertAllEqual(mat_diag.eval(), np.array([1.0, 5.0]))
      mat = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
      mat_diag = array_ops.matrix_diag_part(mat)
      self.assertEqual((2,), mat_diag.get_shape())
      self.assertAllEqual(mat_diag.eval(), np.array([1.0, 4.0]))

  def _testSquareBatch(self, dtype):
    """Checks a batch of two 3x3 diagonal matrices cast to numpy `dtype`."""
    with self.test_session(use_gpu=True):
      v_batch = np.array([[1.0, 0.0, 3.0], [4.0, 5.0, 6.0]]).astype(dtype)
      mat_batch = np.array([[[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 3.0]],
                            [[4.0, 0.0, 0.0], [0.0, 5.0, 0.0],
                             [0.0, 0.0, 6.0]]]).astype(dtype)
      self.assertEqual(mat_batch.shape, (2, 3, 3))
      mat_batch_diag = array_ops.matrix_diag_part(mat_batch)
      self.assertEqual((2, 3), mat_batch_diag.get_shape())
      self.assertAllEqual(mat_batch_diag.eval(), v_batch)

  def testSquareBatch(self):
    # `np.bool_` instead of the deprecated `np.bool` alias, which NumPy 1.24
    # removed (it was only ever the builtin `bool`).
    self._testSquareBatch(np.float32)
    self._testSquareBatch(np.float64)
    self._testSquareBatch(np.int32)
    self._testSquareBatch(np.int64)
    self._testSquareBatch(np.bool_)

  def testRectangularBatch(self):
    with self.test_session(use_gpu=True):
      v_batch = np.array([[1.0, 2.0], [4.0, 5.0]])
      mat_batch = np.array([[[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]],
                            [[4.0, 0.0, 0.0], [0.0, 5.0, 0.0]]])
      self.assertEqual(mat_batch.shape, (2, 2, 3))
      mat_batch_diag = array_ops.matrix_diag_part(mat_batch)
      self.assertEqual((2, 2), mat_batch_diag.get_shape())
      self.assertAllEqual(mat_batch_diag.eval(), v_batch)

  def testInvalidShape(self):
    # A scalar is rejected at graph-construction time by shape inference.
    with self.assertRaisesRegexp(ValueError, "must be at least rank 2"):
      array_ops.matrix_diag_part(0)

  def testInvalidShapeAtEval(self):
    # With an unknown static shape the check is deferred to the kernel.
    with self.test_session(use_gpu=True):
      v = array_ops.placeholder(dtype=dtypes_lib.float32)
      with self.assertRaisesOpError("input must be at least 2-dim"):
        array_ops.matrix_diag_part(v).eval(feed_dict={v: 0.0})

  def testGrad(self):
    """Compares analytic and numeric gradients of matrix_diag_part."""
    shapes = ((3, 3), (2, 3), (3, 2), (5, 3, 3))
    with self.test_session(use_gpu=True):
      for shape in shapes:
        x = constant_op.constant(np.random.rand(*shape), dtype=np.float32)
        y = array_ops.matrix_diag_part(x)
        error = gradient_checker.compute_gradient_error(x,
                                                        x.get_shape().as_list(),
                                                        y,
                                                        y.get_shape().as_list())
        self.assertLess(error, 1e-4)
269
270
class DiagTest(test.TestCase):
  """Tests `array_ops.diag` and, as its inverse, `array_ops.diag_part`."""

  def _diagOp(self, diag, dtype, expected_ans, use_gpu):
    """Checks diag(diag) == expected_ans and diag_part(expected_ans) == diag.

    Both ops are evaluated on inputs cast to `dtype`, so the inverse check
    exercises `diag_part` for the dtype under test as well.

    Args:
      diag: numpy array fed to `diag`.
      dtype: numpy dtype both inputs are cast to before running the ops.
      expected_ans: numpy array expected from `diag(diag)`.
      use_gpu: forwarded to `test_session`.
    """
    with self.test_session(use_gpu=use_gpu):
      tf_ans = array_ops.diag(ops.convert_to_tensor(diag.astype(dtype)))
      out = tf_ans.eval()
      # Previously `expected_ans` was passed uncast (float64/int64), so the
      # inverse op never ran with the requested dtype.
      tf_ans_inv = array_ops.diag_part(
          ops.convert_to_tensor(expected_ans.astype(dtype)))
      inv_out = tf_ans_inv.eval()
    self.assertAllClose(out, expected_ans)
    self.assertAllClose(inv_out, diag)
    self.assertShapeEqual(expected_ans, tf_ans)
    self.assertShapeEqual(diag, tf_ans_inv)

  def diagOp(self, diag, dtype, expected_ans):
    """Runs `_diagOp` on both CPU and (if available) GPU."""
    self._diagOp(diag, dtype, expected_ans, False)
    self._diagOp(diag, dtype, expected_ans, True)

  def testEmptyTensor(self):
    x = np.array([])
    expected_ans = np.empty([0, 0])
    self.diagOp(x, np.int32, expected_ans)

  def testRankOneIntTensor(self):
    x = np.array([1, 2, 3])
    expected_ans = np.array([[1, 0, 0], [0, 2, 0], [0, 0, 3]])
    self.diagOp(x, np.int32, expected_ans)
    self.diagOp(x, np.int64, expected_ans)

  def testRankOneFloatTensor(self):
    x = np.array([1.1, 2.2, 3.3])
    expected_ans = np.array([[1.1, 0, 0], [0, 2.2, 0], [0, 0, 3.3]])
    self.diagOp(x, np.float32, expected_ans)
    self.diagOp(x, np.float64, expected_ans)

  def testRankOneComplexTensor(self):
    for dtype in [np.complex64, np.complex128]:
      x = np.array([1.1 + 1.1j, 2.2 + 2.2j, 3.3 + 3.3j], dtype=dtype)
      expected_ans = np.array(
          [[1.1 + 1.1j, 0 + 0j, 0 + 0j], [0 + 0j, 2.2 + 2.2j, 0 + 0j],
           [0 + 0j, 0 + 0j, 3.3 + 3.3j]],
          dtype=dtype)
      self.diagOp(x, dtype, expected_ans)

  def testRankTwoIntTensor(self):
    # For input shape S, diag produces shape S + S with
    # out[i, j, i, j] = x[i, j] and zeros elsewhere.
    x = np.array([[1, 2, 3], [4, 5, 6]])
    expected_ans = np.array([[[[1, 0, 0], [0, 0, 0]], [[0, 2, 0], [0, 0, 0]],
                              [[0, 0, 3], [0, 0, 0]]],
                             [[[0, 0, 0], [4, 0, 0]], [[0, 0, 0], [0, 5, 0]],
                              [[0, 0, 0], [0, 0, 6]]]])
    self.diagOp(x, np.int32, expected_ans)
    self.diagOp(x, np.int64, expected_ans)

  def testRankTwoFloatTensor(self):
    x = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]])
    expected_ans = np.array(
        [[[[1.1, 0, 0], [0, 0, 0]], [[0, 2.2, 0], [0, 0, 0]],
          [[0, 0, 3.3], [0, 0, 0]]], [[[0, 0, 0], [4.4, 0, 0]],
                                      [[0, 0, 0], [0, 5.5, 0]], [[0, 0, 0],
                                                                 [0, 0, 6.6]]]])
    self.diagOp(x, np.float32, expected_ans)
    self.diagOp(x, np.float64, expected_ans)

  def testRankTwoComplexTensor(self):
    for dtype in [np.complex64, np.complex128]:
      x = np.array(
          [[1.1 + 1.1j, 2.2 + 2.2j, 3.3 + 3.3j],
           [4.4 + 4.4j, 5.5 + 5.5j, 6.6 + 6.6j]],
          dtype=dtype)
      expected_ans = np.array(
          [[[[1.1 + 1.1j, 0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j, 0 + 0j]], [
              [0 + 0j, 2.2 + 2.2j, 0 + 0j], [0 + 0j, 0 + 0j, 0 + 0j]
          ], [[0 + 0j, 0 + 0j, 3.3 + 3.3j], [0 + 0j, 0 + 0j, 0 + 0j]]], [[
              [0 + 0j, 0 + 0j, 0 + 0j], [4.4 + 4.4j, 0 + 0j, 0 + 0j]
          ], [[0 + 0j, 0 + 0j, 0 + 0j], [0 + 0j, 5.5 + 5.5j, 0 + 0j]
             ], [[0 + 0j, 0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j, 6.6 + 6.6j]]]],
          dtype=dtype)
      self.diagOp(x, dtype, expected_ans)

  def testRankThreeFloatTensor(self):
    x = np.array([[[1.1, 2.2], [3.3, 4.4]], [[5.5, 6.6], [7.7, 8.8]]])
    expected_ans = np.array([[[[[[1.1, 0], [0, 0]], [[0, 0], [0, 0]]],
                               [[[0, 2.2], [0, 0]], [[0, 0], [0, 0]]]],
                              [[[[0, 0], [3.3, 0]], [[0, 0], [0, 0]]],
                               [[[0, 0], [0, 4.4]], [[0, 0], [0, 0]]]]],
                             [[[[[0, 0], [0, 0]], [[5.5, 0], [0, 0]]],
                               [[[0, 0], [0, 0]], [[0, 6.6], [0, 0]]]],
                              [[[[0, 0], [0, 0]], [[0, 0], [7.7, 0]]],
                               [[[0, 0], [0, 0]], [[0, 0], [0, 8.8]]]]]])
    self.diagOp(x, np.float32, expected_ans)
    self.diagOp(x, np.float64, expected_ans)

  def testRankThreeComplexTensor(self):
    for dtype in [np.complex64, np.complex128]:
      x = np.array(
          [[[1.1 + 1.1j, 2.2 + 2.2j], [3.3 + 3.3j, 4.4 + 4.4j]],
           [[5.5 + 5.5j, 6.6 + 6.6j], [7.7 + 7.7j, 8.8 + 8.8j]]],
          dtype=dtype)
      expected_ans = np.array(
          [[[[[[1.1 + 1.1j, 0 + 0j], [0 + 0j, 0 + 0j]], [[0 + 0j, 0 + 0j], [
              0 + 0j, 0 + 0j
          ]]], [[[0 + 0j, 2.2 + 2.2j], [0 + 0j, 0 + 0j]], [[0 + 0j, 0 + 0j], [
              0 + 0j, 0 + 0j
          ]]]], [[[[0 + 0j, 0 + 0j], [3.3 + 3.3j, 0 + 0j]], [[0 + 0j, 0 + 0j], [
              0 + 0j, 0 + 0j
          ]]], [[[0 + 0j, 0 + 0j], [0 + 0j, 4.4 + 4.4j]], [[0 + 0j, 0 + 0j], [
              0 + 0j, 0 + 0j
          ]]]]], [[[[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]], [
              [5.5 + 5.5j, 0 + 0j], [0 + 0j, 0 + 0j]
          ]], [[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]], [[0 + 0j, 6.6 + 6.6j], [
              0 + 0j, 0 + 0j
          ]]]], [[[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]], [[0 + 0j, 0 + 0j], [
              7.7 + 7.7j, 0 + 0j
          ]]], [[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]],
                [[0 + 0j, 0 + 0j], [0 + 0j, 8.8 + 8.8j]]]]]],
          dtype=dtype)
      self.diagOp(x, dtype, expected_ans)

  def testRankFourNumberTensor(self):
    for dtype in [np.float32, np.float64, np.int64, np.int32]:
      # Input with shape [2, 1, 2, 3]
      x = np.array(
          [[[[1, 2, 3], [4, 5, 6]]], [[[7, 8, 9], [10, 11, 12]]]], dtype=dtype)
      # Output with shape [2, 1, 2, 3, 2, 1, 2, 3]
      expected_ans = np.array(
          [[[[[[[[1, 0, 0], [0, 0, 0]]], [[[0, 0, 0], [0, 0, 0]]]], [
              [[[0, 2, 0], [0, 0, 0]]], [[[0, 0, 0], [0, 0, 0]]]
          ], [[[[0, 0, 3], [0, 0, 0]]], [[[0, 0, 0], [0, 0, 0]]]]], [[
              [[[0, 0, 0], [4, 0, 0]]], [[[0, 0, 0], [0, 0, 0]]]
          ], [[[[0, 0, 0], [0, 5, 0]]], [[[0, 0, 0], [0, 0, 0]]]], [
              [[[0, 0, 0], [0, 0, 6]]], [[[0, 0, 0], [0, 0, 0]]]
          ]]]], [[[[[[[0, 0, 0], [0, 0, 0]]], [[[7, 0, 0], [0, 0, 0]]]], [
              [[[0, 0, 0], [0, 0, 0]]], [[[0, 8, 0], [0, 0, 0]]]
          ], [[[[0, 0, 0], [0, 0, 0]]], [[[0, 0, 9], [0, 0, 0]]]]], [[
              [[[0, 0, 0], [0, 0, 0]]], [[[0, 0, 0], [10, 0, 0]]]
          ], [[[[0, 0, 0], [0, 0, 0]]], [[[0, 0, 0], [0, 11, 0]]]
             ], [[[[0, 0, 0], [0, 0, 0]]], [[[0, 0, 0], [0, 0, 12]]]]]]]],
          dtype=dtype)
      self.diagOp(x, dtype, expected_ans)

  def testInvalidRank(self):
    # A scalar is rejected by shape inference at graph-construction time.
    with self.assertRaisesRegexp(ValueError, "must be at least rank 1"):
      array_ops.diag(0.0)
413
414
class DiagPartOpTest(test.TestCase):
  """Tests for `array_ops.diag_part` on even-rank inputs."""

  def setUp(self):
    # Fixed seed so the random inputs below are reproducible.
    np.random.seed(0)

  def _diagPartOp(self, tensor, dtype, expected_ans, use_gpu):
    """Checks diag_part(tensor cast to `dtype`) against `expected_ans`."""
    with self.test_session(use_gpu=use_gpu):
      tensor = ops.convert_to_tensor(tensor.astype(dtype))
      tf_ans_inv = array_ops.diag_part(tensor)
      inv_out = tf_ans_inv.eval()
    self.assertAllClose(inv_out, expected_ans)
    self.assertShapeEqual(expected_ans, tf_ans_inv)

  def diagPartOp(self, tensor, dtype, expected_ans):
    """Runs `_diagPartOp` on both CPU and (if available) GPU."""
    self._diagPartOp(tensor, dtype, expected_ans, False)
    self._diagPartOp(tensor, dtype, expected_ans, True)

  def testRankTwoFloatTensor(self):
    x = np.random.rand(3, 3)
    i = np.arange(3)
    expected_ans = x[i, i]
    self.diagPartOp(x, np.float32, expected_ans)
    self.diagPartOp(x, np.float64, expected_ans)

  def testRankFourFloatTensorUnknownShape(self):
    # NOTE: despite the name, the input here is rank 2; what varies is how
    # much of its static shape is known when the graph is built.
    x = np.random.rand(3, 3)
    i = np.arange(3)
    expected_ans = x[i, i]
    for shape in None, (None, 3), (3, None):
      with self.test_session(use_gpu=False):
        # A placeholder is required: set_shape() on a constant merges with
        # the constant's fully known (3, 3) shape, so nothing would be unknown.
        t = array_ops.placeholder(dtypes_lib.float32, shape=shape)
        tf_ans = array_ops.diag_part(t)
        out = tf_ans.eval(feed_dict={t: x.astype(np.float32)})
      self.assertAllClose(out, expected_ans)
      # The static shape may be (partly) unknown, so only require that it is
      # compatible with the actual result shape.
      self.assertTrue(
          tf_ans.get_shape().is_compatible_with(expected_ans.shape))

  def testRankFourFloatTensor(self):
    x = np.random.rand(2, 3, 2, 3)
    i = np.arange(2)[:, None]
    j = np.arange(3)
    # Broadcast fancy indexing selects x[i, j, i, j] for all (i, j).
    expected_ans = x[i, j, i, j]
    self.diagPartOp(x, np.float32, expected_ans)
    self.diagPartOp(x, np.float64, expected_ans)

  def testRankSixFloatTensor(self):
    x = np.random.rand(2, 2, 2, 2, 2, 2)
    i = np.arange(2)[:, None, None]
    j = np.arange(2)[:, None]
    k = np.arange(2)
    expected_ans = x[i, j, k, i, j, k]
    self.diagPartOp(x, np.float32, expected_ans)
    self.diagPartOp(x, np.float64, expected_ans)

  def testRankEightComplexTensor(self):
    shape = (2, 2, 2, 3, 2, 2, 2, 3)
    # Give the input a non-zero imaginary part; purely real data would leave
    # the imaginary component of the complex kernels untested.
    x = np.random.rand(*shape) + 1j * np.random.rand(*shape)
    i = np.arange(2)[:, None, None, None]
    j = np.arange(2)[:, None, None]
    k = np.arange(2)[:, None]
    l = np.arange(3)
    expected_ans = x[i, j, k, l, i, j, k, l]
    self.diagPartOp(x, np.complex64, expected_ans)
    self.diagPartOp(x, np.complex128, expected_ans)

  def testOddRank(self):
    # Odd-rank inputs are rejected by shape inference; expected_ans (0) is
    # never compared because the op construction raises first.
    w = np.random.rand(2)
    x = np.random.rand(2, 2, 2)
    self.assertRaises(ValueError, self.diagPartOp, w, np.float32, 0)
    self.assertRaises(ValueError, self.diagPartOp, x, np.float32, 0)
    with self.assertRaises(ValueError):
      array_ops.diag_part(0.0)

  def testUnevenDimensions(self):
    # The first and second halves of the shape must match.
    w = np.random.rand(2, 5)
    x = np.random.rand(2, 1, 2, 3)
    self.assertRaises(ValueError, self.diagPartOp, w, np.float32, 0)
    self.assertRaises(ValueError, self.diagPartOp, x, np.float32, 0)
492
493
class DiagGradOpTest(test.TestCase):
  """Gradient check for `array_ops.diag`."""

  def testDiagGrad(self):
    """Compares analytic and numeric gradients over several ranks/dtypes."""
    np.random.seed(0)
    shapes = ((3,), (3, 3), (3, 3, 3))
    dtypes = (dtypes_lib.float32, dtypes_lib.float64)
    with self.test_session(use_gpu=False):
      for shape in shapes:
        for dtype in dtypes:
          x1 = constant_op.constant(np.random.rand(*shape), dtype=dtype)
          y = array_ops.diag(x1)
          error = gradient_checker.compute_gradient_error(
              x1,
              x1.get_shape().as_list(), y,
              y.get_shape().as_list())
          tf_logging.info("error = %f", error)
          self.assertLess(error, 1e-4)
512
513
class DiagGradPartOpTest(test.TestCase):
  """Gradient check for `array_ops.diag_part`."""

  def testDiagPartGrad(self):
    """Compares analytic and numeric gradients for rank-2 and rank-4 inputs."""
    np.random.seed(0)
    shapes = ((3, 3), (3, 3, 3, 3))
    dtypes = (dtypes_lib.float32, dtypes_lib.float64)
    with self.test_session(use_gpu=False):
      for shape in shapes:
        for dtype in dtypes:
          x1 = constant_op.constant(np.random.rand(*shape), dtype=dtype)
          y = array_ops.diag_part(x1)
          error = gradient_checker.compute_gradient_error(
              x1,
              x1.get_shape().as_list(), y,
              y.get_shape().as_list())
          tf_logging.info("error = %f", error)
          self.assertLess(error, 1e-4)
532
533
if __name__ == "__main__":
  # Run the test cases defined in this module via TensorFlow's test runner.
  test.main()
536