我正在为情感分类任务训练多语言 BERT 模型。我在一台机器上有 2 个 GPU,因此使用 Hugging Face 的 Accelerate 库(`Accelerator`)进行分布式训练。但是运行代码时,它会抛出 RuntimeError。下面依次是模型定义、训练函数和完整的错误回溯:
class BERTModel(nn.Module):
    """Multilingual BERT encoder followed by a mean+max pooled linear head.

    The encoder's per-token output is reduced along dimension 1 with both a
    mean and a max, the two results are concatenated, passed through dropout
    and projected to a single output value per example.
    """

    def __init__(self):
        super().__init__()
        self.bert = transformers.BertModel.from_pretrained("bert-base-multilingual-uncased")
        self.bert_drop = nn.Dropout(0.3)
        # *2 because the mean-pooled and max-pooled vectors are concatenated.
        # Read the width from the loaded config instead of hard-coding 768.
        self.out = nn.Linear(self.bert.config.hidden_size * 2, 1)

    def forward(self, ids, mask, token_type_ids):
        """Return one raw score per example.

        Args:
            ids: token id tensor passed to the encoder as its first argument.
            mask: attention mask forwarded as ``attention_mask``.
            token_type_ids: segment ids forwarded as ``token_type_ids``.

        Returns:
            The output of the final linear layer (last dimension of size 1).
        """
        encoder_outputs = self.bert(
            ids,
            attention_mask=mask,
            token_type_ids=token_type_ids,
        )
        # Index rather than tuple-unpack: the encoder may return either a plain
        # tuple or a dict-like output object. Unpacking a dict-like object
        # yields its keys, not tensors, while ``[0]`` works for both.
        # NOTE(review): the second encoder output (presumably the pooled
        # output) is intentionally unused here; under DistributedDataParallel
        # that likely leaves some encoder parameters without gradients.
        sequence_output = encoder_outputs[0]

        # assumes sequence_output is (batch, seq_len, hidden) - TODO confirm
        mean_pooling = torch.mean(sequence_output, 1)
        max_pooling, _ = torch.max(sequence_output, 1)
        pooled = torch.cat((mean_pooling, max_pooling), 1)

        return self.out(self.bert_drop(pooled))
def train_fn(data_loader, model, optimizer, scheduler):
    """
    Training Function for the Model

    Runs one pass over ``data_loader``. The model, optimizer and data loader
    are handed to a Hugging Face ``Accelerator`` so the same loop works on one
    or several devices.

    parameters: data_loader - PyTorch DataLoader whose batches are dicts with
                              the keys "ids", "token_type_ids", "mask" and
                              "targets"
                model - The Model to be used for training
                optimizer - The Optimizer to be used for training
                scheduler - The Learning Rate Scheduler (stepped once per batch)
    returns:    None
    """
    # Imported locally so this function brings its own dependency into scope.
    from accelerate import DistributedDataParallelKwargs

    # The model does not use every encoder parameter when producing the loss.
    # Without unused-parameter detection, DistributedDataParallel aborts with
    # "Expected to have finished reduction in the prior iteration".
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
    model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)

    model.train()
    for bi, d in tqdm(enumerate(data_loader), total=len(data_loader)):
        # Cast inputs to the dtypes the model and the loss are called with.
        ids = d["ids"].to(torch.long)
        token_type_ids = d["token_type_ids"].to(torch.long)
        mask = d["mask"].to(torch.long)
        targets = d["targets"].to(torch.float)

        optimizer.zero_grad()
        outputs = model(ids=ids, mask=mask, token_type_ids=token_type_ids)
        # loss_fn is defined elsewhere in the file.
        loss = loss_fn(outputs, targets)
        if bi % 1000 == 0:
            print(f"bi={bi}, loss={loss}")

        # Backward goes through the accelerator rather than loss.backward().
        accelerator.backward(loss)
        optimizer.step()
        scheduler.step()
---------------------------------------------------------------------------
Exception Traceback (most recent call last)
<timed exec> in <module>
/opt/conda/lib/python3.6/site-packages/accelerate/notebook_launcher.py in notebook_launcher(function, args, num_processes, use_fp16, use_port)
107 try:
108 print(f"Launching a training on {num_processes} GPUs.")
--> 109 start_processes(launcher, nprocs=num_processes, start_method="fork")
110 finally:
111 # Clean up the environment variables set.
/opt/conda/lib/python3.6/site-packages/torch/multiprocessing/spawn.py in start_processes(fn, args, nprocs, join, daemon, start_method)
156
157 # Loop on join until it returns True or raises an exception.
--> 158 while not context.join():
159 pass
160
/opt/conda/lib/python3.6/site-packages/torch/multiprocessing/spawn.py in join(self, timeout)
117 msg = "\n\n-- Process %d terminated with the following error:\n" % error_index
118 msg += original_trace
--> 119 raise Exception(msg)
120
121
Exception:
-- Process 1 terminated with the following error:
Traceback (most recent call last):
File "/opt/conda/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 20, in _wrap
fn(i, *args)
File "/opt/conda/lib/python3.6/site-packages/accelerate/utils.py", line 274, in __call__
self.launcher(*args)
File "<timed exec>", line 276, in run
File "<timed exec>", line 74, in train_fn
File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 726, in _call_impl
result = self.forward(*input, **kwargs)
File "/opt/conda/lib/python3.6/site-packages/torch/cuda/amp/autocast_mode.py", line 135, in decorate_autocast
return func(*args, **kwargs)
File "/opt/conda/lib/python3.6/site-packages/torch/nn/parallel/distributed.py", line 585, in forward
self.reducer.prepare_for_backward([])
RuntimeError: Expected to have finished reduction in the prior iteration before
starting a new one. This error indicates that your module has parameters that
were not used in producing loss. You can enable unused parameter detection by (1)
passing the keyword argument `find_unused_parameters=True` to
`torch.nn.parallel.DistributedDataParallel`; (2) making sure all `forward`
function outputs participate in calculating loss. If you already have done the
above two steps, then the distributed data parallel module wasn't able to locate
the output tensors in the return value of your module's `forward` function.
Please include the loss function and the structure of the return value of
`forward` of your module when reporting this issue (e.g. list, dict, iterable).
官方文档在这一点上有些误导。不要像下面这样用默认参数创建 Accelerator:
# Default construction: no kwargs handlers are passed to the Accelerator.
from accelerate import Accelerator, DistributedDataParallelKwargs
accelerator = Accelerator()
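而应当在创建时传入 `DistributedDataParallelKwargs(find_unused_parameters=True)`。模型的 `forward` 丢弃了 BERT 的第二个输出,因此有一部分参数不参与损失计算,DDP 需要开启未使用参数检测才能正常完成梯度同步: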
# Pass a DDP kwargs handler that requests find_unused_parameters=True.
from accelerate import Accelerator, DistributedDataParallelKwargs
accelerator = Accelerator(kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)])
希望对你有帮助 :p