我想在主进程中做一些计算,并将张量广播到其他进程.下面是我当前代码的草图:

from accelerate.utils import broadcast

x = None
if accelerator.is_local_main_process:
    x = <do_some_computation>
    x = broadcast(x)  # I have even tried moving this line out of the if block
print(x.shape)

这给了我以下错误: TypeError: Unsupported types (<class 'NoneType'>) passed to `_gpu_broadcast_one` . Only nested list/tuple/dicts of objects that are valid for `is_torch_tensor` s hould be passed.

这意味着x仍然是None,并没有真正被播出.我该怎么解决这个问题?

推荐答案

x不可能是None它必须是一个张量,它是相同的形状和正确的设备(当前过程).我怀疑这是因为broadcast内部做了copy_.出于某种原因,空张量也不起作用.相反,我只是创建了一个全为零的张量.

from accelerate.utils import broadcast

x = torch.zeros(*final_shape, device=accelerator.device)
if accelerator.is_local_main_process:
    x = <do_some_computation>
    x = broadcast(x)
print(x.shape)

Python相关问答推荐

将特定列信息移动到当前行下的新行

我们可以为Flask模型中的id字段主键设置默认uuid吗

cv2.matchTemplate函数匹配失败

索引到 torch 张量,沿轴具有可变长度索引

使用特定值作为引用替换数据框行上的值

未调用自定义JSON编码器

找到相对于列表索引的当前最大值列表""

如何在海上配对图中使某些标记周围的黑色边框

在用于Python的Bokeh包中设置按钮的样式

如果有2个或3个,则从pandas列中删除空格

语法错误:文档. evaluate:表达式不是合法表达式

为什么t sns.barplot图例不显示所有值?'

提取数组每行的非零元素

python的文件. truncate()意外地没有截断'

如何在信号的FFT中获得正确的频率幅值

提取最内层嵌套链接

Pandas数据框上的滚动平均值,其中平均值的中心基于另一数据框的时间

为什么在更新Pandas 2.x中的列时,数据类型不会更改,而在Pandas 1.x中会更改?

将参数从另一个python脚本中传递给main(argv

如何定义一个将类型与接收该类型的参数的可调用进行映射的字典?