我正在尝试在我的数据上训练 ASR 的 Wav2Vec2CTC 模型。我正在使用 CTC 损失。代码如下:
def train_model(model, train_loader, optimizer, criterion, device):
    """Run one training epoch with CTC loss and return the mean loss per batch.

    Args:
        model: a Wav2Vec2-style CTC model whose forward returns an object
            with a ``.logits`` tensor of shape (batch, frames, vocab).
        train_loader: DataLoader yielding dicts with "input_values" and
            "labels" (labels assumed shape (batch, 1, max_label_len) —
            TODO confirm against the collator, given the squeeze(1)).
        optimizer: torch optimizer stepping the model parameters.
        criterion: ``torch.nn.CTCLoss`` (or compatible callable).
        device: target device for all tensors.

    Returns:
        float: total loss divided by the number of batches.
    """
    model.train()
    total_loss = 0.0
    for batch in tqdm(train_loader, desc="Training", unit="batch"):
        input_values = batch["input_values"].to(device)
        labels = batch["labels"].to(device)
        labels = labels.squeeze(1)

        optimizer.zero_grad()
        logits = model(input_values).logits  # (batch, frames, vocab)
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
        # nn.CTCLoss expects log_probs time-major: (T, N, C). Passing the
        # batch-first tensor makes the loss misread dim 0 as T and dim 1 as N,
        # producing "input_lengths must be of size batch_size".
        log_probs = log_probs.transpose(0, 1)

        # input_lengths must count model-output FRAMES (logits.shape[1]), not
        # raw audio samples (input_values.shape[1]). The wav2vec2 feature
        # extractor downsamples heavily, so 250000 samples far exceeds the
        # actual CTC time dimension.
        input_lengths = torch.full(
            size=(logits.shape[0],),
            fill_value=logits.shape[1],
            dtype=torch.long,
        ).to(device)
        # Assumes labels are padded to a uniform length and the criterion was
        # built with the pad id as blank (zero_infinity may also help) —
        # NOTE(review): confirm padding/blank handling against the collator.
        label_lengths = torch.full(
            size=(labels.shape[0],),
            fill_value=labels.shape[1],
            dtype=torch.long,
        ).to(device)

        loss = criterion(log_probs, labels, input_lengths, label_lengths)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(train_loader)
但我从 CTCLoss 得到运行时错误："input_lengths must be of size batch_size"。我使用的是 batch_size = 4，并且 input_lengths.size() 和 label_lengths.size() 都给出 torch.Size([4])。print(input_lengths) 还给出张量 [250000, 250000, 250000, 250000]。我不知道如何解决这个问题。
错误的真正原因不是 input_lengths 的元素个数（它确实已经是 4），而是传给 CTCLoss 的 log_probs 的维度顺序：nn.CTCLoss 要求 log_probs 为时间主序 (T, N, C)，而模型输出是批次主序 (N, T, C)，因此损失函数把第 1 维当成了 batch_size，导致尺寸校验失败。此外 input_lengths 应当是模型输出的帧数（logits.shape[1]），而不是原始音频采样点数 250000——wav2vec2 的特征提取器会大幅下采样。正确写法：
log_probs = log_probs.transpose(0, 1)
input_lengths = torch.full(size=(logits.shape[0],), fill_value=logits.shape[1], dtype=torch.long).to(device)
def train_model(model, train_loader, optimizer, criterion, device):
    """Run one training epoch with CTC loss and return the mean loss per batch.

    Args:
        model: a Wav2Vec2-style CTC model whose forward returns an object
            with a ``.logits`` tensor of shape (batch, frames, vocab).
        train_loader: DataLoader yielding dicts with "input_values" and
            "labels" (labels assumed shape (batch, 1, max_label_len) —
            TODO confirm against the collator, given the squeeze(1)).
        optimizer: torch optimizer stepping the model parameters.
        criterion: ``torch.nn.CTCLoss`` (or compatible callable).
        device: target device for all tensors.

    Returns:
        float: total loss divided by the number of batches.
    """
    model.train()
    total_loss = 0.0
    for batch in tqdm(train_loader, desc="Training", unit="batch"):
        input_values = batch["input_values"].to(device)
        labels = batch["labels"].to(device)
        labels = labels.squeeze(1)

        optimizer.zero_grad()
        logits = model(input_values).logits  # (batch, frames, vocab)
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
        # nn.CTCLoss expects time-major log_probs (T, N, C); the model emits
        # batch-first (N, T, C), so transpose before computing the loss.
        log_probs = log_probs.transpose(0, 1)

        # Lengths must be measured in model-output FRAMES, not raw audio
        # samples: len(seq) over input_values rows gives sample counts
        # (e.g. 250000), which far exceed the CTC time dimension because the
        # wav2vec2 feature extractor downsamples the waveform.
        input_lengths = torch.full(
            size=(logits.shape[0],),
            fill_value=logits.shape[1],
            dtype=torch.long,
        ).to(device)
        # Assumes labels are padded to a uniform length —
        # NOTE(review): confirm padding/blank handling against the collator.
        label_lengths = torch.full(
            size=(labels.shape[0],),
            fill_value=labels.shape[1],
            dtype=torch.long,
        ).to(device)

        loss = criterion(log_probs, labels, input_lengths, label_lengths)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(train_loader)