# Python Numba：迭代相互独立的循环反而变慢

@numba.jit(nopython=True)
def power_method(A, v, n_iter=3 * 10**3):
    """Run ``n_iter`` steps of the power method on ``A`` starting from ``v``.

    Each step multiplies the current iterate by ``A`` and rescales it to unit
    Euclidean norm.

    Parameters
    ----------
    A : 2-D array
        Square matrix to iterate with (``A @ u`` must be defined).
    v : 1-D array
        Starting vector. It is copied, so the caller's array is never modified.
    n_iter : int, optional
        Number of multiply-and-normalize steps. Defaults to 3000, the value
        that used to be hard-coded.

    Returns
    -------
    1-D array
        The normalized iterate after ``n_iter`` steps (a copy of ``v`` when
        ``n_iter`` is 0).
        NOTE(review): if an iterate becomes the zero vector the division by a
        zero norm presumably yields NaNs (NumPy array semantics) — confirm.
    """
    # Copy so that the n_iter == 0 case still returns a fresh array and the
    # caller's buffer is never aliased by the result.
    u = v.copy()
    for _ in range(n_iter):
        # A @ u allocates a new array each step; u is rebound, not mutated.
        u = A @ u
        u /= np.linalg.norm(u)
    return u

@numba.jit(nopython=True, parallel=True)
def iterate_grid(A, scale, sz):
    """Apply ``power_method`` to every point of a cubic grid of start vectors.

    The grid is the Cartesian product ``tmp x tmp x tmp`` with
    ``tmp = np.linspace(-scale, scale, sz)``.

    Parameters
    ----------
    A : 2-D array
        Must be 3x3 (checked by the assert below).
    scale : float
        Half-width of the grid along each axis.
    sz : int
        Number of grid points per axis.

    Returns
    -------
    array of shape (sz**3, 3)
        Row ``i1*sz**2 + i2*sz + i3`` holds ``power_method`` applied to the
        start vector ``(tmp[i1], tmp[i2], tmp[i3])``.

    Notes
    -----
    The outer loop uses plain ``range``. Numba only turns ``prange`` loops
    into parallel loops, so this loop runs serially even though
    ``parallel=True`` is set. This is the serial baseline.
    """
    assert A.shape[0] == A.shape[1] == 3
    n = A.shape[0]

    results = np.empty((sz**3, n))
    tmp = np.linspace(-scale, scale, sz)

    for i1 in range(sz):
        # Scratch start vector, overwritten for every grid point. This is safe
        # because power_method copies its input before iterating.
        v = np.empty(n, dtype=np.float64)
        v1 = tmp[i1]
        for i2, v2 in enumerate(tmp):
            for i3, v3 in enumerate(tmp):
                v[0] = v1
                v[1] = v2
                v[2] = v3

                u = power_method(A, v)

                idx = i1 * sz**2 + i2 * sz + i3
                # Row assignment copies u's values into results, so an extra
                # u.copy() would only add a throwaway allocation.
                results[idx] = u

    return results

n = 3
# Random n x n matrix from the global NumPy RNG; iterate_grid asserts 3x3.
A = np.random.standard_normal((n, n))
# The first call compiles iterate_grid and then evaluates the 20**3 grid.
iterate_grid(A, scale=5.0, sz=20)

@numba.jit(nopython=True, parallel=True)
def iterate_grid(A, scale, sz):
    """Parallel version: apply ``power_method`` over a cubic grid of start vectors.

    The grid is the Cartesian product ``tmp x tmp x tmp`` with
    ``tmp = np.linspace(-scale, scale, sz)``. The outer axis is split across
    threads with ``numba.prange``.

    Parameters
    ----------
    A : 2-D array
        Must be 3x3 (checked by the assert below).
    scale : float
        Half-width of the grid along each axis.
    sz : int
        Number of grid points per axis.

    Returns
    -------
    array of shape (sz**3, 3)
        Row ``i1*sz**2 + i2*sz + i3`` holds ``power_method`` applied to the
        start vector ``(tmp[i1], tmp[i2], tmp[i3])``.

    Notes
    -----
    NOTE(review): ``power_method`` computes ``A @ u`` and so allocates a new
    array on every step. Many threads allocating concurrently is a plausible
    reason this version is not faster than the serial one. This is
    unconfirmed and should be profiled.
    """
    assert A.shape[0] == A.shape[1] == 3
    n = A.shape[0]

    results = np.empty((sz**3, n))
    tmp = np.linspace(-scale, scale, sz)

    for i1 in numba.prange(sz):
        # Allocated inside the prange body, so each outer iteration (and hence
        # each thread) writes to its own buffer.
        v = np.empty(n, dtype=np.float64)
        v1 = tmp[i1]
        for i2, v2 in enumerate(tmp):
            for i3, v3 in enumerate(tmp):
                v[0] = v1
                v[1] = v2
                v[2] = v3

                u = power_method(A, v)

                # Distinct i1 values map to disjoint row ranges, so parallel
                # writes into results do not overlap.
                idx = i1 * sz**2 + i2 * sz + i3
                # Row assignment already copies u's values; no u.copy() needed.
                results[idx] = u

    return results

## 推荐答案

@numba.njit
def fast_power_method(A, v1, v2, v3, n_iter=3_000):
    """Power method for a 3x3 matrix using scalar arithmetic only.

    This performs the same multiply-and-normalize iteration as an array-based
    power method. The start vector is passed as three scalars and the 3x3
    matrix-vector product is written out by hand, so no temporary arrays are
    allocated inside the loop.

    Parameters
    ----------
    A : 2-D array
        Its first three rows are each unpacked into three entries, so it must
        be 3x3.
    v1, v2, v3 : float
        Components of the starting vector.
    n_iter : int, optional
        Number of multiply-and-normalize steps. Defaults to 3000, the value
        that used to be hard-coded.

    Returns
    -------
    tuple of float
        ``(u1, u2, u3)``, the unit-norm iterate after ``n_iter`` steps, or
        ``(nan, nan, nan)`` if an iterate becomes the zero vector (for
        example a zero start vector, or one mapped to zero by a singular A).
    """
    # Unpack into scalars; plain rebinding, so no copy is needed.
    u1, u2, u3 = v1, v2, v3
    A11, A12, A13 = A[0]
    A21, A22, A23 = A[1]
    A31, A32, A33 = A[2]

    for _ in range(n_iter):
        # Hand-written 3x3 matrix-vector product.
        t1 = u1 * A11 + u2 * A12 + u3 * A13
        t2 = u1 * A21 + u2 * A22 + u3 * A23
        t3 = u1 * A31 + u2 * A32 + u3 * A33

        norm = np.sqrt(t1**2 + t2**2 + t3**2)
        if norm == 0.0:
            # Under njit's default ('python') error model, 1.0 / 0.0 raises
            # ZeroDivisionError. Return NaNs instead, matching NumPy's 0/0.
            return np.nan, np.nan, np.nan

        # One division and three multiplications instead of three divisions.
        inv_norm = 1.0 / norm
        u1 = t1 * inv_norm
        u2 = t2 * inv_norm
        u3 = t3 * inv_norm

    return u1, u2, u3

@numba.njit
def iterate_grid(A, scale, sz):
    """Apply ``fast_power_method`` to every point of a cubic grid of start vectors.

    The grid is the Cartesian product ``tmp x tmp x tmp`` with
    ``tmp = np.linspace(-scale, scale, sz)``.

    Parameters
    ----------
    A : 2-D array
        Must be 3x3 (checked by the assert below).
    scale : float
        Half-width of the grid along each axis.
    sz : int
        Number of grid points per axis.

    Returns
    -------
    array of shape (sz**3, 3)
        Row ``i1*sz**2 + i2*sz + i3`` holds the result for the start vector
        ``(tmp[i1], tmp[i2], tmp[i3])``.
    """
    assert A.shape[0] == A.shape[1] == 3
    n = A.shape[0]

    results = np.empty((sz**3, n))
    tmp = np.linspace(-scale, scale, sz)

    # The start vector is passed as scalars, so no per-iteration scratch
    # array is needed (the unused np.empty allocation was removed).
    for i1 in range(sz):
        v1 = tmp[i1]
        for i2, v2 in enumerate(tmp):
            for i3, v3 in enumerate(tmp):
                u1, u2, u3 = fast_power_method(A, v1, v2, v3)

                idx = i1 * sz**2 + i2 * sz + i3
                results[idx, 0] = u1
                results[idx, 1] = u2
                results[idx, 2] = u3

    return results

更新：我已就此问题向 Numba 提交了一份 bug 报告（见此处）。