我试图理解为什么我不能直接覆盖 torch 层的权重.

import torch
from torch import nn

net = nn.Linear(3, 1)
weights = torch.zeros(1,3)

# Overwriting does not work
net.state_dict()["weight"] = weights  # nothing happens
print(f"{net.state_dict()['weight']=}")

# But mutating does work
net.state_dict()["weight"][0] = weights  # indexing works
print(f"{net.state_dict()['weight']=}")

#########
# output
: net.state_dict()['weight']=tensor([[ 0.5464, -0.4110, -0.1063]])
: net.state_dict()['weight']=tensor([[0., 0., 0.]])

我很困惑,因为state_dict()["weight"]只是一个 torch 张量,所以我觉得我遗漏了一些非常明显的东西.

推荐答案

这是因为net.state_dict()首先创建collections.OrderedDict对象,然后将该模块的权重张量存储到该对象,并返回dict:

state_dict = net.state_dict()
print(type(state_dict))    # <class 'collections.OrderedDict'>

当你"覆盖"(实际上不是覆盖;在python中是assignment)这个有序dict时,你会给这个有序dict的键'weights'重新赋值一个int 0.这个张量中的数据没有被修改,它只是没有被有序dict引用.

当您判断张量是否由以下项修改时:

print(f"{net.state_dict()['weight']}")

创建了一个新的有序dict,它与您修改的dict不同,因此您可以看到未更改的张量.

但是,当您这样使用索引时:

net.state_dict()["weight"][0] = weights  # indexing works

那么它就不再是对有序dict的赋值了.相反,我们调用了张量的__setitem__方法,它允许您访问和修改底层内存.其他张量API(如copy_)也可以实现预期结果.

a是张量/数组时,a = ba[:] = b的差异可以在这里找到明确的解释:https://stackoverflow.com/a/68978622/11790637

Python相关问答推荐

' osmnx.shortest_track '返回有效源 node 和目标 node 的'无'

如何使用数组的最小条目拆分数组

从spaCy的句子中提取日期

递归访问嵌套字典中的元素值

如何保持服务器发送的事件连接活动?

考虑到同一天和前2天的前2个数值,如何估算电力时间序列数据中的缺失值?

旋转多边形而不改变内部空间关系

Python Tkinter为特定样式调整所有ttkbootstrap或ttk Button填充的大小,适用于所有主题

Maya Python脚本将纹理应用于所有对象,而不是选定对象

Python Pandas—时间序列—时间戳缺失时间精确在00:00

如何获取Python synsets列表的第一个内容?

巨 Python :逆向猜谜游戏

Django Table—如果项目是唯一的,则单行

需要帮助使用Python中的Google的People API更新联系人的多个字段'

应用指定的规则构建数组

根据过滤后的牛郎星图表中的数据计算新系列

合并Pandas中的数据帧,但处理不存在的列

在Pandas 中,有没有办法让元组作为索引运行得很好?

如何导入与我试图从该目录之外运行的文件位于同一目录中的Python文件?

如何判断特定的OPC UA node 是否已经存在Asyncua?