如果我想在KERAS中使用BatchNormalization函数,那么我是否只需要在开始时调用它一次?

我读了这个文档:http://keras.io/layers/normalization/

我不知道该怎么称呼它.下面是我试图使用它的代码:

model = Sequential()
keras.layers.normalization.BatchNormalization(epsilon=1e-06, mode=0, momentum=0.9, weights=None)
model.add(Dense(64, input_dim=14, init='uniform'))
model.add(Activation('tanh'))
model.add(Dropout(0.5))
model.add(Dense(64, init='uniform'))
model.add(Activation('tanh'))
model.add(Dropout(0.5))
model.add(Dense(2, init='uniform'))
model.add(Activation('softmax'))

sgd = SGD(lr=0.1, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='binary_crossentropy', optimizer=sgd)
model.fit(X_train, y_train, nb_epoch=20, batch_size=16, show_accuracy=True, validation_split=0.2, verbose = 2)

我之所以这样问,是因为如果我在第二行运行代码,包括批处理规范化,如果我在没有第二行的情况下运行代码,我会得到类似的输出.所以要么我没有在正确的位置调用函数,要么我猜它没有造成太大的不同.

推荐答案

更详细地回答这个问题,正如Pavel所说,批处理标准化只是另一个层,因此您可以使用它来创建您想要的网络体系 struct .

一般情况下,在网络中的线性层和非线性层之间使用BN,因为它规范化了激活函数的输入,因此您位于激活函数的线性部分(例如Sigmoid)的中心.这里有一个关于它的小讨论

在上面的例子中,这可能看起来像:


# import BatchNormalization
from keras.layers.normalization import BatchNormalization

# instantiate model
model = Sequential()

# we can think of this chunk as the input layer
model.add(Dense(64, input_dim=14, init='uniform'))
model.add(BatchNormalization())
model.add(Activation('tanh'))
model.add(Dropout(0.5))

# we can think of this chunk as the hidden layer    
model.add(Dense(64, init='uniform'))
model.add(BatchNormalization())
model.add(Activation('tanh'))
model.add(Dropout(0.5))

# we can think of this chunk as the output layer
model.add(Dense(2, init='uniform'))
model.add(BatchNormalization())
model.add(Activation('softmax'))

# setting up the optimization of our weights 
sgd = SGD(lr=0.1, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='binary_crossentropy', optimizer=sgd)

# running the fitting
model.fit(X_train, y_train, nb_epoch=20, batch_size=16, show_accuracy=True, validation_split=0.2, verbose = 2)

希望这能让事情更清楚一点.

Python相关问答推荐

如何确保Flask应用程序管理面板中的项目具有单击删除功能?

使用pandas MultiIndex进行不连续 Select

Snap 7- read_Area用于类似地址的变量

Polars Select 多个元素产品

如何让我的Tkinter应用程序适合整个窗口,无论大小如何?

计算相同形状的两个张量的SSE损失

如何根据另一列值用字典中的值替换列值

运行回文查找器代码时发生错误:[类型错误:builtin_index_or_system对象不可订阅]

Select 用a和i标签包裹的复选框?

我从带有langchain的mongoDB中的vector serch获得一个空数组

如何根据参数推断对象的返回类型?

scikit-learn导入无法导入名称METRIC_MAPPING64'

将输入管道传输到正在运行的Python脚本中

当从Docker的--env-file参数读取Python中的环境变量时,每个\n都会添加一个\'.如何没有额外的?

python中字符串的条件替换

为一个组的每个子组绘制,

什么是最好的方法来切割一个相框到一个面具的第一个实例?

numpy.unique如何消除重复列?

Python Pandas—时间序列—时间戳缺失时间精确在00:00

Python—转换日期:价目表到新行