1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Add one or more `LinearOperators` efficiently.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import abc 22 23import six 24 25from tensorflow.python.framework import ops 26from tensorflow.python.ops import array_ops 27from tensorflow.python.ops import check_ops 28from tensorflow.python.ops.linalg import linear_operator 29from tensorflow.python.ops.linalg import linear_operator_diag 30from tensorflow.python.ops.linalg import linear_operator_full_matrix 31from tensorflow.python.ops.linalg import linear_operator_identity 32from tensorflow.python.ops.linalg import linear_operator_lower_triangular 33 34__all__ = [] 35 36 37def add_operators(operators, 38 operator_name=None, 39 addition_tiers=None, 40 name=None): 41 """Efficiently add one or more linear operators. 42 43 Given operators `[A1, A2,...]`, this `Op` returns a possibly shorter list of 44 operators `[B1, B2,...]` such that 45 46 ```sum_k Ak.matmul(x) = sum_k Bk.matmul(x).``` 47 48 The operators `Bk` result by adding some of the `Ak`, as allowed by 49 `addition_tiers`. 50 51 Example of efficient adding of diagonal operators. 
52 53 ```python 54 A1 = LinearOperatorDiag(diag=[1., 1.], name="A1") 55 A2 = LinearOperatorDiag(diag=[2., 2.], name="A2") 56 57 # Use two tiers, the first contains an Adder that returns Diag. Since both 58 # A1 and A2 are Diag, they can use this Adder. The second tier will not be 59 # used. 60 addition_tiers = [ 61 [_AddAndReturnDiag()], 62 [_AddAndReturnMatrix()]] 63 B_list = add_operators([A1, A2], addition_tiers=addition_tiers) 64 65 len(B_list) 66 ==> 1 67 68 B_list[0].__class__.__name__ 69 ==> 'LinearOperatorDiag' 70 71 B_list[0].to_dense() 72 ==> [[3., 0.], 73 [0., 3.]] 74 75 B_list[0].name 76 ==> 'Add/A1__A2/' 77 ``` 78 79 Args: 80 operators: Iterable of `LinearOperator` objects with same `dtype`, domain 81 and range dimensions, and broadcastable batch shapes. 82 operator_name: String name for returned `LinearOperator`. Defaults to 83 concatenation of "Add/A__B/" that indicates the order of addition steps. 84 addition_tiers: List tiers, like `[tier_0, tier_1, ...]`, where `tier_i` 85 is a list of `Adder` objects. This function attempts to do all additions 86 in tier `i` before trying tier `i + 1`. 87 name: A name for this `Op`. Defaults to `add_operators`. 88 89 Returns: 90 Subclass of `LinearOperator`. Class and order of addition may change as new 91 (and better) addition strategies emerge. 92 93 Raises: 94 ValueError: If `operators` argument is empty. 95 ValueError: If shapes are incompatible. 96 """ 97 # Default setting 98 if addition_tiers is None: 99 addition_tiers = _DEFAULT_ADDITION_TIERS 100 101 # Argument checking. 102 check_ops.assert_proper_iterable(operators) 103 operators = list(reversed(operators)) 104 if len(operators) < 1: 105 raise ValueError( 106 "Argument 'operators' must contain at least one operator. " 107 "Found: %s" % operators) 108 if not all( 109 isinstance(op, linear_operator.LinearOperator) for op in operators): 110 raise TypeError( 111 "Argument 'operators' must contain only LinearOperator instances. 
" 112 "Found: %s" % operators) 113 _static_check_for_same_dimensions(operators) 114 _static_check_for_broadcastable_batch_shape(operators) 115 116 graph_parents = [] 117 for operator in operators: 118 graph_parents.extend(operator.graph_parents) 119 120 with ops.name_scope(name or "add_operators", values=graph_parents): 121 122 # Additions done in one of the tiers. Try tier 0, 1,... 123 ops_to_try_at_next_tier = list(operators) 124 for tier in addition_tiers: 125 ops_to_try_at_this_tier = ops_to_try_at_next_tier 126 ops_to_try_at_next_tier = [] 127 while ops_to_try_at_this_tier: 128 op1 = ops_to_try_at_this_tier.pop() 129 op2, adder = _pop_a_match_at_tier(op1, ops_to_try_at_this_tier, tier) 130 if op2 is not None: 131 # Will try to add the result of this again at this same tier. 132 new_operator = adder.add(op1, op2, operator_name) 133 ops_to_try_at_this_tier.append(new_operator) 134 else: 135 ops_to_try_at_next_tier.append(op1) 136 137 return ops_to_try_at_next_tier 138 139 140def _pop_a_match_at_tier(op1, operator_list, tier): 141 # Search from the back of list to the front in order to create nice default 142 # order of operations. 143 for i in range(1, len(operator_list) + 1): 144 op2 = operator_list[-i] 145 for adder in tier: 146 if adder.can_add(op1, op2): 147 return operator_list.pop(-i), adder 148 return None, None 149 150 151def _infer_hints_allowing_override(op1, op2, hints): 152 """Infer hints from op1 and op2. hints argument is an override. 153 154 Args: 155 op1: LinearOperator 156 op2: LinearOperator 157 hints: _Hints object holding "is_X" boolean hints to use for returned 158 operator. 159 If some hint is None, try to set using op1 and op2. If the 160 hint is provided, ignore op1 and op2 hints. This allows an override 161 of previous hints, but does not allow forbidden hints (e.g. you still 162 cannot say a real diagonal operator is not self-adjoint. 163 164 Returns: 165 _Hints object. 
166 """ 167 hints = hints or _Hints() 168 # If A, B are self-adjoint, then so is A + B. 169 if hints.is_self_adjoint is None: 170 is_self_adjoint = op1.is_self_adjoint and op2.is_self_adjoint 171 else: 172 is_self_adjoint = hints.is_self_adjoint 173 174 # If A, B are positive definite, then so is A + B. 175 if hints.is_positive_definite is None: 176 is_positive_definite = op1.is_positive_definite and op2.is_positive_definite 177 else: 178 is_positive_definite = hints.is_positive_definite 179 180 # A positive definite operator is always non-singular. 181 if is_positive_definite and hints.is_positive_definite is None: 182 is_non_singular = True 183 else: 184 is_non_singular = hints.is_non_singular 185 186 return _Hints( 187 is_non_singular=is_non_singular, 188 is_self_adjoint=is_self_adjoint, 189 is_positive_definite=is_positive_definite) 190 191 192def _static_check_for_same_dimensions(operators): 193 """ValueError if operators determined to have different dimensions.""" 194 if len(operators) < 2: 195 return 196 197 domain_dimensions = [(op.name, op.domain_dimension.value) for op in operators 198 if op.domain_dimension.value is not None] 199 if len(set(value for name, value in domain_dimensions)) > 1: 200 raise ValueError("Operators must have the same domain dimension. Found: %s" 201 % domain_dimensions) 202 203 range_dimensions = [(op.name, op.range_dimension.value) for op in operators 204 if op.range_dimension.value is not None] 205 if len(set(value for name, value in range_dimensions)) > 1: 206 raise ValueError("Operators must have the same range dimension. Found: %s" % 207 range_dimensions) 208 209 210def _static_check_for_broadcastable_batch_shape(operators): 211 """ValueError if operators determined to have non-broadcastable shapes.""" 212 if len(operators) < 2: 213 return 214 215 # This will fail if they cannot be broadcast together. 
216 batch_shape = operators[0].batch_shape 217 for op in operators[1:]: 218 batch_shape = array_ops.broadcast_static_shape(batch_shape, op.batch_shape) 219 220 221class _Hints(object): 222 """Holds 'is_X' flags that every LinearOperator is initialized with.""" 223 224 def __init__(self, 225 is_non_singular=None, 226 is_positive_definite=None, 227 is_self_adjoint=None): 228 self.is_non_singular = is_non_singular 229 self.is_positive_definite = is_positive_definite 230 self.is_self_adjoint = is_self_adjoint 231 232 233################################################################################ 234# Classes to add two linear operators. 235################################################################################ 236 237 238@six.add_metaclass(abc.ABCMeta) 239class _Adder(object): 240 """Abstract base class to add two operators. 241 242 Each `Adder` acts independently, adding everything it can, paying no attention 243 as to whether another `Adder` could have done the addition more efficiently. 244 """ 245 246 @property 247 def name(self): 248 return self.__class__.__name__ 249 250 @abc.abstractmethod 251 def can_add(self, op1, op2): 252 """Returns `True` if this `Adder` can add `op1` and `op2`. Else `False`.""" 253 pass 254 255 @abc.abstractmethod 256 def _add(self, op1, op2, operator_name, hints): 257 # Derived classes can assume op1 and op2 have been validated, e.g. they have 258 # the same dtype, and their domain/range dimensions match. 259 pass 260 261 def add(self, op1, op2, operator_name, hints=None): 262 """Return new `LinearOperator` acting like `op1 + op2`. 263 264 Args: 265 op1: `LinearOperator` 266 op2: `LinearOperator`, with `shape` and `dtype` such that adding to 267 `op1` is allowed. 268 operator_name: `String` name to give to returned `LinearOperator` 269 hints: `_Hints` object. Returned `LinearOperator` will be created with 270 these hints. 
271 272 Returns: 273 `LinearOperator` 274 """ 275 updated_hints = _infer_hints_allowing_override(op1, op2, hints) 276 277 if operator_name is None: 278 operator_name = "Add/" + op1.name + "__" + op2.name + "/" 279 280 values = op1.graph_parents + op2.graph_parents 281 scope_name = self.name 282 if scope_name.startswith("_"): 283 scope_name = scope_name[1:] 284 with ops.name_scope(scope_name, values=values): 285 return self._add(op1, op2, operator_name, updated_hints) 286 287 288class _AddAndReturnScaledIdentity(_Adder): 289 """Handles additions resulting in an Identity family member. 290 291 The Identity (`LinearOperatorScaledIdentity`, `LinearOperatorIdentity`) family 292 is closed under addition. This `Adder` respects that, and returns an Identity 293 """ 294 295 def can_add(self, op1, op2): 296 types = {_type(op1), _type(op2)} 297 return not types.difference(_IDENTITY_FAMILY) 298 299 def _add(self, op1, op2, operator_name, hints): 300 # Will build a LinearOperatorScaledIdentity. 301 302 if _type(op1) == _SCALED_IDENTITY: 303 multiplier_1 = op1.multiplier 304 else: 305 multiplier_1 = array_ops.ones(op1.batch_shape_tensor(), dtype=op1.dtype) 306 307 if _type(op2) == _SCALED_IDENTITY: 308 multiplier_2 = op2.multiplier 309 else: 310 multiplier_2 = array_ops.ones(op2.batch_shape_tensor(), dtype=op2.dtype) 311 312 return linear_operator_identity.LinearOperatorScaledIdentity( 313 num_rows=op1.range_dimension_tensor(), 314 multiplier=multiplier_1 + multiplier_2, 315 is_non_singular=hints.is_non_singular, 316 is_self_adjoint=hints.is_self_adjoint, 317 is_positive_definite=hints.is_positive_definite, 318 name=operator_name) 319 320 321class _AddAndReturnDiag(_Adder): 322 """Handles additions resulting in a Diag operator.""" 323 324 def can_add(self, op1, op2): 325 types = {_type(op1), _type(op2)} 326 return not types.difference(_DIAG_LIKE) 327 328 def _add(self, op1, op2, operator_name, hints): 329 return linear_operator_diag.LinearOperatorDiag( 330 diag=op1.diag_part() + 
op2.diag_part(), 331 is_non_singular=hints.is_non_singular, 332 is_self_adjoint=hints.is_self_adjoint, 333 is_positive_definite=hints.is_positive_definite, 334 name=operator_name) 335 336 337class _AddAndReturnTriL(_Adder): 338 """Handles additions resulting in a TriL operator.""" 339 340 def can_add(self, op1, op2): 341 types = {_type(op1), _type(op2)} 342 return not types.difference(_DIAG_LIKE.union({_TRIL})) 343 344 def _add(self, op1, op2, operator_name, hints): 345 if _type(op1) in _EFFICIENT_ADD_TO_TENSOR: 346 op_add_to_tensor, op_other = op1, op2 347 else: 348 op_add_to_tensor, op_other = op2, op1 349 350 return linear_operator_lower_triangular.LinearOperatorLowerTriangular( 351 tril=op_add_to_tensor.add_to_tensor(op_other.to_dense()), 352 is_non_singular=hints.is_non_singular, 353 is_self_adjoint=hints.is_self_adjoint, 354 is_positive_definite=hints.is_positive_definite, 355 name=operator_name) 356 357 358class _AddAndReturnMatrix(_Adder): 359 """"Handles additions resulting in a `LinearOperatorFullMatrix`.""" 360 361 def can_add(self, op1, op2): # pylint: disable=unused-argument 362 return isinstance(op1, linear_operator.LinearOperator) and isinstance( 363 op2, linear_operator.LinearOperator) 364 365 def _add(self, op1, op2, operator_name, hints): 366 if _type(op1) in _EFFICIENT_ADD_TO_TENSOR: 367 op_add_to_tensor, op_other = op1, op2 368 else: 369 op_add_to_tensor, op_other = op2, op1 370 return linear_operator_full_matrix.LinearOperatorFullMatrix( 371 matrix=op_add_to_tensor.add_to_tensor(op_other.to_dense()), 372 is_non_singular=hints.is_non_singular, 373 is_self_adjoint=hints.is_self_adjoint, 374 is_positive_definite=hints.is_positive_definite, 375 name=operator_name) 376 377 378################################################################################ 379# Constants designating types of LinearOperators 380################################################################################ 381 382# Type name constants for LinearOperator classes. 
_IDENTITY = "identity"
_SCALED_IDENTITY = "scaled_identity"
_DIAG = "diag"
_TRIL = "tril"
_MATRIX = "matrix"

# Groups of operators.
_DIAG_LIKE = {_DIAG, _IDENTITY, _SCALED_IDENTITY}
_IDENTITY_FAMILY = {_IDENTITY, _SCALED_IDENTITY}
# operators with an efficient .add_to_tensor() method.
_EFFICIENT_ADD_TO_TENSOR = _DIAG_LIKE


def _type(operator):
  """Returns the type name constant (e.g. _TRIL) for operator."""
  # Ordered dispatch table; the first matching class wins, preserving the
  # original check order (Diag, TriL, FullMatrix, Identity, ScaledIdentity).
  class_to_type = (
      (linear_operator_diag.LinearOperatorDiag, _DIAG),
      (linear_operator_lower_triangular.LinearOperatorLowerTriangular, _TRIL),
      (linear_operator_full_matrix.LinearOperatorFullMatrix, _MATRIX),
      (linear_operator_identity.LinearOperatorIdentity, _IDENTITY),
      (linear_operator_identity.LinearOperatorScaledIdentity,
       _SCALED_IDENTITY),
  )
  for operator_class, type_name in class_to_type:
    if isinstance(operator, operator_class):
      return type_name
  raise TypeError("Operator type unknown: %s" % operator)


################################################################################
# Addition tiers:
# We attempt to use Adders in tier K before K+1.
#
# Organize tiers to
#   (i) reduce O(..) complexity of forming final operator, and
#   (ii) produce the "most efficient" final operator.
# Dev notes:
#  * Results of addition at tier K will be added at tier K or higher.
#  * Tiers may change, and we warn the user that it may change.
################################################################################

# Note that the final tier, _AddAndReturnMatrix, will convert everything to a
# dense matrix.  So it is sometimes very inefficient.
_DEFAULT_ADDITION_TIERS = [
    [_AddAndReturnScaledIdentity()],
    [_AddAndReturnDiag()],
    [_AddAndReturnTriL()],
    [_AddAndReturnMatrix()],
]