xgboost 预测(cox 线性预测器)与 mlr3 xgboost.cox

问题描述 投票:0回答:1

我想在 mlr3 和 xgboost 包中重现 XGBoost 模型的拟合(训练和后续预测)。为了简单起见,请参阅以下使用 Lung 数据集的示例,并在训练数据集上进行预测。 xgboost (xgb_pred) 和 mlr3 (mlr3_xgb$lp) 的线性预测变量并不完全相同。任何关于为什么会出现这种情况的建议将不胜感激(希望这只是我的编码中的一个小故障或缺乏理解)。

library(mlr3)
library(mlr3pipelines)
library(mlr3extralearners)
library(mlr3proba)
library(xgboost)

## mlr3 as an example ----
task_lung = tsk('lung')
lung = task_lung$data()
xgb_basic = as_learner(
  po("encode") %>>%  lrn("surv.xgboost.cox",  eta = 0.0103))

set.seed(123)
xgb_basic$train(task_lung)
mlr3_xgb = xgb_basic$predict(task_lung)


## use xgboost package -----
# labels to be attached to dataset
label <- ifelse(lung$status == 0, lung$time, -lung$time) # label
y_lower_bound = lung$time
y_upper_bound = ifelse(lung$status==0, +Inf, lung$time)

xgb_data=model.matrix(~.+0, data = lung[,-c(1,2),with=F]) # one hot coding 

# Data matrix
dmat = xgb.DMatrix(xgb_data, label=label) # for cox


params <- list(objective='survival:cox',  # train
               eval_metric='cox-nloglik',
               learning_rate=0.0103)  #aka eta

set.seed(123) 
bst <- xgb.train(params=params, 
                 data = dmat, 
                 nrounds=1, 
                 watchlist=list(train = dmat, eval=dmat))
#> [1]  train-cox-nloglik:3.896049  eval-cox-nloglik:3.896049

xgb_pred = predict(bst, newdata=dmat)

round(exp(mlr3_xgb$lp),3)
#>   [1] 0.500 0.510 0.496 0.510 0.500 0.503 0.500 0.498 0.498 0.499 0.498 0.510
#>  [13] 0.500 0.501 0.498 0.499 0.499 0.513 0.502 0.513 0.499 0.513 0.510 0.513
#>  [25] 0.499 0.496 0.513 0.499 0.500 0.505 0.500 0.500 0.513 0.505 0.498 0.499
#>  [37] 0.496 0.499 0.498 0.498 0.500 0.499 0.499 0.500 0.503 0.499 0.510 0.510
#>  [49] 0.495 0.499 0.505 0.499 0.499 0.513 0.504 0.499 0.498 0.499 0.505 0.498
#>  [61] 0.503 0.503 0.504 0.496 0.500 0.499 0.498 0.499 0.513 0.499 0.505 0.510
#>  [73] 0.513 0.499 0.495 0.499 0.505 0.499 0.503 0.513 0.505 0.503 0.500 0.513
#>  [85] 0.510 0.500 0.502 0.505 0.505 0.499 0.513 0.500 0.505 0.510 0.496 0.496
#>  [97] 0.499 0.503 0.505 0.496 0.499 0.503 0.513 0.505 0.513 0.499 0.498 0.499
#> [109] 0.503 0.505 0.500 0.510 0.513 0.502 0.502 0.499 0.495 0.504 0.500 0.499
#> [121] 0.498 0.495 0.503 0.499 0.499 0.499 0.501 0.499 0.500 0.503 0.499 0.513
#> [133] 0.499 0.495 0.498 0.499 0.501 0.495 0.505 0.498 0.510 0.503 0.505 0.498
#> [145] 0.499 0.500 0.499 0.495 0.502 0.499 0.513 0.495 0.495 0.495 0.510 0.495
#> [157] 0.505 0.502 0.498 0.513 0.500 0.495 0.496 0.499 0.499 0.500 0.499 0.499
round(xgb_pred,3)
#>   [1] 0.496 0.499 0.495 0.501 0.497 0.495 0.502 0.499 0.499 0.496 0.495 0.501
#>  [13] 0.496 0.496 0.496 0.498 0.496 0.495 0.499 0.495 0.496 0.495 0.501 0.495
#>  [25] 0.495 0.497 0.495 0.496 0.500 0.501 0.500 0.495 0.501 0.503 0.500 0.505
#>  [37] 0.498 0.505 0.505 0.496 0.498 0.496 0.496 0.501 0.499 0.505 0.505 0.495
#>  [49] 0.505 0.505 0.500 0.496 0.500 0.495 0.498 0.496 0.498 0.499 0.503 0.498
#>  [61] 0.495 0.506 0.510 0.506 0.496 0.505 0.505 0.503 0.495 0.496 0.505 0.503
#>  [73] 0.498 0.505 0.500 0.505 0.496 0.499 0.498 0.501 0.503 0.506 0.505 0.495
#>  [85] 0.495 0.502 0.495 0.500 0.497 0.497 0.495 0.496 0.496 0.499 0.505 0.507
#>  [97] 0.496 0.498 0.498 0.502 0.496 0.499 0.501 0.503 0.495 0.505 0.500 0.510
#> [109] 0.499 0.500 0.495 0.495 0.495 0.495 0.510 0.507 0.500 0.506 0.502 0.505
#> [121] 0.503 0.505 0.498 0.510 0.505 0.510 0.496 0.507 0.498 0.499 0.505 0.495
#> [133] 0.500 0.510 0.505 0.506 0.498 0.506 0.497 0.505 0.500 0.498 0.500 0.507
#> [145] 0.496 0.505 0.510 0.498 0.505 0.500 0.495 0.505 0.510 0.510 0.501 0.510
#> [157] 0.503 0.505 0.500 0.498 0.501 0.507 0.505 0.505 0.506 0.501 0.505 0.505

reprex 包于 2024 年 8 月 30 日创建(v2.0.1)

会议信息
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value                       
#>  version  R version 4.1.1 (2021-08-10)
#>  os       macOS Big Sur 10.16         
#>  system   x86_64, darwin17.0          
#>  ui       X11                         
#>  language (EN)                        
#>  collate  en_US.UTF-8                 
#>  ctype    en_US.UTF-8                 
#>  tz       Australia/Adelaide          
#>  date     2024-08-30                  
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package           * version    date       lib
#>  backports           1.5.0      2024-05-23 [1]
#>  checkmate           2.3.1      2023-12-04 [1]
#>  cli                 3.6.3      2024-06-21 [1]
#>  codetools           0.2-18     2020-11-04 [2]
#>  colorspace          2.1-0      2023-01-23 [1]
#>  crayon              1.4.1      2021-02-08 [2]
#>  data.table          1.15.4     2024-03-30 [1]
#>  dictionar6          0.1.3      2021-09-13 [1]
#>  digest              0.6.36     2024-06-23 [1]
#>  distr6              1.8.4      2024-06-13 [1]
#>  dplyr               1.1.3      2023-09-03 [1]
#>  evaluate            0.24.0     2024-06-10 [1]
#>  fansi               1.0.6      2023-12-08 [1]
#>  fastmap             1.1.0      2021-01-25 [2]
#>  fs                  1.5.0      2020-07-31 [2]
#>  future              1.33.2     2024-03-26 [1]
#>  generics            0.1.3      2022-07-05 [1]
#>  ggplot2             3.5.1      2024-04-23 [1]
#>  globals             0.16.3     2024-03-08 [1]
#>  glue                1.7.0      2024-01-09 [1]
#>  gtable              0.3.5      2024-04-22 [1]
#>  highr               0.9        2021-04-16 [2]
#>  htmltools           0.5.6      2023-08-10 [1]
#>  jsonlite            1.7.2      2020-12-09 [2]
#>  knitr               1.33       2021-04-24 [2]
#>  lattice             0.20-44    2021-05-02 [2]
#>  lgr                 0.4.4      2022-09-05 [1]
#>  lifecycle           1.0.4      2023-11-07 [1]
#>  listenv             0.9.1      2024-01-29 [1]
#>  magrittr            2.0.3      2022-03-30 [1]
#>  Matrix              1.3-4      2021-06-01 [2]
#>  mlr3              * 0.20.0     2024-06-28 [1]
#>  mlr3extralearners * 0.8.0-9000 2024-06-15 [1]
#>  mlr3misc            0.15.1     2024-06-24 [1]
#>  mlr3pipelines     * 0.6.0      2024-07-16 [1]
#>  mlr3proba         * 0.6.3      2024-06-13 [1]
#>  mlr3viz             0.9.0      2024-07-01 [1]
#>  munsell             0.5.1      2024-04-01 [1]
#>  ooplah              0.2.0      2022-01-21 [1]
#>  palmerpenguins      0.1.1      2022-08-15 [1]
#>  paradox             1.0.1      2024-07-09 [1]
#>  parallelly          1.37.1     2024-02-29 [1]
#>  param6              0.2.4      2023-11-22 [1]
#>  pillar              1.9.0      2023-03-22 [1]
#>  pkgconfig           2.0.3      2019-09-22 [2]
#>  R6                  2.5.1      2021-08-19 [1]
#>  Rcpp                1.0.12     2024-01-09 [1]
#>  reprex              2.0.1      2021-08-05 [1]
#>  RhpcBLASctl         0.23-42    2023-02-11 [1]
#>  rlang               1.1.4      2024-06-04 [1]
#>  rmarkdown           2.10       2021-08-06 [2]
#>  rstudioapi          0.15.0     2023-07-07 [1]
#>  scales              1.3.0      2023-11-28 [1]
#>  sessioninfo         1.1.1      2018-11-05 [2]
#>  set6                0.2.6      2023-11-22 [1]
#>  stringi             1.7.3      2021-07-16 [2]
#>  stringr             1.5.0      2022-12-02 [1]
#>  survival            3.7-0      2024-06-05 [1]
#>  tibble              3.2.1      2023-03-20 [1]
#>  tidyselect          1.2.0      2022-10-10 [1]
#>  utf8                1.2.4      2023-10-22 [1]
#>  uuid                1.2-0      2024-01-14 [1]
#>  vctrs               0.6.5      2023-12-01 [1]
#>  withr               3.0.0      2024-01-16 [1]
#>  xfun                0.25       2021-08-06 [2]
#>  xgboost           * 1.7.8.1    2024-07-24 [1]
#>  yaml                2.2.1      2020-02-01 [2]
#>  source                                    
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.1)                            
#>  Github (xoopR/distr6@95d7359)             
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  Github (mlr-org/mlr3extralearners@6dc6965)
#>  CRAN (R 4.1.1)                            
#>  Github (mlr-org/mlr3pipelines@c542a26)    
#>  Github (mlr-org/mlr3proba@5205752)        
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  Github (xoopR/param6@0fa3577)             
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.0)                            
#>  Github (xoopR/set6@a901255)               
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.0)                            
#> 
#> [1] /Users/Lee/Library/R/x86_64/4.1/library
#> [2] /Library/Frameworks/R.framework/Versions/4.1/Resources/library

谢谢你。

xgboost mlr3proba
1个回答
0
投票

很难肯定地说,尽管我非常肯定我们正在做正确的数据转换。一些可能需要检查的事情:

  • 您是否对
    xgboost::xgb.DMatrix
    进行相同的编码?对于
    cox
    目标,我们只需要否定审查观察的标签(观察时间):请参阅 here 了解 xgboost 的内部辅助函数。还有一个“类似的问题”,其中有一些代码可以完全执行该转换。我认为那部分是相同的。 两个版本都有
  • nrounds = 1
  • 吗? (最近默认改为1000,见
    新闻
    我会尝试在没有因子编码的数据集中进行测试,例如
  • tsk("gbcs")
  • 以在调查此类事情时简化事情。
    您在手动版本中设置了
  • watchlist
  • ,在
    mlr3
    中它是空的(
    NULL
    )。在最新版本中,我们支持内部验证和提前停止 xgboost 顺便说一句,但在这里我将确保每个参数都相同,并且输出
    raw
    xgboost 模型相同(预测肯定会遵循)。
    
        
© www.soinside.com 2019 - 2024. All rights reserved.