10cf9ed3a719c0782695154d5a0bca260001cec15A. Unique TensorFlower# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
29c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur#
39c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# Licensed under the Apache License, Version 2.0 (the "License");
49c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# you may not use this file except in compliance with the License.
59c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# You may obtain a copy of the License at
69c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur#
79c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur#     http://www.apache.org/licenses/LICENSE-2.0
89c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur#
99c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# Unless required by applicable law or agreed to in writing, software
109c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# distributed under the License is distributed on an "AS IS" BASIS,
119c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
129c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# See the License for the specific language governing permissions and
139c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# limitations under the License.
149c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# ==============================================================================
15f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur"""Tests for tensorflow.ops.tf.Cholesky."""
165866e065bc95c1d7de8a27413b368016941889a6Justine Tunney
17f2102f4e2c1c87f1d1bf9ab856a2849c54478760Vijay Vasudevanfrom __future__ import absolute_import
18f2102f4e2c1c87f1d1bf9ab856a2849c54478760Vijay Vasudevanfrom __future__ import division
19f2102f4e2c1c87f1d1bf9ab856a2849c54478760Vijay Vasudevanfrom __future__ import print_function
20f2102f4e2c1c87f1d1bf9ab856a2849c54478760Vijay Vasudevan
21f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurimport numpy as np
22f2102f4e2c1c87f1d1bf9ab856a2849c54478760Vijay Vasudevanfrom six.moves import xrange  # pylint: disable=redefined-builtin
235866e065bc95c1d7de8a27413b368016941889a6Justine Tunney
24473a590c9cd26cdde1e77117778e3fd50a36d7dfA. Unique TensorFlowerfrom tensorflow.python.client import session
255866e065bc95c1d7de8a27413b368016941889a6Justine Tunneyfrom tensorflow.python.framework import constant_op
265866e065bc95c1d7de8a27413b368016941889a6Justine Tunneyfrom tensorflow.python.framework import dtypes as dtypes_lib
27ac742fab0bf4c8b7bde5febc33e09fedfcb57aa1A. Unique TensorFlowerfrom tensorflow.python.framework import errors_impl
28473a590c9cd26cdde1e77117778e3fd50a36d7dfA. Unique TensorFlowerfrom tensorflow.python.framework import ops
295866e065bc95c1d7de8a27413b368016941889a6Justine Tunneyfrom tensorflow.python.ops import array_ops
30a23255bc079bb94006aa0bfdc5000eed0d97098aA. Unique TensorFlowerfrom tensorflow.python.ops import control_flow_ops
31473a590c9cd26cdde1e77117778e3fd50a36d7dfA. Unique TensorFlowerfrom tensorflow.python.ops import gen_linalg_ops
325866e065bc95c1d7de8a27413b368016941889a6Justine Tunneyfrom tensorflow.python.ops import gradient_checker
33473a590c9cd26cdde1e77117778e3fd50a36d7dfA. Unique TensorFlowerfrom tensorflow.python.ops import gradients_impl
345866e065bc95c1d7de8a27413b368016941889a6Justine Tunneyfrom tensorflow.python.ops import linalg_ops
355866e065bc95c1d7de8a27413b368016941889a6Justine Tunneyfrom tensorflow.python.ops import math_ops
36ac742fab0bf4c8b7bde5febc33e09fedfcb57aa1A. Unique TensorFlowerfrom tensorflow.python.ops import random_ops
37d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlowerfrom tensorflow.python.ops import variables
38bc52fbda2bbe458c9ff5f20ebc48188959ebe026A. Unique TensorFlowerfrom tensorflow.python.ops.linalg import linalg
395866e065bc95c1d7de8a27413b368016941889a6Justine Tunneyfrom tensorflow.python.platform import test
405866e065bc95c1d7de8a27413b368016941889a6Justine Tunneyfrom tensorflow.python.platform import tf_logging
41f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
42f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
# Alternative Cholesky gradient implementations, kept for benchmarking and for
# cross-checking the registered gradient in the tests below.
def SpecializedGrad(l, grad):
  """Cholesky gradient computed by the generated `cholesky_grad` op.

  Args:
    l: The Cholesky factor produced by the forward pass.
    grad: The incoming gradient with respect to `l`.

  Returns:
    The gradient with respect to the matrix that was factored.
  """
  grad_a = gen_linalg_ops.cholesky_grad(l, grad)
  return grad_a
46473a590c9cd26cdde1e77117778e3fd50a36d7dfA. Unique TensorFlower
47473a590c9cd26cdde1e77117778e3fd50a36d7dfA. Unique TensorFlower
def _GradWithInverseL(l, l_inverse, grad):
  """Cholesky gradient given an already-materialized inverse of `l`.

  Computes g = l^{-H} @ phi @ l^{-1}, where phi is l^{H} @ grad restricted
  to its lower triangle with the diagonal halved, and returns the
  symmetrized result 0.5 * (g + g^{H}).

  Args:
    l: The Cholesky factor.
    l_inverse: A tensor holding the inverse of `l`.
    grad: The incoming gradient with respect to `l`.

  Returns:
    The gradient with respect to the matrix that was factored.
  """
  phi = math_ops.matmul(l, grad, adjoint_a=True)
  halved_diag = 0.5 * array_ops.matrix_diag_part(phi)
  phi = array_ops.matrix_set_diag(phi, halved_diag)
  # Keep only the lower triangle (band of -1 sub-diagonals, 0 super-diagonals).
  phi = array_ops.matrix_band_part(phi, -1, 0)
  left_product = math_ops.matmul(l_inverse, phi, adjoint_a=True)
  grad_a = math_ops.matmul(left_product, l_inverse)
  grad_a_adjoint = math_ops.conj(array_ops.matrix_transpose(grad_a))
  return (grad_a + grad_a_adjoint) * 0.5
57473a590c9cd26cdde1e77117778e3fd50a36d7dfA. Unique TensorFlower
58473a590c9cd26cdde1e77117778e3fd50a36d7dfA. Unique TensorFlower
def TriAngSolveCompositeGrad(l, grad):
  """Composite Cholesky gradient that avoids forming l^{-1} explicitly.

  The gradient is
    0.5 * (g + g^{H}),  g = l^{-H} @ phi @ l^{-1},
  where phi = (l^{H} @ grad) * (tril(ones) - 1/2 * eye). Both products with
  the inverse of `l` are carried out as triangular solves against l^{H}.

  Args:
    l: The Cholesky factor.
    grad: The incoming gradient with respect to `l`.

  Returns:
    The gradient with respect to the matrix that was factored.
  """
  # phi: lower triangle of l^{H} @ grad with a halved diagonal.
  phi = math_ops.matmul(l, grad, adjoint_a=True)
  phi = array_ops.matrix_set_diag(phi, 0.5 * array_ops.matrix_diag_part(phi))
  phi = array_ops.matrix_band_part(phi, -1, 0)

  # z = l^{-H} @ phi, via a solve against l^{H}.
  z = linalg_ops.matrix_triangular_solve(l, phi, adjoint=True)

  # We need z @ l^{-1}. Solving against l^{H} with z^{H} as the right-hand
  # side yields l^{-H} @ z^{H} = (z @ l^{-1})^{H}, i.e. the adjoint of what we
  # want. That is harmless, because the result is symmetrized with its own
  # adjoint below, so there is no need to transpose it back.
  z_adjoint = math_ops.conj(array_ops.matrix_transpose(z))
  grad_a = linalg_ops.matrix_triangular_solve(l, z_adjoint, adjoint=True)
  return (grad_a + linalg.adjoint(grad_a)) * 0.5
78473a590c9cd26cdde1e77117778e3fd50a36d7dfA. Unique TensorFlower
79473a590c9cd26cdde1e77117778e3fd50a36d7dfA. Unique TensorFlower
def MatrixInverseCompositeGrad(l, grad):
  """Composite Cholesky gradient using a general dense inverse of `l`.

  Args:
    l: The Cholesky factor.
    grad: The incoming gradient with respect to `l`.

  Returns:
    The gradient with respect to the matrix that was factored.
  """
  return _GradWithInverseL(l, linalg_ops.matrix_inverse(l), grad)
83473a590c9cd26cdde1e77117778e3fd50a36d7dfA. Unique TensorFlower
84473a590c9cd26cdde1e77117778e3fd50a36d7dfA. Unique TensorFlower
def TriAngInvCompositeGrad(l, grad):
  """Composite Cholesky gradient that inverts `l` with a triangular solve.

  l^{-1} is obtained by solving l @ X = I, with the identity built to match
  `l`'s batch dimensions and dtype. This relies on `matrix_triangular_solve`
  treating `l` as lower triangular by default.

  Args:
    l: The Cholesky factor.
    grad: The incoming gradient with respect to `l`.

  Returns:
    The gradient with respect to the matrix that was factored.
  """
  l_shape = array_ops.shape(l)
  identity = linalg_ops.eye(
      l_shape[-1], batch_shape=l_shape[:-2], dtype=l.dtype)
  l_inverse = linalg_ops.matrix_triangular_solve(l, identity)
  return _GradWithInverseL(l, l_inverse, grad)
94473a590c9cd26cdde1e77117778e3fd50a36d7dfA. Unique TensorFlower
95473a590c9cd26cdde1e77117778e3fd50a36d7dfA. Unique TensorFlower
class CholeskyOpTest(test.TestCase):
  """Tests the forward computation of `linalg_ops.cholesky`."""

  def _verifyCholeskyBase(self, sess, x, chol, verification):
    """Evaluates `chol` and `verification` and checks the factorization.

    Args:
      sess: Session in which `chol` and `verification` are evaluated.
      x: numpy array holding the matrix (or batch of matrices) that was
        factored.
      chol: Tensor with the Cholesky factor(s) of `x`.
      verification: Tensor that reconstructs `x` from `chol`; it must match
        `x` up to numerical tolerance.
    """
    chol_np, verification_np = sess.run([chol, verification])
    self.assertAllClose(x, verification_np)
    self.assertShapeEqual(x, chol)
    # Check that the cholesky is lower triangular, and has positive diagonal
    # elements. Empty matrices have nothing to inspect.
    if chol_np.shape[-1] > 0:
      # Collapse any leading batch dimensions so each factor is checked
      # individually.
      chol_reshaped = np.reshape(chol_np, (-1, chol_np.shape[-2],
                                           chol_np.shape[-1]))
      for chol_matrix in chol_reshaped:
        self.assertAllClose(chol_matrix, np.tril(chol_matrix))
        self.assertTrue((np.diag(chol_matrix) > 0.0).all())

  def _verifyCholesky(self, x):
    """Factors `x` and verifies that chol @ chol^H reproduces `x`."""
    with self.test_session(use_gpu=True) as sess:
      chol = linalg_ops.cholesky(x)
      verification = math_ops.matmul(chol, chol, adjoint_b=True)
      self._verifyCholeskyBase(sess, x, chol, verification)

  def testBasic(self):
    data = np.array([[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]])
    for dtype in (np.float32, np.float64):
      self._verifyCholesky(data.astype(dtype))
    for dtype in (np.complex64, np.complex128):
      # Hermitian matrix whose real part is `data`: the strictly lower
      # triangle adds +i * data and the strictly upper triangle -i * data.
      complex_data = np.tril(1j * data, -1).astype(dtype)
      complex_data += np.triu(-1j * data, 1).astype(dtype)
      complex_data += data
      self._verifyCholesky(complex_data)

  def testBatch(self):
    simple_array = np.array([[[1., 0.], [0., 5.]]])  # shape (1, 2, 2)
    self._verifyCholesky(simple_array)
    self._verifyCholesky(np.vstack((simple_array, simple_array)))
    odd_sized_array = np.array([[[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]]])
    self._verifyCholesky(np.vstack((odd_sized_array, odd_sized_array)))

    # Generate random positive-definite matrices.
    matrices = np.random.rand(10, 5, 5)
    for i in xrange(10):
      matrices[i] = np.dot(matrices[i].T, matrices[i])
    self._verifyCholesky(matrices)

    # Generate random complex valued positive-definite matrices.
    matrices = np.random.rand(10, 5, 5) + 1j * np.random.rand(10, 5, 5)
    for i in xrange(10):
      matrices[i] = np.dot(matrices[i].T.conj(), matrices[i])
    self._verifyCholesky(matrices)

  def testNonSquareMatrix(self):
    with self.assertRaises(ValueError):
      linalg_ops.cholesky(np.array([[1., 2., 3.], [3., 4., 5.]]))
    with self.assertRaises(ValueError):
      linalg_ops.cholesky(
          np.array([[[1., 2., 3.], [3., 4., 5.]], [[1., 2., 3.], [3., 4., 5.]]
                   ]))

  def testWrongDimensions(self):
    # Inputs of rank < 2 are rejected as soon as the op is constructed. Check
    # both a vector and a scalar, rather than checking the same vector twice.
    vector = constant_op.constant([1., 2.])
    with self.assertRaises(ValueError):
      linalg_ops.cholesky(vector)
    scalar = constant_op.constant(1.)
    with self.assertRaises(ValueError):
      linalg_ops.cholesky(scalar)

  def testNotInvertibleCPU(self):
    # The input should be positive definite.
    with self.test_session(use_gpu=True):
      with self.assertRaisesRegexp(
          errors_impl.InvalidArgumentError,
          "Cholesky decomposition was not successful. The"
          " input might not be valid."):
        # The symmetric matrix below is indefinite (its determinant is -1),
        # so it has no Cholesky factorization.
        self._verifyCholesky(
            np.array([[1., -1., 0.], [-1., 1., -1.], [0., -1., 1.]]))

  def testEmpty(self):
    self._verifyCholesky(np.empty([0, 2, 2]))
    self._verifyCholesky(np.empty([2, 0, 0]))

  def testConcurrentExecutesWithoutError(self):
    # Both inputs are drawn with the same seed, so the two factorizations,
    # fetched in a single sess.run call, must agree exactly.
    with self.test_session(use_gpu=True) as sess:
      matrix1 = random_ops.random_normal([5, 5], seed=42)
      matrix2 = random_ops.random_normal([5, 5], seed=42)
      matrix1 = math_ops.matmul(matrix1, matrix1, adjoint_a=True)
      matrix2 = math_ops.matmul(matrix2, matrix2, adjoint_a=True)
      c1 = linalg_ops.cholesky(matrix1)
      c2 = linalg_ops.cholesky(matrix2)
      c1_val, c2_val = sess.run([c1, c2])
      self.assertAllEqual(c1_val, c2_val)
187ac742fab0bf4c8b7bde5febc33e09fedfcb57aa1A. Unique TensorFlower
188f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
class CholeskyGradTest(test.TestCase):
  """Checks the gradient of `linalg_ops.cholesky` numerically."""

  # Test sizes are chosen just above multiples of this value. Presumably it
  # mirrors a block size used by the gradient implementation; TODO confirm.
  _backprop_block_size = 32

  def getShapes(self, shapeList):
    """Returns a generator of (n, floor(1.2 * n)) pairs, one per n in `shapeList`.

    Each pair is the shape of a random matrix `x`. Since `x` has at least as
    many columns as rows, `x @ x^H` is an n x n matrix that is almost surely
    positive definite.
    """
    return ((elem, int(np.floor(1.2 * elem))) for elem in shapeList)

  def testSmallMatrices(self):
    np.random.seed(0)
    shapes = self.getShapes([1, 2, 10])
    self.runFiniteDifferences(
        shapes, dtypes=(dtypes_lib.float32, dtypes_lib.float64))

  def testSmallMatricesComplex(self):
    np.random.seed(0)
    shapes = self.getShapes([1, 2, 10])
    self.runFiniteDifferences(
        shapes, dtypes=(dtypes_lib.complex64, dtypes_lib.complex128))

  def testOneBlockMatrices(self):
    np.random.seed(0)
    shapes = self.getShapes([self._backprop_block_size + 1])
    self.runFiniteDifferences(
        shapes,
        dtypes=(dtypes_lib.float32, dtypes_lib.float64),
        scalarTest=True)

  def testTwoBlockMatrixFloat(self):
    np.random.seed(0)
    shapes = self.getShapes([2 * self._backprop_block_size + 1])
    self.runFiniteDifferences(
        shapes, dtypes=(dtypes_lib.float32,), scalarTest=True)

  def testTwoBlockMatrixDouble(self):
    np.random.seed(0)
    shapes = self.getShapes([2 * self._backprop_block_size + 1])
    self.runFiniteDifferences(
        shapes, dtypes=(dtypes_lib.float64,), scalarTest=True)

  def testTwoBlockMatrixComplexFloat(self):
    np.random.seed(0)
    shapes = self.getShapes([2 * self._backprop_block_size + 1])
    self.runFiniteDifferences(
        shapes, dtypes=(dtypes_lib.complex64,), scalarTest=True)

  def testTwoBlockMatrixComplexDouble(self):
    np.random.seed(0)
    shapes = self.getShapes([2 * self._backprop_block_size + 1])
    self.runFiniteDifferences(
        shapes, dtypes=(dtypes_lib.complex128,), scalarTest=True)

  def testAgainstSpecialized(self):
    # The registered (composite) gradient must agree with the dedicated
    # cholesky_grad op on a random 33x33 positive-definite input.
    np.random.seed(0)
    data = np.random.randn(33, 33).astype(np.float32)
    data = np.matmul(data, data.T)
    grad_data = np.random.randn(*data.shape).astype(np.float32)

    with ops.Graph().as_default(), self.test_session(use_gpu=False) as s:
      x = constant_op.constant(data, dtypes_lib.float32)
      chol = linalg_ops.cholesky(x)
      composite_grad = gradients_impl.gradients(chol, x, grad_data)[0]
      specialized_grad = SpecializedGrad(chol, grad_data)
      reference, actual = s.run([specialized_grad, composite_grad])
    self.assertAllClose(reference, actual)

  def runFiniteDifferences(self,
                           shapes,
                           dtypes=(dtypes_lib.float32, dtypes_lib.float64,
                                   dtypes_lib.complex64, dtypes_lib.complex128),
                           scalarTest=False):
    """Compares the Cholesky gradient with a finite-difference estimate.

    Args:
      shapes: Iterable of (rows, cols) pairs; for each, a random rows x cols
        matrix is generated and its scaled Gram matrix is factored.
      dtypes: TensorFlow dtypes to test.
      scalarTest: If True, differentiate with respect to a single scalar that
        scales a fixed random matrix, and reduce the Cholesky factor to its
        mean. This keeps the finite-difference Jacobian small for larger
        matrices.
    """
    with self.test_session(use_gpu=True):
      for shape in shapes:
        for batch in False, True:
          for dtype in dtypes:
            if not scalarTest:
              data = np.random.randn(shape[0], shape[1])
              if dtype.is_complex:
                # Build complex data at the precision of `dtype`, so that
                # complex128 inputs are not truncated to complex64.
                data = data.astype(dtype.as_numpy_dtype)
                data += 1j * np.random.randn(shape[0], shape[1])
              x = constant_op.constant(data, dtype)
              tensor = math_ops.matmul(
                  x, math_ops.conj(array_ops.transpose(x))) / shape[0]
            else:
              # This is designed to be a faster test for larger matrices.
              data = np.random.randn()
              if dtype.is_complex:
                data = dtype.as_numpy_dtype(data)
                data += 1j * np.random.randn()
              x = constant_op.constant(data, dtype)
              R = constant_op.constant(
                  np.random.randn(shape[0], shape[1]), dtype)
              e = math_ops.multiply(R, x)
              tensor = math_ops.matmul(
                  e, math_ops.conj(array_ops.transpose(e))) / shape[0]

            # Inner-most matrices in tensor are positive definite.
            if batch:
              tensor = array_ops.tile(
                  array_ops.expand_dims(tensor, 0), [4, 1, 1])
            y = linalg_ops.cholesky(tensor)
            if scalarTest:
              y = math_ops.reduce_mean(y)
            # Shapes here are fully defined, so the public as_list() is safe
            # (it replaces the private Tensor._shape_as_list()).
            error = gradient_checker.compute_gradient_error(
                x, x.get_shape().as_list(), y, y.get_shape().as_list())
            tf_logging.info("error = %f", error)
            if dtype == dtypes_lib.float64:
              self.assertLess(error, 1e-5)
            elif dtype == dtypes_lib.complex128:
              self.assertLess(error, 5e-5)
            else:
              self.assertLess(error, 5e-3)
299473a590c9cd26cdde1e77117778e3fd50a36d7dfA. Unique TensorFlower
300473a590c9cd26cdde1e77117778e3fd50a36d7dfA. Unique TensorFlower
301473a590c9cd26cdde1e77117778e3fd50a36d7dfA. Unique TensorFlowerclass CholeskyBenchmark(test.Benchmark):
302473a590c9cd26cdde1e77117778e3fd50a36d7dfA. Unique TensorFlower
303d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower  shapes = [
304d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower      (4, 4),
305d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower      (10, 10),
306d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower      (16, 16),
307d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower      (101, 101),
308d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower      (256, 256),
309d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower      (1000, 1000),
310d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower      (1024, 1024),
311d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower      (2048, 2048),
312d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower      (513, 2, 2),
313d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower      (513, 8, 8),
314d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower      (513, 256, 256),
315d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower      (4, 513, 2, 2),
316473a590c9cd26cdde1e77117778e3fd50a36d7dfA. Unique TensorFlower  ]
317473a590c9cd26cdde1e77117778e3fd50a36d7dfA. Unique TensorFlower
318d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower  def _GenerateMatrix(self, shape):
319d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower    batch_shape = shape[:-2]
320d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower    shape = shape[-2:]
321d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower    assert shape[0] == shape[1]
322d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower    n = shape[0]
323d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower    matrix = np.ones(shape).astype(np.float32) / (
324d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower        2.0 * n) + np.diag(np.ones(n).astype(np.float32))
325d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower    return np.tile(matrix, batch_shape + (1, 1))
326473a590c9cd26cdde1e77117778e3fd50a36d7dfA. Unique TensorFlower
327473a590c9cd26cdde1e77117778e3fd50a36d7dfA. Unique TensorFlower  def benchmarkCholeskyOp(self):
328d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower    for shape in self.shapes:
329473a590c9cd26cdde1e77117778e3fd50a36d7dfA. Unique TensorFlower      with ops.Graph().as_default(), \
330473a590c9cd26cdde1e77117778e3fd50a36d7dfA. Unique TensorFlower          session.Session() as sess, \
331473a590c9cd26cdde1e77117778e3fd50a36d7dfA. Unique TensorFlower          ops.device("/cpu:0"):
332d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower        matrix = variables.Variable(self._GenerateMatrix(shape))
333d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower        l = linalg_ops.cholesky(matrix)
334d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower        variables.global_variables_initializer().run()
335473a590c9cd26cdde1e77117778e3fd50a36d7dfA. Unique TensorFlower        self.run_op_benchmark(
336d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower            sess,
337d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower            control_flow_ops.group(
338d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower                l,),
339473a590c9cd26cdde1e77117778e3fd50a36d7dfA. Unique TensorFlower            min_iters=25,
340d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower            name="cholesky_cpu_{shape}".format(shape=shape))
341473a590c9cd26cdde1e77117778e3fd50a36d7dfA. Unique TensorFlower
342473a590c9cd26cdde1e77117778e3fd50a36d7dfA. Unique TensorFlower      if test.is_gpu_available(True):
343473a590c9cd26cdde1e77117778e3fd50a36d7dfA. Unique TensorFlower        with ops.Graph().as_default(), \
344473a590c9cd26cdde1e77117778e3fd50a36d7dfA. Unique TensorFlower            session.Session() as sess, \
34528ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower            ops.device("/device:GPU:0"):
346d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower          matrix = variables.Variable(self._GenerateMatrix(shape))
347d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower          l = linalg_ops.cholesky(matrix)
348d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower          variables.global_variables_initializer().run()
349473a590c9cd26cdde1e77117778e3fd50a36d7dfA. Unique TensorFlower          self.run_op_benchmark(
35019dba27bd7882e12990c0eb611c80064c1e426f1A. Unique TensorFlower              sess,
35119dba27bd7882e12990c0eb611c80064c1e426f1A. Unique TensorFlower              control_flow_ops.group(
35219dba27bd7882e12990c0eb611c80064c1e426f1A. Unique TensorFlower                  l,),
353473a590c9cd26cdde1e77117778e3fd50a36d7dfA. Unique TensorFlower              min_iters=25,
354d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower              name="cholesky_gpu_{shape}".format(shape=shape))
355473a590c9cd26cdde1e77117778e3fd50a36d7dfA. Unique TensorFlower
356473a590c9cd26cdde1e77117778e3fd50a36d7dfA. Unique TensorFlower  def benchmarkGradVariants(self):
357d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower
358473a590c9cd26cdde1e77117778e3fd50a36d7dfA. Unique TensorFlower    def _BenchmarkGrad(grad_fn, name, device):
359d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower      for shape in self.shapes:
360d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower        matrix = self._GenerateMatrix(shape)
361473a590c9cd26cdde1e77117778e3fd50a36d7dfA. Unique TensorFlower        with ops.Graph().as_default(), \
362473a590c9cd26cdde1e77117778e3fd50a36d7dfA. Unique TensorFlower            session.Session() as sess, \
363473a590c9cd26cdde1e77117778e3fd50a36d7dfA. Unique TensorFlower            ops.device(device):
364d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower          l = variables.Variable(np.linalg.cholesky(matrix))
365d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower          grad_matrix = variables.Variable(
366d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower              np.random.randn(*matrix.shape).astype(np.float32))
367d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower          grad = grad_fn(l, grad_matrix)
368d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower          variables.global_variables_initializer().run()
369473a590c9cd26cdde1e77117778e3fd50a36d7dfA. Unique TensorFlower          self.run_op_benchmark(
370d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower              sess,
371d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower              control_flow_ops.group(
372d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower                  grad,),
373473a590c9cd26cdde1e77117778e3fd50a36d7dfA. Unique TensorFlower              min_iters=25,
374d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower              name="{name}_{dev}_{shape}".format(
375d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower                  name=name, dev=grad.device, shape=shape))
376473a590c9cd26cdde1e77117778e3fd50a36d7dfA. Unique TensorFlower
377473a590c9cd26cdde1e77117778e3fd50a36d7dfA. Unique TensorFlower    if test.is_gpu_available(True):
378d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower      _BenchmarkGrad(MatrixInverseCompositeGrad, "composite_matrix_inverse",
379d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower                     "/device:GPU:0")
380d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower      _BenchmarkGrad(TriAngInvCompositeGrad, "composite_tri_ang_inverse",
381d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower                     "/device:GPU:0")
382d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower      _BenchmarkGrad(TriAngSolveCompositeGrad, "composite_triangular_solve",
383d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower                     "/device:GPU:0")
384d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower
385d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower    _BenchmarkGrad(MatrixInverseCompositeGrad, "composite_matrix_inverse",
386d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower                   "/cpu:0")
387d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower    _BenchmarkGrad(TriAngInvCompositeGrad, "composite_tri_ang_inverse",
388d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower                   "/cpu:0")
389d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower    _BenchmarkGrad(TriAngSolveCompositeGrad, "composite_triangular_solve",
390d7e425f0bd61676aa347a93a81d8e89bb5c1a1a1A. Unique TensorFlower                   "/cpu:0")
391473a590c9cd26cdde1e77117778e3fd50a36d7dfA. Unique TensorFlower    _BenchmarkGrad(SpecializedGrad, "specialized", "/cpu:0")
3925c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin
393fa9c7b76f2ff4d570c1b35b418309853e15b8728A. Unique TensorFlower
394f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurif __name__ == "__main__":
3955866e065bc95c1d7de8a27413b368016941889a6Justine Tunney  test.main()
396