如何在运行时修改Python类中方法的代码

问题描述 投票:0回答:1

我正在开发一个神经网络加速框架,使用XLA来编译模型的计算图。但是,由于模型中使用的库的代码不是由我编写的,因此它们通常不会优先考虑性能,这可能会降低我的低级代码的性能。因此,我正在考虑在导入加速库时添加一些补丁来动态修改模型前向函数的部分内容。但是,由于需要保持与更高级别模型库的兼容性,我需要更灵活的修补方法。

这是我目前正在考虑的策略,但不知何故它不起作用。 我可以使用检查库获取类的代码,找到目标行,将其替换为另一行,然后使用

exec
设置修改后的类。

使用inspect的代码demo如下:

def patch_qwen():
    import inspect
    import transformers.models.qwen2.modeling_qwen2 as qwen2
    replace_str = "        rotary_seq_len = kv_seq_len"
    search_str = "position_ids[:, -1].max().item()"
    src = inspect.getsource(qwen2.Qwen2FlashAttention2).splitlines()
    print('\n'.join(src)) # first print
    for i, line in enumerate(src):
        if search_str in line:
            # print(f"target str find {search_str} at {i}")
            src[i] = replace_str
            break
    src = '\n'.join(src)
    exec(src, qwen2.__dict__)
    src = inspect.getsource(qwen2.Qwen2FlashAttention2)
    print("------------------------------------------------")
    print(src)# second print

结果是第一次打印和第二次打印是相同的,Qwen2FlashAttention2类没有以某种方式改变。

任何人都可以给我一些关于为什么不起作用以及如何解决问题的解释吗?

python deep-learning
1个回答
0
投票

不要尝试编辑源代码,而是用方法的更新版本替换类上的方法(猴子修补类)。


举个简单的例子,如果您有

models.py
包含一个类:

class Model:
    def __init__(self, values):
        self._values = values
    
    def get_values(self):
        return self._values

    def total(self):
        print("Slow total")
        total = 0
        for value in self.get_values():
            total += value
        return total

并且您想要修补

total
函数,那么您可以猴子修补对象来替换该函数:

from models import Model

def patched_total(self):
    print("Fast total")
    return sum(self.get_values())

Model.total = patched_total

然后,每当您想要模型的总数时,它都会调用修补版本,而不是原始方法。 (注意:其他方法不会被修改。如上面使用

get_values
方法的示例所示。)


如果您希望能够在原始功能和新功能之间进行交换,那么您可以在上下文管理器中进行猴子修补:

def patched_total(self):
    print("Fast total")
    return sum(self.get_values())

class PatchTotal:
    def __init__(self):
        from models import Model
        self.model = Model
        self.total = None
        
    def __enter__(self):
        self.total = self.model.total
        self.model.total = patched_total
        
    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.model.total = self.total
        self.total = None

然后:

from models import Model
from patch import PatchTotal

model = Model([1, 2, 3])

print(model.total())

with PatchTotal():
    print(model.total())

print(model.total())

输出:

Slow total
6
Fast total
6
Slow total
6
© www.soinside.com 2019 - 2024. All rights reserved.