Keras中的RMSE / RMSLE损失函数

问题描述 投票:14回答:3

我尝试参加我的第一场Kaggle比赛,其中RMSLE被作为所需的损失函数。因为我没有发现如何实施这个loss function我试图解决RMSE。我知道这是Keras过去的一部分,有没有办法在最新版本中使用它,也许通过backend定制功能?

这是我设计的NN:

from keras.models import Sequential
from keras.layers.core import Dense , Dropout
from keras import regularizers

model = Sequential()
model.add(Dense(units = 128, kernel_initializer = "uniform", activation = "relu", input_dim = 28,activity_regularizer = regularizers.l2(0.01)))
model.add(Dropout(rate = 0.2))
model.add(Dense(units = 128, kernel_initializer = "uniform", activation = "relu"))
model.add(Dropout(rate = 0.2))
model.add(Dense(units = 1, kernel_initializer = "uniform", activation = "relu"))
model.compile(optimizer = "rmsprop", loss = "root_mean_squared_error")#, metrics =["accuracy"])

model.fit(train_set, label_log, batch_size = 32, epochs = 50, validation_split = 0.15)

我尝试了在GitHub上找到的自定义root_mean_squared_error函数,但是据我所知,语法不是必需的。我认为y_truey_pred必须在传递给返回之前定义,但我不知道究竟是什么,我刚开始使用python进行编程,我在数学方面真的不太好...

from keras import backend as K

def root_mean_squared_error(y_true, y_pred):
        return K.sqrt(K.mean(K.square(y_pred - y_true), axis=-1)) 

我使用此函数收到以下错误:

ValueError: ('Unknown loss function', ':root_mean_squared_error')

感谢您的想法,我感谢您的一切帮助!

python keras custom-function loss-function
3个回答
30
投票

当您使用自定义丢失时,您需要在没有引号的情况下放置它,因为您传递的是函数对象,而不是字符串:

def root_mean_squared_error(y_true, y_pred):
        return K.sqrt(K.mean(K.square(y_pred - y_true))) 

model.compile(optimizer = "rmsprop", loss = root_mean_squared_error, 
              metrics =["accuracy"])

19
投票

接受的答案包含错误,导致RMSE实际上是MAE,根据以下问题:

https://github.com/keras-team/keras/issues/10706

应该是正确的定义

def root_mean_squared_error(y_true, y_pred):
        return K.sqrt(K.mean(K.square(y_pred - y_true)))

1
投票

如果你每晚使用最新的tensorflow,虽然文档中没有RMSE,但tf.keras.metrics.RootMeanSquaredError()中有一个source code

样品用量:

model.compile(tf.compat.v1.train.GradientDescentOptimizer(learning_rate),
              loss=tf.keras.metrics.mean_squared_error,
              metrics=[tf.keras.metrics.RootMeanSquaredError(name='rmse')])
© www.soinside.com 2019 - 2024. All rights reserved.