Programmatically change components of a PyTorch model?

Problem description

I am training a model in PyTorch and would like to be able to programmatically swap out certain components of the model architecture, to check which one works best, without any if blocks in forward(). Consider a toy example:

import torch

class Model(torch.nn.Module):
    def __init__(self, layers: str, d_in: int, d_out: int):
        super().__init__()
        self.layers = layers
        self.linears = torch.nn.ModuleList([
            torch.nn.Linear(d_in, d_out),
            torch.nn.Linear(d_in, d_out),
        ])

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        if self.layers == "parallel":
            # one layer per input, then combine
            x1 = self.linears[0](x1)
            x2 = self.linears[1](x2)
            x = x1 + x2
        elif self.layers == "sequential":
            # combine first, then chain both layers (assumes d_in == d_out)
            x = x1 + x2
            x = self.linears[0](x)
            x = self.linears[1](x)
        return x

My first instinct was to provide external functions, e.g.

def parallel(x1, x2):
    # `self` is not defined in this function's scope
    x1 = self.linears[0](x1)
    x2 = self.linears[1](x2)
    return x1 + x2

and pass them to the model, e.g.

from typing import Callable

class Model(torch.nn.Module):
    def __init__(self, layers: str, d_in: int, d_out: int, fn: Callable):
        super().__init__()
        self.layers = layers
        self.linears = torch.nn.ModuleList([
            torch.nn.Linear(d_in, d_out),
            torch.nn.Linear(d_in, d_out),
        ])
        self.fn = fn

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        return self.fn(x1, x2)

But of course the function's scope knows nothing about self.linears, and I would also like to avoid passing every architecture element to the function explicitly.
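To make the scoping problem concrete, here is what actually happens with the code above (the input shapes are arbitrary example values, not part of the original question):

model = Model(layers="parallel", d_in=8, d_out=8, fn=parallel)
model(torch.randn(4, 8), torch.randn(4, 8))
# NameError: name 'self' is not defined -- `parallel` has no access to the instance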

Am I asking for too much? Do I have to bite into the sour apple, as we say in German, and either accept a larger function signature, or the if conditions, or something else? Or is there a solution to my problem?

python pytorch
1 Answer

You can use the if statement in the __init__ function, or in another method, for example:

from enum import Enum

import torch

class ModelType(Enum):
    Parallel = 1
    Sequential = 2

class Model(torch.nn.Module):
    def __init__(self, d_in: int, d_out: int, model_type: ModelType):
        super().__init__()
        self.linears = torch.nn.ModuleList([
            torch.nn.Linear(d_in, d_out),
            torch.nn.Linear(d_in, d_out),
        ])
        self.model_type = model_type
        self.initialize()

    def initialize(self):
        # resolve the architecture choice once, at construction time
        if self.model_type == ModelType.Parallel:
            self.fn = self.parallel
        elif self.model_type == ModelType.Sequential:
            self.fn = self.sequential

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        return self.fn(x1, x2)

    def parallel(self, x1, x2):
        x1 = self.linears[0](x1)
        x2 = self.linears[1](x2)
        return x1 + x2

    def sequential(self, x1, x2):
        x = x1 + x2
        x = self.linears[0](x)
        x = self.linears[1](x)
        return x
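A quick usage sketch for this pattern (the sizes below are example values I picked, not part of the original answer; d_in == d_out so the sequential variant can chain the two layers):

model_par = Model(d_in=16, d_out=16, model_type=ModelType.Parallel)
model_seq = Model(d_in=16, d_out=16, model_type=ModelType.Sequential)

x1, x2 = torch.randn(8, 16), torch.randn(8, 16)
print(model_par(x1, x2).shape)  # torch.Size([8, 16])
print(model_seq(x1, x2).shape)  # torch.Size([8, 16])

Since the choice is resolved once in initialize(), forward() itself stays free of if blocks, which is what the question asked for.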

Hope this helps.
