我有一个形状为[7, 7, 2, 4]tensor A和一个形状为[7, 7]tensor B.

Tensor Btensor Aargmax,它的值是0,1.

我想从A和B得到形状为[7, 7, 4][7, 7, 1, 4]tensor C分.

规则是,tensor B(i, j)个元素是tensor A的第二个维度的索引.

我怎么才能快点做呢?我试着在A[B]之前得到C,但不起作用.有谁能帮我吗?谢谢.

推荐答案

好吧,我用了tf.gather_nd来解决这个问题:

tensor_C = tf.gather_nd(tensor_A, tf.expand_dims(tf.argmax(tensor_B, 2), 2), batch_dims=3)

Python相关问答推荐

Python 3.12中的通用[T]类方法隐式类型检索

线性模型PanelOLS和statmodels OLS之间的区别

如何使用pandasDataFrames和scipy高度优化相关性计算

如何使用Python将工作表从一个Excel工作簿复制粘贴到另一个工作簿?

对整个 pyramid 进行分组与对 pyramid 列子集进行分组

在Pandas DataFrame操作中用链接替换'方法的更有效方法

将图像拖到另一个图像

在pandas中使用group_by,但有条件

不允许访问非IPM文件夹

如何防止Pandas将索引标为周期?

如何创建引用列表并分配值的Systemrame列

Numpyro AR(1)均值切换模型抽样不一致性

polars:有效的方法来应用函数过滤列的字符串

Python将一个列值分割成多个列,并保持其余列相同

应用指定的规则构建数组

使用SQLAlchemy从多线程Python应用程序在postgr中插入多行的最佳方法是什么?'

如何在Pandas中用迭代器求一个序列的平均值?

如何在Python中画一个只能在对角线内裁剪的圆?

Pandas:根据相邻行之间的差异过滤数据帧

为什么fizzbuzz在两个数字的条件出现在一个数字的条件之后时不起作用?