1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Solvers for linear least-squares."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22
23from tensorflow.contrib.solvers.python.ops import util
24from tensorflow.python.framework import constant_op
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import ops
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import control_flow_ops
29from tensorflow.python.ops import math_ops
30
31
def cgls(operator, rhs, tol=1e-6, max_iter=20, name="cgls"):
  r"""Conjugate gradient least squares solver.

  Solves a linear least squares problem \\(||A x - rhs||_2\\) for a single
  right-hand side, using an iterative, matrix-free algorithm where the action
  of the matrix A is represented by `operator`. The CGLS algorithm implicitly
  applies the symmetric conjugate gradient algorithm to the normal equations
  \\(A^* A x = A^* rhs\\). The iteration terminates when either
  the number of iterations exceeds `max_iter` or when the norm of the conjugate
  residual (residual of the normal equations) has been reduced to `tol` times
  its initial value, i.e.
  \\(||A^* (rhs - A x_k)|| <= tol ||A^* rhs||\\).

  Args:
    operator: An object representing a linear operator with attributes:
      - shape: Either a list of integers or a 1-D `Tensor` of type `int32` of
        length 2. `shape[0]` is the dimension of the co-domain of the operator,
        `shape[1]` is the dimension of the domain of the operator. In other
        words, if operator represents an M x N matrix A, `shape` must contain
        `[M, N]`.
      - dtype: The datatype of input to and output from `apply` and
        `apply_adjoint`.
      - apply: Callable object taking a vector `x` as input and returning a
        vector with the result of applying the operator to `x`, i.e. if
        `operator` represents matrix `A`, `apply` should return `A * x`.
      - apply_adjoint: Callable object taking a vector `x` as input and
        returning a vector with the result of applying the adjoint operator
        to `x`, i.e. if `operator` represents matrix `A`, `apply_adjoint`
        should return `conj(transpose(A)) * x`.

    rhs: A rank-1 `Tensor` of shape `[M]` containing the right-hand side
      vector.
    tol: A float scalar convergence tolerance.
    max_iter: An integer giving the maximum number of iterations.
    name: A name scope for the operation.


  Returns:
    output: A namedtuple representing the final state with fields:
      - i: A scalar `int32` `Tensor`. Number of iterations executed.
      - x: A rank-1 `Tensor` of shape `[N]` containing the computed solution.
      - r: A rank-1 `Tensor` of shape `[M]` containing the residual vector.
      - p: A rank-1 `Tensor` of shape `[N]`. The next descent direction.
      - gamma: \\(||A^* r||_2^2\\)
  """
  # Ephemeral class holding CGLS state.
  cgls_state = collections.namedtuple("CGLSState",
                                      ["i", "x", "r", "p", "gamma"])

  def stopping_criterion(i, state):
    # Continue while under the iteration budget and the squared norm of the
    # normal-equations residual is above the (rescaled) tolerance.
    return math_ops.logical_and(i < max_iter, state.gamma > tol)

  # TODO(rmlarsen): add preconditioning
  def cgls_step(i, state):
    # One CGLS iteration: line search along p, update solution and residuals.
    q = operator.apply(state.p)
    alpha = state.gamma / util.l2norm_squared(q)
    x = state.x + alpha * state.p
    r = state.r - alpha * q
    # s is the residual of the normal equations, A^* (rhs - A x).
    s = operator.apply_adjoint(r)
    gamma = util.l2norm_squared(s)
    beta = gamma / state.gamma
    p = s + beta * state.p
    return i + 1, cgls_state(i + 1, x, r, p, gamma)

  with ops.name_scope(name):
    n = operator.shape[1:]
    rhs = array_ops.expand_dims(rhs, -1)
    # Initial normal-equations residual A^* rhs (x_0 = 0) and its squared norm.
    s0 = operator.apply_adjoint(rhs)
    gamma0 = util.l2norm_squared(s0)
    # Fold the relative tolerance into an absolute threshold on gamma:
    # gamma = ||A^* r||^2, so compare against (tol * ||A^* rhs||)^2.
    tol = tol * tol * gamma0
    x = array_ops.expand_dims(
        array_ops.zeros(
            n, dtype=rhs.dtype.base_dtype), -1)
    i = constant_op.constant(0, dtype=dtypes.int32)
    state = cgls_state(i=i, x=x, r=rhs, p=s0, gamma=gamma0)
    _, state = control_flow_ops.while_loop(stopping_criterion, cgls_step,
                                           [i, state])
    # Squeeze the column dimension back out so callers get rank-1 tensors.
    return cgls_state(
        state.i,
        x=array_ops.squeeze(state.x),
        r=array_ops.squeeze(state.r),
        p=array_ops.squeeze(state.p),
        gamma=state.gamma)
114