我对下面代码片段中的方法view()感到困惑.

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool  = nn.MaxPool2d(2,2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1   = nn.Linear(16*5*5, 120)
        self.fc2   = nn.Linear(120, 84)
        self.fc3   = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16*5*5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()

我的困惑在于下面这句话.

x = x.view(-1, 16*5*5)

tensor.view()函数的作用是什么?我在很多地方见过它的用法,但我不明白它是如何解释其参数的.

如果我将负值作为view()函数的参数,会发生什么?例如,如果我打tensor_variable.view(1, 1, -1)会怎么样?

有没有人能举例说明view()函数的主要原理?

推荐答案

视图函数旨在reshape 张量.

假设你有一个张量

import torch
a = torch.range(1, 16)

a是一个张量,有16个元素,从1到16(包括).如果你想reshape 这个张量,使之成为4 x 4张量,那么你可以使用

a = a.view(4, 4)

现在a将是4 x 4张量.Note that after the reshape the total number of elements need to remain the same. Reshaping the tensor 100 to a 103 tensor would not be appropriate.

参数-1的含义是什么?

如果您不知道需要多少行,但确定列的数量,那么可以使用-1指定.(Note that you can extend this to tensors with more dimensions. Only one of the axis value can be -1). 这是告诉图书馆的一种方式:"给我一个有这么多列的张量,你计算出实现这一点所需的适当行数.".

这可以在您上面给出的神经网络代码中看到.在FORWARD函数中的第x = self.pool(F.relu(self.conv2(x)))行之后,您将拥有一个16深度的特征图.您必须将其展平,才能将其提供给完全连接的图层.因此,您告诉pytorchreshape 您获得的张量,使其具有特定的列数,并告诉它自己决定行数.

在numpy和pytorch之间绘制了一个相似之处,view类似于numpy的reshape函数.

Python相关问答推荐

错误:找不到TensorFlow / Cygwin的匹配分布

Asyncio与队列的多处理通信-仅运行一个协程

有什么方法可以修复奇怪的y轴Python matplotlib图吗?

Pandas read_jsonfuture 警告:解析字符串时,to_datetime与单位的行为已被反对

了解shuffle在NP.random.Generator.choice()中的作用

如何在超时的情况下同步运行Matplolib服务器端?该过程随机挂起

如何使用Tkinter创建两个高度相同的框架(顶部和底部)?

在for循环中仅执行一次此操作

Pandas 除以一列中出现的每个值

添加包含中具有任何值的其他列的计数的列

如何才能知道Python中2列表中的巧合.顺序很重要,但当1个失败时,其余的不应该失败或是0巧合

ModuleNotFound错误:没有名为Crypto Windows 11、Python 3.11.6的模块

Python上的Instagram API:缺少client_id参数"

C#使用程序从Python中执行Exec文件

如何将一个动态分配的C数组转换为Numpy数组,并在C扩展模块中返回给Python

如何创建一个缓冲区周围的一行与manim?

Python—从np.array中 Select 复杂的列子集

在ubuntu上安装dlib时出错

在Python中调用变量(特别是Tkinter)

使用Python查找、替换和调整PDF中的图像'