1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Special Math Ops."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import math
22import numpy as np
23
24from tensorflow.python.framework import constant_op
25from tensorflow.python.framework import ops
26from tensorflow.python.ops import array_ops
27from tensorflow.python.ops import math_ops
28
# Public names of this module; this list controls what
# `from <module> import *` exports.
__all__ = [
    "erfinv",
    "ndtr",
    "ndtri",
    "log_ndtr",
    "log_cdf_laplace",
]
36
37
# log_ndtr uses a different formula on each of the three ranges
#   (-infty, lower], (lower, upper], (upper, infty).
# Lower bound values were chosen by examining where the support of ndtr
# appears to be zero, relative to scipy's (which is always 64bit). They were
# then made more conservative just to be safe. (Conservative means use the
# expansion more than we probably need to.) See `NdtrTest` in
# special_math_test.py.
LOGNDTR_FLOAT64_LOWER = -20
LOGNDTR_FLOAT32_LOWER = -10

# Upper bound values were chosen by examining for which values of 'x'
# Log[cdf(x)] is 0, after which point we need to use the approximation
# Log[cdf(x)] = Log[1 - cdf(-x)] approx -cdf(-x). We chose a value slightly
# conservative, meaning we use the approximation earlier than needed.
LOGNDTR_FLOAT64_UPPER = 8
LOGNDTR_FLOAT32_UPPER = 5
54
55
def ndtr(x, name="ndtr"):
  """Cumulative distribution function of the standard Normal distribution.

  Computes the area under the Gaussian probability density function to the
  left of `x`:

  ```
                    1       / x
     ndtr(x)  = ----------  |    exp(-0.5 t**2) dt
                sqrt(2 pi)  /-inf

              = 0.5 (1 + erf(x / sqrt(2)))
              = 0.5 erfc(-x / sqrt(2))
  ```

  Args:
    x: `Tensor` of type `float32`, `float64`.
    name: Python string. A name for the operation (default="ndtr").

  Returns:
    ndtr: `Tensor` with `dtype=x.dtype`.

  Raises:
    TypeError: if `x.dtype` is neither `float32` nor `float64`.
  """
  supported_dtypes = (np.float32, np.float64)
  with ops.name_scope(name, values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    if x.dtype.as_numpy_dtype in supported_dtypes:
      return _ndtr(x)
    # Anything else is rejected rather than silently cast.
    raise TypeError(
        "x.dtype=%s is not handled, see docstring for supported types."
        % x.dtype)
89
90
def _ndtr(x):
  """Evaluates the standard Normal CDF of `x` from `erf` and `erfc`.

  With `w = x * sqrt(2) / 2`, `erf(w)` is used where `|w| < sqrt(2) / 2`;
  elsewhere `erfc(|w|)` is used, in the form selected by the sign of `w`.
  """
  half_sqrt_2 = constant_op.constant(
      0.5 * math.sqrt(2.), dtype=x.dtype, name="half_sqrt_2")
  scaled = x * half_sqrt_2
  magnitude = math_ops.abs(scaled)

  # Candidate values for each region, paired with their selection masks.
  near_zero = math_ops.less(magnitude, half_sqrt_2)
  central_value = 1. + math_ops.erf(scaled)
  is_positive = math_ops.greater(scaled, 0.)
  upper_tail_value = 2. - math_ops.erfc(magnitude)
  lower_tail_value = math_ops.erfc(magnitude)

  tail_value = array_ops.where(is_positive, upper_tail_value, lower_tail_value)
  return 0.5 * array_ops.where(near_zero, central_value, tail_value)
103
104
def ndtri(p, name="ndtri"):
  """Inverse of the CDF of the standard Normal distribution.

  Returns `x` such that the area under the pdf from minus infinity to `x`
  equals `p`.

  The function is evaluated with a piece-wise rational approximation; this is
  a port of the implementation in netlib.

  Args:
    p: `Tensor` of type `float32`, `float64`.
    name: Python string. A name for the operation (default="ndtri").

  Returns:
    x: `Tensor` with `dtype=p.dtype`.

  Raises:
    TypeError: if `p.dtype` is neither `float32` nor `float64`.
  """
  supported_dtypes = (np.float32, np.float64)
  with ops.name_scope(name, values=[p]):
    p = ops.convert_to_tensor(p, name="p")
    if p.dtype.as_numpy_dtype in supported_dtypes:
      return _ndtri(p)
    # Anything else is rejected rather than silently cast.
    raise TypeError(
        "p.dtype=%s is not handled, see docstring for supported types."
        % p.dtype)
132
133
def _ndtri(p):
  """Builds the inverse standard Normal CDF from rational approximations.

  Inputs above `1 - exp(-2)` are reflected to `1 - p` before the piece-wise
  approximation is evaluated; the sign of the result is chosen afterwards
  depending on whether the input was reflected. `p <= 0` maps to `-inf` and
  `p >= 1` maps to `+inf`.
  """

  # Coefficients of the piece-wise rational approximations, taken from the
  # cephes library:
  # https://github.com/scipy/scipy/blob/master/scipy/special/cephes/ndtri.c
  # Each table is written highest order first and then flipped, so that entry
  # `i` of the resulting list multiplies `var**i`.
  p0 = [-5.99633501014107895267E1,
        9.80010754185999661536E1,
        -5.66762857469070293439E1,
        1.39312609387279679503E1,
        -1.23916583867381258016E0][::-1]
  q0 = [1.0,
        1.95448858338141759834E0,
        4.67627912898881538453E0,
        8.63602421390890590575E1,
        -2.25462687854119370527E2,
        2.00260212380060660359E2,
        -8.20372256168333339912E1,
        1.59056225126211695515E1,
        -1.18331621121330003142E0][::-1]
  p1 = [4.05544892305962419923E0,
        3.15251094599893866154E1,
        5.71628192246421288162E1,
        4.40805073893200834700E1,
        1.46849561928858024014E1,
        2.18663306850790267539E0,
        -1.40256079171354495875E-1,
        -3.50424626827848203418E-2,
        -8.57456785154685413611E-4][::-1]
  q1 = [1.0,
        1.57799883256466749731E1,
        4.53907635128879210584E1,
        4.13172038254672030440E1,
        1.50425385692907503408E1,
        2.50464946208309415979E0,
        -1.42182922854787788574E-1,
        -3.80806407691578277194E-2,
        -9.33259480895457427372E-4][::-1]
  p2 = [3.23774891776946035970E0,
        6.91522889068984211695E0,
        3.93881025292474443415E0,
        1.33303460815807542389E0,
        2.01485389549179081538E-1,
        1.23716634817820021358E-2,
        3.01581553508235416007E-4,
        2.65806974686737550832E-6,
        6.23974539184983293730E-9][::-1]
  q2 = [1.0,
        6.02427039364742014255E0,
        3.67983563856160859403E0,
        1.37702099489081330271E0,
        2.16236993594496635890E-1,
        1.34204006088543189037E-2,
        3.28014464682127739104E-4,
        2.89247864745380683936E-6,
        6.79019408009981274425E-9][::-1]

  def _horner(var, coeffs):
    """Evaluates `sum_i coeffs[i] * var**i` with Horner's method."""
    acc = 0.
    for coeff in reversed(coeffs):
      acc = coeff + acc * var
    return acc

  upper_cutoff = 1. - np.exp(-2.)
  reflected_p = array_ops.where(p > upper_cutoff, 1. - p, p)
  # A non-positive probability would produce NaNs further down. The value
  # computed for those entries is discarded at the end, so substitute any
  # harmless number (0.5) for them.
  half = constant_op.constant(0.5, dtype=p.dtype)
  safe_p = array_ops.where(
      reflected_p <= 0.,
      array_ops.fill(array_ops.shape(p), half),
      reflected_p)

  # Central region, p > exp(-2):
  #   x / sqrt(2 pi) = w + w**3 P0(w**2) / Q0(w**2),  w = p - 0.5.
  centered = safe_p - 0.5
  centered_sq = centered ** 2
  x_central = centered + centered * centered_sq * (
      _horner(centered_sq, p0) / _horner(centered_sq, q0))
  x_central = x_central * -np.sqrt(2. * np.pi)

  # Tail region, p <= exp(-2):
  #   x = z - log(z) / z - (1 / z) P(1 / z) / Q(1 / z),  z = sqrt(-2 log(p)),
  # where P/Q is one of two coefficient sets depending on whether z >= 8
  # (i.e. p <= exp(-32)).
  z = math_ops.sqrt(-2. * math_ops.log(safe_p))
  leading_term = z - math_ops.log(z) / z
  far_tail_correction = (_horner(1. / z, p2) / _horner(1. / z, q2)) / z
  near_tail_correction = (_horner(1. / z, p1) / _horner(1. / z, q1)) / z
  x_far_tail = leading_term - far_tail_correction
  x_near_tail = leading_term - near_tail_correction

  x = array_ops.where(safe_p > np.exp(-2.),
                      x_central,
                      array_ops.where(z >= 8.0, x_far_tail, x_near_tail))

  # Keep the sign for reflected inputs, flip it for all others.
  x = array_ops.where(p > upper_cutoff, x, -x)

  # Overwrite the out-of-range probabilities with the limiting values.
  inf_scalar = constant_op.constant(np.inf, dtype=p.dtype)
  inf = array_ops.fill(array_ops.shape(p), inf_scalar)
  return array_ops.where(
      p <= 0.0, -inf, array_ops.where(p >= 1.0, inf, x))
236
237
def log_ndtr(x, series_order=3, name="log_ndtr"):
  """Log of the Normal distribution function.

  For details of the Normal distribution function see `ndtr`.

  This function calculates `(log o ndtr)(x)` by either calling `log(ndtr(x))`
  or using an asymptotic series. Specifically:
  - For `x > upper_segment`, use the approximation `-ndtr(-x)` based on
    `log(1-x) ~= -x, x << 1`.
  - For `lower_segment < x <= upper_segment`, use the existing `ndtr` technique
    and take a log.
  - For `x <= lower_segment`, use the series approximation of erf to compute
    the log CDF directly.

  Both segment bounds are set based on the precision of the input:

  ```
  lower_segment = { -20,  x.dtype=float64
                  { -10,  x.dtype=float32
  upper_segment = {   8,  x.dtype=float64
                  {   5,  x.dtype=float32
  ```

  When `x < lower_segment`, the `ndtr` asymptotic series approximation is:

  ```
     ndtr(x) = scale * (1 + sum) + R_N
     scale   = exp(-0.5 x**2) / (-x sqrt(2 pi))
     sum     = Sum{(-1)^n (2n-1)!! / (x**2)^n, n=1:N}
     R_N     = O(exp(-0.5 x**2) (2N+1)!! / |x|^{2N+3})
  ```

  where `(2n-1)!! = (2n-1) (2n-3) (2n-5) ...  (3) (1)` is a
  [double-factorial](https://en.wikipedia.org/wiki/Double_factorial).

  Args:
    x: `Tensor` of type `float32`, `float64`.
    series_order: Non-negative Python `integer` in `[0, 30]`. Maximum depth to
      evaluate the asymptotic expansion. This is the `N` above.
    name: Python string. A name for the operation (default="log_ndtr").

  Returns:
    log_ndtr: `Tensor` with `dtype=x.dtype`.

  Raises:
    TypeError: if `x.dtype` is neither `float32` nor `float64`.
    TypeError: if `series_order` is not a Python `integer`.
    ValueError: if `series_order` is not in `[0, 30]`.
  """
  # Validate the pure-Python argument before any op is created.
  if not isinstance(series_order, int):
    raise TypeError("series_order must be a Python integer.")
  if series_order < 0:
    raise ValueError("series_order must be non-negative.")
  if series_order > 30:
    raise ValueError("series_order must be <= 30.")

  with ops.name_scope(name, values=[x]):
    x = ops.convert_to_tensor(x, name="x")

    numpy_dtype = x.dtype.as_numpy_dtype
    if numpy_dtype == np.float64:
      lower_segment = LOGNDTR_FLOAT64_LOWER
      upper_segment = LOGNDTR_FLOAT64_UPPER
    elif numpy_dtype == np.float32:
      lower_segment = LOGNDTR_FLOAT32_LOWER
      upper_segment = LOGNDTR_FLOAT32_UPPER
    else:
      raise TypeError("x.dtype=%s is not supported." % x.dtype)

    # The basic idea here was ported from py/scipy/special/cephes/ndtr.c.
    # The main idea is kept, with a few changes:
    # * For x >> 1, and X ~ Normal(0, 1),
    #     Log[P[X < x]] = Log[1 - P[X < -x]] approx -P[X < -x],
    #   which extends the range of validity of this function.
    # * One fixed series_order is used for all of 'x', rather than an adaptive
    #   one.
    # * The docstring reflects that this is an asymptotic series, not a Taylor
    #   series, and gives a bound on the remainder.
    # * The arguments of the middle and lower branches are clamped with
    #   maximum/minimum to avoid nan when x=0. This happens even though the
    #   branch is unchosen, because when x=0 the gradient of a select involves
    #   the calculation 1*dy+0*(-inf)=nan regardless of whether dy is finite.
    #   The clamp is a NOP wherever the branch is actually chosen.
    above_upper = math_ops.greater(x, upper_segment)
    upper_value = -_ndtr(-x)  # log(1-x) ~= -x, x << 1
    above_lower = math_ops.greater(x, lower_segment)
    middle_value = math_ops.log(_ndtr(math_ops.maximum(x, lower_segment)))
    lower_value = _log_ndtr_lower(math_ops.minimum(x, lower_segment),
                                  series_order)
    return array_ops.where(
        above_upper,
        upper_value,
        array_ops.where(above_lower, middle_value, lower_value))
327
328
def _log_ndtr_lower(x, series_order):
  """Asymptotic expansion of `Log[cdf(x)]`, appropriate for `x << -1`.

  Returns `log(scale) + log(series)`, where
  `scale = exp(-0.5 x**2) / (-x sqrt(2 pi))` and `series` is the value of
  `_log_ndtr_asymptotic_series(x, series_order)`.
  """
  # Log of the factor multiplying the series, assembled term by term.
  neg_half_x_sq = -0.5 * math_ops.square(x)
  log_neg_x = math_ops.log(-x)
  log_sqrt_2pi = 0.5 * math.log(2. * math.pi)
  log_scale = neg_half_x_sq - log_neg_x - log_sqrt_2pi

  series = _log_ndtr_asymptotic_series(x, series_order)
  return log_scale + math_ops.log(series)
335
336
337def _log_ndtr_asymptotic_series(x, series_order):
338  """Calculates the asymptotic series used in log_ndtr."""
339  if series_order <= 0:
340    return 1.
341  x_2 = math_ops.square(x)
342  even_sum = 0.
343  odd_sum = 0.
344  x_2n = x_2  # Start with x^{2*1} = x^{2*n} with n = 1.
345  for n in range(1, series_order + 1):
346    if n % 2:
347      odd_sum += _double_factorial(2 * n - 1) / x_2n
348    else:
349      even_sum += _double_factorial(2 * n - 1) / x_2n
350    x_2n *= x_2
351  return 1. + even_sum - odd_sum
352
353
def erfinv(x, name="erfinv"):
  """The inverse function for erf, the error function.

  Computed from the inverse Normal CDF as
  `ndtri((x + 1) / 2) / sqrt(2)`.

  Args:
    x: `Tensor` of type `float32`, `float64`.
    name: Python string. A name for the operation (default="erfinv").

  Returns:
    x: `Tensor` with `dtype=x.dtype`.

  Raises:
    TypeError: if `x.dtype` is neither `float32` nor `float64`.
  """
  supported_dtypes = (np.float32, np.float64)
  with ops.name_scope(name, values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    if x.dtype.as_numpy_dtype in supported_dtypes:
      # Map x to the probability (x + 1) / 2 and invert the Normal CDF there.
      probability = (x + 1.0) / 2.0
      return ndtri(probability) / np.sqrt(2)
    raise TypeError(
        "x.dtype=%s is not handled, see docstring for supported types."
        % x.dtype)
375
376
377def _double_factorial(n):
378  """The double factorial function for small Python integer `n`."""
379  return np.prod(np.arange(n, 1, -2))
380
381
def log_cdf_laplace(x, name="log_cdf_laplace"):
  """Log of the Laplace cumulative distribution function.

  This function calculates `Log[L(x)]`, where `L(x)` is the cumulative
  distribution function of the Laplace distribution, i.e.

  ```L(x) := 0.5 * int_{-infty}^x e^{-|t|} dt```

  For numerical accuracy, `L(x)` is computed in different ways depending on
  `x`,

  ```
  x <= 0:
    Log[L(x)] = Log[0.5] + x, which is exact

  0 < x:
    Log[L(x)] = Log[1 - 0.5 * e^{-x}], which is exact
  ```

  Note that, unlike the other public functions in this file, no explicit dtype
  check is performed here.

  Args:
    x: `Tensor` of type `float32`, `float64`.
    name: Python string. A name for the operation
      (default="log_cdf_laplace").

  Returns:
    `Tensor` with `dtype=x.dtype`.
  """

  with ops.name_scope(name, values=[x]):
    x = ops.convert_to_tensor(x, name="x")

    # Negative side: L(x) = 0.5 * exp{x} exactly, hence
    # Log[L(x)] = log(0.5) + x.
    log_cdf_negative_side = -np.log(2.) + x

    # exp{-|x|} equals exp{-x} for x > 0 and never exceeds 1, which avoids
    #   log[1 - 1] = -inf at x = log(1/2), and
    #   exp{-x} --> inf for x << -1
    # in the entries where this branch is not selected.
    bounded_exp_neg_x = math_ops.exp(-math_ops.abs(x))

    # log1p(z) = log(1 + z) approx z for |z| << 1. This approximation is used
    # internally by log1p, rather than being done explicitly here.
    log_cdf_positive_side = math_ops.log1p(-0.5 * bounded_exp_neg_x)

    is_negative = x < 0.
    return array_ops.where(
        is_negative, log_cdf_negative_side, log_cdf_positive_side)
428