cholesky_op_test.py revision 5c9bc51857bc0c330d3ab976871ee3509647d1e7
19c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# Copyright 2015 Google Inc. All Rights Reserved.
29c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur#
39c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# Licensed under the Apache License, Version 2.0 (the "License");
49c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# you may not use this file except in compliance with the License.
59c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# You may obtain a copy of the License at
69c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur#
79c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur#     http://www.apache.org/licenses/LICENSE-2.0
89c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur#
99c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# Unless required by applicable law or agreed to in writing, software
109c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# distributed under the License is distributed on an "AS IS" BASIS,
119c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
129c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# See the License for the specific language governing permissions and
139c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# limitations under the License.
149c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# ==============================================================================
159c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur
"""Tests for tf.cholesky, tf.batch_cholesky and the gradient of tf.cholesky."""
17f2102f4e2c1c87f1d1bf9ab856a2849c54478760Vijay Vasudevanfrom __future__ import absolute_import
18f2102f4e2c1c87f1d1bf9ab856a2849c54478760Vijay Vasudevanfrom __future__ import division
19f2102f4e2c1c87f1d1bf9ab856a2849c54478760Vijay Vasudevanfrom __future__ import print_function
20f2102f4e2c1c87f1d1bf9ab856a2849c54478760Vijay Vasudevan
21f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurimport numpy as np
22f2102f4e2c1c87f1d1bf9ab856a2849c54478760Vijay Vasudevanfrom six.moves import xrange  # pylint: disable=redefined-builtin
23f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurimport tensorflow as tf
24f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
25f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
class CholeskyOpTest(tf.test.TestCase):
  """Forward-pass tests for tf.cholesky and tf.batch_cholesky."""

  def _verifyCholesky(self, x):
    """Checks the Cholesky factorization TensorFlow computes for `x`.

    Asserts that L L^T reproduces `x`, that the factor has the same static
    shape as `x`, and that every factor is lower triangular with a strictly
    positive diagonal.

    Args:
      x: numpy array. Rank 2 is treated as a single matrix (tf.cholesky);
        any other rank is treated as a batch of matrices (tf.batch_cholesky).
    """
    with self.test_session() as sess:
      # Verify that LL^T == x.
      if x.ndim == 2:
        chol = tf.cholesky(x)
        verification = tf.matmul(chol,
                                 chol,
                                 transpose_a=False,
                                 transpose_b=True)
      else:
        chol = tf.batch_cholesky(x)
        # For real inputs the adjoint (adj_y=True) is just the transpose.
        verification = tf.batch_matmul(chol, chol, adj_x=False, adj_y=True)
      chol_np, verification_np = sess.run([chol, verification])
    self.assertAllClose(x, verification_np)
    self.assertShapeEqual(x, chol)
    # Check that the cholesky is lower triangular, and has positive diagonal
    # elements. np.reshape cannot infer the -1 dimension of a zero-size array
    # whose matrices are 0x0, so those inputs skip the per-matrix checks.
    if chol_np.shape[-1] > 0:
      chol_reshaped = np.reshape(chol_np, (-1, chol_np.shape[-2],
                                           chol_np.shape[-1]))
      for chol_matrix in chol_reshaped:
        self.assertAllClose(chol_matrix, np.tril(chol_matrix))
        self.assertTrue((np.diag(chol_matrix) > 0.0).all())

  def testBasic(self):
    self._verifyCholesky(np.array([[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]]))

  def testBatch(self):
    simple_array = np.array([[[1., 0.], [0., 5.]]])  # shape (1, 2, 2)
    self._verifyCholesky(simple_array)
    self._verifyCholesky(np.vstack((simple_array, simple_array)))
    odd_sized_array = np.array([[[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]]])
    self._verifyCholesky(np.vstack((odd_sized_array, odd_sized_array)))

    # Generate random positive-definite matrices: M^T M is positive
    # semi-definite, and positive definite whenever M is nonsingular (almost
    # surely for random M). Seed the generator so any failure is reproducible.
    np.random.seed(0)
    matrices = np.random.rand(10, 5, 5)
    for i in xrange(10):
      matrices[i] = np.dot(matrices[i].T, matrices[i])
    self._verifyCholesky(matrices)

  def testNonSquareMatrix(self):
    with self.assertRaises(ValueError):
      tf.cholesky(np.array([[1., 2., 3.], [3., 4., 5.]]))

  def testWrongDimensions(self):
    # tf.cholesky requires a matrix; a rank-1 input must be rejected.
    vector = tf.constant([1., 2.])
    with self.assertRaises(ValueError):
      tf.cholesky(vector)

  def testNotInvertible(self):
    # The op must reject a symmetric input that is not positive definite.
    with self.test_session():
      with self.assertRaisesOpError("LLT decomposition was not successful. The"
                                    " input might not be valid."):
        # This symmetric matrix is not positive definite: its second leading
        # principal minor is 0 and its determinant is -1.
        self._verifyCholesky(np.array([[1., -1., 0.], [-1., 1., -1.],
                                       [0., -1., 1.]]))

  def testEmpty(self):
    self._verifyCholesky(np.empty([0, 2, 2]))
    self._verifyCholesky(np.empty([2, 0, 0]))
89f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
90f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
class CholeskyGradTest(tf.test.TestCase):
  """Compares the gradient of tf.cholesky against finite differences."""
  # Test sizes are chosen just above multiples of this value. Presumably it
  # mirrors the block size used by the Cholesky gradient implementation --
  # TODO confirm against the gradient code.
  _backprop_block_size = 32

  def getShapes(self, shapeList):
    """Returns a generator of (n, floor(1.2 * n)) pairs, one per n.

    Each pair is the (rows, cols) shape of a random matrix X. With at least
    as many columns as rows, X X^T is almost surely positive definite.
    """
    return ((elem, int(np.floor(1.2 * elem))) for elem in shapeList)

  def testSmallMatrices(self):
    np.random.seed(0)
    shapes = self.getShapes([1, 2, 10])
    self.runFiniteDifferences(shapes)

  def testOneBlockMatrices(self):
    np.random.seed(0)
    shapes = self.getShapes([self._backprop_block_size + 1])
    self.runFiniteDifferences(shapes, dtypes=(tf.float32, tf.float64),
                              scalarTest=True)

  def testTwoBlockMatrixFloat(self):
    np.random.seed(0)
    shapes = self.getShapes([2 * self._backprop_block_size + 1])
    self.runFiniteDifferences(shapes, dtypes=(tf.float32,), scalarTest=True)

  def testTwoBlockMatrixDouble(self):
    np.random.seed(0)
    shapes = self.getShapes([2 * self._backprop_block_size + 1])
    self.runFiniteDifferences(shapes, dtypes=(tf.float64,), scalarTest=True)

  def runFiniteDifferences(self, shapes, dtypes=(tf.float32, tf.float64),
                           scalarTest=False):
    """Checks the tf.cholesky gradient against numeric finite differences.

    Args:
      shapes: iterable of (rows, cols) pairs. Each builds a random rows x cols
        matrix, and the positive-definite input is K = X X^T / rows.
      dtypes: TensorFlow dtypes to test every shape with.
      scalarTest: if False, differentiate cholesky(K) with respect to every
        entry of X. If True, X is a fixed random matrix scaled by a single
        random scalar, and the mean of cholesky(K) is differentiated with
        respect to that scalar only. This is much cheaper for large matrices.
    """
    with self.test_session(use_gpu=False):
      for shape in shapes:
        for dtype in dtypes:
          if not scalarTest:
            x = tf.constant(np.random.randn(shape[0], shape[1]), dtype)
            K = tf.matmul(x, tf.transpose(x)) / shape[0]  # K is posdef
            y = tf.cholesky(K)
          else:  # This is designed to be a faster test for larger matrices.
            x = tf.constant(np.random.randn(), dtype)
            R = tf.constant(np.random.randn(shape[0], shape[1]), dtype)
            e = tf.mul(R, x)
            K = tf.matmul(e, tf.transpose(e)) / shape[0]  # K is posdef
            y = tf.reduce_mean(tf.cholesky(K))
          # Use the public static-shape API rather than Tensor._shape_as_list().
          error = tf.test.compute_gradient_error(x, x.get_shape().as_list(),
                                                 y, y.get_shape().as_list())
          tf.logging.info("error = %f", error)
          tolerance = 1e-5 if dtype == tf.float64 else 2e-3
          self.assertLess(error, tolerance)
1405c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin
if __name__ == "__main__":
  # Run this module's test cases through TensorFlow's test runner.
  tf.test.main()
143