我实现了一个分类器(SimpleModel)。当我用 input=torch.randn(100,1)、label=(1 if input>0 else 0) 训练它时,分类器训练得很好,验证准确率约为 1。
但是,当我改用 input=gen(z)(其中 z=torch.randn(100,1),gen() 是我自己定义的模型)、label=(1 if z>0 else 0) 训练分类器时,分类器训练得不好:每次运行的训练/验证损失和训练/验证准确率差异很大,训练效果不佳。
问题出在哪里?
代码如下:
import torch
from torch import nn
from torch import optim
import matplotlib.pyplot as plt
class SimpleModel(nn.Module):
    """Binary classifier on a scalar feature: one affine map followed by a sigmoid.

    The forward pass computes ``sigmoid(w * x + b)``, so the decision boundary
    is a single threshold on the input value.
    """

    def __init__(self) -> None:
        super().__init__()
        self.linear1 = nn.Linear(1, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-sample probabilities in (0, 1).

        Args:
            x: Tensor whose last dimension is 1 (as required by ``nn.Linear(1, 1)``).

        Returns:
            Tensor of the same shape as ``x`` holding ``sigmoid(linear1(x))``.
        """
        x = self.linear1(x)
        x = self.sigmoid(x)
        return x
class Gen(nn.Module):
    """Small MLP mapping a scalar to a scalar: Linear(1, 2) -> ReLU -> Linear(2, 1).

    With two hidden ReLU units the output is a piecewise-linear function of
    the input; its shape depends entirely on the (random) initial weights
    unless the module is trained.
    """

    def __init__(self) -> None:
        super().__init__()
        self.linear1 = nn.Linear(1, 2)
        self.relu1 = nn.ReLU()
        self.linear2 = nn.Linear(2, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the two-layer network.

        Args:
            x: Tensor whose last dimension is 1 (as required by ``nn.Linear(1, 2)``).

        Returns:
            Tensor with last dimension 1 holding ``linear2(relu(linear1(x)))``.
        """
        x = self.linear1(x)
        x = self.relu1(x)
        x = self.linear2(x)
        return x
# Train SimpleModel on features produced by a frozen, randomly initialised Gen.
# Labels are the sign of the noise z that is fed INTO the generator.
model = SimpleModel()
gen = Gen()
criterion = nn.BCELoss()
# Only the classifier's parameters are optimised; gen is never updated.
optimizer = optim.Adam(model.parameters(), lr=0.0001)

loss_history = []
val_loss_history = []
acc_history = []
val_acc_history = []

# NOTE(review): gen is Linear(1, 2) -> ReLU -> Linear(2, 1) with random weights,
# so gen(z) is a piecewise-linear function of z that can be flat or
# non-monotonic over part of the input range. SimpleModel can only place one
# threshold on gen(z); whenever sign(z) is not recoverable from such a
# threshold, accuracy is capped below 1. This likely explains the run-to-run
# variance -- confirm by plotting gen(z) against z for a few seeds.

for step in range(10000):
    z = torch.randn(100, 1)
    label = (z > 0).float()
    # gen is frozen, so there is no need to build an autograd graph through it.
    with torch.no_grad():
        x = gen(z)

    out = model(x)
    loss = criterion(out, label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    acc = (out.round() == label).float().mean()

    # Validation must see the same kind of input as training (gen output, not
    # raw noise); otherwise the metrics describe a different task.
    with torch.no_grad():
        val_z = torch.randn(100, 1)
        val_label = (val_z > 0).float()
        val_x = gen(val_z)
        val_out = model(val_x)
        val_loss = criterion(val_out, val_label)
        val_acc = (val_out.round() == val_label).float().mean()

    print(loss.item(), acc.item(), val_loss.item(), val_acc.item())
    loss_history.append(loss.item())
    val_loss_history.append(val_loss.item())
    acc_history.append(acc.item())
    val_acc_history.append(val_acc.item())

plt.figure()
plt.plot(loss_history, label="train loss")
plt.plot(val_loss_history, label="validation loss")
plt.legend()