Keras Metrics 意外行为

问题描述 投票:0回答:1

我正在使用 DNN 解决语音去噪问题。我正在通过下面的函数计算我的信噪比。

def calculate_snr(clean_signal, recovered_signal):

    clean_power = tf.reduce_sum(tf.square(clean_signal))

    noise_power = tf.reduce_sum(tf.square(clean_signal - recovered_signal))

    snr_db = 10 * tf.math.log(clean_power / noise_power) / tf.math.log(10.0)

    return snr_db

我正在使用 keras api 创建这样的模型

model.compile(loss='mean_squared_error', optimizer=keras.optimizers.Adam(learning_rate=learning_rate),metrics=[calculate_snr])

sound_denoising_history = model.fit(x = X_abs.T, y = S_abs.T,epochs=200,batch_size = 100,validation_data=(X_test_01_abs.T,S_test_01_abs.T))

calculate_snr (X_test_01_abs.T,model.predict(X_test_01_abs.T) : 10.9
While model fit: -4.4 to -3

当我训练它时,我发现我的验证 SNR 指标为 -7 并在该范围内振荡。然而,如果我预测 xval 输入,然后将其与上述函数一起使用,它会给出 8.2。这是相同的功能,我已经检查过多次尺寸。我不确定发生了什么?

编辑:我知道我错过了信号信噪比计算的处理步骤,但即使指标是独立使用的,它也应该在火车端产生几乎相同的大致结果,然后进行推理,然后进行计算

python tensorflow machine-learning keras deep-learning
1个回答
0
投票

当您在

calculate_snr
中使用
model.compile
作为度量时,它会在训练期间批量应用,然后对这些批量值进行平均以计算最终度量。与在代码末尾进行预测后在整个数据集上手动计算 SNR 相比,这可能会导致计算出的 SNR 存在差异。

您可以通过将

snr_metric
定义为类来克服此限制。

class SNRMetric(keras.metrics.Metric):
    def __init__(self, **kwargs):
        super(SNRMetric, self).__init__(**kwargs)
        self.clean_power = self.add_weight(name="clean_power", initializer="zeros")
        self.noise_power = self.add_weight(name="noise_power", initializer="zeros")
        self.count = self.add_weight(name="count", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        clean_power = tf.reduce_sum(tf.square(y_true))
        noise_power = tf.reduce_sum(tf.square(y_true - y_pred))

        self.clean_power.assign_add(clean_power)
        self.noise_power.assign_add(noise_power)
        self.count.assign_add(1)

    def result(self):
        snr_db = 10 * tf.math.log(self.clean_power / self.noise_power) / tf.math.log(10.0)
        return snr_db

然后您可以修改您的代码进行训练和测试,如下所示:

# TODO define your model

model.compile(
   loss='mean_squared_error', 
   optimizer=keras.optimizers.Adam(learning_rate=learning_rate), 
   metrics=[SNRMetric()] # here the crucial point
)

# Train 
sound_denoising_history = model.fit(x=X_abs.T, y=S_abs.T, epochs=200, batch_size=100, validation_data=(X_test_01_abs.T, S_test_01_abs.T))

# Calculate SNR using the custom metric after training
snr_metric = SNRMetric()
snr_metric.update_state(S_test_01_abs.T, model.predict(X_test_01_abs.T))
snr_value = snr_metric.result()
print(f"SNR after training: {snr_value.numpy()}")
© www.soinside.com 2019 - 2024. All rights reserved.