当我研究在不支持double
计算的32位处理器上高效实现MRG32k3a PRNG时,出现了这个问题.我对ARM、RISC-V和GPU特别感兴趣.MRG32k3a是一种非常高质量的PRNG,因此至今仍在广泛使用,尽管它可以追溯到20世纪90年代末:
P. L'Ecuyer,《Good Parameters and Implementations for Combined Multiple Recursive Random Number Generators》,Operations Research,第47卷,第1期,1999年1-2月,第159-164页
MRG32k3a 组合了两个形如 $(c_0 \cdot state_0 - c_1 \cdot state_1) \bmod m$ 的递归序列,其中 $state_i < m$。MRG32k3a 中的常数和状态变量都是可以用 32 位表示的正整数,并且计算中所有中间表达式的绝对值都小于 $2^{53}$。这是有意设计的,因为参考实现使用 IEEE-754 `double` 进行存储和计算。数学上的取模运算 mod 与 ISO C 的 `%` 运算符的不同之处在于它总是给出非负结果。下面代码中的第一个变体给出了使用 `double` 的、稍作现代化的参考实现。
在 MRG32k3a 的纯整数实现中,状态分量使用 32 位变量,中间计算以 64 位算术执行。只要保证被除数非负,取模就可以直接用 `%` 计算:$(c_0 \cdot state_0 - c_1 \cdot state_1) \bmod m = (c_0 \cdot state_0 - c_1 \cdot state_1 + c_1 \cdot m)\ \%\ m$。在 32 位处理器上计算 `% m` 的代价很高,通常会导致库调用。这可以通过标准的"除以常数"优化来解决,此时 64 位乘法运算是开销最大的部分(参见下面代码中的 `GENERIC_MOD=1` 变体)。
当模数形如 $m = 2^n - d$(其中 $d$ 较小)且被除数的大小受限时,可以实现更快的取模计算。令 $lo = x \bmod 2^n$,$hi = \lfloor x / 2^n \rfloor$,$t = hi \cdot d + lo$。只要 $t < 2 \cdot m$,就有 $x \bmod m = (t \ge m)\ ?\ (t - m) : t$。当 $x < 2^{\,n + (n - \lceil \log_2(d+1) \rceil)}$ 时,所需条件成立。这对 MRG32k3a 使用的第一个递归效果很好:$n = 32$、$d = 209$ 要求 $x < 2^{56}$,这一条件自然满足。但对于第二个递归,$n = 32$、$d = 22853$,要求 $x < 2^{49}$。而在加上偏移量 $c_1 \cdot m$ 以保证 $x$ 为正之后,此时 $x$ 最大可达 $8.15 \cdot 10^{15}$,仅略小于 $2^{53} \approx 9 \cdot 10^{15}$。
我目前的变通做法是:让计算 `x % m` 之前为保证 $x$ 为正而加上的偏移量取决于状态变量的值,从而保持 $x < 2^{44}$。但从下面摘出的相关代码行可以看出,这是一种相当昂贵的方法,其中包含一次 32 位除法(除数为常量,因此可以优化,但仍会带来不必要的开销)。
prod = ((uint64_t)a21) * MRG32k3a_s22i - ((uint64_t)a23n) * MRG32k3a_s20i;
adj = MRG32k3a_s20i / 3133 - (MRG32k3a_s22i >> 13) + 1;
prod = ((int64_t)(int32_t)adj) * (int64_t)m2 + prod; // ensure it's positive
对于MRG32k3a使用的第二次递归,是否有替代且成本更低的缓解策略来实现更有效的模计算?
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <math.h>
/* Build-time selection of the integer implementation MRG32k3a_i():
     BUILTIN_64BIT != 0                    : variant using the 64-bit % operator
     BUILTIN_64BIT == 0, GENERIC_MOD != 0  : variant using umul64hi() and
                                             precomputed magic multipliers
     both == 0                             : special fast modulo variant
*/
#define BUILTIN_64BIT (0)
#define GENERIC_MOD (0) // applies only when BUILTIN_64BIT == 0
static double MRG32k3a_s10, MRG32k3a_s11, MRG32k3a_s12;
static double MRG32k3a_s20, MRG32k3a_s21, MRG32k3a_s22;
/* SIMD vectorized by Clang with -ffp-model=precise on x86-84 and AArch64
SIMD vectorized by Intel compiler with -fp-model=precise -march=core-avx2
*/
double MRG32k3a (void)
{
const double norm = 2.328306549295728e-10;
const double m1 = 4294967087.0;
const double m2 = 4294944443.0;
const double a12 = 1403580.0;
const double a13n = 810728.0;
const double a21 = 527612.0;
const double a23n = 1370589.0;
double k, p1, p2;
/* Component 1 */
p1 = a12 * MRG32k3a_s11 - a13n * MRG32k3a_s10;
k = floor (p1 / m1);
p1 -= k * m1;
MRG32k3a_s10 = MRG32k3a_s11; MRG32k3a_s11 = MRG32k3a_s12; MRG32k3a_s12 = p1;
/* Component 2 */
p2 = a21 * MRG32k3a_s22 - a23n * MRG32k3a_s20;
k = floor (p2 / m2);
p2 -= k * m2;
MRG32k3a_s20 = MRG32k3a_s21; MRG32k3a_s21 = MRG32k3a_s22; MRG32k3a_s22 = p2;
/* Combination */
return ((p1 <= p2) ? (p1 - p2 + m1) : (p1 - p2)) * norm;
}
/* State of the integer implementations MRG32k3a_i(): three 32-bit words for
   component 1 (s10i, s11i, s12i) and three for component 2 (s20i, s21i, s22i).
   main() seeds these with the same values as the double-precision state. */
static uint32_t MRG32k3a_s10i, MRG32k3a_s11i, MRG32k3a_s12i;
static uint32_t MRG32k3a_s20i, MRG32k3a_s21i, MRG32k3a_s22i;
#if BUILTIN_64BIT
/* Integer MRG32k3a step that relies on the 64-bit '%' operator.

   For each component the two-term recurrence is accumulated in an unsigned
   64-bit value together with an offset that is a multiple of the modulus, so
   the dividend handed to '%' is non-negative. Intermediate unsigned
   wrap-around cancels out once all three terms have been accumulated.

   Reads and updates the six file-scope integer state words; returns the
   combined output scaled by norm.
*/
double MRG32k3a_i (void)
{
    const uint32_t m1   = 4294967087u;  /* modulus of component 1 */
    const uint32_t m2   = 4294944443u;  /* modulus of component 2 */
    const uint32_t a12  = 1403580u;
    const uint32_t a13n = 810728u;      /* enters the recurrence negated */
    const uint32_t a21  = 527612u;
    const uint32_t a23n = 1370589u;     /* enters the recurrence negated */
    const double   norm = 2.328306549295728e-10;
    uint64_t acc;

    /* Component 1: (a13n * m1 + a12 * s11 - a13n * s10) % m1 */
    acc  = (uint64_t)a13n * m1;         /* offset keeps the dividend >= 0 */
    acc += (uint64_t)a12 * MRG32k3a_s11i;
    acc -= (uint64_t)a13n * MRG32k3a_s10i;
    const uint32_t r1 = (uint32_t)(acc % m1);
    MRG32k3a_s10i = MRG32k3a_s11i;
    MRG32k3a_s11i = MRG32k3a_s12i;
    MRG32k3a_s12i = r1;

    /* Component 2: (a23n * m2 + a21 * s22 - a23n * s20) % m2 */
    acc  = (uint64_t)a23n * m2;         /* offset keeps the dividend >= 0 */
    acc += (uint64_t)a21 * MRG32k3a_s22i;
    acc -= (uint64_t)a23n * MRG32k3a_s20i;
    const uint32_t r2 = (uint32_t)(acc % m2);
    MRG32k3a_s20i = MRG32k3a_s21i;
    MRG32k3a_s21i = MRG32k3a_s22i;
    MRG32k3a_s22i = r2;

    /* Combination in 32-bit unsigned arithmetic: when r1 <= r2 the
       subtraction wraps and adding m1 wraps it back into range. */
    uint32_t diff = r1 - r2;
    if (r1 <= r2) {
        diff += m1;
    }
    return diff * norm;
}
#elif GENERIC_MOD
/* Return the most significant 64 bits of the 128-bit product a * b.

   Both operands are split into 32-bit halves and the four 32x32 -> 64-bit
   partial products are summed column by column, using only 64-bit
   arithmetic.
*/
uint64_t umul64hi (uint64_t a, uint64_t b)
{
    const uint32_t a_lo = (uint32_t)a;
    const uint32_t a_hi = (uint32_t)(a >> 32);
    const uint32_t b_lo = (uint32_t)b;
    const uint32_t b_hi = (uint32_t)(b >> 32);

    const uint64_t lo_lo = (uint64_t)a_lo * b_lo;   /* weight 2**0  */
    const uint64_t lo_hi = (uint64_t)a_lo * b_hi;   /* weight 2**32 */
    const uint64_t hi_lo = (uint64_t)a_hi * b_lo;   /* weight 2**32 */
    const uint64_t hi_hi = (uint64_t)a_hi * b_hi;   /* weight 2**64 */

    /* Everything of weight 2**32 except the upper half of lo_hi. The sum
       cannot overflow: hi_lo <= (2**32 - 1)**2 and the other two terms are
       each < 2**32. */
    const uint64_t middle = (lo_lo >> 32) + (uint64_t)(uint32_t)lo_hi + hi_lo;

    return hi_hi + (lo_hi >> 32) + (middle >> 32);
}
/* Integer MRG32k3a step with division by a constant done by multiplication.

   For each component the non-negative 64-bit dividend is built as in the
   '%'-based variant. The quotient is then obtained as
   umul64hi (dividend, magic) >> shift, and the remainder is recovered in the
   low 32 bits as dividend + quotient * (2**32 - m), which equals
   dividend - quotient * m modulo 2**32.

   Reads and updates the six file-scope integer state words; returns the
   combined output scaled by norm.
*/
double MRG32k3a_i (void)
{
    const uint32_t m1   = 4294967087u;  /* modulus of component 1 */
    const uint32_t m2   = 4294944443u;  /* modulus of component 2 */
    const uint32_t a12  = 1403580u;
    const uint32_t a13n = 810728u;      /* enters the recurrence negated */
    const uint32_t a21  = 527612u;
    const uint32_t a23n = 1370589u;     /* enters the recurrence negated */
    const double   norm = 2.328306549295728e-10;

    const uint32_t neg_m1 = 0 - m1;     /* 2**32 - m1 = 209 */
    const uint32_t neg_m2 = 0 - m2;     /* 2**32 - m2 = 22853 */
    /* Reciprocal multipliers and post-shifts; numerically these are about
       2**(64 + shift) / m for the respective modulus. */
    const uint64_t magic_mul_m1 = 0x8000006880005551ull;
    const uint64_t magic_mul_m2 = 0x4000165147c845ddull;
    const uint32_t shft_m1 = 31;
    const uint32_t shft_m2 = 30;
    uint64_t acc, quot;

    /* Component 1 */
    acc  = (uint64_t)a13n * m1;         /* offset keeps the dividend >= 0 */
    acc += (uint64_t)a12 * MRG32k3a_s11i;
    acc -= (uint64_t)a13n * MRG32k3a_s10i;
    quot = umul64hi (acc, magic_mul_m1) >> shft_m1;
    const uint32_t r1 = (uint32_t)(quot * neg_m1 + acc);
    MRG32k3a_s10i = MRG32k3a_s11i;
    MRG32k3a_s11i = MRG32k3a_s12i;
    MRG32k3a_s12i = r1;

    /* Component 2 */
    acc  = (uint64_t)a23n * m2;         /* offset keeps the dividend >= 0 */
    acc += (uint64_t)a21 * MRG32k3a_s22i;
    acc -= (uint64_t)a23n * MRG32k3a_s20i;
    quot = umul64hi (acc, magic_mul_m2) >> shft_m2;
    const uint32_t r2 = (uint32_t)(quot * neg_m2 + acc);
    MRG32k3a_s20i = MRG32k3a_s21i;
    MRG32k3a_s21i = MRG32k3a_s22i;
    MRG32k3a_s22i = r2;

    /* Combination in 32-bit unsigned arithmetic: when r1 <= r2 the
       subtraction wraps and adding m1 wraps it back into range. */
    uint32_t diff = r1 - r2;
    if (r1 <= r2) {
        diff += m1;
    }
    return diff * norm;
}
#else // !BUILTIN_64BIT && !GENERIC_MOD --> special fast modulo computation
/* Integer MRG32k3a step using a special fast reduction for moduli of the form
   m = 2**32 - d.

   With hi = prod >> 32 and lo = (uint32_t)prod, prod is congruent to
   hi * d + lo modulo m. As long as that sum is below 2 * m, a single
   conditional subtraction of m (done here as an addition of d modulo 2**32)
   completes the reduction. The sum is formed in 32 bits; a carry out of
   bit 31 is detected by the result being smaller than lo.

   Reads and updates the six file-scope integer state words; returns the
   combined output scaled by norm.
*/
double MRG32k3a_i (void)
{
    const double norm = 2.328306549295728e-10;
    const uint32_t m1 = 4294967087u;    /* modulus of component 1 */
    const uint32_t m2 = 4294944443u;    /* modulus of component 2 */
    const uint32_t a12 = 1403580u;
    const uint32_t a13n = 810728u;      /* enters the recurrence negated */
    const uint32_t a21 = 527612u;
    const uint32_t a23n = 1370589u;     /* enters the recurrence negated */
    const uint32_t neg_m1 = 0 - m1;     /* d1 = 2**32 - m1 = 209 */
    const uint32_t neg_m2 = 0 - m2;     /* d2 = 2**32 - m2 = 22853 */
    uint64_t prod;
    uint32_t p1, p2, prod_lo, prod_hi;
    int32_t adj;

    /* Component 1 */
    prod = ((uint64_t)a12) * MRG32k3a_s11i - ((uint64_t)a13n) * MRG32k3a_s10i;
    prod = ((uint64_t)a13n) * m1 + prod; // ensure it's positive
    // ! special modulo computation: prod must be < 2**56 !
    prod_lo = (uint32_t)prod;
    prod_hi = (uint32_t)(prod >> 32);
    p1 = prod_hi * neg_m1 + prod_lo;
    if ((p1 >= m1) || (p1 < prod_lo)) {
        p1 += neg_m1;                   /* same as p1 -= m1 modulo 2**32 */
    }
    MRG32k3a_s10i = MRG32k3a_s11i;
    MRG32k3a_s11i = MRG32k3a_s12i;
    MRG32k3a_s12i = p1;

    /* Component 2 */
    prod = ((uint64_t)a21) * MRG32k3a_s22i - ((uint64_t)a23n) * MRG32k3a_s20i;
    /* State-dependent multiple of m2 that makes prod non-negative while
       keeping it small: 3133 is just below m2 / a23n (about 3133.6), so
       s20 / 3133 covers the subtracted term, and 2**13 = 8192 is just above
       m2 / a21 (about 8140.4), so s22 >> 13 never removes more than the added
       term contributes.
       Both quotients are below 2**21, so the arithmetic is done directly in
       int32_t; this avoids converting a wrapped-around unsigned value to a
       signed type, which is implementation-defined in C.
       NOTE(review): the exact upper bound on prod after this adjustment is
       the original author's; it is not re-derived here. */
    adj = (int32_t)(MRG32k3a_s20i / 3133) - (int32_t)(MRG32k3a_s22i >> 13) + 1;
    prod = ((int64_t)adj) * (int64_t)m2 + prod; // ensure it's positive
    // ! special modulo computation: prod must be < 2**49 !
    prod_lo = (uint32_t)prod;
    prod_hi = (uint32_t)(prod >> 32);
    p2 = prod_hi * neg_m2 + prod_lo;
    if ((p2 >= m2) || (p2 < prod_lo)) {
        p2 += neg_m2;                   /* same as p2 -= m2 modulo 2**32 */
    }
    MRG32k3a_s20i = MRG32k3a_s21i;
    MRG32k3a_s21i = MRG32k3a_s22i;
    MRG32k3a_s22i = p2;

    /* Combination in 32-bit unsigned arithmetic: when p1 <= p2 the
       subtraction wraps and adding m1 wraps it back into range. */
    return ((p1 <= p2) ? (p1 - p2 + m1) : (p1 - p2)) * norm;
}
#endif // BUILTIN_64BIT
/*
http://www.burtleburtle.net/bob/hash/doobs.html
By Bob Jenkins, 1996. bob_jenkins@burtleburtle.net. You may use this
code any way you wish, private, educational, or commercial. It's free.
*/
/* mix(a,b,c): nine rounds of subtract / xor-shift mixing over three integer
   lvalues. This is an expression macro: a, b and c are updated in place and
   the value of the whole expression is the final c. Every argument is
   evaluated many times, so pass plain variables only. */
#define mix(a,b,c)                                       \
    ((a) -= (b), (a) -= (c), (a) ^= ((c) >> 13),         \
     (b) -= (c), (b) -= (a), (b) ^= ((a) <<  8),         \
     (c) -= (a), (c) -= (b), (c) ^= ((b) >> 13),         \
     (a) -= (b), (a) -= (c), (a) ^= ((c) >> 12),         \
     (b) -= (c), (b) -= (a), (b) ^= ((a) << 16),         \
     (c) -= (a), (c) -= (b), (c) ^= ((b) >>  5),         \
     (a) -= (b), (a) -= (c), (a) ^= ((c) >>  3),         \
     (b) -= (c), (b) -= (a), (b) ^= ((a) << 10),         \
     (c) -= (a), (c) -= (b), (c) ^= ((b) >> 15))
/* Test driver: seeds the double-precision and the integer generator with
   identical state derived from mix(), then steps both in lockstep and compares
   every output. Exits with EXIT_FAILURE at the first mismatch (after printing
   the iteration count and both values) and with EXIT_SUCCESS once the
   reference generator returns 0. */
int main (void)
{
    const uint32_t m1 = 4294967087u;
    const uint32_t m2 = 4294944443u;
    uint32_t a, b, c;

    a = 3141592654u;
    b = 2718281828u;
    /* mix() updates a, b and c in place, so the order of these statements
       determines the seed values. */
    c = 10; MRG32k3a_s10 = MRG32k3a_s10i = (1u << 10) | (mix (a, b, c) % m1);
    c = 11; MRG32k3a_s11 = MRG32k3a_s11i = (1u << 11) | (mix (a, b, c) % m1);
    c = 12; MRG32k3a_s12 = MRG32k3a_s12i = (1u << 12) | (mix (a, b, c) % m1);
    c = 20; MRG32k3a_s20 = MRG32k3a_s20i = (1u << 20) | (mix (a, b, c) % m2);
    c = 21; MRG32k3a_s21 = MRG32k3a_s21i = (1u << 21) | (mix (a, b, c) % m2);
    c = 22; MRG32k3a_s22 = MRG32k3a_s22i = (1u << 22) | (mix (a, b, c) % m2);

    double res, ref;
    uint64_t count = 0;
    do {
        res = MRG32k3a_i();
        ref = MRG32k3a();
        if (res != ref) {
            /* Arguments follow the order of the labels in the format string
               (ref first, then res); count is cast because uint64_t need not
               be unsigned long long. */
            printf("\ncount=%llu ref=%23.16e res=%23.16e\n",
                   (unsigned long long)count, ref, res);
            return EXIT_FAILURE;
        }
        count++;
        if ((count & 0xfffffff) == 0) {
            printf ("\rcount = %llu ", (unsigned long long)count);
            /* The progress line has no newline, so flush explicitly to make
               it visible on a line-buffered stdout. */
            fflush (stdout);
        }
    } while (ref != 0);
    return EXIT_SUCCESS;
}