1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Utilities for probabilistic layers. 16""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import numpy as np 23 24from tensorflow.contrib.distributions.python.ops import deterministic as deterministic_lib 25from tensorflow.contrib.distributions.python.ops import independent as independent_lib 26from tensorflow.python.framework import dtypes 27from tensorflow.python.ops import array_ops 28from tensorflow.python.ops import init_ops 29from tensorflow.python.ops import math_ops 30from tensorflow.python.ops import nn_ops 31from tensorflow.python.ops import random_ops 32from tensorflow.python.ops.distributions import normal as normal_lib 33 34 35def default_loc_scale_fn( 36 is_singular=False, 37 loc_initializer=init_ops.random_normal_initializer(stddev=0.1), 38 untransformed_scale_initializer=init_ops.random_normal_initializer( 39 mean=-3., stddev=0.1), 40 loc_regularizer=None, 41 untransformed_scale_regularizer=None, 42 loc_constraint=None, 43 untransformed_scale_constraint=None): 44 """Makes closure which creates `loc`, `scale` params from `tf.get_variable`. 45 46 This function produces a closure which produces `loc`, `scale` using 47 `tf.get_variable`. The closure accepts the following arguments: 48 49 dtype: Type of parameter's event. 50 shape: Python `list`-like representing the parameter's event shape. 51 name: Python `str` name prepended to any created (or existing) 52 `tf.Variable`s. 53 trainable: Python `bool` indicating all created `tf.Variable`s should be 54 added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. 55 add_variable_fn: `tf.get_variable`-like `callable` used to create (or 56 access existing) `tf.Variable`s. 57 58 Args: 59 is_singular: Python `bool` indicating if `scale is None`. Default: `False`. 60 loc_initializer: Initializer function for the `loc` parameters. 61 The default is `tf.random_normal_initializer(mean=0., stddev=0.1)`. 62 untransformed_scale_initializer: Initializer function for the `scale` 63 parameters. Default value: `tf.random_normal_initializer(mean=-3., 64 stddev=0.1)`. This implies the softplus transformed result has mean 65 approximately `0.05` and std. deviation approximately `0.005`. 66 loc_regularizer: Regularizer function for the `loc` parameters. 67 The default (`None`) is to use the `tf.get_variable` default. 68 untransformed_scale_regularizer: Regularizer function for the `scale` 69 parameters. The default (`None`) is to use the `tf.get_variable` default. 70 loc_constraint: An optional projection function to be applied to the 71 loc after being updated by an `Optimizer`. The function must take as input 72 the unprojected variable and must return the projected variable (which 73 must have the same shape). Constraints are not safe to use when doing 74 asynchronous distributed training. 75 The default (`None`) is to use the `tf.get_variable` default. 76 untransformed_scale_constraint: An optional projection function to be 77 applied to the `scale` parameters after being updated by an `Optimizer` 78 (e.g. used to implement norm constraints or value constraints). The 79 function must take as input the unprojected variable and must return the 80 projected variable (which must have the same shape). Constraints are not 81 safe to use when doing asynchronous distributed training. The default 82 (`None`) is to use the `tf.get_variable` default. 83 84 Returns: 85 default_loc_scale_fn: Python `callable` which instantiates `loc`, `scale` 86 parameters from args: `dtype, shape, name, trainable, add_variable_fn`. 87 """ 88 def _fn(dtype, shape, name, trainable, add_variable_fn): 89 """Creates `loc`, `scale` parameters.""" 90 loc = add_variable_fn( 91 name=name + "_loc", 92 shape=shape, 93 initializer=loc_initializer, 94 regularizer=loc_regularizer, 95 constraint=loc_constraint, 96 dtype=dtype, 97 trainable=trainable) 98 if is_singular: 99 return loc, None 100 untransformed_scale = add_variable_fn( 101 name=name + "_untransformed_scale", 102 shape=shape, 103 initializer=untransformed_scale_initializer, 104 regularizer=untransformed_scale_regularizer, 105 constraint=untransformed_scale_constraint, 106 dtype=dtype, 107 trainable=trainable) 108 scale = (np.finfo(dtype.as_numpy_dtype).eps + 109 nn_ops.softplus(untransformed_scale)) 110 return loc, scale 111 return _fn 112 113 114def default_mean_field_normal_fn( 115 is_singular=False, 116 loc_initializer=None, 117 untransformed_scale_initializer=None, 118 loc_regularizer=None, 119 untransformed_scale_regularizer=None, 120 loc_constraint=None, 121 untransformed_scale_constraint=None): 122 """Creates a function to build Normal distributions with trainable params. 123 124 This function produces a closure which produces `tf.distributions.Normal` 125 parameterized by a loc` and `scale` each created using `tf.get_variable`. The 126 produced closure accepts the following arguments: 127 128 name: Python `str` name prepended to any created (or existing) 129 `tf.Variable`s. 130 shape: Python `list`-like representing the parameter's event shape. 131 dtype: Type of parameter's event. 132 trainable: Python `bool` indicating all created `tf.Variable`s should be 133 added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. 134 add_variable_fn: `tf.get_variable`-like `callable` used to create (or 135 access existing) `tf.Variable`s. 136 137 Args: 138 is_singular: Python `bool` if `True`, forces the special case limit of 139 `scale->0`, i.e., a `Deterministic` distribution. 140 loc_initializer: Initializer function for the `loc` parameters. 141 If `None` (default), values are initialized using the default 142 initializer used by `tf.get_variable`. 143 untransformed_scale_initializer: Initializer function for the `scale` 144 parameters. If `None` (default), values are initialized using the default 145 initializer used by `tf.get_variable`. 146 loc_regularizer: Regularizer function for the `loc` parameters. 147 untransformed_scale_regularizer: Regularizer function for the `scale` 148 parameters. 149 loc_constraint: An optional projection function to be applied to the 150 loc after being updated by an `Optimizer`. The function must take as input 151 the unprojected variable and must return the projected variable (which 152 must have the same shape). Constraints are not safe to use when doing 153 asynchronous distributed training. 154 untransformed_scale_constraint: An optional projection function to be 155 applied to the `scale` parameters after being updated by an `Optimizer` 156 (e.g. used to implement norm constraints or value constraints). The 157 function must take as input the unprojected variable and must return the 158 projected variable (which must have the same shape). Constraints are not 159 safe to use when doing asynchronous distributed training. 160 161 Returns: 162 make_normal_fn: Python `callable` which creates a `tf.distributions.Normal` 163 using from args: `dtype, shape, name, trainable, add_variable_fn`. 164 """ 165 loc_scale_fn_ = default_loc_scale_fn( 166 is_singular, 167 loc_initializer, 168 untransformed_scale_initializer, 169 loc_regularizer, 170 untransformed_scale_regularizer, 171 loc_constraint, 172 untransformed_scale_constraint) 173 def _fn(dtype, shape, name, trainable, add_variable_fn): 174 """Creates multivariate `Deterministic` or `Normal` distribution.""" 175 loc, scale = loc_scale_fn_(dtype, shape, name, trainable, add_variable_fn) 176 if scale is None: 177 dist = deterministic_lib.Deterministic(loc=loc) 178 else: 179 dist = normal_lib.Normal(loc=loc, scale=scale) 180 reinterpreted_batch_ndims = array_ops.shape(dist.batch_shape_tensor())[0] 181 return independent_lib.Independent( 182 dist, reinterpreted_batch_ndims=reinterpreted_batch_ndims) 183 return _fn 184 185 186def random_sign(shape, dtype=dtypes.float32, seed=None): 187 """Draw values from {-1, 1} uniformly, i.e., Rademacher distribution.""" 188 random_bernoulli = random_ops.random_uniform(shape, minval=0, maxval=2, 189 dtype=dtypes.int32, 190 seed=seed) 191 return math_ops.cast(2 * random_bernoulli - 1, dtype) 192