如何在Pytorch中制作空张量?

问题描述 投票:0回答:2

在Python中,我们可以通过执行

a = []
轻松创建一个空列表。我想做类似的事情,但是使用 Pytorch 张量。

如果您想知道为什么我需要它,我想获取给定数据加载器内的所有数据(以创建另一个客户数据加载器)。拥有一个空张量可以帮助我使用 for 循环收集张量内的所有数据。这是它的 sudo 代码。

all_data_tensor = # An empty tensor

for data in dataloader:
  all_data_tensor  = torch.cat((all_data_tensor, data), 0)

有什么办法可以做到这一点吗?

python pytorch tensor
2个回答
6
投票

我们可以使用 torch.empty 来做到这一点。但请注意

torch.empty
需要维度,我们应该将 0 赋予第一个维度以获得空张量。

代码将是这样的:

# suppose the data generated by the dataloader has the size of (batch, 25)
all_data_tensor = torch.empty((0, 25), dtype=torch.float32)
# first dimension should be zero.

for data in dataloader:
  all_data_tensor = torch.cat((all_data_tensor, data), 0)

0
投票

您可以使用任何允许大小参数的火炬函数创建一个空张量。只需传递 0 作为维度之一即可。因此,以下所有内容都会产生大小相同的空张量

(0,25)
,即具有 0 行和 25 列。

w = torch.empty(0, 25)
x = torch.zeros(0, 25)
y = torch.ones(0, 25)
z = torch.rand(0, 25)
v = torch.randn(0, 25)
# etc.

对这些(例如

tolist()
)调用
x.tolist()
会返回
[]


在大多数情况下,您可能不需要初始化这样的空张量。一些替代方案:

创建单个 Python 列表并稍后连接

例如,与其在循环中连接张量,不如先创建一个列表,最后再创建一次张量,这样会快得多。

对于OP中的示例,您可以这样做:

# if `dataloader` is a list
all_data_tensor = torch.cat(dataloader, dim=0)

# if `dataloader` is a generator
all_data_tensor = torch.cat(list(dataloader), dim=0)

以下 timeit 测试表明,如果我们连接 100 个 1000x4 张量,创建一次列表的速度会快 10 倍以上。随着单个张量的大小和/或张量数量的增加,这种差距会进一步增加。

%%timeit
dataloader = (torch.ones(1000, 4) for _ in range(1, 100))
data_list = list(dataloader)                # <--- create a list of tensors first
all_tensors = torch.cat(data_list, dim=0)   # <--- concatenate them once
# 2.75 ms ± 271 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


%%timeit
dataloader = (torch.ones(1000, 4) for _ in range(1, 100))
all_data_tensor = torch.empty(0, 4)         # <--- initialize 
for data in dataloader:
    all_data_tensor = torch.cat((all_data_tensor, data), dim=0)
# 38.3 ms ± 1.05 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)

用第一个张量初始化

需要在循环中连接的一个用例是连接数千个张量,因此从它们中创建单个列表会消耗太多内存。在这种情况下,循环连接可能会很有用。然而,我们可以使用第一个张量进行初始化并连接到它,而不是初始化一个空张量。

# if `dataloader` is a generator
all_data_tensor = next(dataloader)          # <--- initialize with first tensor
for data in dataloader:
    all_data_tensor = torch.cat((all_data_tensor, data), dim=0)
最新问题
© www.soinside.com 2019 - 2025. All rights reserved.