# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for probabilistic layers.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.contrib.distributions.python.ops import deterministic as deterministic_lib
from tensorflow.contrib.distributions.python.ops import independent as independent_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import normal as normal_lib


def default_loc_scale_fn(
    is_singular=False,
    loc_initializer=init_ops.random_normal_initializer(stddev=0.1),
    untransformed_scale_initializer=init_ops.random_normal_initializer(
        mean=-3., stddev=0.1),
    loc_regularizer=None,
    untransformed_scale_regularizer=None,
    loc_constraint=None,
    untransformed_scale_constraint=None):
  """Makes closure which creates `loc`, `scale` params from `tf.get_variable`.

  This function produces a closure which produces `loc`, `scale` using
  `tf.get_variable`. The closure accepts the following arguments:

    dtype: Type of parameter's event.
    shape: Python `list`-like representing the parameter's event shape.
    name: Python `str` name prepended to any created (or existing)
      `tf.Variable`s.
    trainable: Python `bool` indicating all created `tf.Variable`s should be
      added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`.
    add_variable_fn: `tf.get_variable`-like `callable` used to create (or
      access existing) `tf.Variable`s.

  Args:
    is_singular: Python `bool` indicating if `scale is None`. Default: `False`.
    loc_initializer: Initializer function for the `loc` parameters.
      The default is `tf.random_normal_initializer(mean=0., stddev=0.1)`.
    untransformed_scale_initializer: Initializer function for the `scale`
      parameters. Default value: `tf.random_normal_initializer(mean=-3.,
      stddev=0.1)`. This implies the softplus transformed result has mean
      approximately `0.05` and std. deviation approximately `0.005`.
    loc_regularizer: Regularizer function for the `loc` parameters.
      The default (`None`) is to use the `tf.get_variable` default.
    untransformed_scale_regularizer: Regularizer function for the `scale`
      parameters. The default (`None`) is to use the `tf.get_variable` default.
    loc_constraint: An optional projection function to be applied to the
      loc after being updated by an `Optimizer`. The function must take as input
      the unprojected variable and must return the projected variable (which
      must have the same shape). Constraints are not safe to use when doing
      asynchronous distributed training.
      The default (`None`) is to use the `tf.get_variable` default.
    untransformed_scale_constraint: An optional projection function to be
      applied to the `scale` parameters after being updated by an `Optimizer`
      (e.g. used to implement norm constraints or value constraints). The
      function must take as input the unprojected variable and must return the
      projected variable (which must have the same shape). Constraints are not
      safe to use when doing asynchronous distributed training. The default
      (`None`) is to use the `tf.get_variable` default.

  Returns:
    default_loc_scale_fn: Python `callable` which instantiates `loc`, `scale`
      parameters from args: `dtype, shape, name, trainable, add_variable_fn`.
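
  #### Examples

  A minimal usage sketch (non-authoritative; it assumes TF 1.x graph-style
  variable creation, uses `tf.get_variable` as the `add_variable_fn`, and the
  shape and name below are illustrative):

  ```python
  loc_scale_fn = default_loc_scale_fn()
  loc, scale = loc_scale_fn(
      dtype=tf.float32,
      shape=[5, 3],
      name="kernel",  # creates "kernel_loc" and "kernel_untransformed_scale"
      trainable=True,
      add_variable_fn=tf.get_variable)
  ```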
  """
  def _fn(dtype, shape, name, trainable, add_variable_fn):
    """Creates `loc`, `scale` parameters."""
    loc = add_variable_fn(
        name=name + "_loc",
        shape=shape,
        initializer=loc_initializer,
        regularizer=loc_regularizer,
        constraint=loc_constraint,
        dtype=dtype,
        trainable=trainable)
    if is_singular:
      return loc, None
    untransformed_scale = add_variable_fn(
        name=name + "_untransformed_scale",
        shape=shape,
        initializer=untransformed_scale_initializer,
        regularizer=untransformed_scale_regularizer,
        constraint=untransformed_scale_constraint,
        dtype=dtype,
        trainable=trainable)
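    # Add machine epsilon to the softplus so the resulting `scale` stays
    # strictly positive even when the softplus output underflows.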
    scale = (np.finfo(dtype.as_numpy_dtype).eps +
             nn_ops.softplus(untransformed_scale))
    return loc, scale
  return _fn


def default_mean_field_normal_fn(
    is_singular=False,
    loc_initializer=None,
    untransformed_scale_initializer=None,
    loc_regularizer=None,
    untransformed_scale_regularizer=None,
    loc_constraint=None,
    untransformed_scale_constraint=None):
  """Creates a function to build Normal distributions with trainable params.

  This function produces a closure which produces `tf.distributions.Normal`
  parameterized by a `loc` and `scale` each created using `tf.get_variable`.
  The produced closure accepts the following arguments:

    dtype: Type of parameter's event.
    shape: Python `list`-like representing the parameter's event shape.
    name: Python `str` name prepended to any created (or existing)
      `tf.Variable`s.
    trainable: Python `bool` indicating all created `tf.Variable`s should be
      added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`.
    add_variable_fn: `tf.get_variable`-like `callable` used to create (or
      access existing) `tf.Variable`s.

  Args:
    is_singular: Python `bool`; if `True`, forces the special case limit of
      `scale->0`, i.e., a `Deterministic` distribution.
    loc_initializer: Initializer function for the `loc` parameters.
      If `None` (default), values are initialized using the default
      initializer used by `tf.get_variable`.
    untransformed_scale_initializer: Initializer function for the `scale`
      parameters. If `None` (default), values are initialized using the default
      initializer used by `tf.get_variable`.
    loc_regularizer: Regularizer function for the `loc` parameters.
    untransformed_scale_regularizer: Regularizer function for the `scale`
      parameters.
    loc_constraint: An optional projection function to be applied to the
      loc after being updated by an `Optimizer`. The function must take as input
      the unprojected variable and must return the projected variable (which
      must have the same shape). Constraints are not safe to use when doing
      asynchronous distributed training.
    untransformed_scale_constraint: An optional projection function to be
      applied to the `scale` parameters after being updated by an `Optimizer`
      (e.g. used to implement norm constraints or value constraints). The
      function must take as input the unprojected variable and must return the
      projected variable (which must have the same shape). Constraints are not
      safe to use when doing asynchronous distributed training.

  Returns:
    make_normal_fn: Python `callable` which creates a `tf.distributions.Normal`
      from args: `dtype, shape, name, trainable, add_variable_fn`.
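
  #### Examples

  A minimal usage sketch (non-authoritative; the shape and name are
  illustrative, and `tf.get_variable` stands in for `add_variable_fn`):

  ```python
  make_posterior = default_mean_field_normal_fn()
  posterior = make_posterior(
      dtype=tf.float32,
      shape=[5, 3],
      name="kernel_posterior",
      trainable=True,
      add_variable_fn=tf.get_variable)
  # `posterior` is an `Independent(Normal)` whose event shape is `[5, 3]`;
  # a single `sample()` yields a kernel-shaped draw of weights.
  kernel = posterior.sample()
  ```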
  """
  loc_scale_fn_ = default_loc_scale_fn(
      is_singular,
      loc_initializer,
      untransformed_scale_initializer,
      loc_regularizer,
      untransformed_scale_regularizer,
      loc_constraint,
      untransformed_scale_constraint)
  def _fn(dtype, shape, name, trainable, add_variable_fn):
    """Creates multivariate `Deterministic` or `Normal` distribution."""
    loc, scale = loc_scale_fn_(dtype, shape, name, trainable, add_variable_fn)
    if scale is None:
      dist = deterministic_lib.Deterministic(loc=loc)
    else:
      dist = normal_lib.Normal(loc=loc, scale=scale)
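    # Reinterpret every batch dimension of the base distribution as an event
    # dimension so `log_prob` reduces over the full parameter tensor.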
    reinterpreted_batch_ndims = array_ops.shape(dist.batch_shape_tensor())[0]
    return independent_lib.Independent(
        dist, reinterpreted_batch_ndims=reinterpreted_batch_ndims)
  return _fn


def random_sign(shape, dtype=dtypes.float32, seed=None):
  """Draws values uniformly from {-1, 1}, i.e., a Rademacher distribution."""
  random_bernoulli = random_ops.random_uniform(shape, minval=0, maxval=2,
                                               dtype=dtypes.int32,
                                               seed=seed)
  return math_ops.cast(2 * random_bernoulli - 1, dtype)
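

# Example (a hedged sketch, not part of the module API): `random_sign` can be
# used to draw a Rademacher perturbation matching an existing tensor, e.g.
#
#   signs = random_sign(array_ops.shape(inputs), dtype=inputs.dtype, seed=42)
#   perturbed = inputs * signs
#
# where `inputs` is a hypothetical floating-point tensor; `signs` has the same
# shape with entries drawn uniformly from {-1, 1}.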