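I'm trying to reimplement the model.generate() function of Hugging Face Transformers models. I'm doing this to implement logit bias, which the regular function does not allow. But before I get that far, I've run into a lot of trouble with my top-p sampling.
Here is the code snippet: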
generation_args = {
    "max_new_tokens": 500,
    "temperature": 0.4,  # Adjust temperature if needed for more or less randomness
    "do_sample": True,  # Enable sampling
    "top_p": 0.5,  # Set the cumulative probability for nucleus sampling
    "top_k": None,  # Optionally, you can set top_k if you want to use it alongside or instead of top_p
}
def top_p_filtering(logits, top_p):
    """Filter the logits using top-p (nucleus) sampling."""
    # Sort logits in descending order and get the sorted indices
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    # Compute the cumulative probabilities of the sorted logits
    cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
    # Create a mask for the tokens to keep
    sorted_indices_to_keep = cumulative_probs <= top_p
    # Ensure that at least one token is kept (the first token, which has the highest logit)
    sorted_indices_to_keep[..., 0] = True
    # Filter out the tokens to remove by setting their logits to negative infinity
    logits[sorted_indices[~sorted_indices_to_keep]] = float('-inf')
    return logits
def custom_generate(input_ids, streamer, max_new_tokens, temperature, top_p):
    past_key_values = None
    attention_mask = torch.ones(input_ids.shape, device=input_ids.device)
    for _ in range(max_new_tokens):
        with torch.no_grad():
            outputs = model(
                input_ids=input_ids,
                past_key_values=past_key_values,
                attention_mask=attention_mask,
                use_cache=True
            )
        logits = outputs.logits[:, -1, :]  # Get logits of the last token
        # Apply temperature to logits
        if temperature != 1.0:
            logits = logits / temperature
        # Apply top-p sampling
        if top_p is not None and top_p < 1.0:
            logits = top_p_filtering(logits, top_p)
        print("1")
        next_token_probs = torch.nn.functional.softmax(logits, dim=-1)
        print("2")
        # Check if next_token_probs contains valid probabilities
        next_token_id = torch.multinomial(next_token_probs,
                                          num_samples=1)
        print("3")
        streamer.put(next_token_id)  # Pass the tensor directly to the streamer
        input_ids = next_token_id  # Set the next input to the last generated token
        attention_mask = torch.cat(
            [attention_mask, torch.ones((attention_mask.shape[0], 1), device=attention_mask.device)], dim=1)
        past_key_values = outputs.past_key_values
        if next_token_id.item() == tokenizer.eos_token_id:
            break
with torch.no_grad():
    custom_generate(input_ids, streamer, generation_args["max_new_tokens"], generation_args["temperature"], generation_args["top_p"])
The error I'm facing:
../aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [10,0,0], thread: [63,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
Exception in thread Thread-18 (generate):
Traceback (most recent call last):
File "/usr/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
self.run()
File "/usr/lib/python3.10/threading.py", line 953, in run
self._target(*self._args, **self._kwargs)
File "/mnt/c/Users/User/Documents/EmpatheticChatBot/Inference-Server.py", line 130, in generate
custom_generate(input_ids, streamer, generation_args["max_new_tokens"], generation_args["temperature"], generation_args["top_p"])
File "/mnt/c/Users/User/Documents/EmpatheticChatBot/Inference-Server.py", line 108, in custom_generate
next_token_id = torch.multinomial(next_token_probs,
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
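The whole problem only appeared after I added top-p sampling.
I expected my sampling to work, since I have checked my code about 30 times. ChatGPT says this code is perfect and that the error is really hard to debug. My assumption is that values are being filtered incorrectly or set to "bad" values.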
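The problem is the indexing you do on this line:
logits[sorted_indices[~sorted_indices_to_keep]] = float('-inf')
For reasons I'll explain, this causes an index out-of-bounds error. Out-of-bounds indexing is a common cause of the
CUDA error: device-side assert triggered
error.
Consider the following: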
import torch
import torch.nn as nn
torch.manual_seed(42)
top_p = 0.2
logits = torch.randn(8, 128) # random logits
# sort logits
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
# calculate cumulative probs
cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
# apply top p threshold to cumulative probs
sorted_indices_to_keep = cumulative_probs <= top_p
# ensure at least one index is kept
sorted_indices_to_keep[..., 0] = True
# this is the problem: logits[sorted_indices[~sorted_indices_to_keep]] = float('-inf')
print(logits.shape, sorted_indices[~sorted_indices_to_keep].shape)
> torch.Size([8, 128]) torch.Size([989])
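When you index sorted_indices[~sorted_indices_to_keep], both inputs have shape (8, 128), but the output has shape (989,) (or similar, depending on the random seed used for the dummy logits).
This happens because the number of True values in each row of sorted_indices_to_keep is irregular. The indexing operation therefore cannot resolve the output into a clean 2D tensor with the same size in every row. PyTorch handles this by returning a flattened 1D vector containing one element for each True value in the index tensor.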
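To make the flattening concrete, here is a minimal sketch with a toy tensor and a hand-written mask (both made up for illustration) whose rows contain different numbers of True values:

```python
import torch

# Toy 2D tensor and a ragged boolean mask (illustrative values only)
x = torch.arange(6).reshape(2, 3)           # [[0, 1, 2], [3, 4, 5]]
mask = torch.tensor([[True, False, False],
                     [True, True, False]])  # 1 True in row 0, 2 in row 1

# Boolean-mask indexing returns a flat 1D tensor, not a (2, ?) tensor
selected = x[mask]
print(selected.shape)  # torch.Size([3])
print(selected)        # tensor([0, 3, 4])
```

The row structure is lost: nothing in the result says which row each selected value came from.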
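This means that when you try to evaluate logits[sorted_indices[~sorted_indices_to_keep]], you are using a long 1D tensor to index into a small 2D tensor. Each value in that 1D tensor is a vocabulary position (0 to 127 here), but it gets interpreted as a row index along dimension 0, which only has size 8. If you run this on CPU, you get an error like IndexError: index 20 is out of bounds for dimension 0 with size 8. When you run it on GPU, you get the CUDA assertion error instead.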
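If you want to see the readable version of the error, a sketch like the following reproduces it on CPU. It assumes random dummy logits of shape (8, 128) and a seed chosen arbitrarily; the exact index reported in the message will vary.

```python
import torch

torch.manual_seed(0)
logits = torch.randn(8, 128)  # dummy logits on CPU: 8 rows, 128 "vocabulary" entries

sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
keep = cumulative_probs <= 0.2
keep[..., 0] = True

# Flat 1D tensor of vocabulary positions (values in 0..127)
flat_idx = sorted_indices[~keep]

raised = False
try:
    # The flat positions are treated as row indices into dimension 0 (size 8)
    logits[flat_idx] = float("-inf")
except IndexError as err:
    raised = True
    print("IndexError:", err)
```

Running the failing step on CPU like this is often the quickest way to turn an opaque device-side assert into a message that names the offending index and dimension.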
To fix this, use a scatter operation. Something like this:
def top_p_filtering(logits, top_p, shift_indices=True, debug=False):
    """Filter the logits using top-p (nucleus) sampling."""
    # Sort logits in descending order and get the sorted indices
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    # Compute the cumulative probabilities of the sorted logits
    cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
    # Create a mask for the tokens to keep
    sorted_indices_to_keep = cumulative_probs <= top_p
    # Optional: shift indices to the right. This results in keeping the first
    # token above the top_p threshold. Skip this line to ensure that all
    # token probs are strictly below the top_p threshold
    if shift_indices:
        sorted_indices_to_keep[..., 1:] = sorted_indices_to_keep[..., :-1].clone()
    # Ensure that at least one token is kept (the first token, which has the highest logit)
    sorted_indices_to_keep[..., 0] = True
    # Use scatter to create top_p mask
    mask = sorted_indices_to_keep.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_keep)
    # Optional debug check to make sure top_p is being honored
    # Note we need to compute probs before masking because applying softmax
    # after masking will result in a distribution that sums to 1
    if debug:
        probs = torch.nn.functional.softmax(logits, dim=-1)
        probs[~mask] = 0
        print(probs.sum(-1))
    # Use mask to set logit vals to -inf
    logits[~mask] = float('-inf')
    return logits
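As a rough sanity check of the scatter approach, the sketch below uses a compact variant of the idea. The helper name top_p_mask is hypothetical, it omits the optional shift, and it scatters into a fresh all-False tensor rather than reusing the sorted mask. It then samples with torch.multinomial the same way the question's loop does:

```python
import torch

def top_p_mask(logits, top_p):
    """Hypothetical helper: boolean keep-mask in the original token order."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    keep_sorted = cumulative_probs <= top_p
    keep_sorted[..., 0] = True  # always keep the most likely token
    # Scatter the sorted-order mask back to the original vocabulary positions
    return torch.zeros_like(keep_sorted).scatter(-1, sorted_indices, keep_sorted)

torch.manual_seed(0)
logits = torch.randn(4, 32)  # dummy logits: batch of 4, vocabulary of 32

mask = top_p_mask(logits, top_p=0.5)
filtered = logits.masked_fill(~mask, float("-inf"))  # out-of-place, leaves logits intact
probs = torch.softmax(filtered, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)

print(mask.shape, next_token.shape)
```

Because the mask has the same shape as logits, the masking step is an elementwise operation and no index can fall out of bounds. Masked tokens get probability zero after the softmax, so multinomial only ever draws from the kept set.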