假设我有一个数组

import numpy as np 

z = np.array(
    [
     [1, 1, 0, 0, 0, 0],
     [1, 1, 1, 1, 1, 0],
     [1, 1, 1, 0, 0, 0],
     [1, 1, 1, 1, 1, 1],
    ]
)

其中,1从每个数组的左侧开始,0从右侧开始(如果有).对于许多应用程序,这就是数组的填充方式,以便在数组数组中每个数组的长度相同.

我如何获得这样一个数组的最短非零序列.

在这种情况下,最短的序列是第一个数组,其长度为2.

显而易见的答案是迭代每个数组并找到第一个零的索引,但我觉得可能有一种方法可以更好地利用numpy的c处理.

推荐答案

5000×5000数组的基准测试:

 74.3 ms  Dani
 33.8 ms  user19077881
  2.6 ms  Kelly1
  1.4 ms  Kelly2

My Kelly1是一个从右上到左下的O(m+n)鞍形搜索:

def Kelly1(z):
    m, n = z.shape
    j = n - 1
    for i in range(m):
        while not z[i, j]:
            j -= 1
            if j < 0:
                return 0
    return j + 1

(Michael Szczesny说,使用Numba可以使速度提高约150倍(如果我没记错的话).不过,我自己没有能力测试.)

My Kelly2是一个O(m log n)水平二进制搜索,使用NumPy判断列是否充满非零:

def Kelly2(z):
    m, n = z.shape
    lo, hi = 0, n
    while lo < hi:
        mid = (lo + hi) // 2
        if z[:, mid].all():
            lo = mid + 1
        else:
            hi = mid
    return lo

(使用bisectkey可以缩短时间,但我现在没有Python 3.10测试.)

注意:Dani和user19077881返回不同的结果:任何行中非零数最少,或非零数最少的行.我听从了丹尼的领导,因为这是公认的答案.这其实并不重要,因为你可以很快地从另一个结果中计算出一个结果(通过分别找到列或行中第一个零的索引).

完整基准代码(Try it online!):

import numpy as np
from timeit import timeit
import random

m, n = 5000, 5000

def genz():
    lo = random.randrange(n*5//100, n//3)
    return np.array(
        [
            [1]*ones + [0]*(n-ones)
            for ones in random.choices(range(lo, n+1), k=m)
        ]
    )

def Dani(z):
    return np.count_nonzero(z, axis=1).min()

def user19077881(z):
    z_sums = z.sum(axis = 1)
    z_least = np.argmin(z_sums)
    return z_least

def Kelly1(z):
    m, n = z.shape
    j = n - 1
    for i in range(m):
        while not z[i, j]:
            j -= 1
            if j < 0:
                return 0
    return j + 1

def Kelly2(z):
    m, n = z.shape
    lo, hi = 0, n
    while lo < hi:
        mid = (lo + hi) // 2
        if z[:, mid].all():
            lo = mid + 1
        else:
            hi = mid
    return lo

funcs = Dani, user19077881, Kelly1, Kelly2

for _ in range(3):
    z = genz()
    for f in funcs:
        t = timeit(lambda: f(z), number=1)
        print('%5.1f ms ' % (t * 1e3), f.__name__)
    print()

Python相关问答推荐

Django:如何将一个模型的唯一实例创建为另一个模型中的字段

已安装' owiener ' Python模块,但在导入过程中始终没有名为owiener的模块

Django注释:将时差转换为小数或小数

Altair -箱形图边界设置为黑色,中线设置为红色

如何在不使用字符串的情况下将namedtuple属性传递给方法?

如何让pyparparsing匹配1天或2天,但1天和2天失败?

将numpy矩阵映射到字符串矩阵

Locust请求中的Python和参数

三个给定的坐标可以是矩形的点吗

Pandas 在最近的日期合并,考虑到破产

将jit与numpy linSpace函数一起使用时出错

为什么符号没有按顺序添加?

在Polars(Python库)中将二进制转换为具有非UTF-8字符的字符串变量

把一个pandas文件夹从juyter笔记本放到堆栈溢出问题中的最快方法?

所有列的滚动标准差,忽略NaN

在Python中,从给定范围内的数组中提取索引组列表的更有效方法

SQLAlchemy bindparam在mssql上失败(但在mysql上工作)

Django—cte给出:QuerySet对象没有属性with_cte''''

Pandas Data Wrangling/Dataframe Assignment

在pandas/python中计数嵌套类别