对 numpy 数组的许多随机排列进行采样的最快方法

问题描述 投票:0回答:3

与许多其他 numpy/random 函数不同,

numpy.random.Generator.permutation()
没有提供在单个函数调用中返回多个结果的明显方法。给定一个 (1d) numpy 数组
x
,我想对
n
x
排列进行采样(每个长度为 len(x)),并将结果作为形状为
(n, len(x))
的 numpy 数组。生成许多排列的一种简单方法是
np.array([rng.permutation(x) for _ in range(n)])
。这并不理想,主要是因为循环是在 Python 中而不是在已编译的 numpy 函数内。

import numpy as np

rng = np.random.default_rng(1234)
# x is big enough to not want to enumerate all permutations
x = rng.standard_normal(size=20)
n = 10000
perms = np.array([rng.permutation(x) for _ in range(n)])

我的用例是通过强力搜索来查找最小化特定属性的排列(构成“足够好”的搜索解决方案)。我可以使用 numpy 运算计算每个排列的感兴趣属性,这些运算可以很好地对所得排列矩阵进行矢量化/广播。事实证明,天真地生成排列矩阵是我的代码的瓶颈。有更好的办法吗?

python numpy performance permutation
3个回答
1
投票

您可以使用

rng.permuted
代替
rng.permutation
并将其与
np.tile
结合使用,以便多次重复
x
并独立地打乱每个重复。方法如下:

perms = rng.permuted(np.tile(x, n).reshape(n,x.size), axis=1)

在我的机器上这比您的初始代码快大约 10 倍。


0
投票

请注意,Jérome 的解决方案提供了“n”行的数组,但它可能包含重复。不同的行可能具有相同的“x”顺序(特别是如果“n”大于“x”)

如果您需要在不重复的情况下进行采样(就像我的情况一样),您可以随时执行

set(list(perm))
并保留唯一的组合“x”值


0
投票

遗憾的是我无法发表评论(声誉太低),但我使用长度为 500,000 的 x 对此代码进行了基准测试,实际上我发现第二种方式速度较慢。

在这里,我创建了 500,000 个长“索引”数组的 10,000 种排列(以尝试节省内存!):


    # Hey, lets do all the permutations up front!
    rng = np.random.default_rng(43)

    indices = np.arange(len(rp), len(rn))

    MAX_PERMS = 10000

    time_now = time.time()
    perms = np.array([rng.permutation(indices) for _ in range(MAX_PERMS)])
    print(
        f"Num permutations: {MAX_PERMS:9n}. Time taken: {time.time() - time_now:0.6f}."
    )

    time_now = time.time()
    perms = rng.permuted(
        np.tile(indices, MAX_PERMS).reshape(MAX_PERMS, indices.size), axis=1
    )
    print(
        f"Num permutations: {MAX_PERMS:9n}. Time taken: {time.time() - time_now:0.6f}."
    )

结果:

Num permutations:     10000. Time taken:  85.126815.
Num permutations:     10000. Time taken: 102.404646.
© www.soinside.com 2019 - 2024. All rights reserved.