当我研究在不支持double计算的32位处理器上高效实现MRG32k3a PRNG时,出现了这个问题.我对ARM、RISC-V和GPU特别感兴趣.MRG32k3a是一种非常高质量的PRNG,因此至今仍在广泛使用,尽管它可以追溯到20世纪90年代末:

L所著,,第Operations Research期,第47卷,第1期,1-2月.1999年,第159-164页

MRG32k3a组合了形式为(c0 ⋅ state0 - c1 ⋅ state1) mod m的两个递归序列,其中statei < m.MRG32k3a中的常数和状态变量都是适合32位的正整数,并且计算中的所有中间表达式的大小都小于253.这是经过设计的,因为参考实现使用IEEE-754double进行存储和计算.数学模数mod与ISO-C的%运算符的不同之处在于总是传递非负结果.下面代码中的第一个变体显示了使用double的稍微现代化的参考实现.

在MRG32k3a的纯整数实现中,32位变量用于状态分量,中间计算以64位算法执行.通过确保被除数为非负,模很容易通过%计算:(c0 ⋅ state0 - c1 ⋅ state1) mod m = (c0 ⋅ state0 - c1 ⋅ state1 + c1 ⋅ m) % m.计算% m在32位处理器上是昂贵的,通常导致库调用.这很容易通过标准的常数除法优化来解决,64位乘法运算是最昂贵的部分(参见下面代码中的GENERIC_MOD=1变体).

当被除数的大小被限制为m = 2n-d时,甚至更快的模计算是可能的,小的d.一套lo = x % 2n, hi= x / 2^n, t = hi * d + lo.只要t < 2 ⋅ mx mod m = (t >= m) ? (t - m) : t.当x < 2n+(n-ceil(log2(d+1)))时,所需条件成立.这对于MRG32k3a使用的第一个递归效果很好,n=32d=209需要x < 256,这是微不足道的满足.但是对于第二次重复n=32d = 22853,需要x < 249.在施加偏移c1 ⋅ m以确保正x之后,在这种情况下,x可以大到8.15 ⋅ 1015,仅略小于253 ≈ 9 ⋅ 1015.

我目前正在解决这个问题,方法是在根据状态变量的值计算x % m之前,将偏移量相加以确保正x,这样就保持了x < 244.但从下面提取的相关代码行可以看出,这是一种相当昂贵的方法,它包括32位除法(具有常量除数,因此是可优化的,但仍会产生不必要的成本).

    prod = ((uint64_t)a21) * MRG32k3a_s22i - ((uint64_t)a23n) * MRG32k3a_s20i;
    adj = MRG32k3a_s20i / 3133 - (MRG32k3a_s22i >> 13) + 1;
    prod = ((int64_t)(int32_t)adj) * (int64_t)m2 + prod; // ensure it's positive

对于MRG32k3a使用的第二次递归,是否有替代且成本更低的缓解策略来实现更有效的模计算?

#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <math.h>

#define BUILTIN_64BIT (0)
#define GENERIC_MOD   (0)  // applies ony when BUILTIN_64BIT == 0

static double MRG32k3a_s10, MRG32k3a_s11, MRG32k3a_s12;
static double MRG32k3a_s20, MRG32k3a_s21, MRG32k3a_s22;

/* SIMD vectorized by Clang with -ffp-model=precise on x86-84 and AArch64
   SIMD vectorized by Intel compiler with -fp-model=precise -march=core-avx2
 */
double MRG32k3a (void)
{
    const double norm = 2.328306549295728e-10;
    const double m1 = 4294967087.0;
    const double m2 = 4294944443.0;
    const double a12 = 1403580.0;
    const double a13n = 810728.0;
    const double a21 = 527612.0;
    const double a23n = 1370589.0;
    double k, p1, p2;

    /* Component 1 */
    p1 = a12 * MRG32k3a_s11 - a13n * MRG32k3a_s10;
    k = floor (p1 / m1);  
    p1 -= k * m1;
    MRG32k3a_s10 = MRG32k3a_s11; MRG32k3a_s11 = MRG32k3a_s12; MRG32k3a_s12 = p1;
    /* Component 2 */
    p2 = a21 * MRG32k3a_s22 - a23n * MRG32k3a_s20;
    k = floor (p2 / m2);  
    p2 -= k * m2;
    MRG32k3a_s20 = MRG32k3a_s21; MRG32k3a_s21 = MRG32k3a_s22; MRG32k3a_s22 = p2;
    /* Combination */
    return ((p1 <= p2) ? (p1 - p2 + m1) : (p1 - p2)) * norm;
}

static uint32_t MRG32k3a_s10i, MRG32k3a_s11i, MRG32k3a_s12i;
static uint32_t MRG32k3a_s20i, MRG32k3a_s21i, MRG32k3a_s22i;

#if BUILTIN_64BIT
double MRG32k3a_i (void)
{
    const double norm = 2.328306549295728e-10;
    const uint32_t m1 = 4294967087u;
    const uint32_t m2 = 4294944443u;
    const uint32_t a12 = 1403580u;
    const uint32_t a13n = 810728u;
    const uint32_t a21 = 527612u;
    const uint32_t a23n = 1370589u;
    uint64_t prod;
    uint32_t p1, p2;

    /* Component 1 */
    prod = ((uint64_t)a12) * MRG32k3a_s11i - ((uint64_t)a13n) * MRG32k3a_s10i;
    prod = ((uint64_t)a13n) * m1 + prod; // ensure it's positive
    p1 = (uint32_t)(prod % m1);
    MRG32k3a_s10i=MRG32k3a_s11i; MRG32k3a_s11i=MRG32k3a_s12i; MRG32k3a_s12i=p1;
    /* Component 2 */
    prod = ((uint64_t)a21) * MRG32k3a_s22i - ((uint64_t)a23n) * MRG32k3a_s20i;
    prod = ((uint64_t)a23n) * m2 + prod; // ensure it's positive
    p2 = (uint32_t)(prod % m2);
    MRG32k3a_s20i=MRG32k3a_s21i; MRG32k3a_s21i=MRG32k3a_s22i; MRG32k3a_s22i=p2;
    /* Combination */
    return ((p1 <= p2) ? (p1 - p2 + m1) : (p1 - p2)) * norm;
}
#elif GENERIC_MOD
uint64_t umul64hi (uint64_t a, uint64_t b)
{
    uint32_t alo = (uint32_t)a;
    uint32_t ahi = (uint32_t)(a >> 32);
    uint32_t blo = (uint32_t)b;
    uint32_t bhi = (uint32_t)(b >> 32);
    uint64_t p0 = (uint64_t)alo * blo;
    uint64_t p1 = (uint64_t)alo * bhi;
    uint64_t p2 = (uint64_t)ahi * blo;
    uint64_t p3 = (uint64_t)ahi * bhi;
    return (p1 >> 32) + (((p0 >> 32) + (uint64_t)(uint32_t)p1 + p2) >> 32) + p3;
}

double MRG32k3a_i (void)
{
    const double norm = 2.328306549295728e-10;
    const uint32_t m1 = 4294967087u;
    const uint32_t m2 = 4294944443u;
    const uint32_t a12 = 1403580u;
    const uint32_t a13n = 810728u;
    const uint32_t a21 = 527612u;
    const uint32_t a23n = 1370589u;
    const uint32_t neg_m1 = 0 - m1; // 209
    const uint32_t neg_m2 = 0 - m2; // 22853
    const uint64_t magic_mul_m1 = 0x8000006880005551ull;
    const uint64_t magic_mul_m2 = 0x4000165147c845ddull;
    const uint32_t shft_m1 = 31;
    const uint32_t shft_m2 = 30;
    uint64_t prod;
    uint32_t p1, p2;

    /* Component 1 */
    prod = ((uint64_t)a12) * MRG32k3a_s11i - ((uint64_t)a13n) * MRG32k3a_s10i;
    prod = ((uint64_t)a13n) * m1 + prod; // ensure it's positive
    p1 = (uint32_t)((umul64hi (prod, magic_mul_m1) >> shft_m1) * neg_m1 + prod);
    MRG32k3a_s10i=MRG32k3a_s11i; MRG32k3a_s11i=MRG32k3a_s12i; MRG32k3a_s12i=p1;
    /* Component 2 */
    prod = ((uint64_t)a21) * MRG32k3a_s22i - ((uint64_t)a23n) * MRG32k3a_s20i;
    prod = ((uint64_t)a23n) * m2 + prod; // ensure it's positive
    p2 = (uint32_t)((umul64hi (prod, magic_mul_m2) >> shft_m2) * neg_m2 + prod);
    MRG32k3a_s20i=MRG32k3a_s21i; MRG32k3a_s21i=MRG32k3a_s22i; MRG32k3a_s22i=p2;
    /* Combination */
    return ((p1 <= p2) ? (p1 - p2 + m1) : (p1 - p2)) * norm;
}
#else // !BUILTIN_64BIT && !GENERIC_MOD --> special fast modulo computation
double MRG32k3a_i (void)
{
    const double norm = 2.328306549295728e-10;
    const uint32_t m1 = 4294967087u;
    const uint32_t m2 = 4294944443u;
    const uint32_t a12 = 1403580u;
    const uint32_t a13n = 810728u;
    const uint32_t a21 = 527612u;
    const uint32_t a23n = 1370589u;
    const uint32_t neg_m1 = 0 - m1; // 209
    const uint32_t neg_m2 = 0 - m2; // 22853
    uint64_t prod;
    uint32_t p1, p2, prod_lo, prod_hi, adj;

    /* Component 1 */
    prod = ((uint64_t)a12) * MRG32k3a_s11i - ((uint64_t)a13n) * MRG32k3a_s10i;
    prod = ((uint64_t)a13n) * m1 + prod; // ensure its positive
    // ! special modulo computation: prod must be < 2**56 !
    prod_lo = (uint32_t)prod;
    prod_hi = (uint32_t)(prod >> 32);
    p1 = prod_hi * neg_m1 + prod_lo;
    if ((p1 >= m1) || (p1 < prod_lo)) p1 += neg_m1;
    MRG32k3a_s10i=MRG32k3a_s11i; MRG32k3a_s11i=MRG32k3a_s12i; MRG32k3a_s12i=p1;
    /* Component 2 */
    prod = ((uint64_t)a21) * MRG32k3a_s22i - ((uint64_t)a23n) * MRG32k3a_s20i;
    adj = MRG32k3a_s20i / 3133 - (MRG32k3a_s22i >> 13) + 1;
    prod = ((int64_t)(int32_t)adj) * (int64_t)m2 + prod; // ensure it's positive
    // ! special modulo computation: prod must be < 2**49 !
    prod_lo = (uint32_t)prod;
    prod_hi = (uint32_t)(prod >> 32);
    p2 = prod_hi * neg_m2 + prod_lo;
    if ((p2 >= m2) || (p2 < prod_lo)) p2 += neg_m2;
    MRG32k3a_s20i=MRG32k3a_s21i; MRG32k3a_s21i=MRG32k3a_s22i; MRG32k3a_s22i=p2;
    /* Combination */
    return ((p1 <= p2) ? (p1 - p2 + m1) : (p1 - p2)) * norm;
}
#endif // BUILTIN_64BIT

/*
  http://www.burtleburtle.net/bob/hash/doobs.html
  By Bob Jenkins, 1996.  bob_jenkins@burtleburtle.net.  You may use this
  code any way you wish, private, educational, or commercial.  It's free.
*/
#define mix(a,b,c) \
    (a -= b, a -= c, a ^= (c>>13), \
     b -= c, b -= a, b ^= (a<<8),  \
     c -= a, c -= b, c ^= (b>>13), \
     a -= b, a -= c, a ^= (c>>12), \
     b -= c, b -= a, b ^= (a<<16), \
     c -= a, c -= b, c ^= (b>>5),  \
     a -= b, a -= c, a ^= (c>>3),  \
     b -= c, b -= a, b ^= (a<<10), \
     c -= a, c -= b, c ^= (b>>15))

int main (void)
{
    uint32_t m1 = 4294967087u;
    uint32_t m2 = 4294944443u;
    uint32_t a, b, c;
    a = 3141592654u;
    b = 2718281828u;
    c = 10; MRG32k3a_s10 = MRG32k3a_s10i = (1u << 10) | (mix (a, b, c) % m1);
    c = 11; MRG32k3a_s11 = MRG32k3a_s11i = (1u << 11) | (mix (a, b, c) % m1);
    c = 12; MRG32k3a_s12 = MRG32k3a_s12i = (1u << 12) | (mix (a, b, c) % m1);
    c = 20; MRG32k3a_s20 = MRG32k3a_s20i = (1u << 20) | (mix (a, b, c) % m2);
    c = 21; MRG32k3a_s21 = MRG32k3a_s21i = (1u << 21) | (mix (a, b, c) % m2);
    c = 22; MRG32k3a_s22 = MRG32k3a_s22i = (1u << 22) | (mix (a, b, c) % m2);
    
    double res, ref;
    uint64_t count = 0;
    do {
        res = MRG32k3a_i();
        ref = MRG32k3a();
        if (res != ref) {
            printf("\ncount=%llu  ref=%23.16e  res=%23.16e\n", count, res, ref);
            return EXIT_FAILURE;
        }
        count++;
        if ((count & 0xfffffff) == 0) printf ("\rcount = %llu ", count);
    } while (ref != 0);
    return EXIT_SUCCESS;
}

推荐答案

我还没有对此进行测试或基准测试,但如果我理解你的做法是正确的,我认为另一个 Select 是增加固定的m2倍数,并进行两轮削减.

prod = ((uint64_t)a21) * MRG32k3a_s22i + ((uint64_t)m2 << 22) - ((uint64_t)a23n) * MRG32k3a_s20i; // 55 bits

然后你可以省略下面两行.

adj = MRG32k3a_s20i / 3133 - (MRG32k3a_s22i >> 13) + 1;
prod = ((int64_t)(int32_t)adj) * (int64_t)m2 + prod; // ensure it's positive

然后再go 做

prod_lo = (uint32_t)prod; 
prod_hi = (uint32_t)(prod >> 32); // 23 bits
prod = (uint64_t)prod_hi * neg_m2 + prod_lo; // 39 bits
prod_lo = (uint32_t)prod; 
prod_hi = (uint32_t)(prod >> 32); 
p2 = prod_hi * neg_m2 + prod_lo; 
if ((p2 >= m2) || (p2 < prod_lo)) p2 += neg_m2;

您可能还希望将组件1的计算更改为如下所示

/* Component 1 */
prod = ((uint64_t)a12) * MRG32k3a_s11i + ((uint64_t)a13n) * (m1 - MRG32k3a_s10i); // 54 bits

而不是两个步骤

/* Component 1 */
prod = ((uint64_t)a12) * MRG32k3a_s11i - ((uint64_t)a13n) * MRG32k3a_s10i;
prod = ((uint64_t)a13n) * m1 + prod; // ensure its positive

这样就不依赖于编译器足够聪明来实现它只需要两次乘法.我认为这是有效的,因为0 = MRG32k3a_s1xi m1对吗?<<

C++相关问答推荐

理解没有返回语句的递归C函数的行为

错误:在.h程序中重新定义 struct

手动矢量化性能差异较大

有没有可能我不能打印?(C,流程)

X86/x64上的SIGSEGV,由于原始内存访问和C中的DS寄存器之间的冲突,在Linux上用TCC编译为JIT引擎

为什么cudaFree不需要数据 struct 的地址?

解决S随机内存分配问题,实现跨进程高效数据共享

用C++从外部ELF符号读取值

C I/O:在Windows控制台上处理键盘输入

链接到底是如何工作的,我在这里到底做错了什么

是否定义了此函数的行为?

C代码可以在在线编译器上运行,但不能在Leetcode上运行

不确定如何处理此编译错误

不带Malloc的链表

如何在C中计算包含递增和递减运算符的逻辑表达式?

struct 中的qsort,但排序后的 struct 很乱

Ubuntu编译:C中的文件格式无法识别错误

为什么使用 C 引用这个 char 数组会导致 Stack smasing?

如何在Linux上从控制台左上角开始打印文本?

c中数组上显示的随机元素