最近,PyTorch推出了nested tensor.然而,如果我创建一个嵌套的张量,例如,

import torch

a = torch.randn(20, 128)
nt = torch.nested.nested_tensor([a, a], dtype=torch.float32)

然后看看它的类类型,它显示:

type(nt)
torch.Tensor

即,类类型只是一个普通的PyTorch Tensor.因此,type(nt) == torch.Tensorisinstance(nt, torch.Tensor)都将返回True.

所以,我的问题是,有没有办法区分正则张量和嵌套张量?

我能想到的一种方式是,嵌套张量的size方法(目前)与常规张量的工作方式不同,因为它需要一个参数,否则它会引发RuntimeError.因此,解决方案可能是:

def is_nested_tensor(nt):
    if not isinstance(nt, torch.Tensor):
        return False

    try:
        # try calling size without an argument
        nt.size()
        return False
    except RuntimeError:
        return True

    return False

但是,有没有更简单的方法,不依赖于像size方法这样的方法在future 不会改变?

推荐答案

所有torch.Tensor都有一个名为is_nested的属性,但遗憾的是,它没有被记录在案.只有FAQ个人提到过

> nt.is_nested
True
> a.is_nested
False

Python相关问答推荐

在Python中处理大量CSV文件中的数据

numba jitClass,记录类型为字符串

删除任何仅包含字符(或不包含其他数字值的邮政编码)的观察

如何在Windows上用Python提取名称中带有逗号的文件?

pandas滚动和窗口中有效观察的最大数量

发生异常:TclMessage命令名称无效.!listbox"

python中字符串的条件替换

Asyncio:如何从子进程中读取stdout?

什么是最好的方法来切割一个相框到一个面具的第一个实例?

* 动态地 * 修饰Python中的递归函数

ConversationalRetrivalChain引发键错误

从嵌套极轴列的列表中删除元素

在Django中重命名我的表后,旧表中的项目不会被移动或删除

获取git修订版中每个文件的最后修改时间的最有效方法是什么?

如何防止html代码出现在quarto gfm报告中的pandas表之上

SpaCy:Regex模式在基于规则的匹配器中不起作用

将数字数组添加到Pandas DataFrame的单元格依赖于初始化

两个名称相同但值不同的 Select 都会产生相同的值(discord.py)

Pandas ,快速从词典栏中提取信息到新栏

是否将列表分割为2?