1// Optimizations for random number functions, x86 version -*- C++ -*-
2
3// Copyright (C) 2012-2014 Free Software Foundation, Inc.
4//
5// This file is part of the GNU ISO C++ Library.  This library is free
6// software; you can redistribute it and/or modify it under the
7// terms of the GNU General Public License as published by the
8// Free Software Foundation; either version 3, or (at your option)
9// any later version.
10
11// This library is distributed in the hope that it will be useful,
12// but WITHOUT ANY WARRANTY; without even the implied warranty of
13// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14// GNU General Public License for more details.
15
16// Under Section 7 of GPL version 3, you are granted additional
17// permissions described in the GCC Runtime Library Exception, version
18// 3.1, as published by the Free Software Foundation.
19
20// You should have received a copy of the GNU General Public License and
21// a copy of the GCC Runtime Library Exception along with this program;
22// see the files COPYING3 and COPYING.RUNTIME respectively.  If not, see
23// <http://www.gnu.org/licenses/>.
24
25/** @file bits/opt_random.h
26 *  This is an internal header file, included by other library headers.
27 *  Do not attempt to use it directly. @headername{random}
28 */
29
30#ifndef _BITS_OPT_RANDOM_H
31#define _BITS_OPT_RANDOM_H 1
32
33#include <x86intrin.h>
34
35
36#pragma GCC system_header
37
38
39namespace std _GLIBCXX_VISIBILITY(default)
40{
41_GLIBCXX_BEGIN_NAMESPACE_VERSION
42
43#ifdef __SSE3__
44  template<>
45    template<typename _UniformRandomNumberGenerator>
46      void
47      normal_distribution<double>::
48      __generate(typename normal_distribution<double>::result_type* __f,
49		 typename normal_distribution<double>::result_type* __t,
50		 _UniformRandomNumberGenerator& __urng,
51		 const param_type& __param)
52      {
53	typedef uint64_t __uctype;
54
55	if (__f == __t)
56	  return;
57
58	if (_M_saved_available)
59	  {
60	    _M_saved_available = false;
61	    *__f++ = _M_saved * __param.stddev() + __param.mean();
62
63	    if (__f == __t)
64	      return;
65	  }
66
67	constexpr uint64_t __maskval = 0xfffffffffffffull;
68	static const __m128i __mask = _mm_set1_epi64x(__maskval);
69	static const __m128i __two = _mm_set1_epi64x(0x4000000000000000ull);
70	static const __m128d __three = _mm_set1_pd(3.0);
71	const __m128d __av = _mm_set1_pd(__param.mean());
72
73	const __uctype __urngmin = __urng.min();
74	const __uctype __urngmax = __urng.max();
75	const __uctype __urngrange = __urngmax - __urngmin;
76	const __uctype __uerngrange = __urngrange + 1;
77
78	while (__f + 1 < __t)
79	  {
80	    double __le;
81	    __m128d __x;
82	    do
83	      {
84                union
85                {
86                  __m128i __i;
87                  __m128d __d;
88		} __v;
89
90		if (__urngrange > __maskval)
91		  {
92		    if (__detail::_Power_of_2(__uerngrange))
93		      __v.__i = _mm_and_si128(_mm_set_epi64x(__urng(),
94							     __urng()),
95					      __mask);
96		    else
97		      {
98			const __uctype __uerange = __maskval + 1;
99			const __uctype __scaling = __urngrange / __uerange;
100			const __uctype __past = __uerange * __scaling;
101			uint64_t __v1;
102			do
103			  __v1 = __uctype(__urng()) - __urngmin;
104			while (__v1 >= __past);
105			__v1 /= __scaling;
106			uint64_t __v2;
107			do
108			  __v2 = __uctype(__urng()) - __urngmin;
109			while (__v2 >= __past);
110			__v2 /= __scaling;
111
112			__v.__i = _mm_set_epi64x(__v1, __v2);
113		      }
114		  }
115		else if (__urngrange == __maskval)
116		  __v.__i = _mm_set_epi64x(__urng(), __urng());
117		else if ((__urngrange + 2) * __urngrange >= __maskval
118			 && __detail::_Power_of_2(__uerngrange))
119		  {
120		    uint64_t __v1 = __urng() * __uerngrange + __urng();
121		    uint64_t __v2 = __urng() * __uerngrange + __urng();
122
123		    __v.__i = _mm_and_si128(_mm_set_epi64x(__v1, __v2),
124					    __mask);
125		  }
126		else
127		  {
128		    size_t __nrng = 2;
129		    __uctype __high = __maskval / __uerngrange / __uerngrange;
130		    while (__high > __uerngrange)
131		      {
132			++__nrng;
133			__high /= __uerngrange;
134		      }
135		    const __uctype __highrange = __high + 1;
136		    const __uctype __scaling = __urngrange / __highrange;
137		    const __uctype __past = __highrange * __scaling;
138		    __uctype __tmp;
139
140		    uint64_t __v1;
141		    do
142		      {
143			do
144			  __tmp = __uctype(__urng()) - __urngmin;
145			while (__tmp >= __past);
146			__v1 = __tmp / __scaling;
147			for (size_t __cnt = 0; __cnt < __nrng; ++__cnt)
148			  {
149			    __tmp = __v1;
150			    __v1 *= __uerngrange;
151			    __v1 += __uctype(__urng()) - __urngmin;
152			  }
153		      }
154		    while (__v1 > __maskval || __v1 < __tmp);
155
156		    uint64_t __v2;
157		    do
158		      {
159			do
160			  __tmp = __uctype(__urng()) - __urngmin;
161			while (__tmp >= __past);
162			__v2 = __tmp / __scaling;
163			for (size_t __cnt = 0; __cnt < __nrng; ++__cnt)
164			  {
165			    __tmp = __v2;
166			    __v2 *= __uerngrange;
167			    __v2 += __uctype(__urng()) - __urngmin;
168			  }
169		      }
170		    while (__v2 > __maskval || __v2 < __tmp);
171
172		    __v.__i = _mm_set_epi64x(__v1, __v2);
173		  }
174
175		__v.__i = _mm_or_si128(__v.__i, __two);
176		__x = _mm_sub_pd(__v.__d, __three);
177		__m128d __m = _mm_mul_pd(__x, __x);
178		__le = _mm_cvtsd_f64(_mm_hadd_pd (__m, __m));
179              }
180            while (__le == 0.0 || __le >= 1.0);
181
182            double __mult = (std::sqrt(-2.0 * std::log(__le) / __le)
183                             * __param.stddev());
184
185            __x = _mm_add_pd(_mm_mul_pd(__x, _mm_set1_pd(__mult)), __av);
186
187            _mm_storeu_pd(__f, __x);
188            __f += 2;
189          }
190
191        if (__f != __t)
192          {
193            result_type __x, __y, __r2;
194
195            __detail::_Adaptor<_UniformRandomNumberGenerator, result_type>
196              __aurng(__urng);
197
198            do
199              {
200                __x = result_type(2.0) * __aurng() - 1.0;
201                __y = result_type(2.0) * __aurng() - 1.0;
202                __r2 = __x * __x + __y * __y;
203              }
204            while (__r2 > 1.0 || __r2 == 0.0);
205
206            const result_type __mult = std::sqrt(-2 * std::log(__r2) / __r2);
207            _M_saved = __x * __mult;
208            _M_saved_available = true;
209            *__f = __y * __mult * __param.stddev() + __param.mean();
210          }
211      }
212#endif
213
214
215_GLIBCXX_END_NAMESPACE_VERSION
216} // namespace
217
218
219#endif // _BITS_OPT_RANDOM_H
220