# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.contrib.tpu.python.tpu import tpu_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables

# Method used for inverting matrices.
POSDEF_INV_METHOD = "cholesky"
POSDEF_EIG_METHOD = "self_adjoint"


def set_global_constants(posdef_inv_method=None):
  """Sets various global constants used by the classes in this module."""
  global POSDEF_INV_METHOD

  if posdef_inv_method is not None:
    POSDEF_INV_METHOD = posdef_inv_method


class SequenceDict(object):
  """A dict convenience wrapper that allows getting/setting with sequences."""

  def __init__(self, iterable=None):
    self._dict = dict(iterable or [])

  def __getitem__(self, key_or_keys):
    if isinstance(key_or_keys, (tuple, list)):
      return list(map(self.__getitem__, key_or_keys))
    else:
      return self._dict[key_or_keys]

  def __setitem__(self, key_or_keys, val_or_vals):
    if isinstance(key_or_keys, (tuple, list)):
      for key, value in zip(key_or_keys, val_or_vals):
        self[key] = value
    else:
      self._dict[key_or_keys] = val_or_vals

  def items(self):
    return list(self._dict.items())

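# Example usage of SequenceDict (an illustrative sketch; the keys and values
# below are arbitrary, not part of this module):
#
#   sd = SequenceDict()
#   sd[("a", "b")] = (1, 2)   # sets sd["a"] = 1 and sd["b"] = 2
#   sd["a"]                   # -> 1
#   sd[["a", "b"]]            # -> [1, 2]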

def tensors_to_column(tensors):
  """Converts a tensor or list of tensors to a column vector.

  Args:
    tensors: A tensor or list of tensors.

  Returns:
    The tensors reshaped into vectors and stacked on top of each other.
  """
  if isinstance(tensors, (tuple, list)):
    return array_ops.concat(
        tuple(array_ops.reshape(tensor, [-1, 1]) for tensor in tensors), axis=0)
  else:
    return array_ops.reshape(tensors, [-1, 1])


def column_to_tensors(tensors_template, colvec):
  """Converts a column vector back to the shape of the given template.

  Args:
    tensors_template: A tensor or list of tensors.
    colvec: A 2d column vector with the same shape as the value of
        tensors_to_column(tensors_template).

  Returns:
    X, where X is a tensor or list of tensors with the properties:
     1) tensors_to_column(X) = colvec
     2) X (or its elements) has the same shape as tensors_template (or its
        elements)
  """
  if isinstance(tensors_template, (tuple, list)):
    offset = 0
    tensors = []
    for tensor_template in tensors_template:
      sz = np.prod(tensor_template.shape.as_list(), dtype=np.int32)
      tensor = array_ops.reshape(colvec[offset:(offset + sz)],
                                 tensor_template.shape)
      tensors.append(tensor)
      offset += sz

    tensors = tuple(tensors)
  else:
    tensors = array_ops.reshape(colvec, tensors_template.shape)

  return tensors

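# Example round trip through tensors_to_column / column_to_tensors (a sketch;
# assumes `import tensorflow as tf` in user code and arbitrary shapes):
#
#   params = [tf.ones([2, 3]), tf.ones([3])]
#   col = tensors_to_column(params)         # shape [9, 1]
#   back = column_to_tensors(params, col)   # tuple of shapes (2, 3) and (3,)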

def kronecker_product(mat1, mat2):
  """Computes the Kronecker product of two matrices."""
  m1, n1 = mat1.get_shape().as_list()
  mat1_rsh = array_ops.reshape(mat1, [m1, 1, n1, 1])
  m2, n2 = mat2.get_shape().as_list()
  mat2_rsh = array_ops.reshape(mat2, [1, m2, 1, n2])
  return array_ops.reshape(mat1_rsh * mat2_rsh, [m1 * m2, n1 * n2])

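# Example of kronecker_product (a sketch; the concrete values are assumptions):
#
#   a = tf.constant([[1., 2.], [3., 4.]])
#   b = tf.eye(2)
#   kronecker_product(a, b)   # shape [4, 4], equal to np.kron(a, b)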

def layer_params_to_mat2d(vector):
  """Converts a vector shaped like layer parameters to a 2D matrix.

  In particular, we reshape the weights/filter component of the vector to be
  2D, flattening all leading (input) dimensions. If there is a bias component,
  we concatenate it to the reshaped weights/filter component.

  Args:
    vector: A Tensor or pair of Tensors shaped like layer parameters.

  Returns:
    A 2D Tensor with the same coefficients and the same output dimension.
  """
  if isinstance(vector, (tuple, list)):
    w_part, b_part = vector
    w_part_reshaped = array_ops.reshape(w_part,
                                        [-1, w_part.shape.as_list()[-1]])
    return array_ops.concat(
        (w_part_reshaped, array_ops.reshape(b_part, [1, -1])), axis=0)
  elif isinstance(vector, ops.IndexedSlices):
    return vector
  else:  # Tensor or Tensor-like.
    return array_ops.reshape(vector, [-1, vector.shape.as_list()[-1]])


def mat2d_to_layer_params(vector_template, mat2d):
  """Converts a canonical 2D matrix representation back to a vector.

  Args:
    vector_template: A Tensor or pair of Tensors shaped like layer parameters.
    mat2d: A 2D Tensor with the same shape as the value of
        layer_params_to_mat2d(vector_template).

  Returns:
    A Tensor or pair of Tensors with the same coefficients as mat2d and the
        same shape as vector_template.
  """
  if isinstance(vector_template, (tuple, list)):
    w_part, b_part = mat2d[:-1], mat2d[-1]
    return array_ops.reshape(w_part, vector_template[0].shape), b_part
  elif isinstance(vector_template, ops.IndexedSlices):
    if not isinstance(mat2d, ops.IndexedSlices):
      raise TypeError(
          "If vector_template is an IndexedSlices, so should mat2d be.")
    return mat2d
  else:
    return array_ops.reshape(mat2d, vector_template.shape)

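# Example round trip for a dense layer's (weights, bias) pair (a sketch; the
# shapes are assumptions):
#
#   w, b = tf.ones([5, 10]), tf.zeros([10])
#   mat2d = layer_params_to_mat2d((w, b))          # shape [6, 10]; last row is b
#   w2, b2 = mat2d_to_layer_params((w, b), mat2d)  # shapes [5, 10] and [10]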

def posdef_inv(tensor, damping):
  """Computes the inverse of tensor + damping * identity."""
  identity = linalg_ops.eye(tensor.shape.as_list()[0], dtype=tensor.dtype)
  damping = math_ops.cast(damping, dtype=tensor.dtype)
  return posdef_inv_functions[POSDEF_INV_METHOD](tensor, identity, damping)


def posdef_inv_matrix_inverse(tensor, identity, damping):
  """Computes inverse(tensor + damping * identity) directly."""
  return linalg_ops.matrix_inverse(tensor + damping * identity)


def posdef_inv_cholesky(tensor, identity, damping):
  """Computes inverse(tensor + damping * identity) with Cholesky."""
  chol = linalg_ops.cholesky(tensor + damping * identity)
  return linalg_ops.cholesky_solve(chol, identity)


def posdef_inv_eig(tensor, identity, damping):
  """Computes inverse(tensor + damping * identity) with eigendecomposition."""
  eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(
      tensor + damping * identity)
  return math_ops.matmul(
      eigenvectors / eigenvalues, eigenvectors, transpose_b=True)


posdef_inv_functions = {
    "matrix_inverse": posdef_inv_matrix_inverse,
    "cholesky": posdef_inv_cholesky,
    "eig": posdef_inv_eig,
}

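# All posdef_inv_* functions compute inverse(tensor + damping * identity) and
# differ only in numerical method; posdef_inv dispatches on POSDEF_INV_METHOD
# ("cholesky" by default). A sketch with assumed values:
#
#   cov = tf.constant([[2., 0.], [0., 4.]])
#   inv = posdef_inv(cov, damping=0.1)   # here equals diag(1/2.1, 1/4.1)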

def posdef_eig(mat):
  """Computes the eigendecomposition of a positive semidefinite matrix."""
  return posdef_eig_functions[POSDEF_EIG_METHOD](mat)


def posdef_eig_svd(mat):
  """Computes the singular values and left singular vectors of a matrix."""
  evals, evecs, _ = linalg_ops.svd(mat)

  return evals, evecs


def posdef_eig_self_adjoint(mat):
  """Computes eigendecomposition using self_adjoint_eig."""
  evals, evecs = linalg_ops.self_adjoint_eig(mat)
  evals = math_ops.abs(evals)  # Should be equivalent to svd approach.

  return evals, evecs


posdef_eig_functions = {
    "self_adjoint": posdef_eig_self_adjoint,
    "svd": posdef_eig_svd,
}

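# For a symmetric positive semidefinite matrix the two posdef_eig backends
# agree up to ordering and signs, since the singular values of such a matrix
# equal its eigenvalues. A sketch with assumed values:
#
#   mat = tf.constant([[2., 1.], [1., 2.]])
#   evals, evecs = posdef_eig(mat)   # evals ~ [1., 3.] with the default
#                                    # "self_adjoint" method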

class SubGraph(object):
  """Defines a subgraph given by all the dependencies of a given set of outputs.
  """

  def __init__(self, outputs):
    # Set of all ancestor Tensors and Ops of 'outputs'.
    self._members = set()

    self._recurse_add(outputs)

  def _recurse_add(self, nodes):
    """Recursively adds all of nodes' ancestors."""
    for node in nodes:
      if node in self._members:
        continue
      self._members.add(node)

      if isinstance(node, ops.Tensor):
        self._recurse_add((node.op,))
      elif isinstance(node, ops.Operation):
        self._recurse_add(node.inputs)

  def is_member(self, node):
    """Check if 'node' is in this subgraph."""
    return node in self._members

  def variable_uses(self, var):
    """Computes the number of times a variable is used.

    Args:
      var: Variable or ResourceVariable instance.

    Returns:
      Number of times the variable is used within this subgraph.

    Raises:
      ValueError: If 'var' is not a variable type.
    """
    if isinstance(var, resource_variable_ops.ResourceVariable):
      var = var.handle
    elif isinstance(var, variables.Variable):
      var = var.value()
    else:
      raise ValueError("%s does not appear to be a variable." % str(var))

    return len(self._members.intersection(set(var.consumers())))

  def filter_list(self, node_list):
    """Filters 'node_list' to nodes in this subgraph."""
    filtered_list = []
    for node in node_list:
      if self.is_member(node):
        filtered_list.append(node)
    return filtered_list

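# Example usage of SubGraph (a sketch; the small graph below is an assumption):
#
#   x = tf.get_variable("x", shape=[3])
#   y = 2.0 * x
#   z = 3.0 * x
#   sg = SubGraph((y,))    # everything y depends on
#   sg.is_member(y.op)     # -> True
#   sg.variable_uses(x)    # counts uses of x among y's ancestors; here the
#                          # multiply producing z is excluded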

def generate_random_signs(shape, dtype=dtypes.float32):
  """Generate a random tensor with {-1, +1} entries."""
  ints = random_ops.random_uniform(shape, maxval=2, dtype=dtypes.int32)
  return 2 * math_ops.cast(ints, dtype=dtype) - 1


def fwd_gradients(ys, xs, grad_xs=None, stop_gradients=None):
  """Compute forward-mode gradients."""
  # See b/37888268.

  # This version of forward-mode autodiff is based on code by Tim Cooijmans.
  # It handles list arguments and certain special cases, such as when ys does
  # not depend on one or more of the xs, and when ops.IndexedSlices are
  # generated by the first gradients_impl.gradients call.

  us = [array_ops.zeros_like(y) + float("nan") for y in ys]
  dydxs = gradients_impl.gradients(
      ys, xs, grad_ys=us, stop_gradients=stop_gradients)

  # Deal with strange types that gradients_impl.gradients returns but can't
  # accept as input.
  dydxs = [
      ops.convert_to_tensor(dydx)
      if isinstance(dydx, ops.IndexedSlices) else dydx for dydx in dydxs
  ]
  dydxs = [
      array_ops.zeros_like(x) if dydx is None else dydx
      for x, dydx in zip(xs, dydxs)
  ]

  dysdx = gradients_impl.gradients(dydxs, us, grad_ys=grad_xs)

  return dysdx

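# fwd_gradients computes the Jacobian-vector product d(ys)/d(xs) * grad_xs via
# the double-backward trick: the first gradients() call builds
# dydxs = (d(ys)/d(xs))^T * us as a function of the dummy seeds `us`, and
# differentiating that with respect to `us`, seeded with grad_xs, yields
# (d(ys)/d(xs)) * grad_xs. A sketch with assumed values:
#
#   x = tf.constant([1., 2.])
#   jvp = fwd_gradients([x * x], [x], grad_xs=[tf.constant([1., 0.])])
#   # jvp[0] ~ [2., 0.]: the elementwise derivative 2*x times tangent [1., 0.]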

def on_tpu():
  """Returns True when building a TPU computation."""
  return tpu_function.get_tpu_context().number_of_shards is not None


def cross_replica_mean(tensor, name=None):
  """Takes the mean value of a Tensor across all TPU cores.

  Args:
    tensor: Tensor to be synchronized.
    name: None or string. Name of Op.

  Returns:
    Average of Tensor across all TPU cores.

  Raises:
    ValueError: If called outside of TPU context.
  """
  with ops.name_scope(name, "cross_replica_mean", [tensor]):
    num_shards = tpu_function.get_tpu_context().number_of_shards
    if num_shards is None:
      raise ValueError(
          "Cannot take cross_replica_mean() outside of TPU Context.")
    if num_shards == 1:
      return tensor
    return tpu_ops.cross_replica_sum(tensor / num_shards)

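# cross_replica_mean uses the identity mean = cross_replica_sum(tensor / n):
# each core divides by the shard count first, and the subsequent sum across
# replicas yields the average on every core. A sketch of typical use inside a
# TPU-replicated step function (`compute_loss` is a hypothetical user function):
#
#   def tpu_step(features):
#     loss = compute_loss(features)
#     return cross_replica_mean(loss)   # identical averaged value on each core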

def ensure_sequence(obj):
  """If `obj` isn't a tuple or list, return a tuple containing `obj`."""
  if isinstance(obj, (tuple, list)):
    return obj
  else:
    return (obj,)


def batch_execute(global_step, thunks, batch_size, name=None):
  """Executes a subset of ops per global step.

  Given a list of thunks, each of which produces a single stateful op,
  ensures that exactly 'batch_size' ops are run per global step. Ops are
  scheduled in a round-robin fashion. For example, with 3 ops and a
  batch_size of 2:

    global_step | op0 | op1 | op2
    ------------+-----+-----+-----
        0       |  x  |  x  |
    ------------+-----+-----+-----
        1       |  x  |     |  x
    ------------+-----+-----+-----
        2       |     |  x  |  x
    ------------+-----+-----+-----
        3       |  x  |  x  |
    ------------+-----+-----+-----
        4       |  x  |     |  x

  Does not guarantee order of op execution within a single global step.

  Args:
    global_step: Tensor indicating time. Determines which ops run.
    thunks: List of thunks. Each thunk encapsulates one op. Return values are
      ignored.
    batch_size: int. Number of ops to execute per global_step.
    name: string or None. Name scope for newly added ops.

  Returns:
    List of ops. Exactly 'batch_size' ops are guaranteed to have an effect
    every global step.
  """

  def true_fn(thunk):
    """Ensures thunk is executed and returns an Op (not a Tensor)."""

    def result():
      with ops.control_dependencies([thunk()]):
        return control_flow_ops.no_op()

    return result

  def false_fn(_):
    """Executes a no-op."""

    def result():
      return control_flow_ops.no_op()

    return result

  with ops.name_scope(name, "batch_execute"):
    true_fns = [true_fn(thunk) for thunk in thunks]
    false_fns = [false_fn(thunk) for thunk in thunks]
    num_thunks = len(thunks)
    conditions = [
        math_ops.less(
            math_ops.mod(batch_size - 1 + global_step * batch_size - j,
                         num_thunks), batch_size) for j in range(num_thunks)
    ]
    result = [
        control_flow_ops.cond(condition, true_fn, false_fn)
        for (condition, true_fn,
             false_fn) in zip(conditions, true_fns, false_fns)
    ]
    return result

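# The round-robin schedule above comes from the condition
#   mod(batch_size - 1 + global_step * batch_size - j, num_thunks) < batch_size,
# which selects a contiguous (wrapping) window of batch_size thunk indices that
# advances by batch_size each step. Worked example with num_thunks=3,
# batch_size=2, global_step=1: j=0 gives mod(3, 3)=0 < 2 (run), j=1 gives
# mod(2, 3)=2 (skip), j=2 gives mod(1, 3)=1 < 2 (run), matching the
# global_step=1 row of the docstring table.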

def matmul_sparse_dense(A, B, name=None):  # pylint: disable=invalid-name
  """Computes matmul(A, B) where A is sparse, B is dense.

  Args:
    A: tf.IndexedSlices with dense shape [m, n].
    B: tf.Tensor with shape [n, k].
    name: str. Name of op.

  Returns:
    tf.IndexedSlices resulting from matmul(A, B).

  Raises:
    ValueError: If A doesn't represent a matrix.
    ValueError: If B is not rank-2.
  """
  with ops.name_scope(name, "matmul_sparse_dense", [A, B]):
    if A.indices.shape.ndims != 1 or A.values.shape.ndims != 2:
      raise ValueError("A must represent a matrix. Found: %s." % A)
    if B.shape.ndims != 2:
      raise ValueError("B must be a matrix.")
    new_values = math_ops.matmul(A.values, B)
    return ops.IndexedSlices(
        new_values,
        A.indices,
        dense_shape=array_ops.stack([A.dense_shape[0], new_values.shape[1]]))

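# Example of matmul_sparse_dense (a sketch; values are assumptions). Only the
# rows of A present in the IndexedSlices are multiplied:
#
#   a = tf.IndexedSlices(values=tf.ones([2, 3]), indices=tf.constant([0, 4]),
#                        dense_shape=tf.constant([6, 3]))
#   b = tf.ones([3, 5])
#   c = matmul_sparse_dense(a, b)   # values of shape [2, 5], indices [0, 4],
#                                   # dense shape [6, 5]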

def matmul_diag_sparse(A_diag, B, name=None):  # pylint: disable=invalid-name
  """Computes matmul(A, B) where A is a diagonal matrix, B is sparse.

  Args:
    A_diag: diagonal entries of matrix A of shape [m, m].
    B: tf.IndexedSlices. Represents matrix of shape [m, n].
    name: str. Name of op.

  Returns:
    tf.IndexedSlices resulting from matmul(A, B).

  Raises:
    ValueError: If A_diag is not rank-1.
    ValueError: If B doesn't represent a matrix.
  """
  with ops.name_scope(name, "matmul_diag_sparse", [A_diag, B]):
    A_diag = ops.convert_to_tensor(A_diag)
    if A_diag.shape.ndims != 1:
      raise ValueError("A_diag must be a rank-1 Tensor.")
    if B.indices.shape.ndims != 1 or B.values.shape.ndims != 2:
      raise ValueError("B must represent a matrix. Found: %s." % B)
    a = array_ops.gather(A_diag, B.indices)
    a = array_ops.reshape(a, list(a.shape) + [1] * (B.values.shape.ndims - 1))
    return ops.IndexedSlices(a * B.values, B.indices, dense_shape=B.dense_shape)

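# Example of matmul_diag_sparse (a sketch; values are assumptions). Each
# selected row of B is scaled by the matching diagonal entry of A:
#
#   a_diag = tf.constant([1., 2., 3., 4.])
#   b = tf.IndexedSlices(values=tf.ones([2, 5]), indices=tf.constant([1, 3]),
#                        dense_shape=tf.constant([4, 5]))
#   c = matmul_diag_sparse(a_diag, b)   # rows scaled by 2. and 4.; indices and
#                                       # dense shape unchanged
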
# TODO(b/69623235): Add a function for finding tensors that share gradients
# to eliminate redundant fisher factor computations.