1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Operations for linear algebra.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21from tensorflow.python.framework import ops 22from tensorflow.python.ops import array_ops 23from tensorflow.python.ops import gen_linalg_ops 24from tensorflow.python.ops import linalg_ops 25from tensorflow.python.ops import math_ops 26from tensorflow.python.ops import special_math_ops 27from tensorflow.python.util.tf_export import tf_export 28 29# Linear algebra ops. 
band_part = array_ops.matrix_band_part
cholesky = linalg_ops.cholesky
cholesky_solve = linalg_ops.cholesky_solve
det = linalg_ops.matrix_determinant
# Several aliases below (slogdet, expm, logm) bind underscore-prefixed wrappers
# from gen_linalg_ops, so the protected-access check stays disabled until the
# end of the alias table and is re-enabled afterwards.
# pylint: disable=protected-access
slogdet = gen_linalg_ops._log_matrix_determinant
diag = array_ops.matrix_diag
diag_part = array_ops.matrix_diag_part
eigh = linalg_ops.self_adjoint_eig
eigvalsh = linalg_ops.self_adjoint_eigvals
einsum = special_math_ops.einsum
expm = gen_linalg_ops._matrix_exponential
eye = linalg_ops.eye
inv = linalg_ops.matrix_inverse
logm = gen_linalg_ops._matrix_logarithm
lstsq = linalg_ops.matrix_solve_ls
norm = linalg_ops.norm
qr = linalg_ops.qr
set_diag = array_ops.matrix_set_diag
solve = linalg_ops.matrix_solve
svd = linalg_ops.svd
tensordot = math_ops.tensordot
trace = math_ops.trace
transpose = array_ops.matrix_transpose
triangular_solve = linalg_ops.matrix_triangular_solve
# pylint: enable=protected-access


@tf_export('linalg.logdet')
def logdet(matrix, name=None):
  """Computes log of the determinant of a hermitian positive definite matrix.

  ```python
  # Compute the determinant of a matrix while reducing the chance of over- or
  # underflow:
  A = ... # shape 10 x 10
  det = tf.exp(tf.linalg.logdet(A))  # scalar
  ```

  Args:
    matrix:  A `Tensor`. Must be `float16`, `float32`, `float64`, `complex64`,
      or `complex128` with shape `[..., M, M]`.
    name:  A name to give this `Op`.  Defaults to `logdet`.

  Returns:
    The natural log of the determinant of `matrix`.

  @compatibility(numpy)
  Equivalent to numpy.linalg.slogdet, although no sign is returned since only
  hermitian positive definite matrices are supported.
  @end_compatibility
  """
  # This uses the property that the log det(A) = 2*sum(log(real(diag(C))))
  # where C is the cholesky decomposition of A.
  with ops.name_scope(name, 'logdet', [matrix]):
    chol = gen_linalg_ops.cholesky(matrix)
    # Sum over the last axis, i.e. over the diagonal entries of each matrix.
    # `axis` is the current keyword; `reduction_indices` is its deprecated
    # alias.
    return 2.0 * math_ops.reduce_sum(
        math_ops.log(math_ops.real(array_ops.matrix_diag_part(chol))),
        axis=[-1])


@tf_export('linalg.adjoint')
def adjoint(matrix, name=None):
  """Transposes the last two dimensions of and conjugates tensor `matrix`.

  For example:

  ```python
  x = tf.constant([[1 + 1j, 2 + 2j, 3 + 3j],
                   [4 + 4j, 5 + 5j, 6 + 6j]])
  tf.linalg.adjoint(x)  # [[1 - 1j, 4 - 4j],
                        #  [2 - 2j, 5 - 5j],
                        #  [3 - 3j, 6 - 6j]]
  ```

  Args:
    matrix:  A `Tensor`. Must be `float16`, `float32`, `float64`, `complex64`,
      or `complex128` with shape `[..., M, N]`.
    name:  A name to give this `Op` (optional).

  Returns:
    The adjoint (a.k.a. Hermitian transpose a.k.a. conjugate transpose) of
    matrix.
  """
  with ops.name_scope(name, 'adjoint', [matrix]):
    matrix = ops.convert_to_tensor(matrix, name='matrix')
    return array_ops.matrix_transpose(matrix, conjugate=True)