Error with PipeOp high-cardinality factor encoding in mlr3: invalid name in 'col_roles'

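I am trying to train a model with mlr3 and the glmnet package on data that contains a categorical variable. Since the categorical variable has 27 levels, I treated it as a high-cardinality feature and used impact encoding. However, I get the following error message: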

Error in .__Task__col_roles(self = self, private = private, super = super,  : 
  Assertion on 'names of col_roles' failed: Names must be a permutation of set {'feature','target','name','order','stratum','group','weight','coordinate','space','time'}, but has extra elements {'always_included'}.
This happened PipeOp high_cardinality_encoding's $train()

Here is my code:

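# Assumed setup: attach the packages whose functions are used below without a namespace prefix.
library(mlr3pipelines)     # %>>%, selector_cardinality_greater_than(), selector_type()
library(mlr3learners)      # registers the "classif.glmnet" learner
library(mlr3spatiotempcv)  # registers the "spcv_coords" resampling
data <- read.csv("C:/Users/test.csv")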
data$presence <- as.factor(data$presence)
data$habitat <- as.factor(data$habitat)
classif_task_sp <- mlr3spatial::as_task_classif_st(id = "A1", x = data[, which(!(names(data) %in% c("ID", "year")))], target = "presence", positive = "1", 
                                                   coordinate_names = c("x", "y"), crs = "EPSG:4326", coords_as_features = FALSE)
classif_task_sp$set_col_roles("presence", roles = c("target", "stratum"))
partition_classif_task_sp <- mlr3::partition(classif_task_sp, ratio = 0.67)

factor_encoding <- mlr3pipelines::po("removeconstants") %>>%
  ## mlr3pipelines::po("collapsefactors", no_collapse_above_prevalence = 0.01) %>>%
  mlr3pipelines::po("encodeimpact", affect_columns = selector_cardinality_greater_than(10), id = "high_cardinality_encoding") %>>%
  mlr3pipelines::po("encode", method = "one-hot", affect_columns = selector_cardinality_greater_than(3), id = "low_cardinality_encoding") %>>%
  mlr3pipelines::po("encode", method = "treatment", affect_columns = selector_type("factor"), id = "binary_encoding")

learner_glmnet <- mlr3tuningspaces::lts(mlr3::lrn("classif.glmnet", predict_type = "prob", standardize = FALSE))
learner_glmnet_factor_encoding <- mlr3::as_learner(factor_encoding %>>% learner_glmnet)

tuning <- mlr3tuning::auto_tuner(tuner = mlr3tuning::tnr("grid_search", resolution = 5, batch_size = 10),
                                 learner = learner_glmnet_factor_encoding,
                                 resampling = mlr3::rsmp("spcv_coords", folds = 2),
                                 measure = mlr3::msr("classif.prauc"),
                                 terminator = mlr3tuning::trm("evals", n_evals = 2, k = 0))

run_resampling <- mlr3::resample(classif_task_sp, learner = tuning, resampling = mlr3::rsmp("spcv_coords", folds = 2), store_models = TRUE)

run_training <- tuning$train(classif_task_sp, row_ids = partition_classif_task_sp$train)

Here is the dataset: https://www.dropbox.com/scl/fi/rfyj9oav5z5yipmkr4a9q/test.csv?rlkey=vsfsyhfgh4svnoos5z6t18u5q&st=vak2wayv&dl=0

mlr mlr3
1 Answer

Can you try updating your packages? This was a bug, and it has been fixed in mlr3 0.21.1 and mlr3fselect 1.2.1.
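
To see whether the installed versions predate the fix, something along these lines may help. It is only a sketch: the helper name is_outdated is made up for illustration, the version numbers are the ones quoted above, and the install line is left commented out so that nothing is changed unintentionally.

```r
# Minimum versions quoted in the answer above.
required <- c(mlr3 = "0.21.1", mlr3fselect = "1.2.1")

# Hypothetical helper: TRUE for each package that is installed but older than
# the required version. Packages that are not installed are reported as FALSE.
is_outdated <- function(required) {
  vapply(names(required), function(pkg) {
    requireNamespace(pkg, quietly = TRUE) &&
      utils::packageVersion(pkg) < required[[pkg]]
  }, logical(1))
}

outdated <- names(required)[is_outdated(required)]
if (length(outdated) > 0) {
  message("Older than the fixed version: ", paste(outdated, collapse = ", "))
  # install.packages(outdated)  # uncomment to update from CRAN
}
```

After updating, restart the R session before rerunning the pipeline so that the new package versions are actually loaded.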
