1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2016 Gael Guennebaud <gael.guennebaud@inria.fr>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#include "main.h"
11#include "../Eigen/SpecialFunctions"
12
13template<typename X, typename Y>
14void verify_component_wise(const X& x, const Y& y)
15{
16  for(Index i=0; i<x.size(); ++i)
17  {
18    if((numext::isfinite)(y(i)))
19      VERIFY_IS_APPROX( x(i), y(i) );
20    else if((numext::isnan)(y(i)))
21      VERIFY((numext::isnan)(x(i)));
22    else
23      VERIFY_IS_EQUAL( x(i), y(i) );
24  }
25}
26
27template<typename ArrayType> void array_special_functions()
28{
29  using std::abs;
30  using std::sqrt;
31  typedef typename ArrayType::Scalar Scalar;
32  typedef typename NumTraits<Scalar>::Real RealScalar;
33
34  Scalar plusinf = std::numeric_limits<Scalar>::infinity();
35  Scalar nan = std::numeric_limits<Scalar>::quiet_NaN();
36
37  Index rows = internal::random<Index>(1,30);
38  Index cols = 1;
39
40  // API
41  {
42    ArrayType m1 = ArrayType::Random(rows,cols);
43#if EIGEN_HAS_C99_MATH
44    VERIFY_IS_APPROX(m1.lgamma(), lgamma(m1));
45    VERIFY_IS_APPROX(m1.digamma(), digamma(m1));
46    VERIFY_IS_APPROX(m1.erf(), erf(m1));
47    VERIFY_IS_APPROX(m1.erfc(), erfc(m1));
48#endif  // EIGEN_HAS_C99_MATH
49  }
50
51
52#if EIGEN_HAS_C99_MATH
53  // check special functions (comparing against numpy implementation)
54  if (!NumTraits<Scalar>::IsComplex)
55  {
56
57    {
58      ArrayType m1 = ArrayType::Random(rows,cols);
59      ArrayType m2 = ArrayType::Random(rows,cols);
60
61      // Test various propreties of igamma & igammac.  These are normalized
62      // gamma integrals where
63      //   igammac(a, x) = Gamma(a, x) / Gamma(a)
64      //   igamma(a, x) = gamma(a, x) / Gamma(a)
65      // where Gamma and gamma are considered the standard unnormalized
66      // upper and lower incomplete gamma functions, respectively.
67      ArrayType a = m1.abs() + 2;
68      ArrayType x = m2.abs() + 2;
69      ArrayType zero = ArrayType::Zero(rows, cols);
70      ArrayType one = ArrayType::Constant(rows, cols, Scalar(1.0));
71      ArrayType a_m1 = a - one;
72      ArrayType Gamma_a_x = Eigen::igammac(a, x) * a.lgamma().exp();
73      ArrayType Gamma_a_m1_x = Eigen::igammac(a_m1, x) * a_m1.lgamma().exp();
74      ArrayType gamma_a_x = Eigen::igamma(a, x) * a.lgamma().exp();
75      ArrayType gamma_a_m1_x = Eigen::igamma(a_m1, x) * a_m1.lgamma().exp();
76
77      // Gamma(a, 0) == Gamma(a)
78      VERIFY_IS_APPROX(Eigen::igammac(a, zero), one);
79
80      // Gamma(a, x) + gamma(a, x) == Gamma(a)
81      VERIFY_IS_APPROX(Gamma_a_x + gamma_a_x, a.lgamma().exp());
82
83      // Gamma(a, x) == (a - 1) * Gamma(a-1, x) + x^(a-1) * exp(-x)
84      VERIFY_IS_APPROX(Gamma_a_x, (a - 1) * Gamma_a_m1_x + x.pow(a-1) * (-x).exp());
85
86      // gamma(a, x) == (a - 1) * gamma(a-1, x) - x^(a-1) * exp(-x)
87      VERIFY_IS_APPROX(gamma_a_x, (a - 1) * gamma_a_m1_x - x.pow(a-1) * (-x).exp());
88    }
89
90    {
91      // Check exact values of igamma and igammac against a third party calculation.
92      Scalar a_s[] = {Scalar(0), Scalar(1), Scalar(1.5), Scalar(4), Scalar(0.0001), Scalar(1000.5)};
93      Scalar x_s[] = {Scalar(0), Scalar(1), Scalar(1.5), Scalar(4), Scalar(0.0001), Scalar(1000.5)};
94
95      // location i*6+j corresponds to a_s[i], x_s[j].
96      Scalar igamma_s[][6] = {{0.0, nan, nan, nan, nan, nan},
97                              {0.0, 0.6321205588285578, 0.7768698398515702,
98                              0.9816843611112658, 9.999500016666262e-05, 1.0},
99                              {0.0, 0.4275932955291202, 0.608374823728911,
100                              0.9539882943107686, 7.522076445089201e-07, 1.0},
101                              {0.0, 0.01898815687615381, 0.06564245437845008,
102                              0.5665298796332909, 4.166333347221828e-18, 1.0},
103                              {0.0, 0.9999780593618628, 0.9999899967080838,
104                              0.9999996219837988, 0.9991370418689945, 1.0},
105                              {0.0, 0.0, 0.0, 0.0, 0.0, 0.5042041932513908}};
106      Scalar igammac_s[][6] = {{nan, nan, nan, nan, nan, nan},
107                              {1.0, 0.36787944117144233, 0.22313016014842982,
108                                0.018315638888734182, 0.9999000049998333, 0.0},
109                              {1.0, 0.5724067044708798, 0.3916251762710878,
110                                0.04601170568923136, 0.9999992477923555, 0.0},
111                              {1.0, 0.9810118431238462, 0.9343575456215499,
112                                0.4334701203667089, 1.0, 0.0},
113                              {1.0, 2.1940638138146658e-05, 1.0003291916285e-05,
114                                3.7801620118431334e-07, 0.0008629581310054535,
115                                0.0},
116                              {1.0, 1.0, 1.0, 1.0, 1.0, 0.49579580674813944}};
117      for (int i = 0; i < 6; ++i) {
118        for (int j = 0; j < 6; ++j) {
119          if ((std::isnan)(igamma_s[i][j])) {
120            VERIFY((std::isnan)(numext::igamma(a_s[i], x_s[j])));
121          } else {
122            VERIFY_IS_APPROX(numext::igamma(a_s[i], x_s[j]), igamma_s[i][j]);
123          }
124
125          if ((std::isnan)(igammac_s[i][j])) {
126            VERIFY((std::isnan)(numext::igammac(a_s[i], x_s[j])));
127          } else {
128            VERIFY_IS_APPROX(numext::igammac(a_s[i], x_s[j]), igammac_s[i][j]);
129          }
130        }
131      }
132    }
133  }
134#endif  // EIGEN_HAS_C99_MATH
135
136  // Check the zeta function against scipy.special.zeta
137  {
138    ArrayType x(7), q(7), res(7), ref(7);
139    x << 1.5,   4, 10.5, 10000.5,    3, 1,        0.9;
140    q << 2,   1.5,    3,  1.0001, -2.5, 1.2345, 1.2345;
141    ref << 1.61237534869, 0.234848505667, 1.03086757337e-5, 0.367879440865, 0.054102025820864097, plusinf, nan;
142    CALL_SUBTEST( verify_component_wise(ref, ref); );
143    CALL_SUBTEST( res = x.zeta(q); verify_component_wise(res, ref); );
144    CALL_SUBTEST( res = zeta(x,q); verify_component_wise(res, ref); );
145  }
146
147  // digamma
148  {
149    ArrayType x(7), res(7), ref(7);
150    x << 1, 1.5, 4, -10.5, 10000.5, 0, -1;
151    ref << -0.5772156649015329, 0.03648997397857645, 1.2561176684318, 2.398239129535781, 9.210340372392849, plusinf, plusinf;
152    CALL_SUBTEST( verify_component_wise(ref, ref); );
153
154    CALL_SUBTEST( res = x.digamma(); verify_component_wise(res, ref); );
155    CALL_SUBTEST( res = digamma(x);  verify_component_wise(res, ref); );
156  }
157
158
159#if EIGEN_HAS_C99_MATH
160  {
161    ArrayType n(11), x(11), res(11), ref(11);
162    n << 1, 1,    1, 1.5,   17,   31,   28,    8, 42, 147, 170;
163    x << 2, 3, 25.5, 1.5,  4.7, 11.8, 17.7, 30.2, 15.8, 54.1, 64;
164    ref << 0.644934066848, 0.394934066848, 0.0399946696496, nan, 293.334565435, 0.445487887616, -2.47810300902e-07, -8.29668781082e-09, -0.434562276666, 0.567742190178, -0.0108615497927;
165    CALL_SUBTEST( verify_component_wise(ref, ref); );
166
167    if(sizeof(RealScalar)>=8) {  // double
168      // Reason for commented line: http://eigen.tuxfamily.org/bz/show_bug.cgi?id=1232
169      //       CALL_SUBTEST( res = x.polygamma(n); verify_component_wise(res, ref); );
170      CALL_SUBTEST( res = polygamma(n,x);  verify_component_wise(res, ref); );
171    }
172    else {
173      //       CALL_SUBTEST( res = x.polygamma(n); verify_component_wise(res.head(8), ref.head(8)); );
174      CALL_SUBTEST( res = polygamma(n,x); verify_component_wise(res.head(8), ref.head(8)); );
175    }
176  }
177#endif
178
179#if EIGEN_HAS_C99_MATH
180  {
181    // Inputs and ground truth generated with scipy via:
182    //   a = np.logspace(-3, 3, 5) - 1e-3
183    //   b = np.logspace(-3, 3, 5) - 1e-3
184    //   x = np.linspace(-0.1, 1.1, 5)
185    //   (full_a, full_b, full_x) = np.vectorize(lambda a, b, x: (a, b, x))(*np.ix_(a, b, x))
186    //   full_a = full_a.flatten().tolist()  # same for full_b, full_x
187    //   v = scipy.special.betainc(full_a, full_b, full_x).flatten().tolist()
188    //
189    // Note in Eigen, we call betainc with arguments in the order (x, a, b).
190    ArrayType a(125);
191    ArrayType b(125);
192    ArrayType x(125);
193    ArrayType v(125);
194    ArrayType res(125);
195
196    a << 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
197        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
198        0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
199        0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
200        0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
201        0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
202        0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
203        0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
204        0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
205        0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
206        0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999,
207        0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999,
208        0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999,
209        31.62177660168379, 31.62177660168379, 31.62177660168379,
210        31.62177660168379, 31.62177660168379, 31.62177660168379,
211        31.62177660168379, 31.62177660168379, 31.62177660168379,
212        31.62177660168379, 31.62177660168379, 31.62177660168379,
213        31.62177660168379, 31.62177660168379, 31.62177660168379,
214        31.62177660168379, 31.62177660168379, 31.62177660168379,
215        31.62177660168379, 31.62177660168379, 31.62177660168379,
216        31.62177660168379, 31.62177660168379, 31.62177660168379,
217        31.62177660168379, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999,
218        999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999,
219        999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999,
220        999.999, 999.999, 999.999;
221
222    b << 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379, 0.03062277660168379,
223        0.03062277660168379, 0.03062277660168379, 0.03062277660168379, 0.999,
224        0.999, 0.999, 0.999, 0.999, 31.62177660168379, 31.62177660168379,
225        31.62177660168379, 31.62177660168379, 31.62177660168379, 999.999,
226        999.999, 999.999, 999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0,
227        0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
228        0.03062277660168379, 0.03062277660168379, 0.999, 0.999, 0.999, 0.999,
229        0.999, 31.62177660168379, 31.62177660168379, 31.62177660168379,
230        31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999,
231        999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379,
232        0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
233        0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999,
234        31.62177660168379, 31.62177660168379, 31.62177660168379,
235        31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999,
236        999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379,
237        0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
238        0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999,
239        31.62177660168379, 31.62177660168379, 31.62177660168379,
240        31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999,
241        999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379,
242        0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
243        0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999,
244        31.62177660168379, 31.62177660168379, 31.62177660168379,
245        31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999,
246        999.999, 999.999;
247
248    x << -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5,
249        0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2,
250        0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1,
251        0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1,
252        -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8,
253        1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5,
254        0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2,
255        0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1,
256        0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5,
257        0.8, 1.1;
258
259    v << nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
260        nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
261        nan, nan, nan, 0.47972119876364683, 0.5, 0.5202788012363533, nan, nan,
262        0.9518683957740043, 0.9789663010413743, 0.9931729188073435, nan, nan,
263        0.999995949033062, 0.9999999999993698, 0.9999999999999999, nan, nan,
264        0.9999999999999999, 0.9999999999999999, 0.9999999999999999, nan, nan,
265        nan, nan, nan, nan, nan, 0.006827081192655869, 0.0210336989586256,
266        0.04813160422599567, nan, nan, 0.20014344256217678, 0.5000000000000001,
267        0.7998565574378232, nan, nan, 0.9991401428435834, 0.999999999698403,
268        0.9999999999999999, nan, nan, 0.9999999999999999, 0.9999999999999999,
269        0.9999999999999999, nan, nan, nan, nan, nan, nan, nan,
270        1.0646600232370887e-25, 6.301722877826246e-13, 4.050966937974938e-06,
271        nan, nan, 7.864342668429763e-23, 3.015969667594166e-10,
272        0.0008598571564165444, nan, nan, 6.031987710123844e-08,
273        0.5000000000000007, 0.9999999396801229, nan, nan, 0.9999999999999999,
274        0.9999999999999999, 0.9999999999999999, nan, nan, nan, nan, nan, nan,
275        nan, 0.0, 7.029920380986636e-306, 2.2450728208591345e-101, nan, nan,
276        0.0, 9.275871147869727e-302, 1.2232913026152827e-97, nan, nan, 0.0,
277        3.0891393081932924e-252, 2.9303043666183996e-60, nan, nan,
278        2.248913486879199e-196, 0.5000000000004947, 0.9999999999999999, nan;
279
280    CALL_SUBTEST(res = betainc(a, b, x);
281                 verify_component_wise(res, v););
282  }
283
284  // Test various properties of betainc
285  {
286    ArrayType m1 = ArrayType::Random(32);
287    ArrayType m2 = ArrayType::Random(32);
288    ArrayType m3 = ArrayType::Random(32);
289    ArrayType one = ArrayType::Constant(32, Scalar(1.0));
290    const Scalar eps = std::numeric_limits<Scalar>::epsilon();
291    ArrayType a = (m1 * 4.0).exp();
292    ArrayType b = (m2 * 4.0).exp();
293    ArrayType x = m3.abs();
294
295    // betainc(a, 1, x) == x**a
296    CALL_SUBTEST(
297        ArrayType test = betainc(a, one, x);
298        ArrayType expected = x.pow(a);
299        verify_component_wise(test, expected););
300
301    // betainc(1, b, x) == 1 - (1 - x)**b
302    CALL_SUBTEST(
303        ArrayType test = betainc(one, b, x);
304        ArrayType expected = one - (one - x).pow(b);
305        verify_component_wise(test, expected););
306
307    // betainc(a, b, x) == 1 - betainc(b, a, 1-x)
308    CALL_SUBTEST(
309        ArrayType test = betainc(a, b, x) + betainc(b, a, one - x);
310        ArrayType expected = one;
311        verify_component_wise(test, expected););
312
313    // betainc(a+1, b, x) = betainc(a, b, x) - x**a * (1 - x)**b / (a * beta(a, b))
314    CALL_SUBTEST(
315        ArrayType num = x.pow(a) * (one - x).pow(b);
316        ArrayType denom = a * (a.lgamma() + b.lgamma() - (a + b).lgamma()).exp();
317        // Add eps to rhs and lhs so that component-wise test doesn't result in
318        // nans when both outputs are zeros.
319        ArrayType expected = betainc(a, b, x) - num / denom + eps;
320        ArrayType test = betainc(a + one, b, x) + eps;
321        if (sizeof(Scalar) >= 8) { // double
322          verify_component_wise(test, expected);
323        } else {
324          // Reason for limited test: http://eigen.tuxfamily.org/bz/show_bug.cgi?id=1232
325          verify_component_wise(test.head(8), expected.head(8));
326        });
327
328    // betainc(a, b+1, x) = betainc(a, b, x) + x**a * (1 - x)**b / (b * beta(a, b))
329    CALL_SUBTEST(
330        // Add eps to rhs and lhs so that component-wise test doesn't result in
331        // nans when both outputs are zeros.
332        ArrayType num = x.pow(a) * (one - x).pow(b);
333        ArrayType denom = b * (a.lgamma() + b.lgamma() - (a + b).lgamma()).exp();
334        ArrayType expected = betainc(a, b, x) + num / denom + eps;
335        ArrayType test = betainc(a, b + one, x) + eps;
336        verify_component_wise(test, expected););
337  }
338#endif
339}
340
341void test_special_functions()
342{
343  CALL_SUBTEST_1(array_special_functions<ArrayXf>());
344  CALL_SUBTEST_2(array_special_functions<ArrayXd>());
345}
346