我试图探索Infinite Line上的Random Walk算法,我正在寻找一种方法来使其成为as optimal as possible.以下是代码:

import random
from collections import Counter

def run(initial_pos,iterations,trials):
    final_pos = []
    for i in range (0,trials):
        pos = initial_pos
        for j in range (0,iterations):
            if random.choice(["left","right"]) == "left":
                pos -= 1
            else:
                pos += 1
        final_pos.append(pos)
    return Counter(final_pos)

Variable iterations indicates the number of repetitions in a single walk.
While trials indicates the number of walks.

The runtime is satisfactory for trials and iterations equals 10^4
Hwever, increasing to 10^5 requires long waiting time.

推荐答案

低垂的果子

首先,让我们go 掉那些神奇的数字和字符串,并定义一些命名常量:

RIGHT_NAME, LEFT_NAME = "right", "left"
RIGHT, LEFT = 1, -1

现在,我们可以重写内部循环体以使用这些常量:

if random.choice([LEFT_NAME , RIGHT_NAME]) == LEFT_NAME:
    pos += LEFT
else:
    pos += RIGHT

考虑到这一点,两个问题变得显而易见:

  1. 我们正在比较字符串是否相等.这不太可能比比较小整数表现得更好.
  2. 我们有两对代表同一概念的常量(在左和右之间进行 Select ).我们在重复自己,这很少是一件好事.

让我们只使用整数常量RIGHTLEFT来消除这两个问题.让我们还将random.choice的结果存储到临时变量中,以便更好地查看下一个问题:

current_direction = random.choice([LEFT, RIGHT])
if current_direction == LEFT:
    pos += LEFT
else:
    pos += RIGHT

现在,我们可以看到current_direction可以是RIGHT,也可以是LEFT.如果是LEFT,那么我们将LEFT加到pos,否则(唯一的其他 Select 是RIGHT),我们将RIGHT加到pos.换句话说:

current_direction = random.choice([LEFT, RIGHT])
if current_direction == LEFT:
    pos += current_direction   # current_direction is LEFT
else:
    pos += current_direction   # current_direction is RIGHT

在两个分支中发生相同的事情,所以让我们go 掉这个条件:

pos += random.choice([LEFT, RIGHT])

这是我们的新起点:

def run(initial_pos, iterations, trials):
    final_pos = []
    RIGHT, LEFT = 1, -1
    for i in range(0, trials):
        pos = initial_pos
        for j in range(0, iterations):
            pos += random.choice([LEFT, RIGHT])
        final_pos.append(pos)
    return Counter(final_pos)

现在,如果我们批判性地查看代码(并考虑到字节码和解释器的低级别细节--dis模块在这里有所帮助),我们可以看到其他一些缺陷:

  1. 呼叫range(0, trials)在功能上与range(trials)相同.然而,冗余0意味着将常量压入堆栈额外操作码.我们不想把时间浪费在无用的事情上,对吗?
  2. 在我们最内层的循环中,我们调用一个带有参数[LEFT, RIGHT]的函数.Python不会优化这些东西,所以这意味着我们要创建相同的列表iterations * trials次.即使只有3个操作码,让我们只做一次,然后重用相同的列表.我们不妨把它变成一个元组,它不需要改变.
  3. 班级collections.Counter支持updates.因此,让我们避免中间的final_pos列表,而直接更新Counter.
  4. 如果我们在搜索周期,我们还可以注意到对random.choice的调用涉及两个操作码(首先是GET random,然后是Find choice).我们可以缓存要调用的本地变量的实际函数,以避免额外的步骤.
def run_v1(initial_pos, iterations, trials):
    RIGHT, LEFT = 1, -1
    CHOICES = (RIGHT, LEFT)
   
    result = Counter()
    
    random_choice = random.choice
    result_update = result.update
    
    for i in range(trials):
        pos = initial_pos
        for j in range(iterations):
            pos += random_choice(CHOICES)
        result_update([pos])

    return result

这些最初的更改只意味着代码的运行时间大约是原始代码所需时间的90%-95%,但它们为进一步优化提供了坚实的基础.

消除内环-try 1

现在,我们的内部循环基本上是随机 Select 列表的总和,偏移initial_pos.让我们使用random.choices来在一次调用中生成iterations个选项,并使用内置sum将它们相加.

def run_v2(initial_pos, iterations, trials):
    RIGHT, LEFT = 1, -1
    CHOICES = (RIGHT, LEFT)
    
    result = Counter()
    
    random_choices = random.choices
    result_update = result.update
    
    for i in range(trials):
        result_update([initial_pos + sum(random_choices(CHOICES, k=iterations))])

    return result

这大约需要run_v1年所需时间的25%.

消除内环-try 2

前一个版本的主要问题是,它最终分配了许多中间的Python对象(每个 Select 一个).使用内存的一种更有效的方式是使用NumPy数组和库提供的各种函数.例如,我们可以使用NumPy数组的numpy.Generator.choicesum方法:

def run_v3(initial_pos, iterations, trials):
    RIGHT, LEFT = 1, -1
    CHOICES = (RIGHT, LEFT)
    
    result = Counter()
    
    rng = np.random.default_rng()
    rng_choice = rng.choice
    result_update = result.update
    
    for i in range(trials):
        result_update([initial_pos + rng_choice(CHOICES, size=iterations).sum()])
        
    return result

这个版本需要的时间大约是run_v2所需时间的20%.

更高效的 Select

目前,随机 Select 需要间接查找才能映射到我们想要的(1, -1)个选项集.让我们来观察一下这iterations = right_count + left_count条.我们已经知道iterations的值,所以只要我们知道right_countleft_count中的一个,我们就可以计算另一个.

因此是pos_offset = 2 * right_count - iterations,我们可以这样实施:

def run_v4(initial_pos, iterations, trials):
    result = Counter()
    
    rng = np.random.default_rng()
    rng_choice = rng.choice
    result_update = result.update
    
    for i in range(0, trials):
        result_update([initial_pos + 2 * rng_choice(2, size=iterations).sum() - iterations])

    return result

这一次,它需要大约60%-90%的run_v3.


TODO:解释以下内容

np.random.default_rng().integers代替.

def run_v5a(initial_pos, iterations, trials):
    final_pos = []
    rng = np.random.default_rng()
    rng_integers = rng.integers
    for i in range(0, trials):
        pos = initial_pos + 2 * rng_integers(2, size=iterations).sum() - iterations
        final_pos.append(pos)
    return Counter(final_pos)
def run_v5b(initial_pos, iterations, trials):
    final_pos = []
    rng = np.random.default_rng()
    rng_integers = rng.integers
    for i in range(0, trials):
        pos = initial_pos + 2 * rng_integers(2, dtype=np.uint8, size=iterations).sum(dtype=np.int32) - iterations
        final_pos.append(pos)
    return Counter(final_pos)

使用更多的内存来消除上面的循环.

def run_v6a(initial_pos, iterations, trials):
    rng = np.random.default_rng()
    final_pos = initial_pos + 2 * rng.integers(2, size=(trials, iterations)).sum(axis=1) - iterations
    return Counter(final_pos)
def run_v6b(initial_pos, iterations, trials):
    rng = np.random.default_rng()
    final_pos = initial_pos + 2 * rng.integers(2, dtype=np.uint8, size=(trials, iterations)).sum(axis=1,
        dtype=np.int32) - iterations
    return Counter(final_pos)

根据Jérôme Richard个样本binomial distribution个样本的建议.

应用数学并使用np.random.binomial.

def run_v7(initial_pos, iterations, trials):
    final_pos = initial_pos + 2 * np.random.binomial(iterations, 0.5, trials) - iterations
    return Counter(final_pos)

这将在1/4秒内运行10^6次迭代和试验.

Python相关问答推荐

如何编写一个正规表达式来查找序列中具有2个或更多相同辅音的所有单词

将numpy矩阵映射到字符串矩阵

当值是一个integer时,在Python中使用JMESPath来验证字典中的值(例如:1)

数字梯度的意外值

DuckDB将蜂巢分区插入拼花文件

Python -根据另一个数据框中的列编辑和替换数据框中的列值

如果索引不存在,pandas系列将通过索引获取值,并填充值

LAB中的增强数组

如何将ctyles.POINTER(ctyles.c_float)转换为int?

PMMLPipeline._ fit()需要2到3个位置参数,但给出了4个位置参数

根据二元组列表在pandas中创建新列

为什么默认情况下所有Python类都是可调用的?

如何使用它?

Streamlit应用程序中的Plotly条形图中未正确显示Y轴刻度

如何使用Python以编程方式判断和检索Angular网站的动态内容?

如何在turtle中不使用write()来绘制填充字母(例如OEG)

python—telegraph—bot send_voice发送空文件

巨 Python :逆向猜谜游戏

在二维NumPy数组中,如何 Select 内部数组的第一个和第二个元素?这可以通过索引来实现吗?

如何使用正则表达式修改toml文件中指定字段中的参数值