estimator.py revision de1b4a8a75ae3a50f4fa7480efb1177d79abf553
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines the high-level Fisher estimator class."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

import numpy as np

from tensorflow.contrib.kfac.python.ops import utils
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.util import nest


class FisherEstimator(object):
  """Fisher estimator class supporting various approximations of the Fisher."""

  def __init__(self,
               variables,
               cov_ema_decay,
               damping,
               layer_collection,
               estimation_mode="gradients"):
    """Create a FisherEstimator object.

    Args:
      variables: A list of the variables for which to estimate the Fisher. This
        must match the variables registered in layer_collection (if it is not
        None).
      cov_ema_decay: The decay factor used when calculating the covariance
        estimate moving averages.
      damping: The damping factor used to stabilize training due to errors in
        the local approximation with the Fisher information matrix, and to
        regularize the update direction by making it closer to the gradient.
        (Higher damping means the update looks more like a standard gradient
        update - see Tikhonov regularization.)
      layer_collection: The layer collection object, which holds the Fisher
        blocks, Kronecker factors, and losses associated with the graph.
      estimation_mode: The type of estimator to use for the Fishers. Can be
        'gradients', 'empirical', 'curvature_prop', or 'exact'.
        (Default: 'gradients'). 'gradients' is the basic estimation approach
        from the original K-FAC paper. 'empirical' computes the 'empirical'
        Fisher information matrix (which uses the data's distribution for the
        targets, as opposed to the true Fisher which uses the model's
        distribution) and requires that each registered loss have specified
        targets. 'curvature_prop' is a method which estimates the Fisher using
        self-products of random +1/-1 vectors times "half-factors" of the
        Fisher, as described here: https://arxiv.org/abs/1206.6464 . Finally,
        'exact' is the obvious generalization of Curvature Propagation to
        compute the exact Fisher (modulo any additional diagonal or Kronecker
        approximations) by looping over one-hot vectors for each coordinate of
        the output instead of using +1/-1 vectors. It is more expensive to
        compute than the other three options by a factor equal to the output
        dimension, roughly speaking.

    Raises:
      ValueError: If no losses have been registered with layer_collection.
    """
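
    # Sketch of the distinction documented above (notation introduced here,
    # not defined elsewhere in this file): the true Fisher averages outer
    # products of log-likelihood gradients with targets y sampled from the
    # model,
    #
    #   F = E_{x ~ data, y ~ p(y|x, theta)}[
    #       (d/dtheta log p(y|x, theta)) (d/dtheta log p(y|x, theta))^T],
    #
    # whereas the 'empirical' mode replaces the sampled y with the targets
    # from the data, which is why it requires every registered loss to have
    # targets.
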
    self._variables = variables
    self._damping = damping
    self._estimation_mode = estimation_mode
    self._layers = layer_collection
    self._layers.create_subgraph()
    self._check_registration(variables)
    self._gradient_fns = {
        "gradients": self._get_grads_lists_gradients,
        "empirical": self._get_grads_lists_empirical,
        "curvature_prop": self._get_grads_lists_curvature_prop,
        "exact": self._get_grads_lists_exact
    }
    setup = self._setup(cov_ema_decay)
    self.cov_update_op, self.inv_update_op, self.inv_updates_dict = setup

  @property
  def variables(self):
    return self._variables

  @property
  def damping(self):
    return self._damping

  def _apply_transformation(self, vecs_and_vars, transform):
    """Applies a block-wise transformation to the corresponding vectors.

    Args:
      vecs_and_vars: List of (vector, variable) pairs.
      transform: A function of the form f(fb, vec), where vec is the vector
        to transform and fb is its corresponding block in the matrix, that
        returns the transformed vector.

    Returns:
      A list of (transformed vector, var) pairs in the same order as
      vecs_and_vars.
    """

    vecs = utils.SequenceDict((var, vec) for vec, var in vecs_and_vars)

    trans_vecs = utils.SequenceDict()

    for params, fb in self._layers.fisher_blocks.items():
      trans_vecs[params] = transform(fb, vecs[params])

    return [(trans_vecs[var], var) for _, var in vecs_and_vars]

  def multiply_inverse(self, vecs_and_vars):
    """Multiplies the vecs by the corresponding (damped) inverses of the blocks.

    Args:
      vecs_and_vars: List of (vector, variable) pairs.

    Returns:
      A list of (transformed vector, var) pairs in the same order as
      vecs_and_vars.
    """

    return self._apply_transformation(vecs_and_vars,
                                      lambda fb, vec: fb.multiply_inverse(vec))

  def multiply(self, vecs_and_vars):
    """Multiplies the vectors by the corresponding (damped) blocks.

    Args:
      vecs_and_vars: List of (vector, variable) pairs.

    Returns:
      A list of (transformed vector, var) pairs in the same order as
      vecs_and_vars.
    """

    return self._apply_transformation(vecs_and_vars,
                                      lambda fb, vec: fb.multiply(vec))
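
  # Usage sketch (assumed, not part of the original code): the (vector,
  # variable) pairs passed to multiply() and multiply_inverse() follow the
  # same layout that tf.train.Optimizer.compute_gradients() returns, so a
  # preconditioned update direction can be formed roughly as
  #
  #   grads_and_vars = opt.compute_gradients(loss, estimator.variables)
  #   precon_grads_and_vars = estimator.multiply_inverse(grads_and_vars)
  #
  # where each Fisher block applies (approximately) its damped inverse,
  # (F_block + damping * I)^{-1}, to the vector paired with its parameters,
  # preserving the input ordering.  `opt`, `loss`, and `estimator` are
  # placeholders here.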

  def _check_registration(self, variables):
    """Checks that all variable uses have been registered properly.

    Args:
      variables: List of variables.

    Raises:
      ValueError: If any registered variables are not included in the list.
      ValueError: If any variable in the list is not registered.
      ValueError: If any variable in the list is registered with the wrong
        number of "uses" in the subgraph recorded (vs the number of times that
        variable is actually used in the subgraph).
    """
    # Note that overlapping parameters (i.e. those that share variables) will
    # be caught by layer_collection.LayerParametersDict during registration.

    reg_use_map = self._layers.get_use_count_map()

    error_messages = []

    for var in variables:
      total_uses = self._layers.subgraph.variable_uses(var)
      reg_uses = reg_use_map[var]

      if reg_uses == 0:
        error_messages.append("Variable {} not registered.".format(var))
      elif (not math.isinf(reg_uses)) and reg_uses != total_uses:
        error_messages.append(
            "Variable {} registered with wrong number of uses ({} "
            "vs {} actual).".format(var, reg_uses, total_uses))

    num_get_vars = len(reg_use_map)

    if num_get_vars > len(variables):
      error_messages.append("{} registered variables were not included in list."
                            .format(num_get_vars - len(variables)))

    if error_messages:
      error_messages = [
          "Found the following errors with variable registration:"
      ] + error_messages
      raise ValueError("\n\t".join(error_messages))

  def _setup(self, cov_ema_decay):
    """Sets up the various operations.

    Args:
      cov_ema_decay: The decay factor used when calculating the covariance
        estimate moving averages.

    Returns:
      A triple (covs_update_op, invs_update_op, inv_updates_dict), where
      covs_update_op is the grouped Op to update all the covariance estimates,
      invs_update_op is the grouped Op to update all the inverses, and
      inv_updates_dict is a dict mapping Op names to individual inverse updates.

    Raises:
      ValueError: If estimation_mode was improperly specified at construction.
    """
    fisher_blocks_list = self._layers.get_blocks()
    tensors_to_compute_grads = [
        fb.tensors_to_compute_grads() for fb in fisher_blocks_list
    ]

    try:
      grads_lists = self._gradient_fns[self._estimation_mode](
          tensors_to_compute_grads)
    except KeyError:
      raise ValueError("Unrecognized value {} for estimation_mode.".format(
          self._estimation_mode))

    for grads_list, fb in zip(grads_lists, fisher_blocks_list):
      fb.instantiate_factors(grads_list, self.damping)

    cov_updates = [
        factor.make_covariance_update_op(cov_ema_decay)
        for factor in self._layers.get_factors()
    ]
    inv_updates = {op.name: op for op in self._get_all_inverse_update_ops()}

    return control_flow_ops.group(*cov_updates), control_flow_ops.group(
        *inv_updates.values()), inv_updates

  def _get_all_inverse_update_ops(self):
    for factor in self._layers.get_factors():
      for op in factor.make_inverse_update_ops():
        yield op

  def _get_grads_lists_gradients(self, tensors):
    grads_flat = gradients_impl.gradients(self._layers.total_sampled_loss(),
                                          nest.flatten(tensors))
    grads_all = nest.pack_sequence_as(tensors, grads_flat)
    return tuple((grad,) for grad in grads_all)

  def _get_grads_lists_empirical(self, tensors):
    grads_flat = gradients_impl.gradients(self._layers.total_loss(),
                                          nest.flatten(tensors))
    grads_all = nest.pack_sequence_as(tensors, grads_flat)
    return tuple((grad,) for grad in grads_all)

  def _get_transformed_random_signs(self):
    transformed_random_signs = []
    for loss in self._layers.losses:
      transformed_random_signs.append(
          loss.multiply_fisher_factor(
              utils.generate_random_signs(loss.fisher_factor_inner_shape)))
    return transformed_random_signs
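
  # Background sketch for the two estimators below (the symbol B is shorthand
  # introduced here, not defined elsewhere in this file): each registered loss
  # exposes a "half-factor" B of its Fisher, so that the loss's Fisher with
  # respect to its inputs is B B^T and multiply_fisher_factor(v) computes B v.
  # Curvature propagation draws a random sign vector s with i.i.d. +1/-1
  # entries, so E[s s^T] = I and therefore E[(B s)(B s)^T] = B B^T; in other
  # words, backpropagating B s gives an unbiased one-sample estimate of the
  # curvature.  The 'exact' mode instead backpropagates B e_i for every
  # one-hot vector e_i, and sum_i (B e_i)(B e_i)^T = B B^T recovers the factor
  # exactly, at a cost that scales with the output dimension.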

  def _get_grads_lists_curvature_prop(self, tensors):
    loss_inputs = list(loss.inputs for loss in self._layers.losses)
    transformed_random_signs = self._get_transformed_random_signs()
    grads_flat = gradients_impl.gradients(
        nest.flatten(loss_inputs),
        nest.flatten(tensors),
        grad_ys=nest.flatten(transformed_random_signs))
    grads_all = nest.pack_sequence_as(tensors, grads_flat)
    return tuple((grad,) for grad in grads_all)

  def _get_grads_lists_exact(self, tensors):
    # Loop over all coordinates of all losses.
    grads_all = []
    for loss in self._layers.losses:
      for index in np.ndindex(*loss.fisher_factor_inner_static_shape[1:]):
        transformed_one_hot = loss.multiply_fisher_factor_replicated_one_hot(
            index)
        grads_flat = gradients_impl.gradients(
            loss.inputs, nest.flatten(tensors), grad_ys=transformed_one_hot)
        grads_all.append(nest.pack_sequence_as(tensors, grads_flat))
    return zip(*grads_all)
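

# Illustrative usage sketch (commented out; `layer_collection`, `loss`, and
# `num_steps` are placeholders assumed to exist, with layers and losses
# already registered on the LayerCollection):
#
#   estimator = FisherEstimator(
#       variables=tf.trainable_variables(),
#       cov_ema_decay=0.95,
#       damping=1e-3,
#       layer_collection=layer_collection,
#       estimation_mode="gradients")
#
#   grads_and_vars = list(
#       zip(tf.gradients(loss, estimator.variables), estimator.variables))
#   precon_grads_and_vars = estimator.multiply_inverse(grads_and_vars)
#   train_op = tf.train.GradientDescentOptimizer(0.1).apply_gradients(
#       precon_grads_and_vars)
#
#   with tf.Session() as sess:
#     sess.run(tf.global_variables_initializer())
#     for _ in range(num_steps):
#       sess.run(estimator.cov_update_op)  # refresh covariance moving averages
#       sess.run(estimator.inv_update_op)  # recompute the damped inverses
#       sess.run(train_op)                 # apply the preconditioned update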