Let's say val is a matrix of size (2, N). I need to multiply it with a mask matrix mask of size
(K,K) containing values zeros and ones at different indexes. This should output a matrix result of size (N, K, K) matrix with each submatrix of result along the dimension 0 to be a (K,K) matrix where zeros are replaced by val(i,1) and ones are replaced by val(i,2).

例如,

mask = tf.constant([[0, 1, 0, 1],
                   [1, 0, 0, 1],
                   [1, 1, 1, 0],
                   [0, 1, 0, 0]], dtype=tf.int32)

val = tf.constant([[3, 2, 8, 1, 9, 5, 6], [7, 4, 9, 8, 3, 1, 9]])

那么输出应该是这样的7 x 4 x 4矩阵,

result  = 

tf.Tensor(
[[[ 3.  7.  3.  7.]
  [ 7.  3.  3.  7.]
  [ 7.  7.  7.  3.]
  [ 3.  7.  3.  3.]]

 [[ 2.  4.  2.  4.]
  [ 4.  2.  2.  4.]
  [ 4.  4.  4.  2.]
  [ 3.  4.  3.  2.]]
          :
          :
          :
 [[ 6.  9.  6.  9.]
  [ 9.  6.  6.  9.]
  [ 9.  9.  9.  6.]
  [ 6.  9.  6.  6.]]]

目前,我正在使用for循环迭代val的每一列,以执行以下操作

val[0,i]*mask + val[1,i]*(1-mask)

我希望将其矢量化,以纳入TensorFlow的矩阵乘法能力.

推荐答案

对于要用tensorflow执行的操作,您不需要显式的for循环.

编辑:我的代码中有一个错误

import tensorflow as tf

mask = tf.constant([[0, 1, 0, 1],
                    [1, 0, 0, 1],
                    [1, 1, 1, 0],
                    [0, 1, 0, 0]], dtype=tf.int32)

val = tf.constant([[3, 2, 8, 1, 9, 5, 6], [7, 4, 9, 8, 3, 1, 9]])

mask_broadcasted = tf.broadcast_to(tf.expand_dims(mask, axis=0), (val.shape[1], *mask.shape))
mask_inv_broadcasted = 1 - mask_broadcasted

val_expanded = tf.expand_dims(val, axis=-1)  # Add an extra dimension to val (expand the last dimension)
val_expanded_reshaped = tf.reshape(val_expanded, (2, -1, 1, 1))

result = val_expanded_reshaped[0] * mask_inv_broadcasted + val_expanded_reshaped[1] * mask_broadcasted

print(result)

Python相关问答推荐

Polars比较了两个预设-有没有方法在第一次不匹配时立即失败

对Numpy函数进行载体化

TARete错误:类型对象任务没有属性模型'

使用FASTCGI在IIS上运行Django频道

删除任何仅包含字符(或不包含其他数字值的邮政编码)的观察

如何标记Spacy中不包含特定符号的单词?

将输入管道传输到正在运行的Python脚本中

基于字符串匹配条件合并两个帧

如何使用Python以编程方式判断和检索Angular网站的动态内容?

python中字符串的条件替换

将JSON对象转换为Dataframe

Python脚本使用蓝牙运行在Windows 11与raspberry pi4

Python列表不会在条件while循环中正确随机化'

如何使用SentenceTransformers创建矢量嵌入?

旋转多边形而不改变内部空间关系

如何在Pyplot表中舍入值

统计numpy. ndarray中的项目列表出现次数的最快方法

设置索引值每隔17行左右更改的索引

使用美汤对维基百科表格进行网络刮擦未返回任何内容

如何通过特定导入在类中执行Python代码