我想在 mlr3 和 xgboost 包中重现 XGBoost 模型的拟合(训练和后续预测)。为了简单起见,请参阅以下使用 Lung 数据集的示例,并在训练数据集上进行预测。 xgboost (xgb_pred) 和 mlr3 (mlr3_xgb$lp) 的线性预测变量并不完全相同。任何关于为什么会出现这种情况的建议将不胜感激(希望这只是我的编码中的一个小故障或缺乏理解)。
library(mlr3)
library(mlr3pipelines)
library(mlr3extralearners)
library(mlr3proba)
library(xgboost)
## mlr3 as an example ----
task_lung = tsk('lung')
lung = task_lung$data()
xgb_basic = as_learner(
po("encode") %>>% lrn("surv.xgboost.cox", eta = 0.0103))
set.seed(123)
xgb_basic$train(task_lung)
mlr3_xgb = xgb_basic$predict(task_lung)
## use xgboost package -----
# labels to be attached to dataset
label <- ifelse(lung$status == 0, lung$time, -lung$time) # label
y_lower_bound = lung$time
y_upper_bound = ifelse(lung$status==0, +Inf, lung$time)
xgb_data=model.matrix(~.+0, data = lung[,-c(1,2),with=F]) # one hot coding
# Data matrix
dmat = xgb.DMatrix(xgb_data, label=label) # for cox
params <- list(objective='survival:cox', # train
eval_metric='cox-nloglik',
learning_rate=0.0103) #aka eta
set.seed(123)
bst <- xgb.train(params=params,
data = dmat,
nrounds=1,
watchlist=list(train = dmat, eval=dmat))
#> [1] train-cox-nloglik:3.896049 eval-cox-nloglik:3.896049
xgb_pred = predict(bst, newdata=dmat)
round(exp(mlr3_xgb$lp),3)
#> [1] 0.500 0.510 0.496 0.510 0.500 0.503 0.500 0.498 0.498 0.499 0.498 0.510
#> [13] 0.500 0.501 0.498 0.499 0.499 0.513 0.502 0.513 0.499 0.513 0.510 0.513
#> [25] 0.499 0.496 0.513 0.499 0.500 0.505 0.500 0.500 0.513 0.505 0.498 0.499
#> [37] 0.496 0.499 0.498 0.498 0.500 0.499 0.499 0.500 0.503 0.499 0.510 0.510
#> [49] 0.495 0.499 0.505 0.499 0.499 0.513 0.504 0.499 0.498 0.499 0.505 0.498
#> [61] 0.503 0.503 0.504 0.496 0.500 0.499 0.498 0.499 0.513 0.499 0.505 0.510
#> [73] 0.513 0.499 0.495 0.499 0.505 0.499 0.503 0.513 0.505 0.503 0.500 0.513
#> [85] 0.510 0.500 0.502 0.505 0.505 0.499 0.513 0.500 0.505 0.510 0.496 0.496
#> [97] 0.499 0.503 0.505 0.496 0.499 0.503 0.513 0.505 0.513 0.499 0.498 0.499
#> [109] 0.503 0.505 0.500 0.510 0.513 0.502 0.502 0.499 0.495 0.504 0.500 0.499
#> [121] 0.498 0.495 0.503 0.499 0.499 0.499 0.501 0.499 0.500 0.503 0.499 0.513
#> [133] 0.499 0.495 0.498 0.499 0.501 0.495 0.505 0.498 0.510 0.503 0.505 0.498
#> [145] 0.499 0.500 0.499 0.495 0.502 0.499 0.513 0.495 0.495 0.495 0.510 0.495
#> [157] 0.505 0.502 0.498 0.513 0.500 0.495 0.496 0.499 0.499 0.500 0.499 0.499
round(xgb_pred,3)
#> [1] 0.496 0.499 0.495 0.501 0.497 0.495 0.502 0.499 0.499 0.496 0.495 0.501
#> [13] 0.496 0.496 0.496 0.498 0.496 0.495 0.499 0.495 0.496 0.495 0.501 0.495
#> [25] 0.495 0.497 0.495 0.496 0.500 0.501 0.500 0.495 0.501 0.503 0.500 0.505
#> [37] 0.498 0.505 0.505 0.496 0.498 0.496 0.496 0.501 0.499 0.505 0.505 0.495
#> [49] 0.505 0.505 0.500 0.496 0.500 0.495 0.498 0.496 0.498 0.499 0.503 0.498
#> [61] 0.495 0.506 0.510 0.506 0.496 0.505 0.505 0.503 0.495 0.496 0.505 0.503
#> [73] 0.498 0.505 0.500 0.505 0.496 0.499 0.498 0.501 0.503 0.506 0.505 0.495
#> [85] 0.495 0.502 0.495 0.500 0.497 0.497 0.495 0.496 0.496 0.499 0.505 0.507
#> [97] 0.496 0.498 0.498 0.502 0.496 0.499 0.501 0.503 0.495 0.505 0.500 0.510
#> [109] 0.499 0.500 0.495 0.495 0.495 0.495 0.510 0.507 0.500 0.506 0.502 0.505
#> [121] 0.503 0.505 0.498 0.510 0.505 0.510 0.496 0.507 0.498 0.499 0.505 0.495
#> [133] 0.500 0.510 0.505 0.506 0.498 0.506 0.497 0.505 0.500 0.498 0.500 0.507
#> [145] 0.496 0.505 0.510 0.498 0.505 0.500 0.495 0.505 0.510 0.510 0.501 0.510
#> [157] 0.503 0.505 0.500 0.498 0.501 0.507 0.505 0.505 0.506 0.501 0.505 0.505
由 reprex 包于 2024 年 8 月 30 日创建(v2.0.1)
会议信息sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#> setting value
#> version R version 4.1.1 (2021-08-10)
#> os macOS Big Sur 10.16
#> system x86_64, darwin17.0
#> ui X11
#> language (EN)
#> collate en_US.UTF-8
#> ctype en_US.UTF-8
#> tz Australia/Adelaide
#> date 2024-08-30
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#> package * version date lib
#> backports 1.5.0 2024-05-23 [1]
#> checkmate 2.3.1 2023-12-04 [1]
#> cli 3.6.3 2024-06-21 [1]
#> codetools 0.2-18 2020-11-04 [2]
#> colorspace 2.1-0 2023-01-23 [1]
#> crayon 1.4.1 2021-02-08 [2]
#> data.table 1.15.4 2024-03-30 [1]
#> dictionar6 0.1.3 2021-09-13 [1]
#> digest 0.6.36 2024-06-23 [1]
#> distr6 1.8.4 2024-06-13 [1]
#> dplyr 1.1.3 2023-09-03 [1]
#> evaluate 0.24.0 2024-06-10 [1]
#> fansi 1.0.6 2023-12-08 [1]
#> fastmap 1.1.0 2021-01-25 [2]
#> fs 1.5.0 2020-07-31 [2]
#> future 1.33.2 2024-03-26 [1]
#> generics 0.1.3 2022-07-05 [1]
#> ggplot2 3.5.1 2024-04-23 [1]
#> globals 0.16.3 2024-03-08 [1]
#> glue 1.7.0 2024-01-09 [1]
#> gtable 0.3.5 2024-04-22 [1]
#> highr 0.9 2021-04-16 [2]
#> htmltools 0.5.6 2023-08-10 [1]
#> jsonlite 1.7.2 2020-12-09 [2]
#> knitr 1.33 2021-04-24 [2]
#> lattice 0.20-44 2021-05-02 [2]
#> lgr 0.4.4 2022-09-05 [1]
#> lifecycle 1.0.4 2023-11-07 [1]
#> listenv 0.9.1 2024-01-29 [1]
#> magrittr 2.0.3 2022-03-30 [1]
#> Matrix 1.3-4 2021-06-01 [2]
#> mlr3 * 0.20.0 2024-06-28 [1]
#> mlr3extralearners * 0.8.0-9000 2024-06-15 [1]
#> mlr3misc 0.15.1 2024-06-24 [1]
#> mlr3pipelines * 0.6.0 2024-07-16 [1]
#> mlr3proba * 0.6.3 2024-06-13 [1]
#> mlr3viz 0.9.0 2024-07-01 [1]
#> munsell 0.5.1 2024-04-01 [1]
#> ooplah 0.2.0 2022-01-21 [1]
#> palmerpenguins 0.1.1 2022-08-15 [1]
#> paradox 1.0.1 2024-07-09 [1]
#> parallelly 1.37.1 2024-02-29 [1]
#> param6 0.2.4 2023-11-22 [1]
#> pillar 1.9.0 2023-03-22 [1]
#> pkgconfig 2.0.3 2019-09-22 [2]
#> R6 2.5.1 2021-08-19 [1]
#> Rcpp 1.0.12 2024-01-09 [1]
#> reprex 2.0.1 2021-08-05 [1]
#> RhpcBLASctl 0.23-42 2023-02-11 [1]
#> rlang 1.1.4 2024-06-04 [1]
#> rmarkdown 2.10 2021-08-06 [2]
#> rstudioapi 0.15.0 2023-07-07 [1]
#> scales 1.3.0 2023-11-28 [1]
#> sessioninfo 1.1.1 2018-11-05 [2]
#> set6 0.2.6 2023-11-22 [1]
#> stringi 1.7.3 2021-07-16 [2]
#> stringr 1.5.0 2022-12-02 [1]
#> survival 3.7-0 2024-06-05 [1]
#> tibble 3.2.1 2023-03-20 [1]
#> tidyselect 1.2.0 2022-10-10 [1]
#> utf8 1.2.4 2023-10-22 [1]
#> uuid 1.2-0 2024-01-14 [1]
#> vctrs 0.6.5 2023-12-01 [1]
#> withr 3.0.0 2024-01-16 [1]
#> xfun 0.25 2021-08-06 [2]
#> xgboost * 1.7.8.1 2024-07-24 [1]
#> yaml 2.2.1 2020-02-01 [2]
#> source
#> CRAN (R 4.1.1)
#> CRAN (R 4.1.1)
#> CRAN (R 4.1.1)
#> CRAN (R 4.1.1)
#> CRAN (R 4.1.2)
#> CRAN (R 4.1.0)
#> CRAN (R 4.1.1)
#> CRAN (R 4.1.0)
#> CRAN (R 4.1.1)
#> Github (xoopR/distr6@95d7359)
#> CRAN (R 4.1.1)
#> CRAN (R 4.1.1)
#> CRAN (R 4.1.1)
#> CRAN (R 4.1.0)
#> CRAN (R 4.1.0)
#> CRAN (R 4.1.1)
#> CRAN (R 4.1.2)
#> CRAN (R 4.1.1)
#> CRAN (R 4.1.1)
#> CRAN (R 4.1.1)
#> CRAN (R 4.1.1)
#> CRAN (R 4.1.0)
#> CRAN (R 4.1.1)
#> CRAN (R 4.1.0)
#> CRAN (R 4.1.0)
#> CRAN (R 4.1.1)
#> CRAN (R 4.1.2)
#> CRAN (R 4.1.1)
#> CRAN (R 4.1.1)
#> CRAN (R 4.1.2)
#> CRAN (R 4.1.1)
#> CRAN (R 4.1.1)
#> Github (mlr-org/mlr3extralearners@6dc6965)
#> CRAN (R 4.1.1)
#> Github (mlr-org/mlr3pipelines@c542a26)
#> Github (mlr-org/mlr3proba@5205752)
#> CRAN (R 4.1.1)
#> CRAN (R 4.1.1)
#> CRAN (R 4.1.2)
#> CRAN (R 4.1.2)
#> CRAN (R 4.1.1)
#> CRAN (R 4.1.1)
#> Github (xoopR/param6@0fa3577)
#> CRAN (R 4.1.2)
#> CRAN (R 4.1.0)
#> CRAN (R 4.1.0)
#> CRAN (R 4.1.1)
#> CRAN (R 4.1.0)
#> CRAN (R 4.1.2)
#> CRAN (R 4.1.1)
#> CRAN (R 4.1.0)
#> CRAN (R 4.1.1)
#> CRAN (R 4.1.1)
#> CRAN (R 4.1.0)
#> Github (xoopR/set6@a901255)
#> CRAN (R 4.1.0)
#> CRAN (R 4.1.2)
#> CRAN (R 4.1.1)
#> CRAN (R 4.1.2)
#> CRAN (R 4.1.2)
#> CRAN (R 4.1.1)
#> CRAN (R 4.1.1)
#> CRAN (R 4.1.1)
#> CRAN (R 4.1.1)
#> CRAN (R 4.1.0)
#> CRAN (R 4.1.1)
#> CRAN (R 4.1.0)
#>
#> [1] /Users/Lee/Library/R/x86_64/4.1/library
#> [2] /Library/Frameworks/R.framework/Versions/4.1/Resources/library
谢谢你。
很难肯定地说,尽管我非常肯定我们正在做正确的数据转换。一些可能需要检查的事情:
xgboost::xgb.DMatrix
进行相同的编码?对于 cox
目标,我们只需要否定审查观察的标签(观察时间):请参阅 here 了解 xgboost 的内部辅助函数。还有一个“类似的问题”,其中有一些代码可以完全执行该转换。我认为那部分是相同的。
两个版本都有nrounds = 1
新闻) 我会尝试在没有因子编码的数据集中进行测试,例如
tsk("gbcs")
watchlist
mlr3
中它是空的(NULL
)。在最新版本中,我们支持内部验证和提前停止 xgboost 顺便说一句,但在这里我将确保每个参数都相同,并且输出 raw
xgboost 模型相同(预测肯定会遵循)。