这让我大吃一惊,这是一个已知的错误还是我错过了什么?如果有漏洞,有办法绕过它吗?


假设我有一个相对较小的二进制(0/1)n x q scipy.sparse.csr_matrix,如下所示:

import numpy as np
from scipy import sparse

def get_dummies(vec, vec_max):
    vec_size = vec.size
    Z = sparse.csr_matrix((np.ones(vec_size), (np.arange(vec_size), vec)), shape=(vec_size, vec_max), dtype=np.uint8)
    return Z

q = 100
ns = np.round(np.random.random(q)*100).astype(np.int16)
Z_idx = np.repeat(np.arange(q), ns)
Z = get_dummies(Z_idx, q)
Z
<5171x100 sparse matrix of type '<class 'numpy.uint8'>'
    with 5171 stored elements in Compressed Sparse Row format>

这里Z是一个标准的虚拟变量矩阵,n=5171个观测值,q=Z个变量:

Z[:5, :5].toarray()
array([[1, 0, 0, 0, 0],
       [1, 0, 0, 0, 0],
       [1, 0, 0, 0, 0],
       [1, 0, 0, 0, 0],
       [1, 0, 0, 0, 0]], dtype=uint8)

例如,如果前5个变量具有...

ns[:5]
array([21, 22, 37, 24, 99], dtype=int16)

频率,我们也会在Z‘S专栏总和中看到这一点:

Z[:, :5].sum(axis=0)
matrix([[21, 22, 37, 24, 99]], dtype=uint64)

现在,不出所料,如果我乘以Z.T @ Z,我应该得到一个Q x Q对角线矩阵,对角线上Q个变量的频率:

print((Z.T @ Z).shape)
print((Z.T @ Z)[:5, :5].toarray()
(100, 100)
[[21  0  0  0  0]
 [ 0 22  0  0  0]
 [ 0  0 37  0  0]
 [ 0  0  0 24  0]
 [ 0  0  0  0 99]]

Now for the bug:如果n真的很大(对我来说,它已经发生在n=Now for the bugK左右):

q = 1000
ns = np.round(np.random.random(q)*1000).astype(np.int16)
Z_idx = np.repeat(np.arange(q), ns)
Z = get_dummies(Z_idx, q)
Z
<495509x1000 sparse matrix of type '<class 'numpy.uint8'>'
    with 495509 stored elements in Compressed Sparse Row format>

频率大,百尺S栏目之和不出所料:

print(ns[:5])
Z[:, :5].sum(axis=0)
array([485, 756, 380,  87, 454], dtype=int16)
matrix([[485, 756, 380,  87, 454]], dtype=uint64)

Z.T @ Z美元却搞砸了!在某种意义上,我在对角线上没有得到正确的频率:

print((Z.T @ Z).shape)
print((Z.T @ Z)[:5, :5].toarray())
(1000, 1000)
[[229   0   0   0   0]
 [  0 244   0   0   0]
 [  0   0 124   0   0]
 [  0   0   0  87   0]
 [  0   0   0   0 198]]

令人惊讶的是,这与真实频率有一些关系:

import matplotlib.pyplot as plt

plt.scatter(ns, (Z.T @ Z).diagonal())
plt.xlabel('real frequencies')
plt.ylabel('values on ZZ diagonal')
plt.show()

enter image description here

怎么一回事?

我使用的是标准的CoLab:

import scipy as sc

print(np.__version__)
print(sc.__version__)
1.25.2
1.11.4

PS:显然,如果我只想要Z.T @ Z‘S输出矩阵,有更容易的方法得到它,这是一个非常简化的问题,谢谢.

推荐答案

您正在使用get_dummies中的uint8

print(Z.dtype, (Z.T @ Z)[:5, :5].dtype)
# uint8,  uint8

因此,结果是溢出的.这种模式是因为结果是真正的以256为模的结果.如果将dtype更改为uint16,则问题将消失.

sum提高了dtype的精度这一事实有点令人惊讶,但它是documented,并且与NumPy sum所做的一致.

Python相关问答推荐

SQLGory-file包FilField不允许提供自定义文件名,自动将文件保存为未命名

标题:如何在Python中使用嵌套饼图可视化分层数据?

不理解Value错误:在Python中使用迭代对象设置时必须具有相等的len键和值

在Python Attrs包中,如何在field_Transformer函数中添加字段?

如何制作10,000年及以后的日期时间对象?

在线条上绘制表面

在Python中管理打开对话框

如何在给定的条件下使numpy数组的计算速度最快?

Pandas:将多级列名改为一级

为什么抓取的HTML与浏览器判断的元素不同?

在极性中创建条件累积和

用砂箱开发Web统计分析

通过追加列表以极向聚合

如何在Great Table中处理inf和nans

使用__json__的 pyramid 在客户端返回意外格式

统计numpy. ndarray中的项目列表出现次数的最快方法

语法错误:文档. evaluate:表达式不是合法表达式

如何重新组织我的Pandas DataFrame,使列名成为列值?

如何从数据框列中提取特定部分并将该值填充到其他列中?

用0填充没有覆盖范围的垃圾箱