作为我之前的问题的后续问题,我对提高现有递归采样函数的性能感兴趣。
通过递归采样,我的意思是为给定的暴露 ID 随机选择最多 n 个唯一的未暴露 ID,并为另一个暴露 ID 从剩余的未暴露 ID 中随机选择最多 n 个唯一的未暴露 ID。如果给定的公开 ID 没有剩余的未公开 ID,则该公开 ID 将被忽略。
原函数如下:
recursive_sample <- function(data, n) {
groups <- unique(data[["exposed"]])
out <- data.frame(exposed = character(), unexposed = character())
for (group in groups) {
chosen <- data %>%
filter(exposed == group,
!unexposed %in% out$unexposed) %>%
group_by(unexposed) %>%
slice(1) %>%
ungroup() %>%
sample_n(size = min(n, nrow(.)))
out <- rbind(out, chosen)
}
out
}
我能够创建一个更高效的,如下所示:
recursive_sample2 <- function(data, n) {
groups <- unique(data[["exposed"]])
out <- tibble(exposed = integer(), unexposed = integer())
for (group in groups) {
chosen <- data %>%
filter(exposed == group,
!unexposed %in% out$unexposed) %>%
filter(!duplicated(unexposed)) %>%
sample_n(size = min(n, nrow(.)))
out <- bind_rows(out, chosen)
}
out
}
样本数据和基准:
set.seed(123)
df <- tibble(exposed = rep(1:100, each = 100),
unexposed = sample(1:7000, 10000, replace = TRUE))
microbenchmark(f1 = recursive_sample(df, 5),
f2 = recursive_sample2(df, 5),
times = 10)
Unit: milliseconds
expr min lq mean median uq max neval cld
f1 1307.7198 1316.5276 1379.0533 1371.3952 1416.6360 1540.955 10 b
f2 839.0086 865.2547 914.8327 901.2288 970.9518 1036.170 10 a
但是,对于我的实际数据集,我需要一个更高效(即更快)的函数。任何关于更高效版本的想法,无论是在
data.table
中,涉及并行化还是其他方法,都是受欢迎的。
2023 年 12 月编辑
@minem 和 @ThomasIsCoding 的版本要快得多;但是,它们仅返回部分正确的结果。考虑样本数据,例如:
df <- structure(list(exposed = c(4L, 4L, 1L, 1L, 1L, 3L, 2L, 2L, 2L,
2L, 2L), unexposed = c(1L, 2L, 1L, 2L, 3L, 10L, 4L, 5L, 7L, 8L,
9L)), class = "data.frame", row.names = c("1", "2", "3", "4",
"5", "6", "7", "8", "9", "10", "11"))
exposed unexposed
1 4 1
2 4 2
3 1 1
4 1 2
5 1 3
6 3 10
7 2 4
8 2 5
9 2 7
10 2 8
11 2 9
我希望explore == 4被采样两次,expose == 1被采样一次,expose == 3被采样一次,expose == 2被采样两次。换句话说,抽样程序应反映所提供的数据顺序。所需输出:
exposed unexposed
1 4 2
2 4 1
3 1 3
4 3 10
5 2 4
6 2 9
来自@minem的选项永远不会返回exposeed == 4,并且始终仅选择unexposeed == 1 OR 2来表示exposeed == 1。
recursive_sample_minem <- function(data, n) {
i <- vector('integer')
unexposed2 <- vector(class(data$unexposed))
ux <- data$unexposed
exl <- split(1:nrow(data), data$exposed)
for (ii in exl) {
f2 <- !ux[ii] %in% unexposed2
f12 <- ii[f2]
dn <- !duplicated(ux[f12])
id3 <- f12[dn]
is <- id3[sample.int(min(length(id3), n))]
i <- c(i, is)
unexposed2 <- ux[i]
}
out <- data[i, ]
out
}
recursive_sample_minem(df, 2)
exposed unexposed
4 1 2
3 1 1
8 2 5
7 2 4
6 3 10
@ThomasIsCoding 中的选项也从不包含公开的 == 4,并且还会返回某些公开 ID 的错误次数:
recursice_sample_ThomasIsCoding <- function(data, n) {
Reduce(
\(x, y) {
rbind(x, head(subset(y, !unexposed %in% x$unexposed), n))
},
split(data[sample(1:nrow(data)), ], ~ exposed)
)
}
recursice_sample_ThomasIsCoding(df, 2)
exposed unexposed
3 1 1
4 1 2
5 1 3
10 2 8
9 2 7
6 3 10
处理向量要快得多:
recursive_sample3 <- function(data, n) {
groups <- unique(data[["exposed"]])
# working on vectors is faster
id <- 1:nrow(data)
i <- vector('integer')
unexposed2 <- vector(class(data$unexposed))
ex <- data$exposed
ux <- data$unexposed
for (group in groups) {
f1 <- ex == group # first filter
f2 <- !ux[f1] %in% unexposed2 # 2nd filter (only on those that match 1st)
id3 <- id[f1][f2][!duplicated(ux[f1][f2])] # check duplicates only on needed
# and select necesary row ids
is <- sample(id3, size = min(length(id3), n)) # sample row ids
i <- c(i, is) # add to list
unexposed2 <- ux[i] # resave unexposed2
}
out <- data[i, ] # only one data.frame subset
out$id <- NULL
out
}
基准:
microbenchmark(f1 = recursive_sample(df, 5),
f2 = recursive_sample2(df, 5),
f3 = recursive_sample3(df, 5),
times = 3)
# Unit: milliseconds
# expr min lq mean median uq max neval cld
# f1 1399.8988 1407.1939 1422.008133 1414.4889 1433.06280 1451.6367 3 a
# f2 667.0813 673.7229 678.106400 680.3645 683.61895 686.8734 3 b
# f3 6.2399 6.2625 9.531267 6.2851 11.17695 16.0688 3 c
迭代
recursive_sample3
并纳入 sample
的关注点:
f_minem <- function(data, n) {
i <- vector('integer')
unexposed2 <- vector(class(data$unexposed))
ux <- data$unexposed
exl <- split(1:nrow(data), data$exposed)
for (ii in exl) {
f2 <- !ux[ii] %in% unexposed2
f12 <- ii[f2]
dn <- !duplicated(ux[f12])
id3 <- f12[dn]
is <- id3[sample.int(min(length(id3), n))]
i <- c(i, is)
unexposed2 <- ux[i]
}
out <- data[i, ]
out
}
基准 nr2:
microbenchmark::microbenchmark(
recursive_sample3 = recursive_sample3(df, 5L),
recursive_sample4 = recursive_sample4(setDT(df), 5L),
f_minem = f_minem(df, 5L),
setup = {df <- copy(data)}
, times = 10
)
# Unit: milliseconds
# expr min lq mean median uq max neval cld
# recursive_sample3 6.2102 6.2974 9.63296 6.43245 16.3367 17.0746 10 a
# recursive_sample4 3.5145 3.6249 3.67077 3.67075 3.7513 3.7970 10 b
# f_minem 2.1705 2.1920 2.27510 2.23215 2.3784 2.4585 10 b
一个
data.table
解决方案,保留 setdiff
中使用的采样值的运行列表(或来自 collapse
的 %!in%):
library(data.table)
library(collapse) # for %!in%
recursive_sample4 <- function(data, n) {
sampled <- vector("list", uniqueN(data$exposed))
data[
,.(
unexposed = {
x <- setdiff(unexposed, unlist(sampled))
sampled[[.GRP]] <- x[sample.int(min(length(x), n))]
}
), exposed
]
}
recursive_sample5 <- function(data, n) {
sampled <- vector("list", uniqueN(data$exposed))
data[
,.(
unexposed = {
x <- unexposed[unexposed %!in% unlist(sampled)]
sampled[[.GRP]] <- x[sample.int(min(length(x), n))]
}
), exposed
]
}
时间(包括 @minem 的
recursive_sample3
):
data <- copy(df)
microbenchmark::microbenchmark(
recursive_sample2 = recursive_sample2(df, 5L),
recursive_sample3 = recursive_sample3(df, 5L),
recursive_sample4 = recursive_sample4(setDT(df), 5L),
recursive_sample5 = recursive_sample5(setDT(df), 5L),
setup = {df <- copy(data)}
)
#> Unit: milliseconds
#> expr min lq mean median uq max neval
#> recursive_sample2 416.5425 427.38700 452.520780 436.58280 459.79430 614.6392 100
#> recursive_sample3 4.5211 5.16330 6.765060 5.79820 6.95425 14.0693 100
#> recursive_sample4 3.2038 3.57650 4.676284 4.41120 4.90855 11.6975 100
#> recursive_sample5 2.2327 2.58255 3.384131 3.27405 3.93265 8.7091 100
请注意,当第一个参数的长度为 1 时,由于
recursive_sample3
的行为,sample
可能会给出错误的结果:
set.seed(123)
df <- tibble(exposed = rep(1:100, each = 100),
unexposed = sample(1:700, 10000, replace = TRUE))
nrow(recursive_sample3(df, 10L))
#> [1] 704
更简洁的解决方案可能是使用
Reduce
+ split
,我们首先对 data
的行进行洗牌,然后按组进行采样 迭代
ftic <- function(data, n) {
Reduce(
\(x, y) {
rbind(x, head(subset(y, !unexposed %in% x$unexposed), n))
},
split(data[sample(1:nrow(data)), ], ~exposed)
)
}
下面是更严格的压力测试,即
data
1e6
行,其中方法包括:
ftmfmnk <- function(data, n) {
groups <- unique(data[["exposed"]])
out <- tibble(exposed = integer(), unexposed = integer())
for (group in groups) {
chosen <- data %>%
filter(
exposed == group,
!unexposed %in% out$unexposed
) %>%
filter(!duplicated(unexposed)) %>%
sample_n(size = min(n, nrow(.)))
out <- bind_rows(out, chosen)
}
out
}
fminem <- function(data, n) {
groups <- unique(data[["exposed"]])
# working on vectors is faster
id <- 1:nrow(data)
i <- vector("integer")
unexposed2 <- vector(class(data$unexposed))
ex <- data$exposed
ux <- data$unexposed
for (group in groups) {
f1 <- ex == group # first filter
f2 <- !ux[f1] %in% unexposed2 # 2nd filter (only on those that match 1st)
id3 <- id[f1][f2][!duplicated(ux[f1][f2])] # check duplicates only on needed
# and select necesary row ids
is <- sample(id3, size = min(length(id3), n)) # sample row ids
i <- c(i, is) # add to list
unexposed2 <- ux[i] # resave unexposed2
}
out <- data[i, ] # only one data.frame subset
out$id <- NULL
out
}
ftic <- function(data, n) {
Reduce(
\(x, y) {
rbind(x, head(subset(y, !unexposed %in% x$unexposed), n))
},
split(data[sample(1:nrow(data)), ], ~exposed)
)
}
基准测试如下
set.seed(123)
df <- tibble(
exposed = rep(1:1000, each = 1000),
unexposed = sample(1:70000, 1000000, replace = TRUE)
)
mbm <- microbenchmark(
tmfmnk = ftmfmnk(df, 5),
minem = fminem(df, 5),
tic = ftic(df, 5),
times = 10
)
boxplot(mbm)
我们将会看到
> mbm
Unit: milliseconds
expr min lq mean median uq max neval
tmfmnk 36809.9563 44276.3545 43780.8407 44897.2661 46175.1031 46948.8906 10
minem 5361.2796 5932.7752 5923.8811 6010.7775 6047.3716 6233.2919 10
tic 504.5749 519.5997 641.7935 607.2825 729.4545 868.1283 10
我这里没有任何先进的技术,只是一个带有
for
循环的动态编程方案,我相信一定有比我的更高效的方法
dp <- function(df, n) {
d <- table(df)
out <- list()
rnm <- row.names(d)
cnm <- colnames(d)
for (i in 1:nrow(d)) {
v <- which(d[i, ] > 0)
l <- length(v)
idx <- v[sample(l, min(l, n))]
out[[i]] <- data.frame(exposed = rnm[i], unexposed = cnm[idx])
d[, idx] <- 0
}
do.call(rbind, out)
}
和基准测试
set.seed(123)
df <- tibble(
exposed = rep(1:100, each = 100),
unexposed = sample(1:7000, 10000, replace = TRUE)
)
mbm <- microbenchmark(
f1 = recursive_sample(df, 5),
f2 = recursive_sample2(df, 5),
f3 = dp(df, 5),
times = 10
)
boxplot(mbm)
表演
> mbm
Unit: milliseconds
expr min lq mean median uq max neval
f1 1271.0135 1302.4310 1449.2193 1326.7630 1686.4329 1888.4549 10
f2 507.9350 516.8854 617.0313 559.0422 706.4300 801.0124 10
f3 212.8944 247.0066 278.1792 271.9010 309.7377 354.4320 10
此外,要检查结果
res <- dp(df, 5)
,我们可以使用
> table(res$exposed)
1 10 100 11 12 13 14 15 16 17 18 19 2 20 21 22 23 24 25 26
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
27 28 29 3 30 31 32 33 34 35 36 37 38 39 4 40 41 42 43 44
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
45 46 47 48 49 5 50 51 52 53 54 55 56 57 58 59 6 60 61 62
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
63 64 65 66 67 68 69 7 70 71 72 73 74 75 76 77 78 79 8 80
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
81 82 83 84 85 86 87 88 89 9 90 91 92 93 94 95 96 97 98 99
5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5
> anyDuplicated(res$unexposed)
[1] 0