最近我一直在开发一个能够处理带维度张量的函数:
Torch.Size([51,265,23,23])
其中第一个是时间,第二个是图案,最后2个是图案大小.
每个单独的模式最多可以有3个状态:[-1,0,1],并且它被认为是"活的" 与此同时,在所有其他情况下,模式都是"死亡"的.
我的目标是通过判断张量的最后一行(最后一个时间步)来过滤所有死模式.
My current implementation (that works) is:
def filter_patterns(tensor_sims):
# Get the indices of the columns that need to be kept
keep_indices = torch.tensor([i for i in
range(tensor_sims.shape[1]) if
tensor_sims[-1,i].unique().numel() == 3])
# Keep only the columns that meet the condition
tensor_sims = tensor_sims[:, keep_indices]
print(f'Number of patterns: {tensor_sims.shape[1]}')
return tensor_sims
不幸的是,我无法摆脱for循环.
我try 使用torch.unique()函数,并使用参数dim,我try 减少张量的维度并拉平,但没有任何效果.
Found Solution (thanks to the answer):
def filter_patterns(tensor_sims):
# Flatten the spatial dimensions of the last timestep
x_ = tensor_sims[-1].flatten(1)
# Create masks to identify -1, 0, and 1 conditions
mask_minus_one = (x_ == -1).any(dim=1)
mask_zero = (x_ == 0).any(dim=1)
mask_one = (x_ == 1).any(dim=1)
# Combine the masks using logical_and
mask =
mask_minus_one.logical_and(mask_zero).logical_and(mask_one)
# Keep only the columns that meet the condition
tensor_sims = tensor_sims[:, mask]
print(f'Number of patterns: {tensor_sims.shape[1]}')
return tensor_sims
新的实现速度非常快.