Can this recursive function be turned into an iterative function with similar performance?
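Since this is just a simple depth-first search (a breadth-first search with a queue instead of a stack would work equally well), converting it to an iterative function is straightforward: use a stack to keep track of the nodes still to visit. Here is a general solution that works for any number of dimensions: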
import numpy as np


def label_image(decoded_image):
    shape = decoded_image.shape
    labels = np.zeros_like(decoded_image, dtype=np.uint32)
    current_label = 0
    # Every positive cell is a potential seed of a new group
    for idx in zip(*np.where(decoded_image > 0)):
        if labels[idx] == 0:
            current_label += 1
            stack = [idx]
            while stack:
                top = stack.pop()
                labels[top] = current_label
                for i in range(0, len(shape)):
                    # Neighbor one step down along axis i
                    if top[i] > 0:
                        neighbor = list(top)
                        neighbor[i] -= 1
                        neighbor = tuple(neighbor)
                        if decoded_image[neighbor] == decoded_image[idx] and labels[neighbor] == 0:
                            stack.append(neighbor)
                    # Neighbor one step up along axis i
                    if top[i] < shape[i] - 1:
                        neighbor = list(top)
                        neighbor[i] += 1
                        neighbor = tuple(neighbor)
                        if decoded_image[neighbor] == decoded_image[idx] and labels[neighbor] == 0:
                            stack.append(neighbor)
    return labels
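The breadth-first alternative mentioned above could look roughly like the following sketch. The function name `label_image_bfs` and the use of `collections.deque` are my own choices for illustration, not part of the original code. Cells are labeled at the moment they are enqueued, so no cell enters the queue twice.

```python
from collections import deque

import numpy as np


def label_image_bfs(decoded_image):
    """Label connected groups of equal positive values, breadth-first."""
    shape = decoded_image.shape
    labels = np.zeros(shape, dtype=np.uint32)
    current_label = 0
    for idx in zip(*np.where(decoded_image > 0)):
        if labels[idx] != 0:
            continue  # already part of an earlier group
        current_label += 1
        value = decoded_image[idx]
        labels[idx] = current_label  # mark when enqueued
        queue = deque([idx])
        while queue:
            cell = queue.popleft()  # FIFO order makes this breadth-first
            for axis in range(len(shape)):
                for step in (-1, 1):
                    pos = cell[axis] + step
                    if not 0 <= pos < shape[axis]:
                        continue  # neighbor would be outside the array
                    # Build the neighbor tuple without a temporary list
                    neighbor = cell[:axis] + (pos,) + cell[axis + 1:]
                    if decoded_image[neighbor] == value and labels[neighbor] == 0:
                        labels[neighbor] = current_label
                        queue.append(neighbor)
    return labels
```

The only structural difference from the stack version is `popleft()` instead of `pop()`; the resulting groups are the same either way.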
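However, adding or subtracting 1 from the i-th component of a tuple is awkward (I go through a temporary list here), and Numba does not accept it (typing error). A simple solution is to write the 2D and 3D versions explicitly, which will likely also improve performance considerably: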
import numba
import numpy as np


@numba.njit
def label_image_2d(decoded_image):
    w, h = decoded_image.shape
    labels = np.zeros_like(decoded_image, dtype=np.uint32)
    current_label = 0
    for idx in zip(*np.where(decoded_image > 0)):
        if labels[idx] == 0:
            current_label += 1
            stack = [idx]
            while stack:
                x, y = stack.pop()
                if decoded_image[x, y] != decoded_image[idx] or labels[x, y] != 0:
                    continue  # already visited or not part of this group
                labels[x, y] = current_label
                if x > 0: stack.append((x-1, y))
                if x+1 < w: stack.append((x+1, y))
                if y > 0: stack.append((x, y-1))
                if y+1 < h: stack.append((x, y+1))
    return labels


@numba.njit
def label_image_3d(decoded_image):
    w, h, l = decoded_image.shape
    labels = np.zeros_like(decoded_image, dtype=np.uint32)
    current_label = 0
    for idx in zip(*np.where(decoded_image > 0)):
        if labels[idx] == 0:
            current_label += 1
            stack = [idx]
            while stack:
                x, y, z = stack.pop()
                if decoded_image[x, y, z] != decoded_image[idx] or labels[x, y, z] != 0:
                    continue  # already visited or not part of this group
                labels[x, y, z] = current_label
                if x > 0: stack.append((x-1, y, z))
                if x+1 < w: stack.append((x+1, y, z))
                if y > 0: stack.append((x, y-1, z))
                if y+1 < h: stack.append((x, y+1, z))
                if z > 0: stack.append((x, y, z-1))
                if z+1 < l: stack.append((x, y, z+1))
    return labels


def label_image(decoded_image):
    # Dispatch on dimensionality; only 2D and 3D inputs are supported
    dim = len(decoded_image.shape)
    if dim == 2:
        return label_image_2d(decoded_image)
    assert dim == 3
    return label_image_3d(decoded_image)
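Also note that the iterative solution is not subject to the stack limit: `np.full((100,100,100), 1)` works fine with the iterative solution but fails with the recursive one (with a segfault if Numba is used).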
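The stack-limit difference can be reproduced in plain Python without Numba. The sketch below is only an illustration: the helper names `fill_recursive`, `fill_iterative` and `recursive_fill_overflows` are made up here, and without Numba the failure shows up as a `RecursionError` rather than a segfault.

```python
import sys

import numpy as np


def fill_recursive(image, labels, x, y, value, label):
    """Recursive 2D flood fill: one Python stack frame per visited cell."""
    if not (0 <= x < image.shape[0] and 0 <= y < image.shape[1]):
        return
    if image[x, y] != value or labels[x, y] != 0:
        return
    labels[x, y] = label
    fill_recursive(image, labels, x - 1, y, value, label)
    fill_recursive(image, labels, x + 1, y, value, label)
    fill_recursive(image, labels, x, y - 1, value, label)
    fill_recursive(image, labels, x, y + 1, value, label)


def fill_iterative(image, labels, x, y, value, label):
    """The same fill with an explicit stack, which lives on the heap."""
    stack = [(x, y)]
    while stack:
        x, y = stack.pop()
        if not (0 <= x < image.shape[0] and 0 <= y < image.shape[1]):
            continue
        if image[x, y] != value or labels[x, y] != 0:
            continue
        labels[x, y] = label
        stack.extend([(x - 1, y), (x + 1, y), (x, y - 1), (x, y + 1)])


def recursive_fill_overflows(image, limit=1000):
    """Return True if the recursive fill needs more than `limit` nested calls."""
    old_limit = sys.getrecursionlimit()
    sys.setrecursionlimit(limit)
    try:
        labels = np.zeros(image.shape, dtype=np.uint32)
        fill_recursive(image, labels, 0, 0, image[0, 0], 1)
        return False
    except RecursionError:
        return True
    finally:
        sys.setrecursionlimit(old_limit)
```

On a single-group `np.full((100, 100), 1)` the recursive fill snakes through the whole array and its call depth grows with the number of cells, while the iterative fill only grows a list.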
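I ran a very basic benchmark:
for i in range(1, 10000):
    label_image(np.full((20,20,20), i))
(It uses many iterations to minimize the impact of the JIT compilation. Alternatively, one could do a few warm-up runs before starting to measure the time, or something similar.)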
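The warm-up idea could be sketched like this. The helper `time_after_warmup` and its parameters are hypothetical names introduced here; the sketch relies only on the standard-library `timeit` module and also builds the input arrays outside the timed region.

```python
import timeit


def time_after_warmup(func, make_input, warmup=3, number=100):
    """Time `number` calls of `func`, after `warmup` untimed calls."""
    for _ in range(warmup):
        func(make_input())  # untimed: gives a JIT the chance to compile first
    # Create all inputs up front so their construction is not measured
    inputs = iter([make_input() for _ in range(number)])
    return timeit.timeit(lambda: func(next(inputs)), number=number)
```

With the functions from this answer, a call might look like `time_after_warmup(label_image, lambda: np.full((20, 20, 20), 1))`. The warm-up inputs should have the same dtype and dimensionality as the timed ones, since a JIT typically compiles per argument type.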
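The iterative solution seems to be several times faster (about 5x on my machine, but see below). You could probably optimize the recursive solution and bring it to a comparable speed, e.g. by avoiding the temporary `coords` list or by changing the `np.where` to `> 0`.
I don't know how well Numba can optimize the zipped `np.where`. For further optimization, you could consider (and benchmark) using explicit nested `for x in range(0, w): for y in range(0, h):` loops there.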
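As a small sanity check of that idea: `np.where` returns indices in row-major order, so nested loops over `x` and then `y` visit the seed cells in the same order, and the resulting label numbers do not change. The two helper names below are made up for this illustration.

```python
import numpy as np


def seeds_via_where(image):
    """Seed coordinates as produced by zip(*np.where(...))."""
    return [tuple(int(c) for c in idx) for idx in zip(*np.where(image > 0))]


def seeds_via_loops(image):
    """The same seeds from explicit nested loops (2D case)."""
    w, h = image.shape
    seeds = []
    for x in range(w):
        for y in range(h):
            if image[x, y] <= 0:
                continue  # skip background instead of pre-filtering with np.where
            seeds.append((x, y))
    return seeds
```

The loop version never materializes the index arrays, which is the part that might matter for performance; whether it actually does under Numba would have to be measured.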
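To stay competitive with the merge-based strategy proposed by Nick, I optimized this a bit further, picking some low-hanging fruit:
- Convert the `zip` to explicit loops, using `continue` rather than `np.where`.
- Store `decoded_image[idx]` in a local variable (ideally this shouldn't matter, but it doesn't hurt).
- Reuse the stack. This prevents unnecessary (re)allocations and GC pressure. One could further consider giving the stack an initial capacity (of `w*h` or `w*h*l`, respectively).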
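One way to read the initial-capacity suggestion is a fixed-size array used as the stack. The following is a plain-NumPy sketch of that idea under one changed assumption: cells are labeled when they are pushed rather than when they are popped, so each cell is pushed at most once and `w*h` slots always suffice. The name `label_image_2d_prealloc` is hypothetical, and whether this helps under Numba would need to be benchmarked.

```python
import numpy as np


def label_image_2d_prealloc(decoded_image):
    """2D labeling with a preallocated array as the stack (illustrative variant)."""
    w, h = decoded_image.shape
    labels = np.zeros((w, h), dtype=np.uint32)
    # Each cell is pushed at most once, so w*h entries are always enough
    stack = np.empty((w * h, 2), dtype=np.int64)
    current_label = 0
    for sx in range(w):
        for sy in range(h):
            value = decoded_image[sx, sy]
            if value <= 0 or labels[sx, sy] != 0:
                continue
            current_label += 1
            labels[sx, sy] = current_label
            stack[0, 0], stack[0, 1] = sx, sy
            top = 1  # number of entries currently on the stack
            while top > 0:
                top -= 1
                x, y = stack[top, 0], stack[top, 1]
                for nx, ny in ((x - 1, y), (x + 1, y), (x, y - 1), (x, y + 1)):
                    inside = 0 <= nx < w and 0 <= ny < h
                    if inside and labels[nx, ny] == 0 and decoded_image[nx, ny] == value:
                        labels[nx, ny] = current_label  # mark on push
                        stack[top, 0], stack[top, 1] = nx, ny
                        top += 1
    return labels
```

With the check-on-pop scheme used in this answer, a cell can sit on the stack several times, so there `w*h` is only a reasonable starting capacity for a growable list, not a hard bound.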
import numba
import numpy as np


@numba.njit
def label_image_2d(decoded_image):
    w, h = decoded_image.shape
    labels = np.zeros_like(decoded_image, dtype=np.uint32)
    current_label = 0
    stack = []  # reused for every group
    for sx in range(0, w):
        for sy in range(0, h):
            start = (sx, sy)
            image_label = decoded_image[start]
            if image_label <= 0 or labels[start] != 0:
                continue  # background or already labeled
            current_label += 1
            stack.append(start)
            while stack:
                x, y = stack.pop()
                if decoded_image[x, y] != image_label or labels[x, y] != 0:
                    continue  # already visited or not part of this group
                labels[x, y] = current_label
                if x > 0: stack.append((x-1, y))
                if x+1 < w: stack.append((x+1, y))
                if y > 0: stack.append((x, y-1))
                if y+1 < h: stack.append((x, y+1))
    return labels


@numba.njit
def label_image_3d(decoded_image):
    w, h, l = decoded_image.shape
    labels = np.zeros_like(decoded_image, dtype=np.uint32)
    current_label = 0
    stack = []  # reused for every group
    for sx in range(0, w):
        for sy in range(0, h):
            for sz in range(0, l):
                start = (sx, sy, sz)
                image_label = decoded_image[start]
                if image_label <= 0 or labels[start] != 0:
                    continue  # background or already labeled
                current_label += 1
                stack.append(start)
                while stack:
                    x, y, z = stack.pop()
                    if decoded_image[x, y, z] != image_label or labels[x, y, z] != 0:
                        continue  # already visited or not part of this group
                    labels[x, y, z] = current_label
                    if x > 0: stack.append((x-1, y, z))
                    if x+1 < w: stack.append((x+1, y, z))
                    if y > 0: stack.append((x, y-1, z))
                    if y+1 < h: stack.append((x, y+1, z))
                    if z > 0: stack.append((x, y, z-1))
                    if z+1 < l: stack.append((x, y, z+1))
    return labels
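I then cobbled together a benchmark to compare the four approaches (original recursive, old iterative, new iterative, merge-based), putting them in four separate modules: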
import numpy as np
import timeit

import rec
import iter_old
import iter_new
import merge

shape = (100, 100, 100)
n = 20

for module in [rec, iter_old, iter_new, merge]:
    print(module)

    label_image = module.label_image
    # Trigger compilation of 2d & 3d functions
    label_image(np.zeros((1, 1)))
    label_image(np.zeros((1, 1, 1)))

    i = 0
    def test_full():
        global i
        i += 1
        label_image(np.full(shape, i))
    print("single group:", timeit.timeit(test_full, number=n))
    print("random (few groups):", timeit.timeit(
        lambda: label_image(np.random.randint(low=1, high=10, size=shape)),
        number=n))
    print("random (many groups):", timeit.timeit(
        lambda: label_image(np.random.randint(low=1, high=400, size=shape)),
        number=n))
    print("only groups:", timeit.timeit(
        lambda: label_image(np.arange(np.prod(shape)).reshape(shape)),
        number=n))
This outputs something like the following:
<module 'rec' from '...'>
single group: 32.39212468900041
random (few groups): 14.648884047001047
random (many groups): 13.304533919001187
only groups: 13.513677138000276
<module 'iter_old' from '...'>
single group: 10.287227957000141
random (few groups): 17.37535468200076
random (many groups): 14.506630064999626
only groups: 13.132202609998785
<module 'iter_new' from '...'>
single group: 7.388022166000155
random (few groups): 11.585243002000425
random (many groups): 9.560101995000878
only groups: 8.693653742000606
<module 'merge' from '...'>
single group: 14.657021331999204
random (few groups): 14.146574055999736
random (many groups): 13.412314713001251
only groups: 12.642367746000673
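It seems to me that the improved iterative approach may be the better one. Note that the original rudimentary benchmark appears to be the worst case for the recursive variant. Overall, the differences are not that large.
The arrays tested so far are quite small (20³). If I test with a larger array (100³) and a smaller n (20), I get roughly the following results (`rec` is omitted because it would segfault due to the stack limit):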
<module 'iter_old' from '...'>
single group: 3.5357716739999887
random (few groups): 4.931695729999774
random (many groups): 3.4671142009992764
only groups: 3.3023930709987326
<module 'iter_new' from '...'>
single group: 2.45903080700009
random (few groups): 2.907660342001691
random (many groups): 2.309699692999857
only groups: 2.052835552000033
<module 'merge' from '...'>
single group: 3.7620838259990705
random (few groups): 3.3524249689999124
random (many groups): 3.126650959999097
only groups: 2.9456547739991947
The iterative approach still seems to be the more efficient one.
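Timings like these are only comparable if all variants produce the same grouping. Different approaches may number the groups differently (that is an assumption about the merge-based module, which is not shown here), so a plain `np.array_equal` on the label arrays could report a false mismatch. A minimal sketch of a check that ignores the label numbers, with the hypothetical name `same_grouping`:

```python
import numpy as np


def same_grouping(labels_a, labels_b):
    """True if two label arrays describe the same groups, ignoring label numbers."""
    if labels_a.shape != labels_b.shape:
        return False
    a = labels_a.ravel()
    b = labels_b.ravel()
    # Background (label 0) has to coincide exactly
    if not np.array_equal(a == 0, b == 0):
        return False
    # All distinct (label in a, label in b) combinations that occur
    pairs = np.unique(np.stack([a, b], axis=1), axis=0)
    # One-to-one mapping: no label shows up in two different pairs
    return (len(np.unique(pairs[:, 0])) == len(pairs)
            and len(np.unique(pairs[:, 1])) == len(pairs))
```

It could be run once per module pair on a small random input before starting the timed runs.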