在Python中,我们可以通过执行
a = []
轻松创建一个空列表。我想做类似的事情,但是使用 Pytorch 张量。
如果您想知道为什么我需要它,我想获取给定数据加载器内的所有数据(以创建另一个客户数据加载器)。拥有一个空张量可以帮助我使用 for 循环收集张量内的所有数据。这是它的 sudo 代码。
all_data_tensor = # An empty tensor
for data in dataloader:
all_data_tensor = torch.cat((all_data_tensor, data), 0)
有什么办法可以做到这一点吗?
我们可以使用 torch.empty 来做到这一点。但请注意
torch.empty
需要维度,我们应该将 0 赋予第一个维度以获得空张量。
代码将是这样的:
# suppose the data generated by the dataloader has the size of (batch, 25)
all_data_tensor = torch.empty((0, 25), dtype=torch.float32)
# first dimension should be zero.
for data in dataloader:
all_data_tensor = torch.cat((all_data_tensor, data), 0)
您可以使用任何允许大小参数的火炬函数创建一个空张量。只需传递 0 作为维度之一即可。因此,以下所有内容都会产生大小相同的空张量
(0,25)
,即具有 0 行和 25 列。
w = torch.empty(0, 25)
x = torch.zeros(0, 25)
y = torch.ones(0, 25)
z = torch.rand(0, 25)
v = torch.randn(0, 25)
# etc.
对这些(例如
tolist()
)调用x.tolist()
会返回[]
。
在大多数情况下,您可能不需要初始化这样的空张量。一些替代方案:
例如,与其在循环中连接张量,不如先创建一个列表,最后再创建一次张量,这样会快得多。
对于OP中的示例,您可以这样做:
# if `dataloader` is a list
all_data_tensor = torch.cat(dataloader, dim=0)
# if `dataloader` is a generator
all_data_tensor = torch.cat(list(dataloader), dim=0)
以下 timeit 测试表明,如果我们连接 100 个 1000x4 张量,创建一次列表的速度会快 10 倍以上。随着单个张量的大小和/或张量数量的增加,这种差距会进一步增加。
%%timeit
dataloader = (torch.ones(1000, 4) for _ in range(1, 100))
data_list = list(dataloader) # <--- create a list of tensors first
all_tensors = torch.cat(data_list, dim=0) # <--- concatenate them once
# 2.75 ms ± 271 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%%timeit
dataloader = (torch.ones(1000, 4) for _ in range(1, 100))
all_data_tensor = torch.empty(0, 4) # <--- initialize
for data in dataloader:
all_data_tensor = torch.cat((all_data_tensor, data), dim=0)
# 38.3 ms ± 1.05 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
需要在循环中连接的一个用例是连接数千个张量,因此从它们中创建单个列表会消耗太多内存。在这种情况下,循环连接可能会很有用。然而,我们可以使用第一个张量进行初始化并连接到它,而不是初始化一个空张量。
# if `dataloader` is a generator
all_data_tensor = next(dataloader) # <--- initialize with first tensor
for data in dataloader:
all_data_tensor = torch.cat((all_data_tensor, data), dim=0)