我遇到了 TypeError: empty(): argument 'size' must be tuple of SymInts, but found element of type str at pos 3
我是这个深度学习领域的新手,只是想将示例 DNN 代码修改为 FCN 代码
所以基本上结构是由其他人编写的,当我尝试修改一些输入参数时
与
我想知道为什么
有人可以给我一些建议吗?
class FCN(nn.Module):
def __init__(self, args):
super(FCN,self).__init__()
model_name = args.model_name
self.mdl_attrib = args.model_name.split('-')[1]
Input_Size = int(model_name.split('-I_')[1].split('-')[0].split('_')[0])
channel_size_list = list(map(int,re.split(r'[-_]',model_name.split('-M_')[1].split('-O_')[0])[::2]))
#a = type(channel_size_list[0])
kernel_size_list = re.split(r'[-_]',model_name.split('-M_')[1].split('-O_')[0])[1::2]
Outpt_Size = int(model_name.split('-O')[1].split('-')[0].split('_')[1])
self.NumLyrs = len(channel_size_list) + 2
self.UsedCmpx = False
if args.use_complex:
Dtype = torch.cfloat
self.UsedCmpx = True
else:
Dtype = torch.float
self.Conv2dLayerList = nn.ModuleList()
self.ActvationLrList = nn.ModuleList()
self.InstNormLyrList = nn.ModuleList()
if self.NumLyrs == 2:
self.Conv2dLayerList.append(nn.Conv2d(Input_Size, Outpt_Size, dtype = Dtype))
for LyInd in range(self.NumLyrs - 1):
if LyInd == self.NumLyrs - 2:
self.Conv2dLayerList.append(nn.Conv2d(channel_size_list[-1], Outpt_Size, kernel_size_list[LyInd],dtype = Dtype))
if self.mdl_attrib == 'Cls':
self.ActvationLrList.append(nn.Sigmoid())
elif LyInd == 0:
self.Conv2dLayerList.append(nn.Conv2d(Input_Size, channel_size_list[LyInd], kernel_size_list[LyInd], dtype = Dtype))
self.ActvationLrList.append(nn.CELU())
self.InstNormLyrList.append(nn.InstanceNorm2d(channel_size_list[LyInd]))
else:
self.Conv2dLayerList.append(nn.Conv2d(channel_size_list[LyInd-1],channel_size_list[LyInd], kernel_size_list[LyInd], dtype = Dtype))
self.ActvationLrList.append(nn.CELU())
self.InstNormLyrList.append(nn.InstanceNorm2d(channel_size_list[LyInd]))
def forward(self, Inp):
if self.NumLyrs == 2:
LyrOut = self.Conv2dLayerList[0](Inp)
if self.mdl_attrib == 'Cls':
if self.UsedCmpx:
RLyrOut = self.ActvationLrList[0](torch.real(LyrOut))
ILyrOut = self.ActvationLrList[0](torch.imag(LyrOut))
LyrOut = RLyrOut + 1j * ILyrOut
else:
LyrOut = self.ActvationLrList[0](LyrOut)
return LyrOut
else:
for lyInd in range(self.NumLyrs - 1):
if lyInd == self.NumLyrs - 2:
LyrOut = self.Conv2dLayerList[lyInd](LyrOut)
if self.mdl_attrib == 'Cls':
if self.UsedCmpx:
RLyrOut = self.ActvationLrList[lyInd](torch.real(LyrOut))
ILyrOut = self.ActvationLrList[lyInd](torch.imag(LyrOut))
LyrOut = RLyrOut + 1j * ILyrOut
else:
LyrOut = self.ActvationLrList[lyInd](LyrOut)
return LyrOut
elif lyInd == 0:
LyrOut = self.Conv2dLayerList[lyInd](Inp)
if self.UsedCmpx:
RLyrOut = self.InstNormLyrList[lyInd](self.ActvationLrList[lyInd](torch.real(LyrOut)).permute(0,2,1)).permute(0,2,1)
ILyrOut = self.InstNormLyrList[lyInd](self.ActvationLrList[lyInd](torch.imag(LyrOut)).permute(0,2,1)).permute(0,2,1)
LyrOut = RLyrOut + 1j * ILyrOut
else:
LyrOut = self.InstNormLyrList[lyInd](self.ActvationLrList[lyInd](LyrOut))
else:
LyrOut = self.Conv2dLayerList[lyInd](LyrOut)
if self.UsedCmpx:
RLyrOut = self.InstNormLyrList[lyInd](self.ActvationLrList[lyInd](torch.real(LyrOut)).permute(0,2,1)).permute(0,2,1)
ILyrOut = self.InstNormLyrList[lyInd](self.ActvationLrList[lyInd](torch.imag(LyrOut)).permute(0,2,1)).permute(0,2,1)
LyrOut = RLyrOut + 1j * ILyrOut
else:
LyrOut = self.InstNormLyrList[lyInd](self.ActvationLrList[lyInd](LyrOut))
尝试在 linux 系统上运行,结果发生了