# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Solvers for linear least-squares."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from tensorflow.contrib.solvers.python.ops import util
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops


def cgls(operator, rhs, tol=1e-6, max_iter=20, name="cgls"):
  r"""Conjugate gradient least squares solver.

  Minimizes \\(||A x - rhs||_2\\) for a single right-hand side, using an
  iterative, matrix-free algorithm where the action of the matrix A is
  represented by `operator`. The CGLS algorithm implicitly applies the
  symmetric conjugate gradient algorithm to the normal equations
  \\(A^* A x = A^* rhs\\). The iteration terminates when either the number of
  iterations exceeds `max_iter` or when the norm of the conjugate residual
  (the residual of the normal equations) has been reduced to `tol` times its
  initial value, i.e.
  \\(||A^* (rhs - A x_k)|| <= tol ||A^* rhs||\\).

  Args:
    operator: An object representing a linear operator with attributes:
      - shape: Either a list of integers or a 1-D `Tensor` of type `int32` of
        length 2. `shape[0]` is the dimension of the co-domain of the operator
        and `shape[1]` is the dimension of its domain. In other words, if
        `operator` represents an M x N matrix A, `shape` must contain `[M, N]`.
      - dtype: The datatype of input to and output from `apply` and
        `apply_adjoint`.
      - apply: Callable object taking a vector `x` as input and returning a
        vector with the result of applying the operator to `x`, i.e. if
        `operator` represents matrix `A`, `apply` should return `A * x`.
      - apply_adjoint: Callable object taking a vector `x` as input and
        returning a vector with the result of applying the adjoint operator
        to `x`, i.e. if `operator` represents matrix `A`, `apply_adjoint`
        should return `conj(transpose(A)) * x`.
    rhs: A rank-1 `Tensor` of shape `[M]` containing the right-hand side
      vector.
    tol: A float scalar convergence tolerance.
    max_iter: An integer giving the maximum number of iterations.
    name: A name scope for the operation.

  Returns:
    output: A namedtuple representing the final state with fields:
      - i: A scalar `int32` `Tensor`. Number of iterations executed.
      - x: A rank-1 `Tensor` of shape `[N]` containing the computed solution.
      - r: A rank-1 `Tensor` of shape `[M]` containing the residual vector.
      - p: A rank-1 `Tensor` of shape `[N]`. The next descent direction.
      - gamma: \\(||A^* r||_2^2\\)
  """
  # Ephemeral namedtuple holding the CGLS iteration state.
  cgls_state = collections.namedtuple("CGLSState",
                                      ["i", "x", "r", "p", "gamma"])

  def stopping_criterion(i, state):
    return math_ops.logical_and(i < max_iter, state.gamma > tol)

  # TODO(rmlarsen): add preconditioning
  def cgls_step(i, state):
    # Apply the operator to the current descent direction.
    q = operator.apply(state.p)
    # Step length minimizing the residual along the direction p.
    alpha = state.gamma / util.l2norm_squared(q)
    x = state.x + alpha * state.p
    r = state.r - alpha * q
    # Residual of the normal equations, s = A^* r.
    s = operator.apply_adjoint(r)
    gamma = util.l2norm_squared(s)
    beta = gamma / state.gamma
    # New conjugate descent direction.
    p = s + beta * state.p
    return i + 1, cgls_state(i + 1, x, r, p, gamma)

  with ops.name_scope(name):
    n = operator.shape[1:]
    rhs = array_ops.expand_dims(rhs, -1)
    s0 = operator.apply_adjoint(rhs)
    gamma0 = util.l2norm_squared(s0)
    # Convert the relative tolerance on ||A^* r|| into an absolute tolerance
    # on gamma = ||A^* r||^2.
    tol = tol * tol * gamma0
    x = array_ops.expand_dims(
        array_ops.zeros(
            n, dtype=rhs.dtype.base_dtype), -1)
    i = constant_op.constant(0, dtype=dtypes.int32)
    state = cgls_state(i=i, x=x, r=rhs, p=s0, gamma=gamma0)
    _, state = control_flow_ops.while_loop(stopping_criterion, cgls_step,
                                           [i, state])
    return cgls_state(
        state.i,
        x=array_ops.squeeze(state.x),
        r=array_ops.squeeze(state.r),
        p=array_ops.squeeze(state.p),
        gamma=state.gamma)
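

# The block below is a minimal usage sketch and is not part of the original
# module: it builds a dense-matrix operator exposing the attributes documented
# in the `cgls` docstring (shape, dtype, apply, apply_adjoint) and solves a
# small overdetermined least-squares problem. The `linear_operator` namedtuple
# and the example matrix/right-hand side are illustrative assumptions, not
# APIs of this file; running it assumes a TF1-style graph session.
if __name__ == "__main__":
  from tensorflow.python.client import session as session_lib

  linear_operator = collections.namedtuple(
      "LinearOperator", ["shape", "dtype", "apply", "apply_adjoint"])

  # Overdetermined system: a 4 x 2 matrix A and right-hand side b for
  # min_x ||A x - b||_2.
  a = constant_op.constant(
      [[1., 0.], [0., 1.], [1., 1.], [1., -1.]], dtype=dtypes.float32)
  b = constant_op.constant([1., 2., 3., 0.], dtype=dtypes.float32)

  example_operator = linear_operator(
      shape=[4, 2],
      dtype=a.dtype,
      apply=lambda v: math_ops.matmul(a, v),
      apply_adjoint=lambda v: math_ops.matmul(a, v, adjoint_a=True))

  result = cgls(example_operator, b, tol=1e-8, max_iter=50)
  with session_lib.Session() as sess:
    # result.x holds the least-squares solution of shape [2].
    print(sess.run(result.x))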