我有以下代码用于在Keras中实现标准分类问题:

import numpy as np
import keras
from keras.models import Sequential
from keras.layers import Dense
from sklearn.datasets import load_breast_cancer

X,y = load_breast_cancer(return_X_y =True)

from sklearn.model_selection import train_test_split
X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.3,random_state=1)

model = Sequential()
model.add(Dense(units=30,input_dim=X_train.shape[1],activation='relu',kernel_initializer='uniform'))
model.add(Dense(units=20,activation='relu',kernel_initializer='uniform'))
model.add(Dense(units=1,activation='sigmoid',kernel_initializer='uniform'))

model.compile(optimizer='sgd',loss='binary_crossentropy',metrics=['accuracy'])
model.fit(X_train,y_train,validation_data=(X_test,y_test),epochs=200,batch_size=40,verbose=1)

一切正常,但我想在kernel_initializer中使用自定义初始化函数.

例如,在这一行中

model.add(Dense(units=20,activation='relu',kernel_initializer='uniform'))

我更喜欢这种代码:

def my_custom_initialization():
    return here

model.add(Dense(units=20,activation='relu',kernel_initializer=my_custom_initialization()))

如何使用符合Keras框架的自定义分布或方法生成数字?

推荐答案

要在Keras中创建自定义内核初始化,请try 以下操作:

import numpy as np
import tensorflow as tf
import keras
from keras.models import Sequential
from keras.layers import Dense
from sklearn.datasets import load_breast_cancer

X,y = load_breast_cancer(return_X_y =True)

from sklearn.model_selection import train_test_split
X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.3,random_state=1)


def my_custom_initialization(shape, dtype=None):
    return tf.random.normal(shape, dtype=dtype)

model = Sequential()
model.add(Dense(units=30,input_dim=X_train.shape[1],activation='relu',
                kernel_initializer=my_custom_initialization))
model.add(Dense(units=20,activation='relu',
                kernel_initializer=my_custom_initialization))
model.add(Dense(units=1,activation='sigmoid',
                kernel_initializer=my_custom_initialization))

model.compile(optimizer='sgd',loss='binary_crossentropy',metrics=['accuracy'])
model.fit(X_train,y_train,validation_data=(X_test,y_test),epochs=3,batch_size=40,verbose=1)

输出:

Epoch 1/3
10/10 [==============================] - 2s 65ms/step - loss: 598388.8750 - accuracy: 0.5553 - val_loss: 0.6913 - val_accuracy: 0.6316
Epoch 2/3
10/10 [==============================] - 0s 9ms/step - loss: 0.6907 - accuracy: 0.6256 - val_loss: 0.6898 - val_accuracy: 0.6316
Epoch 3/3
10/10 [==============================] - 0s 8ms/step - loss: 0.6893 - accuracy: 0.6256 - val_loss: 0.6883 - val_accuracy: 0.6316

Python相关问答推荐

Python—压缩叶 map html作为邮箱附件并通过sendgrid发送

Pandas—堆栈多索引头,但不包括第一列

Pandas—MultiIndex Resample—我不想丢失其他索引的信息´

使用polars. pivot()旋转一个框架(类似于R中的pivot_longer)

仅使用预先计算的排序获取排序元素

如何在Django模板中显示串行化器错误

修改.pdb文件中的值并另存为新的

如何在python tkinter中绑定键盘上的另一个回车?

根据边界点的属性将图划分为子图

Pandas查找给定时间戳之前的最后一个值

以元组为索引的Numpy多维索引

在Pandas 中,有没有办法让元组作为索引运行得很好?

使代码更快地解决哪个字母代表给定公式中的哪个数字

如何判断变量可调用函数的参数是否都属于某个子类?

对齐多个叠置多面Seborn CAT图

我的tkinter应用程序不会改变它正在加载的文件

Python .删除DataFrame中的标题标签(&U;UNNAMED:0),并将其余标题标签向左移动(不更改值)

修复如何使用python排序方法对列表中的元素进行排序

Shopware 6 REST-API产品更新不起作用

我无法用python语言中的matplotlib来绘制一个简单的图形,该图形以列表中给出的5处的X值开始