我有一个(N, T, d)形状的数组x.我有两个函数fg,它们都采用形状数组(some_dimension, d)并返回形状数组(some_dimension, ).

我想把所有的x都算f.这很简单:f(x.reshape(-1, d)).

然后我只想在第二个维度的第一个切片上计算g,也就是g(x[:, 0, :]),然后将它减go 所有维度的fON的值.代码中举例说明了这一点

MWe--效率低下的方式

import numpy as np

# Reproducibility
seed = 1234
rng = np.random.default_rng(seed=seed)

# Generate x
N = 100
T = 10
d = 2
x = rng.normal(loc=0.0, scale=1.0, size=(N, T, d))

# In practice the functions are not this simple
def f(x):
    return x[:, 0] + x[:, 1]

def g(x):
    return x[:, 0]**2 - x[:, 1]**2

# Compute f on all the (flattened) array
fx = f(x.reshape(-1, d)).reshape(N, T)

# Compute g only on the first slice of second dimension. Here are two ways of doing so
gx = np.tile(g(x[:, 0])[:, None], reps=(1, T))
gx = np.repeat(g(x[:, 0]), axis=0, repeats=T).reshape(N, T)

# Finally compute what I really want to compute
diff = fx - gx

有没有更有效的方法?我觉得一定要用广播,但我想不通.

推荐答案

减小示例的大小,以便我们可以判断(5,4)个数组:

In [138]: 
     ...: # Generate x
     ...: N = 5
     ...: T = 4
     ...: d = 2
     ...: x = np.arange(40).reshape(N,T,d) #(rng.normal(loc=0.0, scale=1.0, size=(N, T, d))
     ...: 
     ...: # In practice the functions are not this simple
     ...: def f(x):
     ...:     return x[:, 0] + x[:, 1]
     ...: 
     ...: def g(x):
     ...:     return x[:, 0]**2 - x[:, 1]**2
     ...: 
     ...: # Compute f on all the (flattened) array
     ...: fx = f(x.reshape(-1, d)).reshape(N, T)
     ...: 
     ...: # Compute g only on the first slice of second dimension. Here are two ways of doing so
     ...: gx1 = np.tile(g(x[:, 0])[:, None], reps=(1, T))
     ...: gx2 = np.repeat(g(x[:, 0]), axis=0, repeats=T).reshape(N, T)

In [139]: fx.shape,gx1.shape,gx2.shape
Out[139]: ((5, 4), (5, 4), (5, 4))

fx的所有元素都是不同的,所以不可能再有进一步的"广播".

In [140]: fx
Out[140]: 
array([[ 1,  5,  9, 13],
       [17, 21, 25, 29],
       [33, 37, 41, 45],
       [49, 53, 57, 61],
       [65, 69, 73, 77]])

你使用tilerepeat做同样的事情.tile使用repeat,因此不会添加任何内容:

In [141]: gx1
Out[141]: 
array([[ -1,  -1,  -1,  -1],
       [-17, -17, -17, -17],
       [-33, -33, -33, -33],
       [-49, -49, -49, -49],
       [-65, -65, -65, -65]])

In [142]: gx2
Out[142]: 
array([[ -1,  -1,  -1,  -1],
       [-17, -17, -17, -17],
       [-33, -33, -33, -33],
       [-49, -49, -49, -49],
       [-65, -65, -65, -65]])

gx只会将5 g()的值重复4次.

In [143]: g(x[:, 0])
Out[143]: array([ -1, -17, -33, -49, -65])

In [144]: fx-gx1
Out[144]: 
array([[  2,   6,  10,  14],
       [ 34,  38,  42,  46],
       [ 66,  70,  74,  78],
       [ 98, 102, 106, 110],
       [130, 134, 138, 142]])

所以gx可以用一个(5,1)数组代替,它用(5,4)fx广播:

In [145]: fx-g(x[:,0])[:,None]
Out[145]: 
array([[  2,   6,  10,  14],
       [ 34,  38,  42,  46],
       [ 66,  70,  74,  78],
       [ 98, 102, 106, 110],
       [130, 134, 138, 142]])

我还没有试着更好地理解我所 comments 的T维和d维.

这个答案可能太冗长了,但它说明了我可视化并发现了broadcasting修复的方式.

Python相关问答推荐

比较两个二元组列表,NP.isin

Python 约束无法解决n皇后之谜

如何从在虚拟Python环境中运行的脚本中运行需要宿主Python环境的Shell脚本?

Python虚拟环境的轻量级使用

使用NeuralProphet绘制置信区间时出错

有没有一种ONE—LINER的方法给一个框架的每一行一个由整数和字符串组成的唯一id?

使用Python和文件进行模糊输出

在Python中计算连续天数

Flash只从html表单中获取一个值

以逻辑方式获取自己的pyproject.toml依赖项

在不同的帧B中判断帧A中的子字符串,每个帧的大小不同

ConversationalRetrivalChain引发键错误

根据客户端是否正在传输响应来更改基于Flask的API的行为

Django Table—如果项目是唯一的,则单行

如何在GEKKO中使用复共轭物

ModuleNotFoundError:Python中没有名为google的模块''

Python协议不兼容警告

Pandas数据框上的滚动平均值,其中平均值的中心基于另一数据框的时间

将数字数组添加到Pandas DataFrame的单元格依赖于初始化

对于标准的原始类型注释,从键入`和`从www.example.com `?