最近我一直在开发一个能够处理带维度张量的函数:

Torch.Size([51,265,23,23])

其中第一个是时间,第二个是图案,最后2个是图案大小.

每个单独的模式最多可以有3个状态:[-1,0,1],并且它被认为是"活的" 与此同时,在所有其他情况下,模式都是"死亡"的.

我的目标是通过判断张量的最后一行(最后一个时间步)来过滤所有死模式.

My current implementation (that works) is:

def filter_patterns(tensor_sims):

   # Get the indices of the columns that need to be kept
   keep_indices = torch.tensor([i for i in 
   range(tensor_sims.shape[1]) if 
   tensor_sims[-1,i].unique().numel() == 3])

   # Keep only the columns that meet the condition
   tensor_sims = tensor_sims[:, keep_indices]

   print(f'Number of patterns: {tensor_sims.shape[1]}')
   return tensor_sims

不幸的是,我无法摆脱for循环.

我try 使用torch.unique()函数,并使用参数dim,我try 减少张量的维度并拉平,但没有任何效果.

Found Solution (thanks to the answer):

def filter_patterns(tensor_sims):
   # Flatten the spatial dimensions of the last timestep
   x_ = tensor_sims[-1].flatten(1)

   # Create masks to identify -1, 0, and 1 conditions
   mask_minus_one = (x_ == -1).any(dim=1)
   mask_zero = (x_ == 0).any(dim=1)
   mask_one = (x_ == 1).any(dim=1)

   # Combine the masks using logical_and
   mask = 
   mask_minus_one.logical_and(mask_zero).logical_and(mask_one)

   # Keep only the columns that meet the condition
   tensor_sims = tensor_sims[:, mask]

   print(f'Number of patterns: {tensor_sims.shape[1]}')
   return tensor_sims

新的实现速度非常快.

推荐答案

我不相信你能逃脱torch.unique个惩罚,因为它对每列都不起作用.您可以构造三个屏蔽张量来分别判断-101个值,而不是迭代dim=1.要计算结果的列屏蔽,在组合屏蔽时可以摆脱一些基本逻辑:

考虑到您只判断最后一个时间步,请关注该时间步并拉平空间维度:

x_ = x[-1].flatten(1)

用于识别条件-101的三个口罩可以分别通过:x_ == -1x_ == 0x_ == 1获得.将它们与torch.logical_or合并

mask = (x_ == -1).logical_or(x_ == 0).logical_or(x_ == 1)

最后,判断行中所有元素是否为True:

keep_indices = mask.all(dim=1)

Python相关问答推荐

滚动和,句号来自Pandas列

删除任何仅包含字符(或不包含其他数字值的邮政编码)的观察

Python中的嵌套Ruby哈希

修复mypy错误-赋值中的类型不兼容(表达式具有类型xxx,变量具有类型yyy)

将9个3x3矩阵按特定顺序排列成9x9矩阵

SQLAlchemy Like ALL ORM analog

Pandas—在数据透视表中占总数的百分比

在单个对象中解析多个Python数据帧

SQLAlchemy bindparam在mssql上失败(但在mysql上工作)

无法连接到Keycloat服务器

让函数调用方程

解决调用嵌入式函数的XSLT中表达式的语法移位/归约冲突

pandas fill和bfill基于另一列中的条件

Beautifulsoup:遍历一个列表,从a到z,并解析数据,以便将其存储在pdf中.

Python 3试图访问在线程调用中实例化的类的对象

根据Pandas中带条件的两个列的值创建新列

Pandas在rame中在组内洗牌行,保持相对组的顺序不变,

启动线程时,Python键盘模块冻结/不工作

使用xlsxWriter在EXCEL中为数据帧的各行上色

当lambda函数作为参数传递时,pyo3执行