我用Python编写代码并使用pytorch库来实现GAN,这是一种生成图片的方法。这是我的代码:
class Reshape(torch.nn.Module):
def __init__(self, *shape):
super().__init__()
self.shape = shape
def forward(self, x):
return x.reshape(x.size(0), *self.shape)
这是发电机:
class Generator(torch.nn.Module):
def __init__(self, z_dim=64, num_channels=1):
super().__init__()
self.z_dim = z_dim
self.net = nn.Sequential(
nn.Linear(z_dim, 512),
nn.BatchNorm1d(512),
nn.ReLU(),
nn.Linear(512, 64 * 7 * 7),
nn.BatchNorm1d(64 * 7 * 7),
nn.ReLU(),
Reshape(64, 7, 7),
nn.PixelShuffle(2),
nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(),
nn.PixelShuffle(2),
nn.Conv2d(in_channels=8, out_channels=1, kernel_size=3, padding=1)
)
def forward(self, z):
return self.net(z)
这是鉴别器:
class Discriminator(torch.nn.Module):
def __init__(self, num_channels=1):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=32, kernel_size=4, padding=1, stride=2),
nn.ReLU(),
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, padding=1, stride=2),
nn.ReLU(),
Reshape(64*7*7),
nn.Linear(64*7*7, 512),
nn.ReLU(),
nn.Linear(512, 1),
Reshape()
)
def forward(self, x):
return self.net(x)
这是计算损失的代码:
def loss_nonsaturating(d, g, x_real, *, device):
'''
Input Arguments:
- x_real (torch.Tensor): training data samples (64, 1, 28, 28)
- device (torch.device): 'cpu' by default
Returns:
- d_loss (torch.Tensor): nonsaturating discriminator loss
- g_loss (torch.Tensor): nonsaturating generator loss
'''
z = torch.randn(x_real.shape[0], g.z_dim, device=device)
gz = g(z)
dgz = F.sigmoid(d(gz))
dx = d(x_real)
real_label = torch.ones(x_real.shape[0], device=device)
fake_label = torch.zeros(x_real.shape[0], device=device)
bce_loss = F.binary_cross_entropy_with_logits
g_loss = bce_loss(dgz, real_label).mean()
d_loss = bce_loss(dx, real_label).mean() + bce_loss(dgz, fake_label).mean()
return d_loss, g_loss
这是训练模型的代码:
def build_input(x, y, device):
x_real = x.to(device)
y_real = y.to(device)
return x_real, y_real
num_latents = 64
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
g = Generator(z_dim=64)
d = Discriminator()
g_optimizer = torch.optim.Adam(g.parameters(), lr=1e-3)
d_optimizer = torch.optim.Adam(d.parameters(), lr=1e-3)
iter_max = 1000
torch.autograd.set_detect_anomaly(True)
with tqdm(total=int(iter_max)) as pbar:
for idx, (x, y) in enumerate(train_loader):
x_real, y_real = build_input(x, y, device)
g_loss, d_loss = loss_nonsaturating(d, g, x_real, device=device)
d_optimizer.zero_grad()
d_loss.backward(retain_graph=True)
d_optimizer.step()
g_optimizer.zero_grad()
g_loss.backward()
g_optimizer.step()
我收到此错误:
0%| | 0/1000 [00:00<?, ?it/s]/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py:251: UserWarning: Error detected in AddmmBackward0. Traceback of forward call that caused the error:
File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/usr/local/lib/python3.10/dist-packages/colab_kernel_launcher.py", line 37, in <module>
ColabKernelApp.launch_instance()
File "/usr/local/lib/python3.10/dist-packages/traitlets/config/application.py", line 992, in launch_instance
app.start()
File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelapp.py", line 619, in start
self.io_loop.start()
File "/usr/local/lib/python3.10/dist-packages/tornado/platform/asyncio.py", line 195, in start
self.asyncio_loop.run_forever()
File "/usr/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
self._run_once()
File "/usr/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
handle._run()
File "/usr/lib/python3.10/asyncio/events.py", line 80, in _run
self._context.run(self._callback, *self._args)
File "/usr/local/lib/python3.10/dist-packages/tornado/ioloop.py", line 685, in <lambda>
lambda f: self._run_callback(functools.partial(callback, future))
File "/usr/local/lib/python3.10/dist-packages/tornado/ioloop.py", line 738, in _run_callback
ret = callback()
File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 825, in inner
self.ctx_run(self.run)
File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 786, in run
yielded = self.gen.send(value)
File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 377, in dispatch_queue
yield self.process_one()
File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 250, in wrapper
runner = Runner(ctx_run, result, future, yielded)
File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 748, in __init__
self.ctx_run(self.run)
File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 786, in run
yielded = self.gen.send(value)
File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 361, in process_one
yield gen.maybe_future(dispatch(*args))
File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 234, in wrapper
yielded = ctx_run(next, result)
File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 261, in dispatch_shell
yield gen.maybe_future(handler(stream, idents, msg))
File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 234, in wrapper
yielded = ctx_run(next, result)
File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 539, in execute_request
self.do_execute(
File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 234, in wrapper
yielded = ctx_run(next, result)
File "/usr/local/lib/python3.10/dist-packages/ipykernel/ipkernel.py", line 302, in do_execute
res = shell.run_cell(code, store_history=store_history, silent=silent)
File "/usr/local/lib/python3.10/dist-packages/ipykernel/zmqshell.py", line 539, in run_cell
return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 2975, in run_cell
result = self._run_cell(
File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3030, in _run_cell
return runner(coro)
File "/usr/local/lib/python3.10/dist-packages/IPython/core/async_helpers.py", line 78, in _pseudo_sync_runner
coro.send(None)
File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3257, in run_cell_async
has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3473, in run_ast_nodes
if (await self.run_code(code, result, async_=asy)):
File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3553, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-20-10c36497e22a>", line 17, in <cell line: 13>
g_loss, d_loss = loss_nonsaturating(d, g, x_real, device=device)
File "<ipython-input-18-c563e2132852>", line 16, in loss_nonsaturating
dx = d(x_real)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "<ipython-input-7-e00f0cc91a93>", line 24, in forward
return self.net(x)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/container.py", line 215, in forward
input = module(input)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py", line 114, in forward
return F.linear(input, self.weight, self.bias)
(Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:114.)
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
0%| | 0/1000 [00:00<?, ?it/s]
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-20-10c36497e22a> in <cell line: 13>()
23
24 g_optimizer.zero_grad()
---> 25 g_loss.backward()
26 g_optimizer.step()
27
1 frames
/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
249 # some Python versions print out the first line of a multi-line function
250 # calls in the traceback and some print out the last line
--> 251 Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
252 tensors,
253 grad_tensors_,
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [512, 1]], which is output 0 of AsStridedBackward0, is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
我做了任何事情来解决问题,但没有帮助。我替换了生成器和鉴别器反向传播代码,但在第二个代码中再次出现错误。你知道我该怎么办吗?
我只是改变了火车循环中的顺序,它就固定了!
with tqdm(total=int(iter_max)) as pbar:
for idx, (x, y) in enumerate(train_loader):
x_real, y_real = build_input(x, y, device)
g_loss, d_loss = loss_nonsaturating(d, g, x_real, device=device)
d_loss.backward(retain_graph=True)
g_loss.backward()
d_optimizer.zero_grad()
g_optimizer.zero_grad()
d_optimizer.step()
g_optimizer.step()