I have some code that works with log-probabilities. When I want to draw a sample from the probability distribution, I use
import numpy as np
probs = np.exp(logprobs)
probs /= probs.sum()
sample = np.random.choice(X, p=probs, size=1)[0]
But the exponentiation and division here add some overhead, and NumPy's random.choice requires the probabilities to lie between 0 and 1 and to sum to 1.
Is there a quick trick that lets me sample from an array of unnormalized log-probabilities? I only need one sample at a time, and it only needs to be drawn with frequency proportional to the probability implied by its log-probability.
Use the Gumbel-max trick. See this answer on Cross Validated for more explanation and references. Here's a minimal code example:
import numpy as np
# Assume we only have log-probabilities (for sampling, even logits will do)
log_probs = np.log([0.1, 0.2, 0.3, 0.4])
num_categories = len(log_probs)
# Sample a single category
gumbels = np.random.gumbel(size=num_categories)
sample = np.argmax(log_probs + gumbels)
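For context, the reason this works (the linked Cross Validated answer has the proof and references) is the Gumbel-max identity: if G_1, ..., G_n are i.i.d. standard Gumbel(0, 1) variables, then P(argmax_i (x_i + G_i) = k) = exp(x_k) / sum_j exp(x_j), i.e., the argmax is distributed exactly according to softmax(x). No explicit exponentiation or normalization is ever performed.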
Here's a function with an interface similar to np.random.choice's. It only samples with replacement, and it only returns indices.
from typing import Optional, Union
import numpy as np

def random_choice_log_space(
    logits: np.ndarray,
    size: int = 1,
    random_state: Optional[Union[np.random.RandomState, int]] = None,
) -> np.ndarray:
    """
    Sample (with replacement) from a categorical distribution parametrized by logits or
    log-probabilities.

    Parameters
    ----------
    logits : np.ndarray
        the last dimension contains log-probabilities (e.g., out of a log-softmax
        function) or unnormalized logits corresponding to the categorical
        distribution(s)
    size : int, optional
        sample size, by default 1
    random_state : Optional[Union[np.random.RandomState, int]], optional
        ``np.random.RandomState`` object or an integer seed, by default None

    Returns
    -------
    np.ndarray
        sampled indexes

    Raises
    ------
    ValueError
        if `size` is less than 1
    """
    if size < 1:
        raise ValueError("size must be at least 1.")
    # Independently sample as many Gumbels as needed. During addition, they'll be
    # broadcasted
    _gumbels_shape = (size,) + logits.shape if size > 1 else logits.shape
    # Create a RandomState if needed
    if not isinstance(random_state, np.random.RandomState):
        random_state = np.random.RandomState(seed=random_state)
    gumbels = random_state.gumbel(size=_gumbels_shape)
    gumbels_rescaled: np.ndarray = logits + gumbels
    return gumbels_rescaled.argmax(axis=-1)
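A quick usage example (the exact indices drawn depend on the seed):

import numpy as np

log_probs = np.log([0.1, 0.2, 0.3, 0.4])
# Draw 5 category indices from the distribution implied by log_probs
samples = random_choice_log_space(log_probs, size=5, random_state=0)
print(samples)  # a length-5 array of indices in {0, 1, 2, 3}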
As noted in the docstring, you can pass either log-probabilities or unnormalized logits as the logits input. That's because the two only differ by an additive constant (for log-probabilities vs. logits, that constant is the log-sum-exp of the logits), which is immaterial because an argmax is taken.
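A quick way to see this (my own sketch, not part of the answer above): reuse the same Gumbel noise and check that shifting the inputs by a constant leaves the argmax, and hence the sampled category, unchanged.

import numpy as np

# With the *same* Gumbel noise, adding a constant to every entry cannot change
# which index attains the maximum, so the sampled category is identical.
rng = np.random.RandomState(0)
log_probs = np.log([0.1, 0.2, 0.3, 0.4])
constant = 3.7  # stands in for the log-sum-exp term separating logits from log-probs
gumbels = rng.gumbel(size=log_probs.shape)
assert np.argmax(log_probs + gumbels) == np.argmax(log_probs + constant + gumbels)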
For random_choice_log_space to be correct, it needs to sample independently from the probability distribution implied by the logits. The independence part is already clear, since each draw gets its own Gumbel noise. So we only need to compare the empirical distribution of the samples against the actual distribution.
import numpy as np
import pandas as pd
from scipy.special import logsumexp, softmax
_probs = np.array([0.1, 0.2, 0.3, 0.4])
log_probs = np.log(_probs)
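# logits differ from log_probs by an arbitrary additive constant (logsumexp(_probs) here)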
logits = np.log(_probs) + logsumexp(_probs, axis=-1)
# You start out with access to log_probs or logits
num_categories = len(_probs)
sample_size = 500_000
seed = 123
random_state = np.random.RandomState(seed)
# helper function
def empirical_distr(discrete_samples):
    return (pd.Series(discrete_samples)
            .value_counts(normalize=True)
            .sort_index()
            .to_numpy())
# np.random.choice (select one at a time) AKA vanilla sampling
def random_choice_log_space_vanilla(logits, size, random_state=None):
    probs = softmax(logits, axis=-1)
    if not isinstance(random_state, np.random.RandomState):
        random_state = np.random.RandomState(seed=random_state)
    return random_state.choice(len(probs), p=probs, size=size, replace=True)
samples = random_choice_log_space_vanilla(logits, size=sample_size, random_state=random_state)
distr_vanilla = empirical_distr(samples)
# random_choice_log_space for log-probabilities input
samples = random_choice_log_space(log_probs, size=sample_size, random_state=random_state)
distr_log_probs = empirical_distr(samples)
# random_choice_log_space for logits input
samples = random_choice_log_space(logits, size=sample_size, random_state=random_state)
distr_logits = empirical_distr(samples)
print(pd.DataFrame({'rel error (vanilla)': (distr_vanilla - _probs) / _probs,
                    'rel error (log-probs)': (distr_log_probs - _probs) / _probs,
                    'rel error (logits)': (distr_logits - _probs) / _probs},
                   index=pd.Index(range(num_categories), name='category')))
          rel error (vanilla)  rel error (log-probs)  rel error (logits)
category
0                   -0.005760               0.000560            -0.002960
1                    0.004730               0.000380            -0.002560
2                   -0.002367              -0.000567             0.004273
3                    0.000850               0.000095            -0.001185
None of this work matters if random_choice_log_space is always slower than softmaxing and using np.random.choice. Fortunately, there is a common enough setting where it wins, and it's the one in your question: you have logits, and you want to sample a single element.
from time import time
from scipy.stats import trim_mean

def time_func(func, *args, num_replications: int = 50, **kwargs) -> list[float]:
    '''
    Returns a list, `times`, where `times[i]` is the time it took to run
    `func(*args, **kwargs)` at replication `i` for `i in range(num_replications)`.
    '''
    times = []
    for _ in range(num_replications):
        time_start = time()
        _ = func(*args, **kwargs)
        time_end = time()
        times.append(time_end - time_start)
    return times
category_sizes = np.power(2, np.arange(1, 14 + 1))
num_replications = 100
times_vanilla = []
times_gumbel = []
for size in category_sizes:
    logits = np.random.normal(size=size)
    times_vanilla.append(time_func(random_choice_log_space_vanilla, logits, size=1,
                                   num_replications=num_replications))
    times_gumbel.append(time_func(random_choice_log_space, logits, size=1,
                                  num_replications=num_replications))

(pd.DataFrame({'vanilla': trim_mean(times_vanilla, 0.1, axis=1),
               'Gumbel': trim_mean(times_gumbel, 0.1, axis=1)},
              index=pd.Index(category_sizes, name='# categories'))
 .plot.bar(title='Categorical sampling',
           figsize=(8, 5),
           ylabel='mean wall-clock time (sec)'));
The plot (which I re-ran several times) suggests that, for larger numbers of categories, the comparison becomes unstable.
Note: pre-computing the Gumbel samples and re-using them across draws would speed up the trick considerably. But my worry is that the drawn samples would then depend on the shared Gumbel samples. (Having the Gumbel samples be independent of the logits data is not enough, which is what the linked comment points out.) I'll look into this idea further and update here.
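To make the concern concrete (this sketch is mine, not part of the note above), consider the extreme case where a single pre-computed Gumbel vector is re-used for every draw: every draw then returns the same index, so the samples are clearly not independent.

import numpy as np

# Extreme case of re-using pre-computed Gumbel noise: every draw shares the same
# noise vector, so every draw picks the same category -- no independence at all.
rng = np.random.RandomState(0)
logits = np.log([0.1, 0.2, 0.3, 0.4])
reused_gumbels = rng.gumbel(size=logits.shape)  # computed once, re-used below
samples = [int(np.argmax(logits + reused_gumbels)) for _ in range(5)]
print(samples)  # all five entries are identical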