If you have tensor arrays of different lengths across multiple GPU ranks, the default all_gather method does not work, because it requires the lengths to be the same.
For example, if you have:
if gpu == 0:
    q = torch.tensor([1.5, 2.3], device=torch.device(gpu))
else:
    q = torch.tensor([5.3], device=torch.device(gpu))
and you need to gather these two tensor arrays as follows:
all_q = [torch.tensor([1.5, 2.3]), torch.tensor([5.3])]
then the default torch.distributed.all_gather does not work, because the lengths, 2 and 1, are different.
Since it is not possible to gather directly using the built-in methods, we need to write a custom function with the following steps:

1. Use dist.all_gather to get the sizes of all arrays.
2. Find the max size.
3. Pad the local array to the max size with zeros.
4. Use dist.all_gather to get all the padded arrays.
5. Unpad the added zeros using the sizes found in step 1.

The function below does this:
import torch
import torch.distributed as dist

def all_gather(q, ws, device):
    """
    Gathers tensor arrays of different lengths across multiple gpus

    Parameters
    ----------
    q : tensor array
    ws : world size
    device : current gpu device

    Returns
    -------
    all_q : list of gathered tensor arrays from all the gpus
    """
    # Step 1: gather the size of every rank's array.
    local_size = torch.tensor(q.size(), device=device)
    all_sizes = [torch.zeros_like(local_size) for _ in range(ws)]
    dist.all_gather(all_sizes, local_size)

    # Steps 2-3: pad the local array with zeros up to the max size.
    max_size = max(all_sizes)
    size_diff = max_size.item() - local_size.item()
    if size_diff:
        padding = torch.zeros(size_diff, device=device, dtype=q.dtype)
        q = torch.cat((q, padding))

    # Step 4: gather the padded arrays from every rank.
    all_qs_padded = [torch.zeros_like(q) for _ in range(ws)]
    dist.all_gather(all_qs_padded, q)

    # Step 5: trim each gathered array back to its original size.
    all_qs = []
    for q, size in zip(all_qs_padded, all_sizes):
        all_qs.append(q[:size])
    return all_qs
Once we can do the above, we can then easily use torch.cat to further concatenate into a single array if required:
torch.cat(all_q)
torch.tensor([1.5, 2.3, 5.3])
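For completeness, here is a minimal sketch of how the helper above could be driven end to end. It assumes two GPUs, the NCCL backend, and torch.multiprocessing for spawning; the run_demo name and the MASTER_ADDR/MASTER_PORT values are made up for illustration.

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def run_demo(gpu, world_size):
    # Hypothetical driver: one process per GPU, NCCL backend assumed.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("nccl", rank=gpu, world_size=world_size)
    torch.cuda.set_device(gpu)

    if gpu == 0:
        q = torch.tensor([1.5, 2.3], device=torch.device(gpu))
    else:
        q = torch.tensor([5.3], device=torch.device(gpu))

    all_q = all_gather(q, world_size, torch.device(gpu))  # helper defined above
    print(torch.cat(all_q))  # tensor([1.5000, 2.3000, 5.3000]) on every rank

    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(run_demo, args=(2,), nprocs=2)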
Adapted from: github
This is an extension of @omsrisagar's solution that supports tensors of any number of dimensions (not only 1-dimensional tensors).
def all_gather_nd(tensor):
    """
    Gathers tensor arrays of different lengths in a list.
    The length dimension is 0. This supports any number of extra dimensions in the tensors.
    All the other dimensions should be equal between the tensors.

    Args:
        tensor (Tensor): Tensor to be broadcast from current process.

    Returns:
        (list[Tensor]): output list of tensors that can be of different sizes
    """
    world_size = dist.get_world_size()
    local_size = torch.tensor(tensor.size(), device=tensor.device)
    all_sizes = [torch.zeros_like(local_size) for _ in range(world_size)]
    dist.all_gather(all_sizes, local_size)

    # Pad along dim 0 so that every rank contributes a tensor of the same shape.
    max_length = max(size[0] for size in all_sizes)
    length_diff = max_length.item() - local_size[0].item()
    if length_diff:
        pad_size = (length_diff, *tensor.size()[1:])
        padding = torch.zeros(pad_size, device=tensor.device, dtype=tensor.dtype)
        tensor = torch.cat((tensor, padding))

    all_tensors_padded = [torch.zeros_like(tensor) for _ in range(world_size)]
    dist.all_gather(all_tensors_padded, tensor)

    # Trim the padding off along dim 0 using the gathered sizes.
    all_tensors = []
    for tensor_, size in zip(all_tensors_padded, all_sizes):
        all_tensors.append(tensor_[:size[0]])
    return all_tensors
Note that this requires all tensors to have the same number of dimensions and to have all of their dimensions equal, except for the first one.
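As a quick illustration, here is a hedged usage sketch with 2-D tensors (assuming a two-rank process group has already been initialised; the shapes below are made up):

# Hypothetical per-rank feature matrices sharing the second dimension.
rank = dist.get_rank()
if rank == 0:
    feats = torch.randn(3, 4, device=f"cuda:{rank}")  # 3 rows on rank 0
else:
    feats = torch.randn(5, 4, device=f"cuda:{rank}")  # 5 rows on rank 1

gathered = all_gather_nd(feats)        # [tensor of shape (3, 4), tensor of shape (5, 4)]
merged = torch.cat(gathered, dim=0)    # shape (8, 4) on every rank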
You can also use a built-in method without any padding, since PyTorch 1.6.0 introduced all_to_all:
import torch
import torch.distributed as dist

def all_gather_variable(tensor_list, tensor, group=None, async_op=False):
    world_size = dist.get_world_size()
    # Exchange shapes first so every rank can allocate correctly sized outputs.
    shape = torch.tensor(tensor.shape, device=tensor.device)
    shapes = [torch.empty_like(shape) for _ in range(world_size)]
    dist.all_gather(shapes, shape, group=group)
    # Allocate the output tensors if the caller passed an empty list.
    if not tensor_list:
        tensor_list += [torch.empty(s.tolist(), dtype=tensor.dtype, device=tensor.device)
                        for s in shapes]
    # Sending the same tensor to every rank makes all_to_all act as a variable-size all_gather.
    inputs = [tensor] * world_size
    return dist.all_to_all(tensor_list, inputs, group, async_op)
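A brief usage sketch under the same two-rank setup as the earlier example (the process group is assumed to be initialised already; tensor_list starts empty and is filled in by the helper):

rank = dist.get_rank()
if rank == 0:
    q = torch.tensor([1.5, 2.3], device=torch.device(rank))
else:
    q = torch.tensor([5.3], device=torch.device(rank))

all_q = []
all_gather_variable(all_q, q)
print(torch.cat(all_q))  # tensor([1.5000, 2.3000, 5.3000]) on every rank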