1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Special Math Ops.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import math 22import numpy as np 23 24from tensorflow.python.framework import constant_op 25from tensorflow.python.framework import ops 26from tensorflow.python.ops import array_ops 27from tensorflow.python.ops import math_ops 28 29__all__ = [ 30 "erfinv", 31 "ndtr", 32 "ndtri", 33 "log_ndtr", 34 "log_cdf_laplace", 35] 36 37 38# log_ndtr uses different functions over the ranges 39# (-infty, lower](lower, upper](upper, infty) 40# Lower bound values were chosen by examining where the support of ndtr 41# appears to be zero, relative to scipy's (which is always 64bit). They were 42# then made more conservative just to be safe. (Conservative means use the 43# expansion more than we probably need to.) See `NdtrTest` in 44# special_math_test.py. 45LOGNDTR_FLOAT64_LOWER = -20 46LOGNDTR_FLOAT32_LOWER = -10 47 48# Upper bound values were chosen by examining for which values of 'x' 49# Log[cdf(x)] is 0, after which point we need to use the approximation 50# Log[cdf(x)] = Log[1 - cdf(-x)] approx -cdf(-x). We chose a value slightly 51# conservative, meaning we use the approximation earlier than needed. 52LOGNDTR_FLOAT64_UPPER = 8 53LOGNDTR_FLOAT32_UPPER = 5 54 55 56def ndtr(x, name="ndtr"): 57 """Normal distribution function. 58 59 Returns the area under the Gaussian probability density function, integrated 60 from minus infinity to x: 61 62 ``` 63 1 / x 64 ndtr(x) = ---------- | exp(-0.5 t**2) dt 65 sqrt(2 pi) /-inf 66 67 = 0.5 (1 + erf(x / sqrt(2))) 68 = 0.5 erfc(x / sqrt(2)) 69 ``` 70 71 Args: 72 x: `Tensor` of type `float32`, `float64`. 73 name: Python string. A name for the operation (default="ndtr"). 74 75 Returns: 76 ndtr: `Tensor` with `dtype=x.dtype`. 77 78 Raises: 79 TypeError: if `x` is not floating-type. 80 """ 81 82 with ops.name_scope(name, values=[x]): 83 x = ops.convert_to_tensor(x, name="x") 84 if x.dtype.as_numpy_dtype not in [np.float32, np.float64]: 85 raise TypeError( 86 "x.dtype=%s is not handled, see docstring for supported types." 87 % x.dtype) 88 return _ndtr(x) 89 90 91def _ndtr(x): 92 """Implements ndtr core logic.""" 93 half_sqrt_2 = constant_op.constant( 94 0.5 * math.sqrt(2.), dtype=x.dtype, name="half_sqrt_2") 95 w = x * half_sqrt_2 96 z = math_ops.abs(w) 97 y = array_ops.where(math_ops.less(z, half_sqrt_2), 98 1. + math_ops.erf(w), 99 array_ops.where(math_ops.greater(w, 0.), 100 2. - math_ops.erfc(z), 101 math_ops.erfc(z))) 102 return 0.5 * y 103 104 105def ndtri(p, name="ndtri"): 106 """The inverse of the CDF of the Normal distribution function. 107 108 Returns x such that the area under the pdf from minus infinity to x is equal 109 to p. 110 111 A piece-wise rational approximation is done for the function. 112 This is a port of the implementation in netlib. 113 114 Args: 115 p: `Tensor` of type `float32`, `float64`. 116 name: Python string. A name for the operation (default="ndtri"). 117 118 Returns: 119 x: `Tensor` with `dtype=p.dtype`. 120 121 Raises: 122 TypeError: if `p` is not floating-type. 123 """ 124 125 with ops.name_scope(name, values=[p]): 126 p = ops.convert_to_tensor(p, name="p") 127 if p.dtype.as_numpy_dtype not in [np.float32, np.float64]: 128 raise TypeError( 129 "p.dtype=%s is not handled, see docstring for supported types." 130 % p.dtype) 131 return _ndtri(p) 132 133 134def _ndtri(p): 135 """Implements ndtri core logic.""" 136 137 # Constants used in piece-wise rational approximations. Taken from the cephes 138 # library: 139 # https://github.com/scipy/scipy/blob/master/scipy/special/cephes/ndtri.c 140 p0 = list(reversed([-5.99633501014107895267E1, 141 9.80010754185999661536E1, 142 -5.66762857469070293439E1, 143 1.39312609387279679503E1, 144 -1.23916583867381258016E0])) 145 q0 = list(reversed([1.0, 146 1.95448858338141759834E0, 147 4.67627912898881538453E0, 148 8.63602421390890590575E1, 149 -2.25462687854119370527E2, 150 2.00260212380060660359E2, 151 -8.20372256168333339912E1, 152 1.59056225126211695515E1, 153 -1.18331621121330003142E0])) 154 p1 = list(reversed([4.05544892305962419923E0, 155 3.15251094599893866154E1, 156 5.71628192246421288162E1, 157 4.40805073893200834700E1, 158 1.46849561928858024014E1, 159 2.18663306850790267539E0, 160 -1.40256079171354495875E-1, 161 -3.50424626827848203418E-2, 162 -8.57456785154685413611E-4])) 163 q1 = list(reversed([1.0, 164 1.57799883256466749731E1, 165 4.53907635128879210584E1, 166 4.13172038254672030440E1, 167 1.50425385692907503408E1, 168 2.50464946208309415979E0, 169 -1.42182922854787788574E-1, 170 -3.80806407691578277194E-2, 171 -9.33259480895457427372E-4])) 172 p2 = list(reversed([3.23774891776946035970E0, 173 6.91522889068984211695E0, 174 3.93881025292474443415E0, 175 1.33303460815807542389E0, 176 2.01485389549179081538E-1, 177 1.23716634817820021358E-2, 178 3.01581553508235416007E-4, 179 2.65806974686737550832E-6, 180 6.23974539184983293730E-9])) 181 q2 = list(reversed([1.0, 182 6.02427039364742014255E0, 183 3.67983563856160859403E0, 184 1.37702099489081330271E0, 185 2.16236993594496635890E-1, 186 1.34204006088543189037E-2, 187 3.28014464682127739104E-4, 188 2.89247864745380683936E-6, 189 6.79019408009981274425E-9])) 190 191 def _create_polynomial(var, coeffs): 192 """Compute n_th order polynomial via Horner's method.""" 193 if not coeffs: 194 return 0. 195 return coeffs[0] + _create_polynomial(var, coeffs[1:]) * var 196 197 maybe_complement_p = array_ops.where(p > 1. - np.exp(-2.), 1. - p, p) 198 # Write in an arbitrary value in place of 0 for p since 0 will cause NaNs 199 # later on. The result from the computation when p == 0 is not used so any 200 # number that doesn't result in NaNs is fine. 201 one_half = constant_op.constant(0.5, dtype=p.dtype) 202 sanitized_mcp = array_ops.where( 203 maybe_complement_p <= 0., 204 array_ops.fill(array_ops.shape(p), one_half), 205 maybe_complement_p) 206 207 # Compute x for p > exp(-2): x/sqrt(2pi) = w + w**3 P0(w**2)/Q0(w**2). 208 w = sanitized_mcp - 0.5 209 ww = w ** 2 210 x_for_big_p = w + w * ww * (_create_polynomial(ww, p0) 211 / _create_polynomial(ww, q0)) 212 x_for_big_p *= -np.sqrt(2. * np.pi) 213 214 # Compute x for p <= exp(-2): x = z - log(z)/z - (1/z) P(1/z) / Q(1/z), 215 # where z = sqrt(-2. * log(p)), and P/Q are chosen between two different 216 # arrays based on wether p < exp(-32). 217 z = math_ops.sqrt(-2. * math_ops.log(sanitized_mcp)) 218 first_term = z - math_ops.log(z) / z 219 second_term_small_p = (_create_polynomial(1. / z, p2) 220 / _create_polynomial(1. / z, q2)) / z 221 second_term_otherwise = (_create_polynomial(1. / z, p1) 222 / _create_polynomial(1. / z, q1)) / z 223 x_for_small_p = first_term - second_term_small_p 224 x_otherwise = first_term - second_term_otherwise 225 226 x = array_ops.where(sanitized_mcp > np.exp(-2.), 227 x_for_big_p, 228 array_ops.where(z >= 8.0, x_for_small_p, x_otherwise)) 229 230 x = array_ops.where(p > 1. - np.exp(-2.), x, -x) 231 infinity_scalar = constant_op.constant(np.inf, dtype=p.dtype) 232 infinity = array_ops.fill(array_ops.shape(p), infinity_scalar) 233 x_nan_replaced = array_ops.where( 234 p <= 0.0, -infinity, array_ops.where(p >= 1.0, infinity, x)) 235 return x_nan_replaced 236 237 238def log_ndtr(x, series_order=3, name="log_ndtr"): 239 """Log Normal distribution function. 240 241 For details of the Normal distribution function see `ndtr`. 242 243 This function calculates `(log o ndtr)(x)` by either calling `log(ndtr(x))` or 244 using an asymptotic series. Specifically: 245 - For `x > upper_segment`, use the approximation `-ndtr(-x)` based on 246 `log(1-x) ~= -x, x << 1`. 247 - For `lower_segment < x <= upper_segment`, use the existing `ndtr` technique 248 and take a log. 249 - For `x <= lower_segment`, we use the series approximation of erf to compute 250 the log CDF directly. 251 252 The `lower_segment` is set based on the precision of the input: 253 254 ``` 255 lower_segment = { -20, x.dtype=float64 256 { -10, x.dtype=float32 257 upper_segment = { 8, x.dtype=float64 258 { 5, x.dtype=float32 259 ``` 260 261 When `x < lower_segment`, the `ndtr` asymptotic series approximation is: 262 263 ``` 264 ndtr(x) = scale * (1 + sum) + R_N 265 scale = exp(-0.5 x**2) / (-x sqrt(2 pi)) 266 sum = Sum{(-1)^n (2n-1)!! / (x**2)^n, n=1:N} 267 R_N = O(exp(-0.5 x**2) (2N+1)!! / |x|^{2N+3}) 268 ``` 269 270 where `(2n-1)!! = (2n-1) (2n-3) (2n-5) ... (3) (1)` is a 271 [double-factorial](https://en.wikipedia.org/wiki/Double_factorial). 272 273 274 Args: 275 x: `Tensor` of type `float32`, `float64`. 276 series_order: Positive Python `integer`. Maximum depth to 277 evaluate the asymptotic expansion. This is the `N` above. 278 name: Python string. A name for the operation (default="log_ndtr"). 279 280 Returns: 281 log_ndtr: `Tensor` with `dtype=x.dtype`. 282 283 Raises: 284 TypeError: if `x.dtype` is not handled. 285 TypeError: if `series_order` is a not Python `integer.` 286 ValueError: if `series_order` is not in `[0, 30]`. 287 """ 288 if not isinstance(series_order, int): 289 raise TypeError("series_order must be a Python integer.") 290 if series_order < 0: 291 raise ValueError("series_order must be non-negative.") 292 if series_order > 30: 293 raise ValueError("series_order must be <= 30.") 294 295 with ops.name_scope(name, values=[x]): 296 x = ops.convert_to_tensor(x, name="x") 297 298 if x.dtype.as_numpy_dtype == np.float64: 299 lower_segment = LOGNDTR_FLOAT64_LOWER 300 upper_segment = LOGNDTR_FLOAT64_UPPER 301 elif x.dtype.as_numpy_dtype == np.float32: 302 lower_segment = LOGNDTR_FLOAT32_LOWER 303 upper_segment = LOGNDTR_FLOAT32_UPPER 304 else: 305 raise TypeError("x.dtype=%s is not supported." % x.dtype) 306 307 # The basic idea here was ported from py/scipy/special/cephes/ndtr.c. 308 # We copy the main idea, with a few changes 309 # * For x >> 1, and X ~ Normal(0, 1), 310 # Log[P[X < x]] = Log[1 - P[X < -x]] approx -P[X < -x], 311 # which extends the range of validity of this function. 312 # * We use one fixed series_order for all of 'x', rather than adaptive. 313 # * Our docstring properly reflects that this is an asymptotic series, not a 314 # Taylor series. We also provided a correct bound on the remainder. 315 # * We need to use the max/min in the _log_ndtr_lower arg to avoid nan when 316 # x=0. This happens even though the branch is unchosen because when x=0 317 # the gradient of a select involves the calculation 1*dy+0*(-inf)=nan 318 # regardless of whether dy is finite. Note that the minimum is a NOP if 319 # the branch is chosen. 320 return array_ops.where( 321 math_ops.greater(x, upper_segment), 322 -_ndtr(-x), # log(1-x) ~= -x, x << 1 323 array_ops.where(math_ops.greater(x, lower_segment), 324 math_ops.log(_ndtr(math_ops.maximum(x, lower_segment))), 325 _log_ndtr_lower(math_ops.minimum(x, lower_segment), 326 series_order))) 327 328 329def _log_ndtr_lower(x, series_order): 330 """Asymptotic expansion version of `Log[cdf(x)]`, appropriate for `x<<-1`.""" 331 x_2 = math_ops.square(x) 332 # Log of the term multiplying (1 + sum) 333 log_scale = -0.5 * x_2 - math_ops.log(-x) - 0.5 * math.log(2. * math.pi) 334 return log_scale + math_ops.log(_log_ndtr_asymptotic_series(x, series_order)) 335 336 337def _log_ndtr_asymptotic_series(x, series_order): 338 """Calculates the asymptotic series used in log_ndtr.""" 339 if series_order <= 0: 340 return 1. 341 x_2 = math_ops.square(x) 342 even_sum = 0. 343 odd_sum = 0. 344 x_2n = x_2 # Start with x^{2*1} = x^{2*n} with n = 1. 345 for n in range(1, series_order + 1): 346 if n % 2: 347 odd_sum += _double_factorial(2 * n - 1) / x_2n 348 else: 349 even_sum += _double_factorial(2 * n - 1) / x_2n 350 x_2n *= x_2 351 return 1. + even_sum - odd_sum 352 353 354def erfinv(x, name="erfinv"): 355 """The inverse function for erf, the error function. 356 357 Args: 358 x: `Tensor` of type `float32`, `float64`. 359 name: Python string. A name for the operation (default="erfinv"). 360 361 Returns: 362 x: `Tensor` with `dtype=x.dtype`. 363 364 Raises: 365 TypeError: if `x` is not floating-type. 366 """ 367 368 with ops.name_scope(name, values=[x]): 369 x = ops.convert_to_tensor(x, name="x") 370 if x.dtype.as_numpy_dtype not in [np.float32, np.float64]: 371 raise TypeError( 372 "x.dtype=%s is not handled, see docstring for supported types." 373 % x.dtype) 374 return ndtri((x + 1.0) / 2.0) / np.sqrt(2) 375 376 377def _double_factorial(n): 378 """The double factorial function for small Python integer `n`.""" 379 return np.prod(np.arange(n, 1, -2)) 380 381 382def log_cdf_laplace(x, name="log_cdf_laplace"): 383 """Log Laplace distribution function. 384 385 This function calculates `Log[L(x)]`, where `L(x)` is the cumulative 386 distribution function of the Laplace distribution, i.e. 387 388 ```L(x) := 0.5 * int_{-infty}^x e^{-|t|} dt``` 389 390 For numerical accuracy, `L(x)` is computed in different ways depending on `x`, 391 392 ``` 393 x <= 0: 394 Log[L(x)] = Log[0.5] + x, which is exact 395 396 0 < x: 397 Log[L(x)] = Log[1 - 0.5 * e^{-x}], which is exact 398 ``` 399 400 Args: 401 x: `Tensor` of type `float32`, `float64`. 402 name: Python string. A name for the operation (default="log_ndtr"). 403 404 Returns: 405 `Tensor` with `dtype=x.dtype`. 406 407 Raises: 408 TypeError: if `x.dtype` is not handled. 409 """ 410 411 with ops.name_scope(name, values=[x]): 412 x = ops.convert_to_tensor(x, name="x") 413 414 # For x < 0, L(x) = 0.5 * exp{x} exactly, so Log[L(x)] = log(0.5) + x. 415 lower_solution = -np.log(2.) + x 416 417 # safe_exp_neg_x = exp{-x} for x > 0, but is 418 # bounded above by 1, which avoids 419 # log[1 - 1] = -inf for x = log(1/2), AND 420 # exp{-x} --> inf, for x << -1 421 safe_exp_neg_x = math_ops.exp(-math_ops.abs(x)) 422 423 # log1p(z) = log(1 + z) approx z for |z| << 1. This approxmation is used 424 # internally by log1p, rather than being done explicitly here. 425 upper_solution = math_ops.log1p(-0.5 * safe_exp_neg_x) 426 427 return array_ops.where(x < 0., lower_solution, upper_solution) 428