我正在尝试加速一个函数,该函数以许多类别的可能组合为多个记录随机采样多个记录,并确保它们是唯一的(即,假设有3条记录,其中任何一个可以是0或1,我想要10个随机的记录,它们可能是记录的唯一可能组合)。
如果我不使用numba,我可能会做这样的事情:
import numpy as np
def myfunc(categories, NumberOfRecords, maxsamples):
return np.unique( np.random.choice(np.arange(categories), size=(maxsamples*10, NumberOfRecords), replace=True), axis=0 )[0:maxsamples]
令人讨厌的是,numba在np.unique中不支持轴,所以我可以做类似的事情,但是某些记录可能证明是不唯一的。
from numba import njit, int64
import numpy as np
@njit(int64[:,:](int64, int64, int64), cache=True)
def myfunc(categories, NumberOfRecords, maxsamples):
return np.random.choice(np.arange(categories), size=(maxsamples, NumberOfRecords), replace=True)
myfunc(categories=2, NumberOfRecords=3, maxsamples=10)
例如在一个电话中(显然这里有一些随机性),我得到了下面的内容(索引1和6、3和4、7和9是相同的行):
array([[0, 1, 1],
[1, 1, 0],
[0, 1, 0],
[1, 0, 1],
[1, 0, 1],
[1, 1, 1],
[1, 1, 0],
[1, 0, 0],
[0, 0, 0],
[1, 0, 0]])
我的问题是:
这不能回答您提出的确切问题,但是您可能会发现它很有用。在下面,我不使用numba,但是所有操作都使用矢量化numpy函数。
您生成的结果的每一行都可以解释为以N为底的整数,其中N是类别数。通过这种解释,您想要的是在不替换整数[0,1,... N ** R-1]的情况下进行采样,其中R是“记录”的数目。为此,可以将choice
函数与参数replace=False
一起使用。一旦有了它,就需要将所选的整数转换为以N为底。为此,我使用函数int2base
,它是我在different answer中编写的函数的精简版本。
这里是代码:
import numpy as np def int2base(x, base, ndigits): # x = np.asarray(x) # Uncomment this line for general purpose use. powers = base ** np.arange(ndigits) digits = (x.reshape(x.shape + (1,)) // powers) % base return digits def makesample(ncategories, nrecords, nsamples, rng=None): if rng is None: rng = np.random.default_rng() n = ncategories ** nrecords choices = rng.choice(n, replace=False, size=nsamples) return int2base(choices, ncategories, nrecords)
在
makesample
中,我包括了可选参数rng
。它允许您指定保存choice
功能的对象。如果未提供,则使用np.random.default_rng()
。
示例:
In [118]: makesample(2, 3, 6) Out[118]: array([[0, 1, 1], [0, 0, 1], [1, 0, 1], [0, 0, 0], [1, 1, 0], [1, 1, 1]]) In [119]: makesample(5, 4, 12) Out[119]: array([[3, 4, 0, 1], [2, 0, 2, 0], [4, 2, 4, 3], [0, 1, 0, 4], [0, 2, 0, 1], [1, 2, 0, 1], [0, 3, 0, 4], [3, 3, 0, 3], [3, 4, 1, 4], [2, 4, 1, 1], [3, 4, 1, 0], [1, 1, 4, 4]])
makesample
将引发异常,如果您请求太多样本:
In [120]: makesample(2, 3, 10)
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-120-80044e78a60a> in <module>
----> 1 makesample(2, 3, 10)
~/code_snippets/python/numpy/random_samples_for_so_question.py in makesample(ncategories, nrecords, nsamples, rng)
17 rng = np.random.default_rng()
18 n = ncategories ** nrecords
---> 19 choices = rng.choice(n, replace=False, size=nsamples)
20 return int2base(choices, ncategories, nrecords)
_generator.pyx in numpy.random._generator.Generator.choice()
ValueError: Cannot take a larger sample than population when 'replace=False'