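I have a relatively simple requirement, but surprisingly it doesn't seem to be directly implemented in PyTorch. Given a neural network with $P$ parameters that outputs a vector of length $Y$, and a batch of $B$ data inputs, I want to compute the gradient of each output with respect to the model parameters.
In other words, I want the following function: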
```python
def calculate_gradients(model, X):
    """
    Args:
        model: nn.Module with P parameters in total that outputs a tensor of size (B, Y).
        X: torch tensor of shape (B, ...).
    Returns:
        torch tensor of shape (B, Y, P)
    """
    # function logic here
```
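Concretely, writing $f(x;\theta) \in \mathbb{R}^Y$ for the network output and $\theta \in \mathbb{R}^P$ for all parameters flattened into one vector, the requested tensor stacks one $Y \times P$ Jacobian per input:

$$
G_{b,y,p} = \frac{\partial f_y(x_b;\theta)}{\partial \theta_p}, \qquad b = 1,\dots,B,\quad y = 1,\dots,Y,\quad p = 1,\dots,P.
$$

No sum over $b$ or $y$ is taken, which is what distinguishes this from an ordinary `loss.backward()` call on a scalar.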
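Unfortunately, I can't see an obvious way to compute this efficiently, especially without aggregating over the data or target dimensions. The minimal working example below loops over the input and target dimensions, but surely there is a more efficient way?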
```python
import torch
import torch.nn as nn
from torchvision import datasets, transforms

###### SETUP ######
class MLP(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        h = self.fc1(x)
        pred = self.fc2(self.relu(h))
        return pred

train_dataset = datasets.MNIST(root='./data', train=True, download=True,
                               transform=transforms.Compose([
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5,), (0.5,)),
                               ]))
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=2, shuffle=False)
X, y = next(iter(train_dataloader))  # take the first batch (shuffle=False)
net = MLP(28*28, 20, 10)  # define a network

###### CALCULATE GRADIENTS ######
def calculate_gradients(model, X):
    # Create a tensor of shape (B, Y, P) to hold the gradients (Y = 10 outputs here)
    gradients = torch.zeros(X.shape[0], 10, sum(p.numel() for p in model.parameters()))
    # Calculate the gradients for each input and target dimension
    for i in range(X.shape[0]):
        for j in range(10):
            model.zero_grad()  # not strictly needed: autograd.grad does not write to .grad
            output = model(X[i])
            # Gradient of the j-th output of sample i w.r.t. all parameters
            grads = torch.autograd.grad(output[j], model.parameters())
            # Flatten the gradients and store them
            gradients[i, j, :] = torch.cat([g.view(-1) for g in grads])
    return gradients

grads = calculate_gradients(net, X.view(X.shape[0], -1))
```
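For the toy setup above, the size of the result can be checked by hand: the first layer has $784 \cdot 20 + 20$ parameters and the second $20 \cdot 10 + 10$, so $P = 15910$ and `grads` should have shape `(2, 10, 15910)`. A quick check, using a hypothetical `nn.Sequential` stand-in with the same layer sizes as `MLP` so that no MNIST download is needed:

```python
import torch.nn as nn

# Hypothetical stand-in with the same layer sizes as MLP(28*28, 20, 10)
net = nn.Sequential(nn.Linear(28 * 28, 20), nn.ReLU(), nn.Linear(20, 10))

# fc1: 20*784 weights + 20 biases; fc2: 10*20 weights + 10 biases
P = sum(p.numel() for p in net.parameters())
B, Y = 2, 10  # batch_size=2 in the DataLoader, 10 output classes
print((B, Y, P))  # (2, 10, 15910)
```

With $B \cdot Y = 20$ separate forward/backward passes per call, the loop's cost grows with both the batch size and the output dimension, which is why it scales poorly.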
Edit: I ran some quick benchmarks on Felix Zimmermann's solution, and it does give a nice speedup for this toy problem on my machine:
```python
import time

start = time.time()
for _ in range(1000):
    grads = calculate_gradients(net, X.view(X.shape[0], -1))
end = time.time()
print('Loop solution', end - start)

# `one_sample` comes from Felix Zimmermann's answer (not reproduced here);
# it presumably closes over `params` and `buffers`.
start = time.time()
for _ in range(1000):
    params = {k: v.detach() for k, v in net.named_parameters()}
    buffers = {k: v.detach() for k, v in net.named_buffers()}
    grads2 = torch.vmap(one_sample)(X.flatten(1))
end = time.time()
print('Vmap solution', end - start)
```
Which outputs:
```
Loop solution 8.408899307250977
Vmap solution 2.355229139328003
```
Note that in a more realistic setting, with larger batches on a GPU, the speedup may be even larger.
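The benchmark calls `one_sample`, which is defined in Felix Zimmermann's answer and is not reproduced in this post. As a rough sketch of the same idea (not necessarily his exact code), assuming PyTorch 2.x with `torch.func`: `functional_call` runs the model with an explicit parameter dict, `jacrev` differentiates a single sample's output with respect to that dict, and `vmap` maps over the batch. The names `calculate_gradients_vmap` and `one_sample` below are illustrative.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, jacrev, vmap


class MLP(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))


def calculate_gradients_vmap(model, X):
    """Hypothetical helper: per-sample Jacobians of shape (B, Y, P)."""
    params = {k: v.detach() for k, v in model.named_parameters()}
    buffers = {k: v.detach() for k, v in model.named_buffers()}

    def one_sample(params, x):
        # Run the model on one unbatched input with explicit parameters
        return functional_call(model, (params, buffers), (x,))

    # jacrev w.r.t. params gives {name: (Y, *param.shape)};
    # vmap over x only (params shared) adds a leading batch dim B.
    jac = vmap(jacrev(one_sample), in_dims=(None, 0))(params, X)
    B = X.shape[0]
    Y = next(iter(jac.values())).shape[1]
    # Flatten each parameter block, concatenating in named_parameters() order
    return torch.cat([jac[k].reshape(B, Y, -1) for k in params], dim=-1)
```

Because the blocks are concatenated in `named_parameters()` order (the same order as `model.parameters()`), the result should line up entry for entry with the loop version's `(B, Y, P)` layout.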