我开始使用R中的XGBoost,并试图将binary:logistic模型中的预测与使用自定义日志丢失函数生成的预测相匹配。我希望以下两个对predict的调用会产生相同的结果:
require(xgboost)
loglossobj <- function(preds, dtrain) {
labels <- getinfo(dtrain, "label")
preds <- 1/(1 + exp(-preds))
grad <- preds - labels
hess <- preds * (1 - preds)
return(list(grad = grad, hess = hess))
}
data(agaricus.train, package='xgboost')
data(agaricus.test, package='xgboost')
train<-agaricus.train
test<-agaricus.test
model<-xgboost(data = train$data, label = train$label, nrounds=2,objective="binary:logistic")
preds = predict(model,test$data)
print (head(preds))
model<-xgboost(data = train$data, label = train$label, nrounds=2,objective=loglossobj, eval_metric = "error")
preds = predict(model,test$data)
x = 1 / (1+exp(-preds))
print (head(x))
自定义对数损失函数的模型输出未应用逻辑变换1 /(1 + exp(-x))。但是,如果这样做,两次调用predict:
的结果概率是不同的[1] 0.2582498 0.7433221 0.2582498 0.2582498 0.2576509 0.2750908
对
[1] 0.3076240 0.7995583 0.3076240 0.3076240 0.3079328 0.3231709
我确定这里有一个简单的解释。有什么建议吗?
事实证明,此行为是由于初始条件引起的。当调用binary:logistic或binary:logit_raw时,xgboost隐式假定base_score = 0.5,但在使用自定义损失函数时,必须将base_score设置为0.0,以复制其输出。为了说明这一点,下面的R代码在所有三种情况下都生成相同的预测:
require(xgboost)
loglossobj <- function(preds, dtrain) {
labels <- getinfo(dtrain, "label")
preds <- 1/(1 + exp(-preds))
grad <- preds - labels
hess <- preds * (1 - preds)
return(list(grad = grad, hess = hess))
}
data(agaricus.train, package='xgboost')
data(agaricus.test, package='xgboost')
train<-agaricus.train
test<-agaricus.test
model<-xgboost(data = train$data, label = train$label, objective = "binary:logistic", nround = 10, eta = 0.1, verbose=0)
preds = predict(model,test$data)
print (head(preds))
model<-xgboost(data = train$data, label = train$label, objective = "binary:logitraw", nround = 10, eta = 0.1, verbose=0)
preds = predict(model,test$data)
x = 1 / (1+exp(-preds))
print (head(x))
model<-xgboost(data = train$data, label = train$label, objective = loglossobj, base_score = 0.0, nround = 10, eta = 0.1, verbose=0)
preds = predict(model,test$data)
x = 1 / (1+exp(-preds))
print (head(x))
输出
[1] 0.1814032 0.8204284 0.1814032 0.1814032 0.1837782 0.1952717
[1] 0.1814032 0.8204284 0.1814032 0.1814032 0.1837782 0.1952717
[1] 0.1814032 0.8204284 0.1814032 0.1814032 0.1837782 0.1952717