最近,PyTorch推出了nested tensor.然而,如果我创建一个嵌套的张量,例如,
import torch
a = torch.randn(20, 128)
nt = torch.nested.nested_tensor([a, a], dtype=torch.float32)
然后看看它的类类型,它显示:
type(nt)
torch.Tensor
即,类类型只是一个普通的PyTorch Tensor
.因此,type(nt) == torch.Tensor
和isinstance(nt, torch.Tensor)
都将返回True
.
所以,我的问题是,有没有办法区分正则张量和嵌套张量?
我能想到的一种方式是,嵌套张量的size
方法(目前)与常规张量的工作方式不同,因为它需要一个参数,否则它会引发RuntimeError
.因此,解决方案可能是:
def is_nested_tensor(nt):
if not isinstance(nt, torch.Tensor):
return False
try:
# try calling size without an argument
nt.size()
return False
except RuntimeError:
return True
return False
但是,有没有更简单的方法,不依赖于像size
方法这样的方法在future 不会改变?