# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `LinearOperatorDiag`."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.linalg import linalg as linalg_lib
from tensorflow.python.ops.linalg import linear_operator_test_util
from tensorflow.python.platform import test

linalg = linalg_lib
random_seed.set_random_seed(23)


class LinearOperatorDiagTest(
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Most tests done in the base class SquareLinearOperatorDerivedClassTest."""

  def _operator_and_mat_and_feed_dict(self, shape, dtype, use_placeholder):
    """Builds a `LinearOperatorDiag` plus the equivalent dense matrix.

    Args:
      shape: Shape of the operator; the diagonal has shape `shape[:-1]`.
      dtype: dtype of the diagonal entries.
      use_placeholder: If True, the operator is built from a placeholder and
        the concrete diagonal is supplied through the returned feed_dict.

    Returns:
      A tuple `(operator, mat, feed_dict)`, where `mat` is the dense matrix
      built with `matrix_diag` from the same diagonal values, and `feed_dict`
      is None when `use_placeholder` is False.
    """
    # NOTE(review): presumably entries have magnitude in [1, 2] with a random
    # sign, keeping the operator well away from singular -- confirm against
    # linear_operator_test_util.random_sign_uniform.
    diag = linear_operator_test_util.random_sign_uniform(
        shape[:-1], minval=1., maxval=2., dtype=dtype)
    if use_placeholder:
      diag_ph = array_ops.placeholder(dtype=dtype)
      # Evaluate the diag here because (i) you cannot feed a tensor, and (ii)
      # diag is random and we want the same value used for both mat and
      # feed_dict.
      diag = diag.eval()
      operator = linalg.LinearOperatorDiag(diag_ph)
      feed_dict = {diag_ph: diag}
    else:
      operator = linalg.LinearOperatorDiag(diag)
      feed_dict = None

    mat = array_ops.matrix_diag(diag)

    return operator, mat, feed_dict

  def test_assert_positive_definite_raises_for_zero_eigenvalue(self):
    """A real diagonal containing a zero entry is not positive definite."""
    # Matrix with one positive eigenvalue and one zero eigenvalue.
    with self.test_session():
      diag = [1.0, 0.0]
      operator = linalg.LinearOperatorDiag(diag)

      # is_self_adjoint should be auto-set for real diag.
      self.assertTrue(operator.is_self_adjoint)
      with self.assertRaisesOpError("non-positive.*not positive definite"):
        operator.assert_positive_definite().run()

  def test_assert_positive_definite_raises_for_negative_real_eigvalues(self):
    """A complex diagonal with a negative real part is not positive definite."""
    with self.test_session():
      diag_x = [1.0, -2.0]
      diag_y = [0., 0.]  # Imaginary eigenvalues should not matter.
      diag = math_ops.complex(diag_x, diag_y)
      operator = linalg.LinearOperatorDiag(diag)

      # is_self_adjoint should not be auto-set for complex diag.
      self.assertIsNone(operator.is_self_adjoint)
      with self.assertRaisesOpError("non-positive real.*not positive definite"):
        operator.assert_positive_definite().run()

  def test_assert_positive_definite_does_not_raise_if_pd_and_complex(self):
    """Complex diagonals with strictly positive real parts pass the check."""
    with self.test_session():
      x = [1., 2.]
      y = [1., 0.]
      diag = math_ops.complex(x, y)  # Re[diag] > 0.
      # Should not fail.
      linalg.LinearOperatorDiag(diag).assert_positive_definite().run()

  def test_assert_non_singular_raises_if_zero_eigenvalue(self):
    """A zero on the diagonal makes the operator singular."""
    # Singular matrix with one positive eigenvalue and one zero eigenvalue.
    with self.test_session():
      diag = [1.0, 0.0]
      operator = linalg.LinearOperatorDiag(diag, is_self_adjoint=True)
      with self.assertRaisesOpError("Singular operator"):
        operator.assert_non_singular().run()

  def test_assert_non_singular_does_not_raise_for_complex_nonsingular(self):
    """Nonzero complex diagonal entries (even with zero real part) pass."""
    with self.test_session():
      x = [1., 0.]
      y = [0., 1.]
      diag = math_ops.complex(x, y)
      # Should not raise.
      linalg.LinearOperatorDiag(diag).assert_non_singular().run()

  def test_assert_self_adjoint_raises_if_diag_has_complex_part(self):
    """A nonzero imaginary part on the diagonal breaks self-adjointness."""
    with self.test_session():
      x = [1., 0.]
      y = [0., 1.]
      diag = math_ops.complex(x, y)
      operator = linalg.LinearOperatorDiag(diag)
      with self.assertRaisesOpError("imaginary.*not self-adjoint"):
        operator.assert_self_adjoint().run()

  def test_assert_self_adjoint_does_not_raise_for_diag_with_zero_imag(self):
    """A complex-typed diagonal with zero imaginary parts is self-adjoint."""
    with self.test_session():
      x = [1., 0.]
      y = [0., 0.]
      diag = math_ops.complex(x, y)
      operator = linalg.LinearOperatorDiag(diag)
      # Should not raise.
      operator.assert_self_adjoint().run()

  def test_scalar_diag_raises(self):
    """A rank-0 diagonal is rejected at construction time."""
    with self.assertRaisesRegexp(ValueError, "must have at least 1 dimension"):
      linalg.LinearOperatorDiag(1.)

  def test_broadcast_matmul_and_solve(self):
    """matmul/solve broadcast the operator's batch shape against the rhs."""
    # These cannot be done in the automated (base test class) tests since they
    # test shapes that tf.matmul cannot handle.
    # In particular, tf.matmul does not broadcast.
    with self.test_session() as sess:
      x = random_ops.random_normal(shape=(2, 2, 3, 4))

      # This LinearOperatorDiag will be broadcast to (2, 2, 3, 3) during solve
      # and matmul with 'x' as the argument.
      # Keep the diagonal bounded away from zero (same [1, 2) range as used in
      # _operator_and_mat_and_feed_dict) so the solve comparison below stays
      # well conditioned; the default [0, 1) range can produce tiny pivots.
      diag = random_ops.random_uniform(shape=(2, 1, 3), minval=1., maxval=2.)
      operator = linalg.LinearOperatorDiag(diag, is_self_adjoint=True)
      self.assertAllEqual((2, 1, 3, 3), operator.shape)

      # Create a batch matrix with the broadcast shape of operator.
      diag_broadcast = array_ops.concat((diag, diag), 1)
      mat = array_ops.matrix_diag(diag_broadcast)
      self.assertAllEqual((2, 2, 3, 3), mat.get_shape())  # being pedantic.

      operator_matmul = operator.matmul(x)
      mat_matmul = math_ops.matmul(mat, x)
      self.assertAllEqual(operator_matmul.get_shape(), mat_matmul.get_shape())
      self.assertAllClose(*sess.run([operator_matmul, mat_matmul]))

      operator_solve = operator.solve(x)
      mat_solve = linalg_ops.matrix_solve(mat, x)
      self.assertAllEqual(operator_solve.get_shape(), mat_solve.get_shape())
      self.assertAllClose(*sess.run([operator_solve, mat_solve]))


if __name__ == "__main__":
  test.main()