我正在研究一个使用DNN的语音go 噪问题.我正在通过下面的函数计算我的信噪比.

def calculate_snr(clean_signal, recovered_signal):

    clean_power = tf.reduce_sum(tf.square(clean_signal))

    noise_power = tf.reduce_sum(tf.square(clean_signal - recovered_signal))

    snr_db = 10 * tf.math.log(clean_power / noise_power) / tf.math.log(10.0)

    return snr_db

我正在使用keras API创建这样的模型

model.compile(loss='mean_squared_error', optimizer=keras.optimizers.Adam(learning_rate=learning_rate),metrics=[calculate_snr])

sound_denoising_history = model.fit(x = X_abs.T, y = S_abs.T,epochs=200,batch_size = 100,validation_data=(X_test_01_abs.T,S_test_01_abs.T))

calculate_snr (X_test_01_abs.T,model.predict(X_test_01_abs.T) : 10.9
While model fit: -4.4 to -3

当我训练它时,我发现我的验证SNR度量是-7,并在该范围内振荡.然而,如果我预测xval输入,然后将其与上面的函数一起使用,它会得到8.2.这是相同的功能,我已经判断了多次尺寸.我不知道发生了什么事?

编辑:我知道我错过了信号SNR计算的处理步骤,但即使该度量是独立使用的,它也应该在列车末端产生几乎相同的结果,然后进行推理和计算

推荐答案

当您在model.compile中使用calculate_snr作为度量时,它将在训练期间分批应用,然后对这些分批的值求平均值以计算最终度量.这可能会导致计算的SNR与在代码末尾进行预测后在整个数据集上手动计算时的SNR有所不同.

您可以通过将snr_metric定义为一个类来克服这一限制.

class SNRMetric(keras.metrics.Metric):
    def __init__(self, **kwargs):
        super(SNRMetric, self).__init__(**kwargs)
        self.clean_power = self.add_weight(name="clean_power", initializer="zeros")
        self.noise_power = self.add_weight(name="noise_power", initializer="zeros")
        self.count = self.add_weight(name="count", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        clean_power = tf.reduce_sum(tf.square(y_true))
        noise_power = tf.reduce_sum(tf.square(y_true - y_pred))

        self.clean_power.assign_add(clean_power)
        self.noise_power.assign_add(noise_power)
        self.count.assign_add(1)

    def result(self):
        snr_db = 10 * tf.math.log(self.clean_power / self.noise_power) / tf.math.log(10.0)
        return snr_db

然后,您可以修改用于培训和测试的代码,如下所示:

# TODO define your model

model.compile(
   loss='mean_squared_error', 
   optimizer=keras.optimizers.Adam(learning_rate=learning_rate), 
   metrics=[SNRMetric()] # here the crucial point
)

# Train 
sound_denoising_history = model.fit(x=X_abs.T, y=S_abs.T, epochs=200, batch_size=100, validation_data=(X_test_01_abs.T, S_test_01_abs.T))

# Calculate SNR using the custom metric after training
snr_metric = SNRMetric()
snr_metric.update_state(S_test_01_abs.T, model.predict(X_test_01_abs.T))
snr_value = snr_metric.result()
print(f"SNR after training: {snr_value.numpy()}")

Python相关问答推荐

使用Beautiful Soup获取第二个srcset属性

如何将桌子刮成带有Se的筷子/要求/Beautiful Soup ?

无法使用python.h文件; Python嵌入错误

Pydantic 2.7.0模型接受字符串日期时间或无

从收件箱中的列中删除html格式

Pandas 有条件轮班操作

scikit-learn导入无法导入名称METRIC_MAPPING64'

如何在Django基于类的视图中有效地使用UTE和RST HTIP方法?

组/群集按字符串中的子字符串或子字符串中的字符串轮询数据框

关于Python异步编程的问题和使用await/await def关键字

polars:有效的方法来应用函数过滤列的字符串

Pandas—堆栈多索引头,但不包括第一列

从一个df列提取单词,分配给另一个列

Django Table—如果项目是唯一的,则单行

jsonschema日期格式

当HTTP 201响应包含 Big Data 的POST请求时,应该是什么?  

裁剪数字.nd数组引发-ValueError:无法将空图像写入JPEG

我什么时候应该使用帆布和标签?

多个矩阵的张量积

Python:从目录内的文件导入目录