estimator.py revision de1b4a8a75ae3a50f4fa7480efb1177d79abf553
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines the high-level Fisher estimator class."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

import numpy as np

from tensorflow.contrib.kfac.python.ops import utils
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.util import nest


class FisherEstimator(object):
  """Fisher estimator class supporting various approximations of the Fisher."""

  def __init__(self,
               variables,
               cov_ema_decay,
               damping,
               layer_collection,
               estimation_mode="gradients"):
    """Create a FisherEstimator object.

    Args:
      variables: A list of the variables for which to estimate the Fisher. This
          must match the variables registered in layer_collection (if it is not
          None).
      cov_ema_decay: The decay factor used when calculating the covariance
          estimate moving averages.
      damping: The damping factor used to stabilize training due to errors in
          the local approximation with the Fisher information matrix, and to
          regularize the update direction by making it closer to the gradient.
          (Higher damping means the update looks more like a standard gradient
          update - see Tikhonov regularization.)
      layer_collection: The layer collection object, which holds the Fisher
          blocks, Kronecker factors, and losses associated with the graph.
      estimation_mode: The type of estimator to use for the Fishers.  Can be
          'gradients', 'empirical', 'curvature_prop', or 'exact'.
          (Default: 'gradients').  'gradients' is the basic estimation approach
          from the original K-FAC paper.  'empirical' computes the 'empirical'
          Fisher information matrix (which uses the data's distribution for the
          targets, as opposed to the true Fisher, which uses the model's
          distribution) and requires that each registered loss have specified
          targets.  'curvature_prop' estimates the Fisher using self-products
          of random +1/-1 vectors times "half-factors" of the Fisher, as
          described in https://arxiv.org/abs/1206.6464.  Finally, 'exact' is
          the obvious generalization of curvature propagation: it computes the
          exact Fisher (modulo any additional diagonal or Kronecker
          approximations) by looping over one-hot vectors for each coordinate
          of the output instead of using +1/-1 vectors.  Roughly speaking, it
          is more expensive to compute than the other three options by a
          factor equal to the output dimension.

    Raises:
      ValueError: If no losses have been registered with layer_collection.
    """

    self._variables = variables
    self._damping = damping
    self._estimation_mode = estimation_mode
    self._layers = layer_collection
    self._layers.create_subgraph()
    self._check_registration(variables)
    self._gradient_fns = {
        "gradients": self._get_grads_lists_gradients,
        "empirical": self._get_grads_lists_empirical,
        "curvature_prop": self._get_grads_lists_curvature_prop,
        "exact": self._get_grads_lists_exact
    }
    setup = self._setup(cov_ema_decay)
    self.cov_update_op, self.inv_update_op, self.inv_updates_dict = setup
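
  # A minimal usage sketch (hypothetical values; assumes `layers` is a
  # LayerCollection from the sibling layer_collection module with all layers
  # and at least one loss already registered):
  #
  #   est = FisherEstimator(variables=tf.trainable_variables(),
  #                         cov_ema_decay=0.95,
  #                         damping=1e-3,
  #                         layer_collection=layers,
  #                         estimation_mode="gradients")
  #   # est.cov_update_op and est.inv_update_op can then be run periodically.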

  @property
  def variables(self):
    return self._variables

  @property
  def damping(self):
    return self._damping

  def _apply_transformation(self, vecs_and_vars, transform):
    """Applies a block-wise transformation to the corresponding vectors.

    Args:
      vecs_and_vars: List of (vector, variable) pairs.
      transform: A function of the form f(fb, vec), where vec is the vector
          to transform and fb is its corresponding Fisher block, which returns
          the transformed vector.

    Returns:
      A list of (transformed vector, var) pairs in the same order as
      vecs_and_vars.
    """

    vecs = utils.SequenceDict((var, vec) for vec, var in vecs_and_vars)

    trans_vecs = utils.SequenceDict()

    for params, fb in self._layers.fisher_blocks.items():
      trans_vecs[params] = transform(fb, vecs[params])

    return [(trans_vecs[var], var) for _, var in vecs_and_vars]

  def multiply_inverse(self, vecs_and_vars):
    """Multiplies the vectors by the corresponding (damped) block inverses.

    Args:
      vecs_and_vars: List of (vector, variable) pairs.

    Returns:
      A list of (transformed vector, var) pairs in the same order as
      vecs_and_vars.
    """

    return self._apply_transformation(vecs_and_vars,
                                      lambda fb, vec: fb.multiply_inverse(vec))
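
  # For example, a sketch of preconditioning gradients with the (approximate)
  # inverse Fisher, assuming `grads_and_vars` came from an optimizer's
  # compute_gradients():
  #
  #   precon_grads_and_vars = estimator.multiply_inverse(grads_and_vars)
  #
  # Each vector is multiplied by (F_block + damping * I)^-1 for its block.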

  def multiply(self, vecs_and_vars):
    """Multiplies the vectors by the corresponding (damped) blocks.

    Args:
      vecs_and_vars: List of (vector, variable) pairs.

    Returns:
      A list of (transformed vector, var) pairs in the same order as
      vecs_and_vars.
    """

    return self._apply_transformation(vecs_and_vars,
                                      lambda fb, vec: fb.multiply(vec))
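
  # Conversely, a sketch of a (damped) Fisher-vector product, e.g. for checking
  # the curvature along a proposed update direction:
  #
  #   fisher_vec_and_vars = estimator.multiply(update_and_vars)
  #
  # where `update_and_vars` is a hypothetical list of (update, variable) pairs.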

  def _check_registration(self, variables):
    """Checks that all variable uses have been registered properly.

    Args:
      variables: List of variables.

    Raises:
      ValueError: If any registered variables are not included in the list.
      ValueError: If any variable in the list is not registered.
      ValueError: If any variable in the list is registered with the wrong
          number of "uses" (vs the number of times that variable is actually
          used in the subgraph).
    """
    # Note that overlapping parameters (i.e. those that share variables) will
    # be caught by layer_collection.LayerParametersDict during registration.

    reg_use_map = self._layers.get_use_count_map()

    error_messages = []

    for var in variables:
      total_uses = self._layers.subgraph.variable_uses(var)
      reg_uses = reg_use_map[var]

      if reg_uses == 0:
        error_messages.append("Variable {} not registered.".format(var))
      elif (not math.isinf(reg_uses)) and reg_uses != total_uses:
        error_messages.append(
            "Variable {} registered with wrong number of uses ({} "
            "vs {} actual).".format(var, reg_uses, total_uses))

    num_get_vars = len(reg_use_map)

    if num_get_vars > len(variables):
      error_messages.append("{} registered variables were not included in list."
                            .format(num_get_vars - len(variables)))

    if error_messages:
      error_messages = [
          "Found the following errors with variable registration:"
      ] + error_messages
      raise ValueError("\n\t".join(error_messages))

  def _setup(self, cov_ema_decay):
    """Sets up the covariance and inverse update operations.

    Args:
      cov_ema_decay: The decay factor used when calculating the covariance
          estimate moving averages.

    Returns:
      A triple (covs_update_op, invs_update_op, inv_updates_dict), where
      covs_update_op is the grouped Op to update all the covariance estimates,
      invs_update_op is the grouped Op to update all the inverses, and
      inv_updates_dict is a dict mapping Op names to the individual inverse
      updates.

    Raises:
      ValueError: If estimation_mode was improperly specified at construction.
    """
    fisher_blocks_list = self._layers.get_blocks()
    tensors_to_compute_grads = [
        fb.tensors_to_compute_grads() for fb in fisher_blocks_list
    ]

    try:
      grads_lists = self._gradient_fns[self._estimation_mode](
          tensors_to_compute_grads)
    except KeyError:
      raise ValueError("Unrecognized value {} for estimation_mode.".format(
          self._estimation_mode))

    for grads_list, fb in zip(grads_lists, fisher_blocks_list):
      fb.instantiate_factors(grads_list, self.damping)

    cov_updates = [
        factor.make_covariance_update_op(cov_ema_decay)
        for factor in self._layers.get_factors()
    ]
    inv_updates = {op.name: op for op in self._get_all_inverse_update_ops()}

    return control_flow_ops.group(*cov_updates), control_flow_ops.group(
        *inv_updates.values()), inv_updates
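
  # The two grouped ops are typically run on separate schedules during
  # training.  A rough sketch (`sess`, `step`, and the intervals below are
  # placeholders chosen by the caller, not part of this class):
  #
  #   if step % cov_update_interval == 0:
  #     sess.run(estimator.cov_update_op)
  #   if step % inv_update_interval == 0:
  #     sess.run(estimator.inv_update_op)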

  def _get_all_inverse_update_ops(self):
    for factor in self._layers.get_factors():
      for op in factor.make_inverse_update_ops():
        yield op

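  # The two methods below differ only in which loss gets differentiated:
  # 'gradients' mode backpropagates the loss evaluated on targets sampled from
  # the model's predictive distribution (giving the true Fisher), while
  # 'empirical' mode backpropagates the loss on the actual training targets.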
  def _get_grads_lists_gradients(self, tensors):
    grads_flat = gradients_impl.gradients(self._layers.total_sampled_loss(),
                                          nest.flatten(tensors))
    grads_all = nest.pack_sequence_as(tensors, grads_flat)
    return tuple((grad,) for grad in grads_all)

  def _get_grads_lists_empirical(self, tensors):
    grads_flat = gradients_impl.gradients(self._layers.total_loss(),
                                          nest.flatten(tensors))
    grads_all = nest.pack_sequence_as(tensors, grads_flat)
    return tuple((grad,) for grad in grads_all)

  def _get_transformed_random_signs(self):
    transformed_random_signs = []
    for loss in self._layers.losses:
      transformed_random_signs.append(
          loss.multiply_fisher_factor(
              utils.generate_random_signs(loss.fisher_factor_inner_shape)))
    return transformed_random_signs

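  # Curvature propagation sketch: if B is a "half-factor" of a loss's Fisher
  # (so F = B B^T) and s is a random +1/-1 vector, then E[(B s)(B s)^T] =
  # B E[s s^T] B^T = F.  Backpropagating B s (computed by
  # _get_transformed_random_signs above) therefore gives a one-sample unbiased
  # estimate of each Fisher block.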
  def _get_grads_lists_curvature_prop(self, tensors):
    loss_inputs = list(loss.inputs for loss in self._layers.losses)
    transformed_random_signs = self._get_transformed_random_signs()
    grads_flat = gradients_impl.gradients(
        nest.flatten(loss_inputs),
        nest.flatten(tensors),
        grad_ys=nest.flatten(transformed_random_signs))
    grads_all = nest.pack_sequence_as(tensors, grads_flat)
    return tuple((grad,) for grad in grads_all)

  def _get_grads_lists_exact(self, tensors):
    # Loop over all coordinates of all losses.
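    # Each per-coordinate backprop below yields one column of the Fisher
    # "half-factor" for the registered blocks; accumulating their outer
    # products recovers the Fisher exactly, up to whatever diagonal or
    # Kronecker approximation each block applies (see the class docstring).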
    grads_all = []
    for loss in self._layers.losses:
      for index in np.ndindex(*loss.fisher_factor_inner_static_shape[1:]):
        transformed_one_hot = loss.multiply_fisher_factor_replicated_one_hot(
            index)
        grads_flat = gradients_impl.gradients(
            loss.inputs, nest.flatten(tensors), grad_ys=transformed_one_hot)
        grads_all.append(nest.pack_sequence_as(tensors, grads_flat))
    # Wrap in tuple() so the result, like the other modes', can be iterated
    # more than once under Python 3.
    return tuple(zip(*grads_all))