1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python.framework import random_seed
21from tensorflow.python.ops import array_ops
22from tensorflow.python.ops import linalg_ops
23from tensorflow.python.ops import math_ops
24from tensorflow.python.ops import random_ops
25from tensorflow.python.ops.linalg import linalg as linalg_lib
26from tensorflow.python.ops.linalg import linear_operator_test_util
27from tensorflow.python.platform import test
28
# Short module-level alias for the linalg namespace used by the tests below.
linalg = linalg_lib

# Fix the graph-level random seed so the random draws in these tests are
# repeatable.
# NOTE(review): the test base class may reset the default graph (and with it
# this seed) in setUp -- confirm this call still has an effect.
random_seed.set_random_seed(seed=23)
31
32
class LinearOperatorDiagTest(
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Tests for `LinearOperatorDiag`.

  Most tests are done in the base class SquareLinearOperatorDerivedClassTest;
  this class supplies the operator/matrix pairs that base class consumes and
  adds checks specific to diagonal operators.
  """

  def _operator_and_mat_and_feed_dict(self, shape, dtype, use_placeholder):
    """Builds a `LinearOperatorDiag` together with its dense equivalent.

    Args:
      shape: Shape of the full (batch) matrix. The last dimension is dropped
        to obtain the shape of the diagonal; presumably `[..., N, N]` as
        supplied by the base class -- confirm there.
      dtype: Data type of the diagonal.
      use_placeholder: If True, the operator is built from a placeholder and
        the diagonal's value is supplied through the returned `feed_dict`.

    Returns:
      A tuple `(operator, mat, feed_dict)`, where `mat` is the dense matrix
      built with `matrix_diag` from the same diagonal as `operator`, and
      `feed_dict` is `None` unless `use_placeholder` is True.
    """
    # Presumably entries have magnitude in [1, 2) with a random sign, keeping
    # the operator away from singular -- confirm in linear_operator_test_util.
    diag = linear_operator_test_util.random_sign_uniform(
        shape[:-1], minval=1., maxval=2., dtype=dtype)
    if use_placeholder:
      diag_ph = array_ops.placeholder(dtype=dtype)
      # Evaluate the diag here because (i) you cannot feed a tensor, and (ii)
      # diag is random and we want the same value used for both mat and
      # feed_dict. `.eval()` needs a default session, presumably opened by
      # the base class around this call.
      diag = diag.eval()
      operator = linalg.LinearOperatorDiag(diag_ph)
      feed_dict = {diag_ph: diag}
    else:
      operator = linalg.LinearOperatorDiag(diag)
      feed_dict = None

    mat = array_ops.matrix_diag(diag)

    return operator, mat, feed_dict

  def test_assert_positive_definite_raises_for_zero_eigenvalue(self):
    # Matrix with one positive eigenvalue and one zero eigenvalue.
    with self.test_session():
      diag = [1.0, 0.0]
      operator = linalg.LinearOperatorDiag(diag)

      # is_self_adjoint should be auto-set for real diag.
      self.assertTrue(operator.is_self_adjoint)
      with self.assertRaisesOpError("non-positive.*not positive definite"):
        operator.assert_positive_definite().run()

  def test_assert_positive_definite_raises_for_negative_real_eigvalues(self):
    with self.test_session():
      diag_x = [1.0, -2.0]
      diag_y = [0., 0.]  # Imaginary eigenvalues should not matter.
      diag = math_ops.complex(diag_x, diag_y)
      operator = linalg.LinearOperatorDiag(diag)

      # is_self_adjoint should not be auto-set for complex diag.
      self.assertIsNone(operator.is_self_adjoint)
      with self.assertRaisesOpError("non-positive real.*not positive definite"):
        operator.assert_positive_definite().run()

  def test_assert_positive_definite_does_not_raise_if_pd_and_complex(self):
    with self.test_session():
      x = [1., 2.]
      y = [1., 0.]
      diag = math_ops.complex(x, y)  # Re[diag] > 0.
      # Should not fail: only the real parts need to be positive.
      linalg.LinearOperatorDiag(diag).assert_positive_definite().run()

  def test_assert_non_singular_raises_if_zero_eigenvalue(self):
    # Singular matrix with one positive eigenvalue and one zero eigenvalue.
    with self.test_session():
      diag = [1.0, 0.0]
      operator = linalg.LinearOperatorDiag(diag, is_self_adjoint=True)
      with self.assertRaisesOpError("Singular operator"):
        operator.assert_non_singular().run()

  def test_assert_non_singular_does_not_raise_for_complex_nonsingular(self):
    with self.test_session():
      # Diagonal is [1, 1j]: every entry is nonzero, so the operator is
      # non-singular even though the entries are not all real.
      x = [1., 0.]
      y = [0., 1.]
      diag = math_ops.complex(x, y)
      # Should not raise.
      linalg.LinearOperatorDiag(diag).assert_non_singular().run()

  def test_assert_self_adjoint_raises_if_diag_has_complex_part(self):
    with self.test_session():
      # Second entry is 1j, which has a nonzero imaginary part.
      x = [1., 0.]
      y = [0., 1.]
      diag = math_ops.complex(x, y)
      operator = linalg.LinearOperatorDiag(diag)
      with self.assertRaisesOpError("imaginary.*not self-adjoint"):
        operator.assert_self_adjoint().run()

  def test_assert_self_adjoint_does_not_raise_for_diag_with_zero_imag(self):
    with self.test_session():
      # Complex dtype, but every imaginary part is exactly zero.
      x = [1., 0.]
      y = [0., 0.]
      diag = math_ops.complex(x, y)
      operator = linalg.LinearOperatorDiag(diag)
      # Should not raise
      operator.assert_self_adjoint().run()

  def test_scalar_diag_raises(self):
    # A rank-0 diag is rejected at construction time, in Python.
    # assertRaisesRegexp (not assertRaisesRegex) is kept for Python 2
    # compatibility, which this file still targets via __future__ imports.
    with self.assertRaisesRegexp(ValueError, "must have at least 1 dimension"):
      linalg.LinearOperatorDiag(1.)

  def test_broadcast_matmul_and_solve(self):
    # These cannot be done in the automated (base test class) tests since they
    # test shapes that tf.matmul cannot handle.
    # In particular, tf.matmul does not broadcast.
    with self.test_session() as sess:
      x = random_ops.random_normal(shape=(2, 2, 3, 4))

      # This LinearOperatorDiag will be broadcast to (2, 2, 3, 3) during solve
      # and matmul with 'x' as the argument.
      # Draw the diagonal from [1, 2) instead of the default [0, 1): a zero
      # (or tiny) entry would make `mat` singular (or ill-conditioned) and the
      # solve comparison below meaningless.
      diag = random_ops.random_uniform(
          shape=(2, 1, 3), minval=1., maxval=2.)
      operator = linalg.LinearOperatorDiag(diag, is_self_adjoint=True)
      self.assertAllEqual((2, 1, 3, 3), operator.shape)

      # Create a batch matrix with the broadcast shape of operator.
      diag_broadcast = array_ops.concat((diag, diag), 1)
      mat = array_ops.matrix_diag(diag_broadcast)
      self.assertAllEqual((2, 2, 3, 3), mat.get_shape())  # being pedantic.

      # Both results are fetched in a single sess.run so they share the same
      # random draws of `x` and `diag`.
      operator_matmul = operator.matmul(x)
      mat_matmul = math_ops.matmul(mat, x)
      self.assertAllEqual(operator_matmul.get_shape(), mat_matmul.get_shape())
      self.assertAllClose(*sess.run([operator_matmul, mat_matmul]))

      operator_solve = operator.solve(x)
      mat_solve = linalg_ops.matrix_solve(mat, x)
      self.assertAllEqual(operator_solve.get_shape(), mat_solve.get_shape())
      self.assertAllClose(*sess.run([operator_solve, mat_solve]))
152
153
if __name__ == "__main__":
  # Hand control to TensorFlow's test runner for the test cases defined in
  # this module; argv=None is the runner's default.
  test.main(argv=None)
156