我正在尝试使用 PyTorch 的 Torchscript 来编写第三方库中定义的模块的脚本。
下面的例子是问题的抽象版本。假设某个我无法修改的库定义了
SomeClass
和 LibraryModule
类,其中后者是 PyTorch 模块。
LibraryModule
的主要方法是compute
,它接受一个张量和一个SomeClass
的实例。
import torch
import torch.nn as nn
class SomeClass:
"""A utility class in a library I cannot modify"""
def __init__(self, x):
self.x = x
class LibraryModule(nn.Module):
"""A module provided in a library I cannot modify"""
def __init__(self, in_features, out_features):
super().__init__()
self.linear = nn.Linear(in_features, out_features)
def compute(self, x, some_class_object: SomeClass):
"""
Main function of my module; like forward, but takes a non-tensor argument
"""
return self.linear(x) * some_class_object.x
这就是我试图为课程获取脚本的方法:
script = torch.jit.script(LibraryModule(3, 2))
print(script.compute(torch.tensor([10, 20, 30]), SomeClass(2)))
但我收到以下错误:
File "torchscript.py", line 25, in <module>
print(script.compute(torch.tensor([10, 20, 30]), SomeClass(2)))
^^^^^^^^^^^^^^
File "\Lib\site-packages\torch\jit\_script.py", line 826, in __getattr__
return super().__getattr__(attr)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "Lib\site-packages\torch\jit\_script.py", line 533, in __getattr__
return super().__getattr__(attr)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "Lib\site-packages\torch\nn\modules\module.py", line 1931, in __getattr__
raise AttributeError(
AttributeError: 'RecursiveScriptModule' object has no attribute 'compute'. Did you mean: 'compile'?
我也尝试过直接编写该方法的脚本:
compute_script = torch.jit.script(LibraryModule(3, 2).compute)
print(compute_script(torch.tensor([10, 20, 30]), SomeClass(2)))
但后来我得到:
RuntimeError:
'Tensor (inferred)' object has no attribute or method 'linear'.:
File "torchscript.py", line 21
Main function of my module; like forward, but takes a non-tensor argument
"""
return self.linear(x) * some_class_object.x
~~~~~~~~~~~ <--- HERE
如何获得
LibraryModule.compute
的工作脚本?
错误
AttributeError: 'RecursiveScriptModule' object has no attribute 'compute'
是因为默认情况下只有
forward
和其他最终递归调用的方法forward
是脚本化(编译)的。由于 compute
没有被 forward
调用,因此它不包含在脚本模块中。为此,必须用 @torch.jit.export
来装饰它。
因为我无法更改原始库,所以我可以创建一个包装器模块,其
forward
方法调用 LibraryModule.compute
:
class WrapperModule(nn.Module):
"""A wrapper module that uses the library module"""
def __init__(self, in_features, out_features):
super().__init__()
self.lib = LibraryModule(in_features, out_features)
def forward(self, x, some_class_object: SomeClass):
return self.lib.compute(x, some_class_object)
script = torch.jit.script(WrapperModule(3, 2))
print(script(torch.tensor([10., 20., 30.]), SomeClass(torch.tensor(2))))
这是可行的,但请注意,我还需要将主张量的类型更改为 float,并且需要将张量传递给
SomeClass
。
我的问题中显示的第二个错误有些无关。当我打电话时
torch.jit.script(LibraryModule(3, 2).compute)
不知何故,对
self
的引用丢失了,我无法再访问 self.linear
了。这很奇怪,因为 LibraryModule(3, 2).compute
应该是一个捕获 self
的闭包,但这也许是 Torchscript 的一个怪癖(请注意下面的代码如何正常工作):
f = LibraryModule(3, 2).compute
print(f(torch.tensor([10., 20., 30.]), SomeClass(torch.tensor(2))))