Torchscript 失败:“RecursiveScriptModule”对象没有属性

问题描述 投票:0回答:1

我正在尝试使用 PyTorch 的 Torchscript 来编写第三方库中定义的模块的脚本。

下面的例子是问题的抽象版本。假设某个我无法修改的库定义了

SomeClass
LibraryModule
类,其中后者是 PyTorch 模块。

LibraryModule
的主要方法是
compute
,它接受一个张量和一个
SomeClass
的实例。

import torch
import torch.nn as nn


class SomeClass:
    """A utility class in a library I cannot modify"""
    def __init__(self, x):
        self.x = x


class LibraryModule(nn.Module):
    """A module provided in a library I cannot modify"""
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def compute(self, x, some_class_object: SomeClass):
        """
        Main function of my module; like forward, but takes a non-tensor argument
        """
        return self.linear(x) * some_class_object.x

这就是我试图为课程获取脚本的方法:

script = torch.jit.script(LibraryModule(3, 2))
print(script.compute(torch.tensor([10, 20, 30]), SomeClass(2)))

但我收到以下错误:

  File "torchscript.py", line 25, in <module>
    print(script.compute(torch.tensor([10, 20, 30]), SomeClass(2)))
          ^^^^^^^^^^^^^^
  File "\Lib\site-packages\torch\jit\_script.py", line 826, in __getattr__
    return super().__getattr__(attr)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "Lib\site-packages\torch\jit\_script.py", line 533, in __getattr__
    return super().__getattr__(attr)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "Lib\site-packages\torch\nn\modules\module.py", line 1931, in __getattr__
    raise AttributeError(
AttributeError: 'RecursiveScriptModule' object has no attribute 'compute'. Did you mean: 'compile'?

我也尝试过直接编写该方法的脚本:

compute_script = torch.jit.script(LibraryModule(3, 2).compute)
print(compute_script(torch.tensor([10, 20, 30]), SomeClass(2)))

但后来我得到:

RuntimeError: 
'Tensor (inferred)' object has no attribute or method 'linear'.:
  File "torchscript.py", line 21
        Main function of my module; like forward, but takes a non-tensor argument
        """
        return self.linear(x) * some_class_object.x
               ~~~~~~~~~~~ <--- HERE

如何获得

LibraryModule.compute
的工作脚本?

python pytorch torchscript
1个回答
0
投票

错误

AttributeError: 'RecursiveScriptModule' object has no attribute 'compute'

是因为默认情况下只有

forward
和其他最终递归调用的方法
forward
是脚本化(编译)的。由于
compute
没有被
forward
调用,因此它不包含在脚本模块中。为此,必须用
@torch.jit.export
来装饰它。

因为我无法更改原始库,所以我可以创建一个包装器模块,其

forward
方法调用
LibraryModule.compute
:

class WrapperModule(nn.Module):
    """A wrapper module that uses the library module"""
    def __init__(self, in_features, out_features):
        super().__init__()
        self.lib = LibraryModule(in_features, out_features)

    def forward(self, x, some_class_object: SomeClass):
        return self.lib.compute(x, some_class_object)


script = torch.jit.script(WrapperModule(3, 2))
print(script(torch.tensor([10., 20., 30.]), SomeClass(torch.tensor(2))))

这是可行的,但请注意,我还需要将主张量的类型更改为 float,并且需要将张量传递给

SomeClass

我的问题中显示的第二个错误有些无关。当我打电话时

torch.jit.script(LibraryModule(3, 2).compute)

不知何故,对

self
的引用丢失了,我无法再访问
self.linear
了。这很奇怪,因为
LibraryModule(3, 2).compute
应该是一个捕获
self
的闭包,但这也许是 Torchscript 的一个怪癖(请注意下面的代码如何正常工作):

f = LibraryModule(3, 2).compute
print(f(torch.tensor([10., 20., 30.]), SomeClass(torch.tensor(2))))
© www.soinside.com 2019 - 2024. All rights reserved.