我有一个相对简单的要求,但令人惊讶的是,这似乎并不是在pytorch中直接实现的.假设一个带有$P$参数的神经网络输出长度为$Y$的矢量和一批$B$数据输入,我想计算输出相对于模型参数的梯度.

换句话说,我想要以下函数:

def calculate_gradients(model, X):
    """
    Args:
        nn module with P parameters in total that outputs a tensor of size (B, Y).
        torch tensor of shape (B, .).

    Returns:
        torch tensor of shape (B, Y, P)
    """
    # function logic here

不幸的是,我目前还看不到一种明显的有效计算方法,特别是在不对数据或目标维度进行聚合的情况下.下面的一个最小的工作示例涉及在输入和目标维度上循环,但肯定有更有效的方法?

import torch
from torchvision import datasets, transforms
import torch.nn as nn

###### SETUP ######

class MLP(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)
        
    def forward(self, x):
        h = self.fc1(x)
        pred = self.fc2(self.relu(h))
        return pred
    
train_dataset = datasets.MNIST(root='./data', train=True, download=True, 
                            transform=transforms.Compose(
                                [transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (0.5,))
        ]))

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=2, shuffle=False)

X, y = next(iter(train_dataloader))  # take a random batch of data

net = MLP(28*28, 20, 10)  # define a network


###### CALCULATE GRADIENTS ######
def calculate_gradients(model, X):
    # Create a tensor to hold the gradients
    gradients = torch.zeros(X.shape[0], 10, sum(p.numel() for p in model.parameters()))

    # Calculate the gradients for each input and target dimension
    for i in range(X.shape[0]):
        for j in range(10):
            model.zero_grad()
            output = model(X[i])
            # Calculate the gradients
            grads = torch.autograd.grad(output[j], model.parameters())
            # Flatten the gradients and store them
            gradients[i, j, :] = torch.cat([g.view(-1) for g in grads])
            
    return gradients

grads = calculate_gradients(net, X.view(X.shape[0], -1))

Edit:个 我对Felix Zimmermann的解决方案运行了一些快速基准测试,它确实为我的机器上的这个玩具问题提供了一些很好的加速.

import time

start = time.time()
for _ in range(1000):
    grads = calculate_gradients(net, X.view(X.shape[0], -1))
end = time.time()
print('Loop solution', end - start)

start = time.time()
for _ in range(1000):
    params = {k: v.detach() for k, v in net.named_parameters()}
    buffers = {k: v.detach() for k, v in net.named_buffers()}
    grads2 = torch.vmap(one_sample)(X.flatten(1))
end = time.time()
print('Vmap solution', end - start)

哪一项输出

Loop solution 8.408899307250977
Vmap solution 2.355229139328003

请注意,在更现实的环境中,GPU上的批次更大,性能提升可能会更大.

推荐答案

要解决这个问题,我们需要三个 idea :

这都是第functorch/torch.func部分.

将所有这些放在一起,它的作用与您的代码相同:

# extract the parameters and buffers for a funcional call
params = {k: v.detach() for k, v in net.named_parameters()}
buffers = {k: v.detach() for k, v in net.named_buffers()}

def one_sample(sample):
    # this will calculate the gradients for a single sample
    # we want the gradients for each output wrt to the parameters
    # this is the same as the jacobian of the network wrt the parameters

    # define a function that takes the as input returns the output of the network
    call = lambda x: torch.func.functional_call(net, (x, buffers), sample)
    
    # calculate the jacobian of the network wrt the parameters
    J = torch.func.jacrev(call)(params)
    
    # J is a dictionary with keys the names of the parameters and values the gradients
    # we want a tensor
    grads = torch.cat([v.flatten(1) for v in J.values()],-1) 
    return grads

# no we can use vmap to calculate the gradients for all samples at once
grads2 = torch.vmap(one_sample)(X.flatten(1))

print(torch.allclose(grads,grads2))

should并行运行,你应该在更大的型号上try 一下,我没有对它进行基准测试.

这也与Pytorch: Gradient of output w.r.t parameters(tbh没有很好的答案)和pytorch.org/tutorials/intermediate/per_sample_grads.html有关,后者显示了torch.func中用于计算每个样本的梯度的一些函数.

Python相关问答推荐

运行回文查找器代码时发生错误:[类型错误:builtin_index_or_system对象不可订阅]

使用FASTCGI在IIS上运行Django频道

try 在树叶 map 上应用覆盖磁贴

运行总计基于多列pandas的分组和总和

2D空间中的反旋算法

按顺序合并2个词典列表

通过pandas向每个非空单元格添加子字符串

如何根据一列的值有条件地 Select 前N个组,然后按两列分组?

Python逻辑操作作为Pandas中的条件

在matplotlib中删除子图之间的间隙_mosaic

如果初始groupby找不到满足掩码条件的第一行,我如何更改groupby列,以找到它?

网格基于1.Y轴与2.x轴显示在matplotlib中

Python—转换日期:价目表到新行

比Pandas 更好的 Select

如何使用加速广播主进程张量?

如何获得满足掩码条件的第一行的索引?

设置索引值每隔17行左右更改的索引

我可以同时更改多个图像吗?

Pandas 数据框自定义排序功能

`Convert_time_zone`函数用于根据为极点中的每一行指定的时区检索值