我想计算大于固定阈值(这里是0.8)的元素的平均值.

麻木中的任务:

X = np.array([[0.11,0.99,0.70]])
print(np.nanmean(X[X>0.8]))
Out : 0.99

在不将张量c转换为数值数组的情况下,TensorFlow中的类似功能是什么?

示例:

c = tf.constant([[0.11,0.99,0.70]])
tf.reduce_mean(tf.where(tf.greater(c,(tf.constant(0.8, dtype=tf.float32)))))

输出等于0!

Output : <tf.Tensor: shape=(), dtype=int64, numpy=0>

推荐答案

你不需要tf.greatertf.where.

c = tf.constant([[0.11,0.99,0.70]])
# Or : tf.reduce_mean(c[c > 0.8])
tf.reduce_mean(c[c > tf.constant(0.8, dtype=tf.float32)])

您可以使用tf.boolean_mask作为替代选项:

c = tf.constant([[0.11,0.99,0.70]])
mask = c > 0.8
tf.reduce_mean(tf.boolean_mask(tensor, mask))

输出:<tf.Tensor: shape=(), dtype=float32, numpy=0.99>

Python相关问答推荐

DuckDB将蜂巢分区插入拼花文件

根据给定日期的状态过滤查询集

根据不同列的值在收件箱中移动数据

SQLGory-file包FilField不允许提供自定义文件名,自动将文件保存为未命名

在Polars(Python库)中将二进制转换为具有非UTF-8字符的字符串变量

' osmnx.shortest_track '返回有效源 node 和目标 node 的'无'

如何使用数组的最小条目拆分数组

如何获得每个组的时间戳差异?

如何创建一个缓冲区周围的一行与manim?

使用密钥字典重新配置嵌套字典密钥名

如何并行化/加速并行numba代码?

如何启动下载并在不击中磁盘的情况下呈现响应?

替换现有列名中的字符,而不创建新列

比Pandas 更好的 Select

如何在Python Pandas中填充外部连接后的列中填充DDL值

如何合并具有相同元素的 torch 矩阵的行?

在电影中向西北方向对齐""

如何在Python中将超链接添加到PDF中每个页面的顶部?

如何在Polars中创建条件增量列?

如何在Quarto中的标题页之前创建序言页