我正在尝试复现此网页中介绍的实验：https://adversarial-ml-tutorial.org/adversarial_examples/
我获取了该 Jupyter 笔记本，将其放到本地主机上，并用 Jupyter Notebook 打开。当我运行以下代码获取数据集时：
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# Build the MNIST train/test datasets (downloading into ../data on first use)
# and wrap each one in a DataLoader with mini-batches of 100 images.
mnist_train = datasets.MNIST(
    root="../data",
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
mnist_test = datasets.MNIST(
    root="../data",
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)
train_loader = DataLoader(dataset=mnist_train, batch_size=100, shuffle=True)
test_loader = DataLoader(dataset=mnist_test, batch_size=100, shuffle=False)
我收到以下错误:
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data\MNIST\raw\train-images-idx3-ubyte.gz
0/? [00:00<?, ?it/s]
---------------------------------------------------------------------------
HTTPError Traceback (most recent call last)
<ipython-input-15-e6f62798f426> in <module>
2 from torch.utils.data import DataLoader
3
----> 4 mnist_train = datasets.MNIST("../data", train=True, download=True, transform=transforms.ToTensor())
5 mnist_test = datasets.MNIST("../data", train=False, download=True, transform=transforms.ToTensor())
6 train_loader = DataLoader(mnist_train, batch_size = 100, shuffle=True)
~\Anaconda3\lib\site-packages\torchvision\datasets\mnist.py in __init__(self, root, train, transform, target_transform, download)
77
78 if download:
---> 79 self.download()
80
81 if not self._check_exists():
~\Anaconda3\lib\site-packages\torchvision\datasets\mnist.py in download(self)
144 for url, md5 in self.resources:
145 filename = url.rpartition('/')[2]
--> 146 download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5)
147
148 # process and save as torch files
~\Anaconda3\lib\site-packages\torchvision\datasets\utils.py in download_and_extract_archive(url, download_root, extract_root, filename, md5, remove_finished)
254 filename = os.path.basename(url)
255
--> 256 download_url(url, download_root, filename, md5)
257
258 archive = os.path.join(download_root, filename)
~\Anaconda3\lib\site-packages\torchvision\datasets\utils.py in download_url(url, root, filename, md5)
82 )
83 else:
---> 84 raise e
85 # check integrity of downloaded file
86 if not check_integrity(fpath, md5):
~\Anaconda3\lib\site-packages\torchvision\datasets\utils.py in download_url(url, root, filename, md5)
70 urllib.request.urlretrieve(
71 url, fpath,
---> 72 reporthook=gen_bar_updater()
73 )
74 except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined]
~\Anaconda3\lib\urllib\request.py in urlretrieve(url, filename, reporthook, data)
245 url_type, path = splittype(url)
246
--> 247 with contextlib.closing(urlopen(url, data)) as fp:
248 headers = fp.info()
249
~\Anaconda3\lib\urllib\request.py in urlopen(url, data, timeout, cafile, capath, cadefault, context)
220 else:
221 opener = _opener
--> 222 return opener.open(url, data, timeout)
223
224 def install_opener(opener):
~\Anaconda3\lib\urllib\request.py in open(self, fullurl, data, timeout)
529 for processor in self.process_response.get(protocol, []):
530 meth = getattr(processor, meth_name)
--> 531 response = meth(req, response)
532
533 return response
~\Anaconda3\lib\urllib\request.py in http_response(self, request, response)
639 if not (200 <= code < 300):
640 response = self.parent.error(
--> 641 'http', request, response, code, msg, hdrs)
642
643 return response
~\Anaconda3\lib\urllib\request.py in error(self, proto, *args)
567 if http_err:
568 args = (dict, 'default', 'http_error_default') + orig_args
--> 569 return self._call_chain(*args)
570
571 # XXX probably also want an abstract factory that knows when it makes
~\Anaconda3\lib\urllib\request.py in _call_chain(self, chain, kind, meth_name, *args)
501 for handler in handlers:
502 func = getattr(handler, meth_name)
--> 503 result = func(*args)
504 if result is not None:
505 return result
~\Anaconda3\lib\urllib\request.py in http_error_default(self, req, fp, code, msg, hdrs)
647 class HTTPDefaultErrorHandler(BaseHandler):
648 def http_error_default(self, req, fp, code, msg, hdrs):
--> 649 raise HTTPError(req.full_url, code, msg, hdrs, fp)
650
651 class HTTPRedirectHandler(BaseHandler):
HTTPError: HTTP Error 403: Forbidden
如能帮助解决此问题，将不胜感激。另外，我也可以直接从链接手动下载数据集，但不知道下载后该如何使用！
是的，这是一个已知问题：https://github.com/pytorch/vision/issues/3500
一种可能的解决方案是修补 `MNIST` 的 `download` 方法，改用 `wget` 下载，但这需要先安装 `wget`。
对于 Linux:
sudo apt install wget
对于 Windows:
choco install wget
import os
import subprocess as sp

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets.mnist import MNIST, read_image_file, read_label_file
from torchvision.datasets.utils import check_integrity, extract_archive
def patched_download(self):
    """Download, verify, extract and process MNIST using the external ``wget``.

    Replacement for ``MNIST.download``, installed by assigning
    ``MNIST.download = patched_download``. The stock method fetched the
    archives with ``urllib`` and failed with HTTP 403 (pytorch/vision#3500);
    this version shells out to ``wget`` instead. It then extracts the
    archives and saves the parsed train/test tensors with ``torch.save``.

    Returns immediately if ``self._check_exists()`` reports the data is
    already present.

    Raises:
        FileNotFoundError: if the ``wget`` executable cannot be found.
        subprocess.CalledProcessError: if ``wget`` exits with a non-zero status.
        RuntimeError: if a downloaded archive does not match its MD5 checksum.
    """
    if self._check_exists():
        return

    os.makedirs(self.raw_folder, exist_ok=True)
    os.makedirs(self.processed_folder, exist_ok=True)

    # NOTE(review): assumes each entry of ``self.resources`` is a
    # (full_url, md5) pair, as in the torchvision version from the traceback;
    # newer releases store bare filenames plus a separate ``mirrors`` list —
    # confirm before reusing this patch on another torchvision version.
    download_root = os.path.expanduser(self.raw_folder)
    for url, md5 in self.resources:
        archive = os.path.join(download_root, url.rpartition('/')[2])

        # "-O <path>" writes to an explicit location and overwrites any stale
        # or partial file. With "-P <dir>", wget keeps an existing file and
        # saves the new copy as "<name>.1", so the stale file would be
        # extracted. check=True turns a failed download into an immediate
        # exception instead of a confusing error from extract_archive later.
        sp.run(["wget", "-O", archive, url], check=True)

        # Same safeguard torchvision's own download_url applies: reject a
        # download whose MD5 does not match the expected value.
        if not check_integrity(archive, md5):
            raise RuntimeError("File not found or corrupted: {}".format(archive))

        print("Extracting {} to {}".format(archive, download_root))
        extract_archive(archive, download_root, remove_finished=False)

    # Parse the extracted IDX files and save them as torch files.
    print('Processing...')
    training_set = (
        read_image_file(os.path.join(self.raw_folder, 'train-images-idx3-ubyte')),
        read_label_file(os.path.join(self.raw_folder, 'train-labels-idx1-ubyte')),
    )
    test_set = (
        read_image_file(os.path.join(self.raw_folder, 't10k-images-idx3-ubyte')),
        read_label_file(os.path.join(self.raw_folder, 't10k-labels-idx1-ubyte')),
    )
    with open(os.path.join(self.processed_folder, self.training_file), 'wb') as f:
        torch.save(training_set, f)
    with open(os.path.join(self.processed_folder, self.test_file), 'wb') as f:
        torch.save(test_set, f)

    print('Done!')
# Install the wget-based download in place of the stock urllib-based one.
# This assigns a class attribute, so every MNIST dataset constructed
# afterwards with download=True goes through patched_download.
MNIST.download = patched_download

mnist_train = MNIST("../data", train=True, download=True, transform=transforms.ToTensor())
mnist_test = MNIST("../data", train=False, download=True, transform=transforms.ToTensor())

# batch_size=100 matches the setup being reproduced (the question's code and
# the tutorial). batch_size=1 would turn each 60k-image epoch into 60,000
# iterations instead of 600.
train_loader = DataLoader(mnist_train, batch_size=100, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=100, shuffle=False)
此 MNIST 下载 403 错误已在 2021 年发布的 torchvision v0.10.0 中修复。