在什么逻辑中组合多维张量(例如,在pytorch中)? 假设我有形状为(64,1,1,42)的张量A和形状为(1,42,42)的张量B. A&A;B的结果是什么?我如何提前确定结果形状(如果可能)?

推荐答案

不同形状的张量的组合通常由一种称为广播的机制来完成.广播在NumPy和PyTorch等库中是一个强大的工具,它有助于在不实际复制数据的情况下扩展张量的大小,从而在不同形状的张量之间执行操作.NumPy和PyTorch之间的广播规则通常是一致的.

广播规则: 如果张量的维度数不匹配,则在维度较少的张量维度前面加上1.

比较每个维度的大小:

如果大小匹配或其中一个大小为1,则可以对该维度进行广播. 如果两个大小都不是1并且大小不匹配,则广播失败. 在成功广播之后,每个张量的行为就好像它的形状是两个输入张量的形状的元素最大值一样.

在任何维度中,其中一个张量的大小为1,而另一个张量的大小大于1,则较小的张量的行为就像它已被展开以匹配较大的张量的大小.例如:

已给予:

形状张量A(64,1,1,42) 形状张量B(1,42,42)

要将它们组合在一起,请执行以下操作:

使尺寸的数量相等:

张量A:(64,1,1,42) 张量B:(1,1,42,42) 比较尺寸:

维度为(64,1,1,42)和(1,1,42,42) 每个维度要么相同,要么其中一个是1,因此可以进行广播.得到的张量将具有如下形状:(64,1,42,42)

import torch

# Creating example tensors
A = torch.rand((64, 1, 1, 42))
B = torch.rand((1, 42, 42))

# Broadcasting operation
result = A * B

# Outputting the resulting shape
print(result.shape)  

这导致了

torch.Size([64, 1, 42, 42])

不出所料

编辑:要解决对此答案的第一个 comments :

谢谢!所以每个维度必须是一维,或者两个张量的维度必须匹配.如果前面提到的判断失败,重新安排是否也默认进行,但重新安排会有所帮助?

正确,在PyTorch中使用广播来组合两个张量.两个张量的每个维度要么为1,要么匹配. PyTorch从尾部维度(最后一个维度)开始比较维度,然后向后进行.

如果不满足广播标准,则会出现运行时错误.

但是,如果广播标准失败,您可以重新排列或reshape 张量以使它们兼容:

Adding Singleton Dimensions:您可以使用非挤压方法或NumPy样式的None索引.例如,如果你有一个形状张量(5,4),并且你想要它有一个形状(5,1,4),你可以添加一个单独的维度.

tensor = tensor.unsqueeze(1) 
# or
tensor = tensor[:, None, :]

Removing Singleton Dimensions:可以使用挤压方法删除单个尺寸标注.例如,将张量从形状(5,1,4)更改为(5,4).

tensor = tensor.squeeze(1)

Reshape:如果添加或移除单一维度还不够,您可能需要使用reshape 方法完全reshape 张量.

tensor = tensor.reshape(new_shape)

Permute Dimensions:如果问题出在维度的顺序上,则可以使用置换方法重新排列它们.

tensor = tensor.permute(dimension_order)

在重新排列或reshape 张量时,了解数据的语义是至关重要的.通过更改形状,您可能会更改张量的含义或其值之间的关系.始终确保这样的操作在您的问题背景下是有意义的.

Python相关问答推荐

使用Ubuntu、Python和Weasyprint的Docker文件-venv的问题

Locust请求中的Python和参数

连接两个具有不同标题的收件箱

Streamlit应用程序中的Plotly条形图中未正确显示Y轴刻度

如何创建一个缓冲区周围的一行与manim?

Pre—Commit MyPy无法禁用非错误消息

python中字符串的条件替换

如何根据一列的值有条件地 Select 前N个组,然后按两列分组?

如何在Python中找到线性依赖mod 2

不允许访问非IPM文件夹

计算分布的标准差

为什么numpy. vectorize调用vectorized函数的次数比vector中的元素要多?

合并与拼接并举

Python—压缩叶 map html作为邮箱附件并通过sendgrid发送

从一个df列提取单词,分配给另一个列

查看pandas字符列是否在字符串列中

需要帮助使用Python中的Google的People API更新联系人的多个字段'

启动线程时,Python键盘模块冻结/不工作

是否将Pandas 数据帧标题/标题以纯文本格式转换为字符串输出?

如何在表单中添加管理员风格的输入(PDF)