# Train the steering-angle model for `max_epochs` epochs, taking one optimizer
# step per camera view (center, left, right) of every batch.
#
# NOTE(review): the TypeError reported with this cell is NOT caused by this
# loop. The traceback shows it is raised inside the Dataset's __getitem__ ->
# augment(), where `current_image` is None when it gets sliced. That typically
# means the image loader (presumably cv2.imread) could not read the file, e.g.
# a wrong or relative path in the driving log. Verify the image paths the
# dataset builds. Running the DataLoader with num_workers=0 gives a more
# direct traceback while debugging.
max_epochs = 22

# Moving the model to the target device does not depend on the epoch, so do
# it once instead of on every iteration.
model.to(device)

for epoch in range(max_epochs):
    # Training
    train_loss = 0.0
    model.train()
    for local_batch, (centers, lefts, rights) in enumerate(training_generator):
        # Transfer each camera view to the device. toDevice presumably returns
        # an (images, angles) pair per view -- confirm against its definition.
        centers = toDevice(centers, device)
        lefts = toDevice(lefts, device)
        rights = toDevice(rights, device)

        # One optimization step per camera view.
        for imgs, angles in (centers, lefts, rights):
            print("training image: ", imgs.shape)
            # Clear gradients before every backward pass. Zeroing only once
            # per batch would make the left/right steps re-apply gradients
            # accumulated from the previous view(s).
            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, angles.unsqueeze(1))
            loss.backward()
            optimizer.step()
            # criterion presumably returns a 0-dim (scalar) tensor. Indexing
            # it with [0] raises IndexError, whereas .item() returns the
            # Python number for any one-element tensor.
            train_loss += loss.item()

        if local_batch % 100 == 0:
            # Running mean, over batches so far, of the summed loss of the
            # three camera views.
            print('Loss: %.3f ' % (train_loss / (local_batch + 1)))
TypeError Traceback (most recent call last)
<ipython-input-23-5f093ac2d4e3> in <cell line: 2>()
6 train_loss = 0
7 model.train()
----> 8 for local_batch, (centers, lefts, rights) in enumerate(training_generator):
9 # Transfer to GPU
10 centers, lefts, rights = toDevice(centers, device), toDevice(lefts, device), toDevice(rights, device)
3 frames
/usr/local/lib/python3.10/dist-packages/torch/_utils.py in reraise(self)
642 # instantiate since we don't know how to
643 raise RuntimeError(msg) from None
--> 644 raise exception
645
646
TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
data = fetcher.fetch(index)
File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "<ipython-input-13-19d4a742dd38>", line 10, in __getitem__
center_img, steering_angle_center = augment(batch_samples[0], steering_angle)
File "<ipython-input-11-d1f2a7cf14ec>", line 5, in augment
current_image = current_image[65:-25, :, :]
TypeError: 'NoneType' object is not subscriptable
我一直在用这份现有的代码做自动驾驶汽车的深度学习。但当我运行上面的代码块时,出现了错误,报错位置标在 `for local_batch, (centers, lefts, rights) in enumerate(training_generator):` 这一行上。我该如何解决这个错误?