I'm trying to understand why I can't directly overwrite the weights of a torch layer.
import torch
from torch import nn
net = nn.Linear(3, 1)
weights = torch.zeros(1,3)
# Overwriting does not work
net.state_dict()["weight"] = weights # nothing happens
print(f"{net.state_dict()['weight']=}")
# But mutating does work
net.state_dict()["weight"][0] = weights # indexing works
print(f"{net.state_dict()['weight']=}")
#########
# output:
# net.state_dict()['weight']=tensor([[ 0.5464, -0.4110, -0.1063]])
# net.state_dict()['weight']=tensor([[0., 0., 0.]])
I'm confused, because state_dict()["weight"] is just a torch tensor, so I feel like I'm missing something very obvious.
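For what it's worth, the behaviour above seems consistent with `state_dict()` building a *fresh* dict on every call, while the tensors inside it share storage with the module's actual parameters. Rebinding a key in that throwaway dict never reaches the module; mutating the tensor in place does. A minimal sketch to check this (the `copy_` call at the end is just one supported way to overwrite, not necessarily the only one):

```python
import torch
from torch import nn

net = nn.Linear(3, 1)

# state_dict() returns a new dict object on every call,
# so assigning into one copy never reaches the module.
sd1 = net.state_dict()
sd2 = net.state_dict()
print(sd1 is sd2)  # False

# But the tensors inside share storage with the real parameter,
# which is why in-place mutation (indexing) is visible.
print(sd1["weight"].data_ptr() == net.weight.data_ptr())  # True

# A supported way to overwrite wholesale: in-place copy under no_grad.
with torch.no_grad():
    net.weight.copy_(torch.zeros(1, 3))
print(net.weight)  # all zeros
```

Under this reading, `net.state_dict()["weight"] = weights` only replaces an entry in a temporary dict, while `net.state_dict()["weight"][0] = weights` writes through the shared tensor storage into the parameter itself.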