1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Add one or more `LinearOperators` efficiently."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import abc
22
23import six
24
25from tensorflow.python.framework import ops
26from tensorflow.python.ops import array_ops
27from tensorflow.python.ops import check_ops
28from tensorflow.python.ops.linalg import linear_operator
29from tensorflow.python.ops.linalg import linear_operator_diag
30from tensorflow.python.ops.linalg import linear_operator_full_matrix
31from tensorflow.python.ops.linalg import linear_operator_identity
32from tensorflow.python.ops.linalg import linear_operator_lower_triangular
33
# No public symbols: everything in this module is an implementation detail.
__all__ = []
35
36
def add_operators(operators,
                  operator_name=None,
                  addition_tiers=None,
                  name=None):
  """Efficiently add one or more linear operators.

  Given operators `[A1, A2,...]`, this `Op` returns a possibly shorter list of
  operators `[B1, B2,...]` such that

  ```sum_k Ak.matmul(x) = sum_k Bk.matmul(x).```

  The operators `Bk` result by adding some of the `Ak`, as allowed by
  `addition_tiers`.

  Example of efficient adding of diagonal operators.

  ```python
  A1 = LinearOperatorDiag(diag=[1., 1.], name="A1")
  A2 = LinearOperatorDiag(diag=[2., 2.], name="A2")

  # Use two tiers, the first contains an Adder that returns Diag.  Since both
  # A1 and A2 are Diag, they can use this Adder.  The second tier will not be
  # used.
  addition_tiers = [
      [_AddAndReturnDiag()],
      [_AddAndReturnMatrix()]]
  B_list = add_operators([A1, A2], addition_tiers=addition_tiers)

  len(B_list)
  ==> 1

  B_list[0].__class__.__name__
  ==> 'LinearOperatorDiag'

  B_list[0].to_dense()
  ==> [[3., 0.],
       [0., 3.]]

  B_list[0].name
  ==> 'Add/A1__A2/'
  ```

  Args:
    operators:  Iterable of `LinearOperator` objects with same `dtype`, domain
      and range dimensions, and broadcastable batch shapes.
    operator_name:  String name for returned `LinearOperator`.  Defaults to
      concatenation of "Add/A__B/" that indicates the order of addition steps.
    addition_tiers:  List tiers, like `[tier_0, tier_1, ...]`, where `tier_i`
      is a list of `Adder` objects.  This function attempts to do all additions
      in tier `i` before trying tier `i + 1`.
    name:  A name for this `Op`.  Defaults to `add_operators`.

  Returns:
    Subclass of `LinearOperator`.  Class and order of addition may change as new
      (and better) addition strategies emerge.

  Raises:
    ValueError:  If `operators` argument is empty.
    ValueError:  If shapes are incompatible.
  """
  # Fall back to the module-wide default tiers.
  if addition_tiers is None:
    addition_tiers = _DEFAULT_ADDITION_TIERS

  # Validate arguments before doing any real work.
  check_ops.assert_proper_iterable(operators)
  # Reversed so that the pops below visit operators in their original order.
  operators = list(reversed(operators))
  if len(operators) < 1:
    raise ValueError(
        "Argument 'operators' must contain at least one operator.  "
        "Found: %s" % operators)
  if not all(
      isinstance(op, linear_operator.LinearOperator) for op in operators):
    raise TypeError(
        "Argument 'operators' must contain only LinearOperator instances.  "
        "Found: %s" % operators)
  _static_check_for_same_dimensions(operators)
  _static_check_for_broadcastable_batch_shape(operators)

  # Collect every operator's graph parents for the surrounding name scope.
  graph_parents = [parent for op in operators
                   for parent in op.graph_parents]

  with ops.name_scope(name or "add_operators", values=graph_parents):

    # Greedily combine, tier by tier.  A sum produced within a tier is fed
    # back into that same tier; whatever a tier cannot combine is deferred
    # to the next one.
    remaining = list(operators)
    for tier in addition_tiers:
      pending, remaining = remaining, []
      while pending:
        op1 = pending.pop()
        op2, adder = _pop_a_match_at_tier(op1, pending, tier)
        if op2 is None:
          # No partner found at this tier; retry op1 at the next tier.
          remaining.append(op1)
        else:
          # The freshly formed sum may combine further at this same tier.
          pending.append(adder.add(op1, op2, operator_name))

    return remaining
138
139
140def _pop_a_match_at_tier(op1, operator_list, tier):
141  # Search from the back of list to the front in order to create nice default
142  # order of operations.
143  for i in range(1, len(operator_list) + 1):
144    op2 = operator_list[-i]
145    for adder in tier:
146      if adder.can_add(op1, op2):
147        return operator_list.pop(-i), adder
148  return None, None
149
150
def _infer_hints_allowing_override(op1, op2, hints):
  """Infer hints from op1 and op2.  hints argument is an override.

  Args:
    op1:  LinearOperator
    op2:  LinearOperator
    hints:  _Hints object holding "is_X" boolean hints to use for returned
      operator.
      If some hint is None, try to set using op1 and op2.  If the
      hint is provided, ignore op1 and op2 hints.  This allows an override
      of previous hints, but does not allow forbidden hints (e.g. you still
      cannot say a real diagonal operator is not self-adjoint).

  Returns:
    _Hints object.
  """
  hints = hints or _Hints()
  # If A, B are self-adjoint, then so is A + B.
  if hints.is_self_adjoint is None:
    is_self_adjoint = op1.is_self_adjoint and op2.is_self_adjoint
  else:
    is_self_adjoint = hints.is_self_adjoint

  # If A, B are positive definite, then so is A + B.
  if hints.is_positive_definite is None:
    is_positive_definite = op1.is_positive_definite and op2.is_positive_definite
  else:
    is_positive_definite = hints.is_positive_definite

  # A positive definite operator is always non-singular.
  # Bug fix: guard on the hint being inferred (`is_non_singular`), not on
  # `is_positive_definite`.  The previous check of
  # `hints.is_positive_definite is None` could silently override an
  # explicitly supplied `is_non_singular` hint, and dropped the inference
  # whenever `is_positive_definite` was explicitly provided.
  if is_positive_definite and hints.is_non_singular is None:
    is_non_singular = True
  else:
    is_non_singular = hints.is_non_singular

  return _Hints(
      is_non_singular=is_non_singular,
      is_self_adjoint=is_self_adjoint,
      is_positive_definite=is_positive_definite)
190
191
192def _static_check_for_same_dimensions(operators):
193  """ValueError if operators determined to have different dimensions."""
194  if len(operators) < 2:
195    return
196
197  domain_dimensions = [(op.name, op.domain_dimension.value) for op in operators
198                       if op.domain_dimension.value is not None]
199  if len(set(value for name, value in domain_dimensions)) > 1:
200    raise ValueError("Operators must have the same domain dimension. Found: %s"
201                     % domain_dimensions)
202
203  range_dimensions = [(op.name, op.range_dimension.value) for op in operators
204                      if op.range_dimension.value is not None]
205  if len(set(value for name, value in range_dimensions)) > 1:
206    raise ValueError("Operators must have the same range dimension. Found: %s" %
207                     range_dimensions)
208
209
210def _static_check_for_broadcastable_batch_shape(operators):
211  """ValueError if operators determined to have non-broadcastable shapes."""
212  if len(operators) < 2:
213    return
214
215  # This will fail if they cannot be broadcast together.
216  batch_shape = operators[0].batch_shape
217  for op in operators[1:]:
218    batch_shape = array_ops.broadcast_static_shape(batch_shape, op.batch_shape)
219
220
221class _Hints(object):
222  """Holds 'is_X' flags that every LinearOperator is initialized with."""
223
224  def __init__(self,
225               is_non_singular=None,
226               is_positive_definite=None,
227               is_self_adjoint=None):
228    self.is_non_singular = is_non_singular
229    self.is_positive_definite = is_positive_definite
230    self.is_self_adjoint = is_self_adjoint
231
232
233################################################################################
234# Classes to add two linear operators.
235################################################################################
236
237
@six.add_metaclass(abc.ABCMeta)
class _Adder(object):
  """Abstract base class to add two operators.

  Each `Adder` acts independently, adding everything it can, paying no attention
  as to whether another `Adder` could have done the addition more efficiently.
  """

  @property
  def name(self):
    # The adder's display/scope name is simply its class name.
    return self.__class__.__name__

  @abc.abstractmethod
  def can_add(self, op1, op2):
    """Returns `True` if this `Adder` can add `op1` and `op2`.  Else `False`."""
    pass

  @abc.abstractmethod
  def _add(self, op1, op2, operator_name, hints):
    # Derived classes can assume op1 and op2 have been validated, e.g. they have
    # the same dtype, and their domain/range dimensions match.
    pass

  def add(self, op1, op2, operator_name, hints=None):
    """Return new `LinearOperator` acting like `op1 + op2`.

    Args:
      op1:  `LinearOperator`
      op2:  `LinearOperator`, with `shape` and `dtype` such that adding to
        `op1` is allowed.
      operator_name:  `String` name to give to returned `LinearOperator`
      hints:  `_Hints` object.  Returned `LinearOperator` will be created with
        these hints.

    Returns:
      `LinearOperator`
    """
    merged_hints = _infer_hints_allowing_override(op1, op2, hints)

    if operator_name is None:
      # Default name records which operands were combined, in order.
      operator_name = "Add/" + op1.name + "__" + op2.name + "/"

    # Drop a single leading underscore so the name scope reads as public.
    cls_name = self.name
    scope = cls_name[1:] if cls_name.startswith("_") else cls_name
    with ops.name_scope(scope,
                        values=op1.graph_parents + op2.graph_parents):
      return self._add(op1, op2, operator_name, merged_hints)
286
287
class _AddAndReturnScaledIdentity(_Adder):
  """Handles additions resulting in an Identity family member.

  The Identity (`LinearOperatorScaledIdentity`, `LinearOperatorIdentity`) family
  is closed under addition.  This `Adder` respects that, and returns a member
  of the Identity family.
  """

  def can_add(self, op1, op2):
    # Both operands must belong to the Identity family.
    return {_type(op1), _type(op2)} <= _IDENTITY_FAMILY

  def _add(self, op1, op2, operator_name, hints):
    # The result is always a LinearOperatorScaledIdentity.  A plain identity
    # contributes a multiplier of one; a scaled identity contributes its own.

    def _multiplier_of(op):
      if _type(op) == _SCALED_IDENTITY:
        return op.multiplier
      return array_ops.ones(op.batch_shape_tensor(), dtype=op.dtype)

    total_multiplier = _multiplier_of(op1) + _multiplier_of(op2)

    return linear_operator_identity.LinearOperatorScaledIdentity(
        num_rows=op1.range_dimension_tensor(),
        multiplier=total_multiplier,
        is_non_singular=hints.is_non_singular,
        is_self_adjoint=hints.is_self_adjoint,
        is_positive_definite=hints.is_positive_definite,
        name=operator_name)
319
320
class _AddAndReturnDiag(_Adder):
  """Handles additions resulting in a Diag operator."""

  def can_add(self, op1, op2):
    # Both operands must be diagonal-like (diag, identity, scaled identity).
    return {_type(op1), _type(op2)} <= _DIAG_LIKE

  def _add(self, op1, op2, operator_name, hints):
    # Diagonal-like operators add by summing their diagonals.
    summed_diag = op1.diag_part() + op2.diag_part()
    return linear_operator_diag.LinearOperatorDiag(
        diag=summed_diag,
        is_non_singular=hints.is_non_singular,
        is_self_adjoint=hints.is_self_adjoint,
        is_positive_definite=hints.is_positive_definite,
        name=operator_name)
335
336
class _AddAndReturnTriL(_Adder):
  """Handles additions resulting in a TriL operator."""

  def can_add(self, op1, op2):
    # Operands may be diagonal-like or lower triangular.
    allowed_types = _DIAG_LIKE.union({_TRIL})
    return {_type(op1), _type(op2)} <= allowed_types

  def _add(self, op1, op2, operator_name, hints):
    # Let the operand with an efficient add_to_tensor absorb the other's
    # dense form.
    if _type(op1) in _EFFICIENT_ADD_TO_TENSOR:
      op_add_to_tensor, op_other = op1, op2
    else:
      op_add_to_tensor, op_other = op2, op1

    summed_tril = op_add_to_tensor.add_to_tensor(op_other.to_dense())
    return linear_operator_lower_triangular.LinearOperatorLowerTriangular(
        tril=summed_tril,
        is_non_singular=hints.is_non_singular,
        is_self_adjoint=hints.is_self_adjoint,
        is_positive_definite=hints.is_positive_definite,
        name=operator_name)
356
357
class _AddAndReturnMatrix(_Adder):
  """Handles additions resulting in a `LinearOperatorFullMatrix`."""

  def can_add(self, op1, op2):  # pylint: disable=unused-argument
    # Fallback adder: any pair of LinearOperators qualifies.
    return isinstance(op1, linear_operator.LinearOperator) and isinstance(
        op2, linear_operator.LinearOperator)

  def _add(self, op1, op2, operator_name, hints):
    # Prefer the operand with an efficient add_to_tensor, densifying the other.
    if _type(op1) in _EFFICIENT_ADD_TO_TENSOR:
      op_add_to_tensor, op_other = op1, op2
    else:
      op_add_to_tensor, op_other = op2, op1
    summed_matrix = op_add_to_tensor.add_to_tensor(op_other.to_dense())
    return linear_operator_full_matrix.LinearOperatorFullMatrix(
        matrix=summed_matrix,
        is_non_singular=hints.is_non_singular,
        is_self_adjoint=hints.is_self_adjoint,
        is_positive_definite=hints.is_positive_definite,
        name=operator_name)
376
377
378################################################################################
379# Constants designating types of LinearOperators
380################################################################################
381
# Type name constants for LinearOperator classes.  These string tags are
# assigned to operators by `_type` below and drive the `can_add` checks.
_IDENTITY = "identity"
_SCALED_IDENTITY = "scaled_identity"
_DIAG = "diag"
_TRIL = "tril"
_MATRIX = "matrix"

# Groups of operators.
# Diagonal-like operators: identities count since their diagonal is trivial.
_DIAG_LIKE = {_DIAG, _IDENTITY, _SCALED_IDENTITY}
_IDENTITY_FAMILY = {_IDENTITY, _SCALED_IDENTITY}
# operators with an efficient .add_to_tensor() method.
_EFFICIENT_ADD_TO_TENSOR = _DIAG_LIKE
394
395
def _type(operator):
  """Returns the type name constant (e.g. _TRIL) for operator.

  Raises:
    TypeError:  If `operator` is none of the known LinearOperator classes.
  """
  # Checked in this fixed order; isinstance keeps subclasses working.
  checks = (
      (linear_operator_diag.LinearOperatorDiag, _DIAG),
      (linear_operator_lower_triangular.LinearOperatorLowerTriangular, _TRIL),
      (linear_operator_full_matrix.LinearOperatorFullMatrix, _MATRIX),
      (linear_operator_identity.LinearOperatorIdentity, _IDENTITY),
      (linear_operator_identity.LinearOperatorScaledIdentity,
       _SCALED_IDENTITY),
  )
  for operator_class, type_name in checks:
    if isinstance(operator, operator_class):
      return type_name
  raise TypeError("Operator type unknown: %s" % operator)
411
412
413################################################################################
414# Addition tiers:
415# We attempt to use Adders in tier K before K+1.
416#
417# Organize tiers to
418#   (i) reduce O(..) complexity of forming final operator, and
419#   (ii) produce the "most efficient" final operator.
420# Dev notes:
421#  * Results of addition at tier K will be added at tier K or higher.
422#  * Tiers may change, and we warn the user that it may change.
423################################################################################
424
# Note that the final tier, _AddAndReturnMatrix, will convert everything to a
# dense matrix.  So it is sometimes very inefficient.
# Tiers are ordered from the most structured (cheapest) result type to the
# least: scaled identity, then diag, then lower triangular, then full matrix.
_DEFAULT_ADDITION_TIERS = [
    [_AddAndReturnScaledIdentity()],
    [_AddAndReturnDiag()],
    [_AddAndReturnTriL()],
    [_AddAndReturnMatrix()],
]
433