我正在使用 DNN 解决语音去噪问题。我正在通过下面的函数计算我的信噪比。
def calculate_snr(clean_signal, recovered_signal):
clean_power = tf.reduce_sum(tf.square(clean_signal))
noise_power = tf.reduce_sum(tf.square(clean_signal - recovered_signal))
snr_db = 10 * tf.math.log(clean_power / noise_power) / tf.math.log(10.0)
return snr_db
我正在使用 keras api 创建这样的模型
model.compile(loss='mean_squared_error', optimizer=keras.optimizers.Adam(learning_rate=learning_rate),metrics=[calculate_snr])
sound_denoising_history = model.fit(x = X_abs.T, y = S_abs.T,epochs=200,batch_size = 100,validation_data=(X_test_01_abs.T,S_test_01_abs.T))
calculate_snr (X_test_01_abs.T,model.predict(X_test_01_abs.T) : 10.9
While model fit: -4.4 to -3
当我训练它时,我发现我的验证 SNR 指标为 -7 并在该范围内振荡。然而,如果我预测 xval 输入,然后将其与上述函数一起使用,它会给出 8.2。这是相同的功能,我已经检查过多次尺寸。我不确定发生了什么?
编辑:我知道我错过了信号信噪比计算的处理步骤,但即使指标是独立使用的,它也应该在火车端产生几乎相同的大致结果,然后进行推理,然后进行计算
当您在
calculate_snr
中使用 model.compile
作为度量时,它会在训练期间批量应用,然后对这些批量值进行平均以计算最终度量。与在代码末尾进行预测后在整个数据集上手动计算 SNR 相比,这可能会导致计算出的 SNR 存在差异。
您可以通过将
snr_metric
定义为类来克服此限制。
class SNRMetric(keras.metrics.Metric):
def __init__(self, **kwargs):
super(SNRMetric, self).__init__(**kwargs)
self.clean_power = self.add_weight(name="clean_power", initializer="zeros")
self.noise_power = self.add_weight(name="noise_power", initializer="zeros")
self.count = self.add_weight(name="count", initializer="zeros")
def update_state(self, y_true, y_pred, sample_weight=None):
clean_power = tf.reduce_sum(tf.square(y_true))
noise_power = tf.reduce_sum(tf.square(y_true - y_pred))
self.clean_power.assign_add(clean_power)
self.noise_power.assign_add(noise_power)
self.count.assign_add(1)
def result(self):
snr_db = 10 * tf.math.log(self.clean_power / self.noise_power) / tf.math.log(10.0)
return snr_db
然后您可以修改您的代码进行训练和测试,如下所示:
# TODO define your model
model.compile(
loss='mean_squared_error',
optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
metrics=[SNRMetric()] # here the crucial point
)
# Train
sound_denoising_history = model.fit(x=X_abs.T, y=S_abs.T, epochs=200, batch_size=100, validation_data=(X_test_01_abs.T, S_test_01_abs.T))
# Calculate SNR using the custom metric after training
snr_metric = SNRMetric()
snr_metric.update_state(S_test_01_abs.T, model.predict(X_test_01_abs.T))
snr_value = snr_metric.result()
print(f"SNR after training: {snr_value.numpy()}")