我有两个 torch 张量
mask = torch.ones(1024, 64, dtype=torch.float32)
indices = torch.randint(0, 64, (1024, ))
对于mask
中的每i
行,我希望将indices
的第i
个元素指定的索引之后的所有元素设置为零.例如,如果indices
的第一个元素是50
,那么我想设置mask[0, 50:]=0
.可以在不使用for循环的情况下实现这一点吗?
使用for循环的解决方案:
for i in range(mask.shape[0]):
mask[i, indices[i]:] = 0