# cholesky_op_test.py revision 5c9bc51857bc0c330d3ab976871ee3509647d1e7
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
149c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur# ============================================================================== 159c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur 16f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur"""Tests for tensorflow.ops.tf.Cholesky.""" 17f2102f4e2c1c87f1d1bf9ab856a2849c54478760Vijay Vasudevanfrom __future__ import absolute_import 18f2102f4e2c1c87f1d1bf9ab856a2849c54478760Vijay Vasudevanfrom __future__ import division 19f2102f4e2c1c87f1d1bf9ab856a2849c54478760Vijay Vasudevanfrom __future__ import print_function 20f2102f4e2c1c87f1d1bf9ab856a2849c54478760Vijay Vasudevan 21f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurimport numpy as np 22f2102f4e2c1c87f1d1bf9ab856a2849c54478760Vijay Vasudevanfrom six.moves import xrange # pylint: disable=redefined-builtin 23f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurimport tensorflow as tf 24f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 25f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 26f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass CholeskyOpTest(tf.test.TestCase): 27f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 28f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur def _verifyCholesky(self, x): 29f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur with self.test_session() as sess: 30f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur # Verify that LL^T == x. 
31f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur if x.ndim == 2: 32f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur chol = tf.cholesky(x) 33f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur verification = tf.matmul(chol, 34f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur chol, 35f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur transpose_a=False, 36f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur transpose_b=True) 37f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur else: 38f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur chol = tf.batch_cholesky(x) 39f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur verification = tf.batch_matmul(chol, chol, adj_x=False, adj_y=True) 40f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur chol_np, verification_np = sess.run([chol, verification]) 41f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur self.assertAllClose(x, verification_np) 42f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur self.assertShapeEqual(x, chol) 43f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur # Check that the cholesky is lower triangular, and has positive diagonal 44f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur # elements. 
45f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur if chol_np.shape[-1] > 0: 46f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur chol_reshaped = np.reshape(chol_np, (-1, chol_np.shape[-2], 47f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur chol_np.shape[-1])) 48f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur for chol_matrix in chol_reshaped: 49f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur self.assertAllClose(chol_matrix, np.tril(chol_matrix)) 50f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur self.assertTrue((np.diag(chol_matrix) > 0.0).all()) 51f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 52f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur def testBasic(self): 53f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur self._verifyCholesky(np.array([[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]])) 54f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 55f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur def testBatch(self): 56f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur simple_array = np.array([[[1., 0.], [0., 5.]]]) # shape (1, 2, 2) 57f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur self._verifyCholesky(simple_array) 58f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur self._verifyCholesky(np.vstack((simple_array, simple_array))) 59f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur odd_sized_array = np.array([[[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]]]) 60f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur self._verifyCholesky(np.vstack((odd_sized_array, odd_sized_array))) 61f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 625c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin # Generate random positive-definite matrices. 
63f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur matrices = np.random.rand(10, 5, 5) 64f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur for i in xrange(10): 65f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur matrices[i] = np.dot(matrices[i].T, matrices[i]) 66f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur self._verifyCholesky(matrices) 67f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 68f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur def testNonSquareMatrix(self): 69f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur with self.assertRaises(ValueError): 70f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur tf.cholesky(np.array([[1., 2., 3.], [3., 4., 5.]])) 71f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 72f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur def testWrongDimensions(self): 73f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur tensor3 = tf.constant([1., 2.]) 74f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur with self.assertRaises(ValueError): 75f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur tf.cholesky(tensor3) 76f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 77f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur def testNotInvertible(self): 785c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin # The input should be invertible. 79f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur with self.test_session(): 805c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin with self.assertRaisesOpError("LLT decomposition was not successful. 
The" 815c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin " input might not be valid."): 825c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin # All rows of the matrix below add to zero 83f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur self._verifyCholesky(np.array([[1., -1., 0.], [-1., 1., -1.], [0., -1., 84f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 1.]])) 85f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 86f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur def testEmpty(self): 87f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur self._verifyCholesky(np.empty([0, 2, 2])) 88f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur self._verifyCholesky(np.empty([2, 0, 0])) 89f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 90f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 915c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhinclass CholeskyGradTest(tf.test.TestCase): 925c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin _backprop_block_size = 32 935c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin 945c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin def getShapes(self, shapeList): 955c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin return ((elem, int(np.floor(1.2 * elem))) for elem in shapeList) 965c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin 975c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin def testSmallMatrices(self): 985c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin np.random.seed(0) 995c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin shapes = self.getShapes([1, 2, 10]) 1005c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin self.runFiniteDifferences(shapes) 1015c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin 1025c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin def testOneBlockMatrices(self): 1035c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin np.random.seed(0) 1045c9bc51857bc0c330d3ab976871ee3509647d1e7Illia 
Polosukhin shapes = self.getShapes([self._backprop_block_size + 1]) 1055c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin self.runFiniteDifferences(shapes, dtypes=(tf.float32, tf.float64), 1065c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin scalarTest=True) 1075c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin 1085c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin def testTwoBlockMatrixFloat(self): 1095c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin np.random.seed(0) 1105c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin shapes = self.getShapes([2 * self._backprop_block_size + 1]) 1115c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin self.runFiniteDifferences(shapes, dtypes=(tf.float32,), scalarTest=True) 1125c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin 1135c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin def testTwoBlockMatrixDouble(self): 1145c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin np.random.seed(0) 1155c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin shapes = self.getShapes([2 * self._backprop_block_size + 1]) 1165c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin self.runFiniteDifferences(shapes, dtypes=(tf.float64,), scalarTest=True) 1175c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin 1185c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin def runFiniteDifferences(self, shapes, dtypes=(tf.float32, tf.float64), 1195c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin scalarTest=False): 1205c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin with self.test_session(use_gpu=False): 1215c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin for shape in shapes: 1225c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin for dtype in dtypes: 1235c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin if not(scalarTest): 1245c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin x = tf.constant(np.random.randn(shape[0], shape[1]), dtype) 
1255c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin K = tf.matmul(x, tf.transpose(x)) / shape[0] # K is posdef 1265c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin y = tf.cholesky(K) 1275c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin else: # This is designed to be a faster test for larger matrices. 1285c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin x = tf.constant(np.random.randn(), dtype) 1295c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin R = tf.constant(np.random.randn(shape[0], shape[1]), dtype) 1305c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin e = tf.mul(R, x) 1315c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin K = tf.matmul(e, tf.transpose(e)) / shape[0] # K is posdef 1325c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin y = tf.reduce_mean(tf.cholesky(K)) 1335c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin error = tf.test.compute_gradient_error(x, x._shape_as_list(), 1345c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin y, y._shape_as_list()) 1355c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin tf.logging.info("error = %f", error) 1365c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin if dtype == tf.float64: 1375c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin self.assertLess(error, 1e-5) 1385c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin else: 1395c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin self.assertLess(error, 2e-3) 1405c9bc51857bc0c330d3ab976871ee3509647d1e7Illia Polosukhin 141f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurif __name__ == "__main__": 142f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur tf.test.main() 143