numpy.random.Generator.permutation()
没有提供在单个函数调用中返回多个结果的明显方法。给定一个 (1d) numpy 数组 x
,我想对 n
的 x
排列进行采样(每个长度为 len(x)),并将结果作为形状为 (n, len(x))
的 numpy 数组。生成许多排列的一种简单方法是np.array([rng.permutation(x) for _ in range(n)])
。这并不理想,主要是因为循环是在 Python 中而不是在已编译的 numpy 函数内。
import numpy as np
rng = np.random.default_rng(1234)
# x is big enough to not want to enumerate all permutations
x = rng.standard_normal(size=20)
n = 10000
perms = np.array([rng.permutation(x) for _ in range(n)])
我的用例是通过强力搜索来查找最小化特定属性的排列(构成“足够好”的搜索解决方案)。我可以使用 numpy 运算计算每个排列的感兴趣属性,这些运算可以很好地对所得排列矩阵进行矢量化/广播。事实证明,天真地生成排列矩阵是我的代码的瓶颈。有更好的办法吗?
您可以使用
rng.permuted
代替 rng.permutation
并将其与 np.tile
结合使用,以便多次重复 x
并独立地打乱每个重复。方法如下:
perms = rng.permuted(np.tile(x, n).reshape(n,x.size), axis=1)
在我的机器上这比您的初始代码快大约 10 倍。
请注意,Jérome 的解决方案提供了“n”行的数组,但它可能包含重复。不同的行可能具有相同的“x”顺序(特别是如果“n”大于“x”)
如果您需要在不重复的情况下进行采样(就像我的情况一样),您可以随时执行
set(list(perm))
并保留唯一的组合“x”值
遗憾的是我无法发表评论(声誉太低),但我使用长度为 500,000 的 x 对此代码进行了基准测试,实际上我发现第二种方式速度较慢。
在这里,我创建了 500,000 个长“索引”数组的 10,000 种排列(以尝试节省内存!):
# Hey, lets do all the permutations up front!
rng = np.random.default_rng(43)
indices = np.arange(len(rp), len(rn))
MAX_PERMS = 10000
time_now = time.time()
perms = np.array([rng.permutation(indices) for _ in range(MAX_PERMS)])
print(
f"Num permutations: {MAX_PERMS:9n}. Time taken: {time.time() - time_now:0.6f}."
)
time_now = time.time()
perms = rng.permuted(
np.tile(indices, MAX_PERMS).reshape(MAX_PERMS, indices.size), axis=1
)
print(
f"Num permutations: {MAX_PERMS:9n}. Time taken: {time.time() - time_now:0.6f}."
)
结果:
Num permutations: 10000. Time taken: 85.126815.
Num permutations: 10000. Time taken: 102.404646.