我希望创建一个自定义池层,可以有效地在GPU上工作.

例如,我有以下输入张量

in = <tf.Tensor: shape=(4, 5), dtype=float32, numpy=
array([[0., 1., 2., 3., 4.],
       [5., 1., 7., 3., 2.],
       [9., 9., 2., 3., 5.],
       [2., 6., 2., 8., 4.]], dtype=float32)>

我希望提供一个列号列表,我希望对其执行池,例如,我希望对以下列索引执行最大池

pool_cols =  
[<tf.Tensor: shape=(2,), dtype=int32, numpy=array([0, 1], dtype=int32)>,
 <tf.Tensor: shape=(3,), dtype=int32, numpy=array([2, 3, 4], dtype=int32)>]

最终的合并输出将如下所示

pooled_out = <tf.Tensor: shape=(4, 2), dtype=float32, numpy=
array([[1., 4.],
       [5., 7.],
       [9., 5.],
       [6., 8.]], dtype=float32)>

最有效的方法是什么?

推荐答案

IIUC,你可以用tf个操作来try 类似的东西,但我不确定在GPU上会有多高效:

import tensorflow as tf

tensor = tf.constant([[0., 1., 2., 3., 4.],
                      [5., 1., 7., 3., 2.],
                      [9., 9., 2., 3., 5.],
                      [2., 6., 2., 8., 4.]])


pool_cols = [tf.constant([0, 1]), tf.constant([2, 3, 4])]

def column_max_pooling(tensor, pool_cols):
  results = []
  tensor_shape = tf.shape(tensor)
  for col in pool_cols:
    col_shape = tf.shape(col)
    t = tf.gather_nd(tensor, tf.transpose(tf.stack([tf.tile(tf.range(tensor_shape[0]), [col_shape[0]]), tf.repeat(col, [tensor_shape[0]])])))
    t = tf.reduce_max(tf.transpose(tf.reshape(t, (col_shape[0], tensor_shape[0]))), axis=-1, keepdims=True)
    results.append(t)
  return tf.concat(results, axis=-1)

print(column_max_pooling(tensor, pool_cols))
tf.Tensor(
[[1. 4.]
 [5. 7.]
 [9. 5.]
 [6. 8.]], shape=(4, 2), dtype=float32)

如果您可以保证订单为pool_cols,您也可以try 使用tf.math.unsorted_segment_max:

import tensorflow as tf

tensor = tf.constant([[0., 1., 2., 3., 4.],
                      [5., 1., 7., 3., 2.],
                      [9., 9., 2., 3., 5.],
                      [2., 6., 2., 8., 4.]])

pool_cols = [tf.constant([0, 1]), tf.constant([2, 3, 4])]
result = tf.transpose(tf.math.unsorted_segment_max(tf.transpose(tensor), tf.concat([tf.repeat(idx, tf.shape(col)[0])for idx, col in enumerate(pool_cols)], axis=0), num_segments=len(pool_cols)))
print(result)
tf.Tensor(
[[1. 4.]
 [5. 7.]
 [9. 5.]
 [6. 8.]], shape=(4, 2), dtype=float32)

Python相关问答推荐

未删除映射表的行

对于一个给定的数字,找出一个整数的最小和最大可能的和

如何从在虚拟Python环境中运行的脚本中运行需要宿主Python环境的Shell脚本?

Python解析整数格式说明符的规则?

如何调整QscrollArea以正确显示内部正在变化的Qgridlayout?

把一个pandas文件夹从juyter笔记本放到堆栈溢出问题中的最快方法?

如何设置视频语言时上传到YouTube与Python API客户端

如何在表中添加重复的列?

在方法中设置属性值时,如何处理语句不可达[Unreacable]";的问题?

(Python/Pandas)基于列中非缺失值的子集DataFrame

OpenCV轮廓.很难找到给定图像的所需轮廓

如何在Python Pandas中填充外部连接后的列中填充DDL值

Python类型提示:对于一个可以迭代的变量,我应该使用什么?

使用Python TCP套接字发送整数并使用C#接收—接收正确数据时出错

Pandas:将值从一列移动到适当的列

将数字数组添加到Pandas DataFrame的单元格依赖于初始化

在matplotlib中重叠极 map 以创建径向龙卷风图

上传文件并使用Panda打开时的Flask 问题

为什么在不先将包作为模块导入的情况下相对导入不起作用

如何删除剪裁圆的对角线的外部部分