我试图通过在训练和推理(蒙特卡罗辍学)过程中保持辍学概率来接近贝叶斯模型,以获得该模型的认知不确定性.

有没有一种方法可以修复重复性的所有随机性来源(随机种子),但保持辍学的随机性?

# Set random seed for reproducibility
seed = 123
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

# Training and Inference phase (with dropout)
dropout_mask = torch.bernoulli(torch.full_like(input, 1 - self.dropout))
skip = self.skip0(input * dropout_mask / (1 - self.dropout))

for i in range(self.layers):
    residual = x
    filter = self.filter_convs[i](x)
    filter = torch.tanh(filter)
    gate = self.gate_convs[i](x)
    gate = torch.sigmoid(gate)
    x = filter * gate

    dropout_mask = torch.bernoulli(torch.full_like(x, 1 - self.dropout))
    x = x * dropout_mask / (1 - self.dropout)

    s = x
    s = self.skip_convs[i](s)
    skip = s + skip
    if self.gcn_true:
        x = self.gconv1[i](x, adp) + self.gconv2[i](x, adp.transpose(1, 0))
    else:
        x = self.residual_convs[i](x)

    x = x + residual[:, :, :, -x.size(3):]
    if idx is None:
        x = self.norm[i](x, self.idx)
    else:
        x = self.norm[i](x, idx)

skip = self.skipE(x) + skip
x = F.relu(skip)
x = F.relu(self.end_conv_1(x))
x = self.end_conv_2(x)

return x

上面的代码每次都会产生相同的结果,这不是我想要做的.

推荐答案

也许可以考虑为您想要复制的部分设置并重置种子?

# Training and Inference phase (with dropout)
torch.manual_seed(torch.initial_seed())
dropout_mask = torch.bernoulli(torch.full_like(input, 1 - self.dropout))
torch.manual_seed(seed) 
skip = self.skip0(input * dropout_mask / (1 - self.dropout))

你可能也想go 看看torch.cuda.manual_seed_all(seed)家店.

此外,再现性也有一些限制,至少如果你不想降低性能的话:https://pytorch.org/docs/stable/notes/randomness.html.

Python相关问答推荐

在后台运行的Python函数

使用Python进行网页抓取,没有页面

从单个列创建多个列并按pandas分组

Numpy索引argsorted使用integer数组,同时保留排序顺序

有没有方法可以修复删除了换码字符的无效的SON记录?

使用Python Cerberus初始化一个循环数据 struct (例如树)(v1.3.5)

如何使用Selenium访问svg对象内部的元素

在上下文管理器中更改异常类型

如何在Python中使用io.BytesIO写入现有缓冲区?

Python中使用时区感知日期时间对象进行时间算术的Incredit

多处理代码在while循环中不工作

根据不同列的值在收件箱中移动数据

Pandas 滚动最接近的价值

删除最后一个pip安装的包

如何使用html从excel中提取条件格式规则列表?

' osmnx.shortest_track '返回有效源 node 和目标 node 的'无'

加速Python循环

如何在UserSerializer中添加显式字段?

Python 3试图访问在线程调用中实例化的类的对象

语法错误:文档. evaluate:表达式不是合法表达式