我正在用 PyTorch 训练一个模型,希望能以编程方式替换模型架构中的某些组件,以比较哪种组件效果最好,同时 `forward()` 中不出现任何 if 块。考虑下面这个玩具示例:
import torch
class Model(torch.nn.Module):
    """Toy model that combines two inputs with two linear layers.

    The wiring is selected by ``layers``:

    * ``"parallel"``   -- ``linears[0](x1) + linears[1](x2)``
    * ``"sequential"`` -- ``linears[1](linears[0](x1 + x2))``

    Both layers map ``d_in -> d_out``, so the ``"sequential"`` wiring only
    works when ``d_in == d_out``.

    Args:
        layers: Either ``"parallel"`` or ``"sequential"``.
        d_in: Input feature size of each linear layer.
        d_out: Output feature size of each linear layer.
    """

    def __init__(self, layers: str, d_in: int, d_out: int):
        super().__init__()
        self.layers = layers
        # Must live on ``self`` (inside a ModuleList) so the layers are
        # registered as submodules: their parameters then appear in
        # ``parameters()`` / ``state_dict()`` and follow ``.to(device)``.
        self.linears = torch.nn.ModuleList([
            torch.nn.Linear(d_in, d_out),
            torch.nn.Linear(d_in, d_out),
        ])

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """Combine ``x1`` and ``x2`` according to ``self.layers``.

        Raises:
            ValueError: If ``self.layers`` is not a supported wiring.
        """
        if self.layers == "parallel":
            # Each input gets its own layer (the original used linears[0]
            # twice, leaving linears[1] unused).
            x1 = self.linears[0](x1)
            x2 = self.linears[1](x2)
            x = x1 + x2
        elif self.layers == "sequential":
            x = x1 + x2
            x = self.linears[0](x)
            x = self.linears[1](x)
        else:
            # Without this branch ``x`` would be unbound and the caller
            # would see a confusing UnboundLocalError.
            raise ValueError(
                f"unknown layers {self.layers!r}; "
                "expected 'parallel' or 'sequential'"
            )
        return x
我的第一直觉是提供外部函数,例如:
def parallel(x1, x2):
    """Apply a linear layer to each input and return the sum.

    NOTE: ``self`` is not a parameter of this free function, so (absent a
    global named ``self``) calling it raises NameError -- this is the
    scoping problem the question describes.
    NOTE(review): ``x2`` presumably was meant to go through
    ``self.linears[1]``; as written both inputs use ``self.linears[0]``.
    """
    x1 = self.linears[0](x1)
    x2 = self.linears[0](x2)
    return x1 + x2
然后把它们传给模型,例如:
from collections.abc import Callable


class Model(torch.nn.Module):
    """Toy model that delegates combining its two inputs to ``fn``.

    ``fn`` is called as ``fn(x1, x2)``: it receives only the two input
    tensors, not the module, so it cannot reach ``self.linears`` unless it
    obtains a reference to the model some other way.

    Args:
        layers: Stored as ``self.layers``; not used by ``forward()``.
        d_in: Input feature size of each linear layer.
        d_out: Output feature size of each linear layer.
        fn: Callable that combines ``x1`` and ``x2`` into the output.
    """

    def __init__(self, layers: str, d_in: int, d_out: int, fn: Callable):
        super().__init__()
        self.layers = layers
        # Stored on ``self`` so the layers are registered as submodules and
        # their parameters are tracked by the module.
        self.linears = torch.nn.ModuleList([
            torch.nn.Linear(d_in, d_out),
            torch.nn.Linear(d_in, d_out),
        ])
        self.fn = fn

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """Return ``self.fn(x1, x2)``."""
        return self.fn(x1, x2)
但外部函数的作用域当然无法访问 `self.linears`,而我也想避免把每个架构组件都作为参数传给这个函数。
我的要求是否过高?我是否只能像德语俗语说的那样“咬酸苹果”(in den sauren Apfel beißen,即硬着头皮接受不理想的方案):要么使用更长的函数签名,要么使用 if 条件,或者别的什么办法?还是说这个问题有解决方案?
您可以把 if 语句放到 `__init__` 或其他只在构造时调用一次的函数中,由它选定要使用的实现,这样 `forward()` 本身就不再需要分支,例如:
from enum import Enum

import torch


class ModelType(Enum):
    """Available ways of wiring the two linear layers in ``Model``."""

    Parallel = 1
    Sequential = 2


class Model(torch.nn.Module):
    """Toy model whose wiring is chosen once, at construction time.

    ``forward()`` contains no branching: ``initialize()`` binds ``self.fn``
    to the method implementing the requested ``ModelType``, and
    ``forward()`` simply calls it.

    Both layers map ``d_in -> d_out``, so ``ModelType.Sequential`` only
    works when ``d_in == d_out``.

    Args:
        layers: Stored as ``self.layers``; the wiring is chosen by
            ``model_type``, not by this value.
        d_in: Input feature size of each linear layer.
        d_out: Output feature size of each linear layer.
        model_type: Which wiring ``forward()`` uses.

    Raises:
        ValueError: If ``model_type`` is not a known ``ModelType``.
    """

    def __init__(self, layers: str, d_in: int, d_out: int, model_type: ModelType):
        super().__init__()
        self.layers = layers
        # Stored on ``self`` so the layers are registered as submodules and
        # their parameters are tracked by the module.
        self.linears = torch.nn.ModuleList([
            torch.nn.Linear(d_in, d_out),
            torch.nn.Linear(d_in, d_out),
        ])
        self.model_type = model_type
        self.initialize()

    def initialize(self):
        """Bind ``self.fn`` to the implementation selected by ``model_type``.

        This is the only place a choice is made; it runs once per model
        rather than on every forward pass.

        Raises:
            ValueError: If ``self.model_type`` has no implementation.
        """
        implementations = {
            ModelType.Parallel: self.parallel,
            ModelType.Sequential: self.sequential,
        }
        try:
            self.fn = implementations[self.model_type]
        except KeyError:
            # Fail at construction instead of with an AttributeError on the
            # first forward pass.
            raise ValueError(
                f"unsupported model_type {self.model_type!r}"
            ) from None

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """Combine ``x1`` and ``x2`` with the wiring chosen at construction."""
        return self.fn(x1, x2)

    def parallel(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """Return ``linears[0](x1) + linears[1](x2)``."""
        x1 = self.linears[0](x1)
        x2 = self.linears[1](x2)
        return x1 + x2

    def sequential(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """Return ``linears[1](linears[0](x1 + x2))``."""
        x = x1 + x2
        x = self.linears[0](x)
        x = self.linears[1](x)
        return x
希望有帮助。