假设我有N个项目和一个值为{0, 1}的多个热向量,表示结果中包含这些项目:

N = 4

# items 1 and 3 will be included in the result
vector = [0, 1, 0, 1]

# item 2 will be included in the result
vector = [0, 0, 1, 0]

我还提供了一个冲突矩阵,指出哪些项目不能同时包含在结果中:

conflicts = [
  [0, 1, 1, 0], # any result that contains items 1 AND 2 is invalid
  [0, 1, 1, 1], # any result that contains AT LEAST 2 items from {1, 2, 3} is invalid
]

鉴于这一冲突矩阵,我们可以确定前vector年的有效性:

# invalid as it triggers conflict 1: [0, 1, 1, 1]
vector = [0, 1, 0, 1]

# valid as it triggers no conflicts
vector = [0, 0, 1, 0]

检测给定vector是否"有效"(即不触发任何冲突)的简单解决方案可以通过numpy中的点积和求和操作完成:

violation = np.dot(conflicts, vector)
is_valid = np.max(violation) <= 1

是否有更有效的方法来执行此操作,或者通过np.einsum或完全绕过numpy数组来支持位操作?

我们假设被判断的向量数量可能非常大(例如,如果我们判断所有可能性,最多可判断2^N个),但一次可能只判断一个向量(以避免生成形状高达(2^N, N)的矩阵作为输入).

推荐答案

TL;DR:你可以用Numba来优化np.dot,只操作only on binary values.更具体地说,您可以使用64-bit views一次对8个字节执行SIMD-like operations.




将列表转换为数组

首先,使用这种方法可以有效地将列表转换为相对紧凑的数组:

vector = np.fromiter(vector, np.uint8)
conflicts = np.array([np.fromiter(conflicts[i], np.uint8) for i in range(len(conflicts))])

这比使用自动Numpy转换或np.array更快(在Numpy代码内部执行的判断更少,Numpy、Numpy知道要构建什么类型的数组,并且生成的数组在内存中更小,因此填充速度更快).此步骤可用于加速基于np.dot的解决方案.

如果输入已经是一个Numpy数组,那么判断它们的类型是np.uint8np.int8.否则,请以conflits = conflits.astype(np.uint8)为例将其转换为该类型.


第一次try

然后,一种解决方案是使用np.packbits将输入二进制值尽可能多地打包到内存中的位数组中,然后执行逻辑and.但事实证明np.packbits是相当慢的.因此,这种解决方案最终不是一个好主意.事实上,任何创建形状类似于conflicts的临时数组的解决方案都会很慢,因为在内存中写入这样的数组通常比np.dot慢(np.dot从内存中读取conflicts一次).


使用麻木

由于np.dot经过了很好的优化,所以击败它的唯一解决方案是使用优化的本机代码.得益于即时编译器,Numba可以在运行时从基于Numpy的Python代码生成本机可执行代码.其思想是在每个块vectorconflicts行之间执行逻辑and.判断每个块的冲突,以便尽早停止计算.通过比较两个数组的uint64视图(以SIMD友好的方式),可以按8个八位字节的组高效地比较块.

import numba as nb

@nb.njit('bool_(uint8[::1], uint8[:,::1])')
def check_valid(vector, conflicts):
    n, m = conflicts.shape
    assert vector.size == m

    for i in range(n):
        block_size = 128 # In the range: 8,16,...,248
        conflicts_row = conflicts[i,:]
        gsum = 0 # Global sum of conflicts
        m_limit = m // block_size * block_size

        for j in range(0, m_limit, block_size):
            vector_block = vector[j:j+block_size].view(np.uint64)
            conflicts_block = conflicts_row[j:j+block_size].view(np.uint64)

            # Matching
            lsum = np.uint64(0) # 8 local sums of conflicts
            for k in range(block_size//8):
                lsum += vector_block[k] & conflicts_block[k]

            # Trick to perform the reduction of all the bytes in lsum
            lsum += lsum >> 32
            lsum += lsum >> 16
            lsum += lsum >> 8
            gsum += lsum & 0xFF

            # Check if there is a conflict
            if gsum >= 2:
                return False

        # Remaining part
        for j in range(m_limit, m):
            gsum += vector[j] & conflicts_row[j]

        if gsum >= 2:
            return False

    return True

后果

在我的机器上,对于形状(16, 65536)的大型conflicts数组(没有冲突),这大约是9 times fasternp.dot.两种情况下都不包括转换列表的时间.当存在冲突时,提供的解决方案会更快,因为它可以提前停止计算.

理论上,计算速度应该更快,但Numba JIT无法使用SIMD指令将循环矢量化.尽管如此,似乎同样的问题也出现在np.dot人身上.如果数组更大,则可以并行化块的计算(如果函数返回False,则以较慢的计算为代价).

Python相关问答推荐

如何根据另一列值用字典中的值替换列值

Python 约束无法解决n皇后之谜

如何使用html从excel中提取条件格式规则列表?

聚合具有重复元素的Python字典列表,并添加具有重复元素数量的新键

将pandas Dataframe转换为3D numpy矩阵

梯度下降:简化要素集的运行时间比原始要素集长

转换为浮点,pandas字符串列,混合千和十进制分隔符

Python逻辑操作作为Pandas中的条件

解决调用嵌入式函数的XSLT中表达式的语法移位/归约冲突

如何更改groupby作用域以找到满足掩码条件的第一个值?

将scipy. sparse矩阵直接保存为常规txt文件

如何在PySide/Qt QColumbnView中删除列

为什么调用函数的值和次数不同,递归在代码中是如何工作的?

Python—为什么我的代码返回一个TypeError

在Google Drive中获取特定文件夹内的FolderID和文件夹名称

Pandas—堆栈多索引头,但不包括第一列

在极点中读取、扫描和接收有什么不同?

如何在Python中自动创建数字文件夹和正在进行的文件夹?

当我定义一个继承的类时,我可以避免使用`metaclass=`吗?

极柱内丢失类型信息""