1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Utility functions.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.contrib.tpu.python.ops import tpu_ops 24from tensorflow.contrib.tpu.python.tpu import tpu_function 25from tensorflow.python.framework import dtypes 26from tensorflow.python.framework import ops 27from tensorflow.python.ops import array_ops 28from tensorflow.python.ops import control_flow_ops 29from tensorflow.python.ops import gradients_impl 30from tensorflow.python.ops import linalg_ops 31from tensorflow.python.ops import math_ops 32from tensorflow.python.ops import random_ops 33from tensorflow.python.ops import resource_variable_ops 34from tensorflow.python.ops import variables 35 36# Method used for inverting matrices. 37POSDEF_INV_METHOD = "cholesky" 38POSDEF_EIG_METHOD = "self_adjoint" 39 40 41def set_global_constants(posdef_inv_method=None): 42 """Sets various global constants used by the classes in this module.""" 43 global POSDEF_INV_METHOD 44 45 if posdef_inv_method is not None: 46 POSDEF_INV_METHOD = posdef_inv_method 47 48 49class SequenceDict(object): 50 """A dict convenience wrapper that allows getting/setting with sequences.""" 51 52 def __init__(self, iterable=None): 53 self._dict = dict(iterable or []) 54 55 def __getitem__(self, key_or_keys): 56 if isinstance(key_or_keys, (tuple, list)): 57 return list(map(self.__getitem__, key_or_keys)) 58 else: 59 return self._dict[key_or_keys] 60 61 def __setitem__(self, key_or_keys, val_or_vals): 62 if isinstance(key_or_keys, (tuple, list)): 63 for key, value in zip(key_or_keys, val_or_vals): 64 self[key] = value 65 else: 66 self._dict[key_or_keys] = val_or_vals 67 68 def items(self): 69 return list(self._dict.items()) 70 71 72def tensors_to_column(tensors): 73 """Converts a tensor or list of tensors to a column vector. 74 75 Args: 76 tensors: A tensor or list of tensors. 77 78 Returns: 79 The tensors reshaped into vectors and stacked on top of each other. 80 """ 81 if isinstance(tensors, (tuple, list)): 82 return array_ops.concat( 83 tuple(array_ops.reshape(tensor, [-1, 1]) for tensor in tensors), axis=0) 84 else: 85 return array_ops.reshape(tensors, [-1, 1]) 86 87 88def column_to_tensors(tensors_template, colvec): 89 """Converts a column vector back to the shape of the given template. 90 91 Args: 92 tensors_template: A tensor or list of tensors. 93 colvec: A 2d column vector with the same shape as the value of 94 tensors_to_column(tensors_template). 95 96 Returns: 97 X, where X is tensor or list of tensors with the properties: 98 1) tensors_to_column(X) = colvec 99 2) X (or its elements) have the same shape as tensors_template (or its 100 elements) 101 """ 102 if isinstance(tensors_template, (tuple, list)): 103 offset = 0 104 tensors = [] 105 for tensor_template in tensors_template: 106 sz = np.prod(tensor_template.shape.as_list(), dtype=np.int32) 107 tensor = array_ops.reshape(colvec[offset:(offset + sz)], 108 tensor_template.shape) 109 tensors.append(tensor) 110 offset += sz 111 112 tensors = tuple(tensors) 113 else: 114 tensors = array_ops.reshape(colvec, tensors_template.shape) 115 116 return tensors 117 118 119def kronecker_product(mat1, mat2): 120 """Computes the Kronecker product two matrices.""" 121 m1, n1 = mat1.get_shape().as_list() 122 mat1_rsh = array_ops.reshape(mat1, [m1, 1, n1, 1]) 123 m2, n2 = mat2.get_shape().as_list() 124 mat2_rsh = array_ops.reshape(mat2, [1, m2, 1, n2]) 125 return array_ops.reshape(mat1_rsh * mat2_rsh, [m1 * m2, n1 * n2]) 126 127 128def layer_params_to_mat2d(vector): 129 """Converts a vector shaped like layer parameters to a 2D matrix. 130 131 In particular, we reshape the weights/filter component of the vector to be 132 2D, flattening all leading (input) dimensions. If there is a bias component, 133 we concatenate it to the reshaped weights/filter component. 134 135 Args: 136 vector: A Tensor or pair of Tensors shaped like layer parameters. 137 138 Returns: 139 A 2D Tensor with the same coefficients and the same output dimension. 140 """ 141 if isinstance(vector, (tuple, list)): 142 w_part, b_part = vector 143 w_part_reshaped = array_ops.reshape(w_part, 144 [-1, w_part.shape.as_list()[-1]]) 145 return array_ops.concat( 146 (w_part_reshaped, array_ops.reshape(b_part, [1, -1])), axis=0) 147 elif isinstance(vector, ops.IndexedSlices): 148 return vector 149 else: # Tensor or Tensor-like. 150 return array_ops.reshape(vector, [-1, vector.shape.as_list()[-1]]) 151 152 153def mat2d_to_layer_params(vector_template, mat2d): 154 """Converts a canonical 2D matrix representation back to a vector. 155 156 Args: 157 vector_template: A Tensor or pair of Tensors shaped like layer parameters. 158 mat2d: A 2D Tensor with the same shape as the value of 159 layer_params_to_mat2d(vector_template). 160 161 Returns: 162 A Tensor or pair of Tensors with the same coefficients as mat2d and the same 163 shape as vector_template. 164 """ 165 if isinstance(vector_template, (tuple, list)): 166 w_part, b_part = mat2d[:-1], mat2d[-1] 167 return array_ops.reshape(w_part, vector_template[0].shape), b_part 168 elif isinstance(vector_template, ops.IndexedSlices): 169 if not isinstance(mat2d, ops.IndexedSlices): 170 raise TypeError( 171 "If vector_template is an IndexedSlices, so should mat2d.") 172 return mat2d 173 else: 174 return array_ops.reshape(mat2d, vector_template.shape) 175 176 177def posdef_inv(tensor, damping): 178 """Computes the inverse of tensor + damping * identity.""" 179 identity = linalg_ops.eye(tensor.shape.as_list()[0], dtype=tensor.dtype) 180 damping = math_ops.cast(damping, dtype=tensor.dtype) 181 return posdef_inv_functions[POSDEF_INV_METHOD](tensor, identity, damping) 182 183 184def posdef_inv_matrix_inverse(tensor, identity, damping): 185 """Computes inverse(tensor + damping * identity) directly.""" 186 return linalg_ops.matrix_inverse(tensor + damping * identity) 187 188 189def posdef_inv_cholesky(tensor, identity, damping): 190 """Computes inverse(tensor + damping * identity) with Cholesky.""" 191 chol = linalg_ops.cholesky(tensor + damping * identity) 192 return linalg_ops.cholesky_solve(chol, identity) 193 194 195def posdef_inv_eig(tensor, identity, damping): 196 """Computes inverse(tensor + damping * identity) with eigendecomposition.""" 197 eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig( 198 tensor + damping * identity) 199 return math_ops.matmul( 200 eigenvectors / eigenvalues, eigenvectors, transpose_b=True) 201 202 203posdef_inv_functions = { 204 "matrix_inverse": posdef_inv_matrix_inverse, 205 "cholesky": posdef_inv_cholesky, 206 "eig": posdef_inv_eig, 207} 208 209 210def posdef_eig(mat): 211 """Computes the eigendecomposition of a positive semidefinite matrix.""" 212 return posdef_eig_functions[POSDEF_EIG_METHOD](mat) 213 214 215def posdef_eig_svd(mat): 216 """Computes the singular values and left singular vectors of a matrix.""" 217 evals, evecs, _ = linalg_ops.svd(mat) 218 219 return evals, evecs 220 221 222def posdef_eig_self_adjoint(mat): 223 """Computes eigendecomposition using self_adjoint_eig.""" 224 evals, evecs = linalg_ops.self_adjoint_eig(mat) 225 evals = math_ops.abs(evals) # Should be equivalent to svd approach. 226 227 return evals, evecs 228 229 230posdef_eig_functions = { 231 "self_adjoint": posdef_eig_self_adjoint, 232 "svd": posdef_eig_svd, 233} 234 235 236class SubGraph(object): 237 """Defines a subgraph given by all the dependencies of a given set of outputs. 238 """ 239 240 def __init__(self, outputs): 241 # Set of all ancestor Tensors, Ops to 'outputs'. 242 self._members = set() 243 244 self._recurse_add(outputs) 245 246 def _recurse_add(self, nodes): 247 """Recursively adds all of nodes' ancestors.""" 248 for node in nodes: 249 if node in self._members: 250 continue 251 self._members.add(node) 252 253 if isinstance(node, ops.Tensor): 254 self._recurse_add((node.op,)) 255 elif isinstance(node, ops.Operation): 256 self._recurse_add(node.inputs) 257 258 def is_member(self, node): 259 """Check if 'node' is in this subgraph.""" 260 return node in self._members 261 262 def variable_uses(self, var): 263 """Computes number of times a variable is used. 264 265 Args: 266 var: Variable or ResourceVariable instance. 267 268 Returns: 269 Number of times a variable is used within this subgraph. 270 271 Raises: 272 ValueError: If 'var' is not a variable type. 273 """ 274 if isinstance(var, resource_variable_ops.ResourceVariable): 275 var = var.handle 276 elif isinstance(var, variables.Variable): 277 var = var.value() 278 else: 279 raise ValueError("%s does not appear to be a variable." % str(var)) 280 281 return len(self._members.intersection(set(var.consumers()))) 282 283 def filter_list(self, node_list): 284 """Filters 'node_list' to nodes in this subgraph.""" 285 filtered_list = [] 286 for node in node_list: 287 if self.is_member(node): 288 filtered_list.append(node) 289 return filtered_list 290 291 292def generate_random_signs(shape, dtype=dtypes.float32): 293 """Generate a random tensor with {-1, +1} entries.""" 294 ints = random_ops.random_uniform(shape, maxval=2, dtype=dtypes.int32) 295 return 2 * math_ops.cast(ints, dtype=dtype) - 1 296 297 298def fwd_gradients(ys, xs, grad_xs=None, stop_gradients=None): 299 """Compute forward-mode gradients.""" 300 # See b/37888268. 301 302 # This version of forward-mode autodiff is based on code by Tim Cooijmans 303 # and handles list arguments and certain special cases such as when the 304 # ys doesn't depend on one or more of the xs, and when ops.IndexedSlices are 305 # generated by the first gradients_impl.gradients call. 306 307 us = [array_ops.zeros_like(y) + float("nan") for y in ys] 308 dydxs = gradients_impl.gradients( 309 ys, xs, grad_ys=us, stop_gradients=stop_gradients) 310 311 # Deal with strange types that gradients_impl.gradients returns but can't 312 # deal with. 313 dydxs = [ 314 ops.convert_to_tensor(dydx) 315 if isinstance(dydx, ops.IndexedSlices) else dydx for dydx in dydxs 316 ] 317 dydxs = [ 318 array_ops.zeros_like(x) if dydx is None else dydx 319 for x, dydx in zip(xs, dydxs) 320 ] 321 322 dysdx = gradients_impl.gradients(dydxs, us, grad_ys=grad_xs) 323 324 return dysdx 325 326 327def on_tpu(): 328 """Returns True when building a TPU computation.""" 329 return tpu_function.get_tpu_context().number_of_shards is not None 330 331 332def cross_replica_mean(tensor, name=None): 333 """Takes mean value of a Tensor across all TPU cores. 334 335 Args: 336 tensor: Tensor to be synchronized. 337 name: None or string. Name of Op. 338 339 Returns: 340 Average of Tensor across all TPU cores. 341 342 Raises: 343 ValueError: If called outside of TPU context. 344 """ 345 with ops.name_scope(name, "cross_replica_mean", [tensor]): 346 num_shards = tpu_function.get_tpu_context().number_of_shards 347 if num_shards is None: 348 raise ValueError( 349 "Cannot take cross_replica_mean() outside of TPU Context.") 350 if num_shards == 1: 351 return tensor 352 return tpu_ops.cross_replica_sum(tensor / num_shards) 353 354 355def ensure_sequence(obj): 356 """If `obj` isn't a tuple or list, return a tuple containing `obj`.""" 357 if isinstance(obj, (tuple, list)): 358 return obj 359 else: 360 return (obj,) 361 362 363def batch_execute(global_step, thunks, batch_size, name=None): 364 """Executes a subset of ops per global step. 365 366 Given a list of thunks, each of which produces a single stateful op, 367 ensures that exactly 'batch_size' ops are run per global step. Ops are 368 scheduled in a round-robin fashion. For example, with 3 ops 369 370 global_step | op0 | op1 | op2 371 ------------+-----+-----+----- 372 0 | x | x | 373 ------------+-----+-----+----- 374 1 | x | | x 375 ------------+-----+-----+----- 376 2 | | x | x 377 ------------+-----+-----+----- 378 3 | x | x | 379 ------------+-----+-----+----- 380 4 | x | | x 381 382 Does not guarantee order of op execution within a single global step. 383 384 Args: 385 global_step: Tensor indicating time. Determines which ops run. 386 thunks: List of thunks. Each thunk encapsulates one op. Return values are 387 ignored. 388 batch_size: int. Number of ops to execute per global_step. 389 name: string or None. Name scope for newly added ops. 390 391 Returns: 392 List of ops. Exactly 'batch_size' ops are guaranteed to have an effect 393 every global step. 394 """ 395 396 def true_fn(thunk): 397 """Ensures thunk is executed and returns an Op (not a Tensor).""" 398 399 def result(): 400 with ops.control_dependencies([thunk()]): 401 return control_flow_ops.no_op() 402 403 return result 404 405 def false_fn(_): 406 """Executes a no-op.""" 407 408 def result(): 409 return control_flow_ops.no_op() 410 411 return result 412 413 with ops.name_scope(name, "batch_execute"): 414 true_fns = [true_fn(thunk) for thunk in thunks] 415 false_fns = [false_fn(thunk) for thunk in thunks] 416 num_thunks = len(thunks) 417 conditions = [ 418 math_ops.less( 419 math_ops.mod(batch_size - 1 + global_step * batch_size - j, 420 num_thunks), batch_size) for j in range(num_thunks) 421 ] 422 result = [ 423 control_flow_ops.cond(condition, true_fn, false_fn) 424 for (condition, true_fn, 425 false_fn) in zip(conditions, true_fns, false_fns) 426 ] 427 return result 428 429 430def matmul_sparse_dense(A, B, name=None): # pylint: disable=invalid-name 431 """Computes matmul(A, B) where A is sparse, B is dense. 432 433 Args: 434 A: tf.IndexedSlices with dense shape [m, n]. 435 B: tf.Tensor with shape [n, k]. 436 name: str. Name of op. 437 438 Returns: 439 tf.IndexedSlices resulting from matmul(A, B). 440 441 Raises: 442 ValueError: If A doesn't represent a matrix. 443 ValueError: If B is not rank-2. 444 """ 445 with ops.name_scope(name, "matmul_sparse_dense", [A, B]): 446 if A.indices.shape.ndims != 1 or A.values.shape.ndims != 2: 447 raise ValueError("A must represent a matrix. Found: %s." % A) 448 if B.shape.ndims != 2: 449 raise ValueError("B must be a matrix.") 450 new_values = math_ops.matmul(A.values, B) 451 return ops.IndexedSlices( 452 new_values, 453 A.indices, 454 dense_shape=array_ops.stack([A.dense_shape[0], new_values.shape[1]])) 455 456 457def matmul_diag_sparse(A_diag, B, name=None): # pylint: disable=invalid-name 458 """Computes matmul(A, B) where A is a diagonal matrix, B is sparse. 459 460 Args: 461 A_diag: diagonal entries of matrix A of shape [m, m]. 462 B: tf.IndexedSlices. Represents matrix of shape [m, n]. 463 name: str. Name of op. 464 465 Returns: 466 tf.IndexedSlices resulting from matmul(A, B). 467 468 Raises: 469 ValueError: If A_diag is not rank-1. 470 ValueError: If B doesn't represent a matrix. 471 """ 472 with ops.name_scope(name, "matmul_diag_sparse", [A_diag, B]): 473 A_diag = ops.convert_to_tensor(A_diag) 474 if A_diag.shape.ndims != 1: 475 raise ValueError("A_diag must be a rank-1 Tensor.") 476 if B.indices.shape.ndims != 1 or B.values.shape.ndims != 2: 477 raise ValueError("B must represent a matrix. Found: %s." % B) 478 a = array_ops.gather(A_diag, B.indices) 479 a = array_ops.reshape(a, list(a.shape) + [1] * (B.values.shape.ndims - 1)) 480 return ops.IndexedSlices(a * B.values, B.indices, dense_shape=B.dense_shape) 481 482# TODO(b/69623235): Add a function for finding tensors that share gradients 483# to eliminate redundant fisher factor computations. 484