我正在开发一个神经网络加速框架,使用XLA来编译模型的计算图。但是,由于模型中使用的库的代码不是由我编写的,因此它们通常不会优先考虑性能,这可能会降低我的低级代码的性能。因此,我正在考虑在导入加速库时添加一些补丁来动态修改模型前向函数的部分内容。但是,由于需要保持与更高级别模型库的兼容性,我需要更灵活的修补方法。
这是我目前正在考虑的策略,但不知何故它不起作用。 我可以使用检查库获取类的代码,找到目标行,将其替换为另一行,然后使用
exec
设置修改后的类。
使用inspect的代码demo如下:
def patch_qwen():
import inspect
import transformers.models.qwen2.modeling_qwen2 as qwen2
replace_str = " rotary_seq_len = kv_seq_len"
search_str = "position_ids[:, -1].max().item()"
src = inspect.getsource(qwen2.Qwen2FlashAttention2).splitlines()
print('\n'.join(src)) # first print
for i, line in enumerate(src):
if search_str in line:
# print(f"target str find {search_str} at {i}")
src[i] = replace_str
break
src = '\n'.join(src)
exec(src, qwen2.__dict__)
src = inspect.getsource(qwen2.Qwen2FlashAttention2)
print("------------------------------------------------")
print(src)# second print
结果是第一次打印和第二次打印是相同的,Qwen2FlashAttention2类没有以某种方式改变。
任何人都可以给我一些关于为什么不起作用以及如何解决问题的解释吗?
不要尝试编辑源代码,而是用方法的更新版本替换类上的方法(猴子修补类)。
举个简单的例子,如果您有
models.py
包含一个类:
class Model:
def __init__(self, values):
self._values = values
def get_values(self):
return self._values
def total(self):
print("Slow total")
total = 0
for value in self.get_values():
total += value
return total
并且您想要修补
total
函数,那么您可以猴子修补对象来替换该函数:
from models import Model
def patched_total(self):
print("Fast total")
return sum(self.get_values())
Model.total = patched_total
然后,每当您想要模型的总数时,它都会调用修补版本,而不是原始方法。 (注意:其他方法不会被修改。如上面使用
get_values
方法的示例所示。)
如果您希望能够在原始功能和新功能之间进行交换,那么您可以在上下文管理器中进行猴子修补:
def patched_total(self):
print("Fast total")
return sum(self.get_values())
class PatchTotal:
def __init__(self):
from models import Model
self.model = Model
self.total = None
def __enter__(self):
self.total = self.model.total
self.model.total = patched_total
def __exit__(self, exc_type, exc_value, exc_traceback):
self.model.total = self.total
self.total = None
然后:
from models import Model
from patch import PatchTotal
model = Model([1, 2, 3])
print(model.total())
with PatchTotal():
print(model.total())
print(model.total())
输出:
Slow total
6
Fast total
6
Slow total
6