为什么我会收到“运行时错误:input_lengths 的大小必须为batch_size”,尽管 input_lengths 等于batch_size?

问题描述 投票:0回答:1

我正在尝试在我的数据上训练 ASR 的 Wav2Vec2CTC 模型。我正在使用 CTC 损失。代码如下:

def train_model(model, train_loader, optimizer, criterion, device):
    model.train()
    total_loss = 0.0

    for batch in tqdm(train_loader, desc="Training", unit="batch"):
        input_values = batch["input_values"].to(device)
        labels = batch["labels"].to(device)
        labels = labels.squeeze(1)
        optimizer.zero_grad()
        logits = model(input_values).logits
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
        input_lengths = torch.full(size=(input_values.shape[0],), fill_value=input_values.shape[1], dtype=torch.long).to(device)
        label_lengths = torch.full(size=(labels.shape[0],), fill_value=labels.shape[1], dtype=torch.long).to(device)
        print(input_lengths.size())
        print(label_lengths.size())
        loss = criterion(log_probs, labels, input_lengths, label_lengths)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    return total_loss / len(train_loader)

但我从 CTCloss 得到:“运行时错误:input_lengths 的大小必须为 batch_size”。 我使用的是batch_size = 4并且input_lengths.size(),label_lengths.size()都给出torch.Size([4])。 print(input_lengths) 还给出一个张量 [250000, 250000, 250000, 250000]。我不知道如何解决这个问题。

python pytorch speech-recognition ctc
1个回答
0
投票

批量大小与提供给 CTC 损失函数的 input_lengths 张量大小不匹配。

input_lengths = torch.full(size=(input_values.shape[0],), fill_value=input_values.shape[1], dtype=torch.long).to(device)

def train_model(model, train_loader, optimizer, criterion, device):
model.train()
total_loss = 0.0

for batch in tqdm(train_loader, desc="Training", unit="batch"):
    input_values = batch["input_values"].to(device)
    labels = batch["labels"].to(device)
    labels = labels.squeeze(1)

    # Compute actual lengths of input sequences
    input_lengths = torch.tensor([len(seq) for seq in input_values]).to(device)
    label_lengths = torch.full(size=(labels.shape[0],), fill_value=labels.shape[1], dtype=torch.long).to(device)

    optimizer.zero_grad()
    logits = model(input_values).logits
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
    
    loss = criterion(log_probs, labels, input_lengths, label_lengths)
    loss.backward()
    optimizer.step()
    
    total_loss += loss.item()

return total_loss / len(train_loader)
最新问题
© www.soinside.com 2019 - 2025. All rights reserved.