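I have the following torch dataset (I have replaced the actual code that reads data from files with random number generation, to make it a minimal reproducible example):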
from torch.utils.data import Dataset
import torch


class TempDataset(Dataset):
    def __init__(self, window_size=200):
        self.window = window_size
        self.x = torch.randn(4340, 10, dtype=torch.float32)  # None
        self.y = torch.randn(4340, 3, dtype=torch.float32)
        self.len = len(self.x) - self.window + 1  # = 4340 - 200 + 1 = 4141
        # Hence, last window start index = 4140
        # And last window will range from 4140 to 4339, i.e. total 200 elements

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        # AFAIU, below if-condition should NEVER evaluate to True as last index with which
        # __getitem__ is called should be self.len - 1
        if index == self.len:
            print('self.__len__(): ', self.__len__())
            print('Tried to access eleemnt @ index: ', index)
        return self.x[index: index + self.window], self.y[index + self.window - 1]


ds = TempDataset(window_size=200)
print('len: ', len(ds))
counter = 0  # no record is read yet
for x, y in ds:
    counter += 1  # above line read one more record from the dataset
print('counter: ', counter)
It prints:
len: 4141
self.__len__(): 4141
Tried to access eleemnt @ index: 4141
counter: 4141
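As far as I understand, __getitem__() is called with index ranging from 0 to __len__() - 1. If that is correct, why does it try to call __getitem__() with index 4141 when the length of the data itself is 4141?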
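Another thing I noticed is that, despite being called with index = 4141, it does not seem to return any element, which is why counter stays at 4141.
What are my eyes (or brain) missing here?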
PS: Although it should not make any difference, just to confirm, I also tried wrapping the Dataset in a torch DataLoader, and it still behaves the same.
What you are seeing here is Python's standard iteration protocol for classes that do not implement __iter__.
When you write for x, y in ds: and ds has no __iter__ method, Python runs the equivalent of this:
i = 0
while True:
    try:
        x, y = ds[i]
    except IndexError:
        break
    # [the body of your loop]
    i = i + 1
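Python never consults __len__() in that loop. It keeps asking for the next index until __getitem__() raises IndexError, so the one extra call with index == len(ds) is expected. The failed call yields nothing, which is why counter stays at 4141. In the question's code, the slice self.x[index: index + self.window] does not raise when it runs past the end, so the IndexError presumably comes from self.y[index + self.window - 1], which at index 4141 asks for row 4340 of a 4340-row tensor.

The same behavior can be reproduced without torch. This is only a minimal sketch: the class name WindowedSeq, its seen attribute, and the small sizes are made up for illustration and are not part of the original code.

```python
class WindowedSeq:
    """Hypothetical stand-in for the dataset: defines __len__ and __getitem__, but no __iter__."""

    def __init__(self, n=10, window=4):
        self.window = window
        self.x = list(range(n))
        self.y = list(range(n))
        self.len = n - window + 1
        self.seen = []  # every index that __getitem__ was called with

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        self.seen.append(index)
        # Slicing never raises, even past the end; it just returns a shorter list.
        xs = self.x[index: index + self.window]
        # Plain indexing raises IndexError once index + window - 1 reaches n.
        return xs, self.y[index + self.window - 1]


seq = WindowedSeq(n=10, window=4)
counter = 0
for xs, y in seq:
    counter += 1

print('len: ', len(seq))               # 7
print('counter: ', counter)            # 7
print('last index: ', seq.seen[-1])    # 7, one past the last valid index
```

If relying on the out-of-range lookup feels too implicit, one option is to check the index at the top of __getitem__() and raise IndexError yourself, or to define __iter__() on the class so the for loop never falls back to this protocol.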