我有一个像下面这样的代码,里面有多个函数用于各种计算。
我使用 python fire 来传递参数,而不是定义 argparse 并从 cli 调用函数。每次添加新参数时,我都必须在 init 中为其添加
self
。我正在寻找更好的方法。
我发现 python dataclasses 可以解决这个问题。我研究了 python fire 命令分组和多个命令。
class MyClass:
def __init__(
self,
input_path: str,
output_path: str = '',
same_size: bool = False,
crop_size: int = 300,
padding: int = 20,
write_json: bool = False,
write_image: bool = False,
line_thickness: int = 2,
side_color: Tuple = (255, 255, 0),
top_color: Tuple = (255, 0, 0),
) -> None:
super(MyClass, self).__init__()
self.input_path = input_path
self.output_path = output_path
self.same_size = same_size
self.crop_size = crop_size
self.padding = padding
self.write_json = write_json
...
...
def something(
self,
) -> None:
...
...
if __name__ == '__main__':
fire.Fire(MyClass)
我的问题是如何使用 python fire 正确获取
MyDataClass
中的 MyClass
数据类值?
@dataclass
class MyDataClass:
input_path: str
output_path: str = ''
same_size: bool = False
crop_size: int = 300
padding: int = 20
write_json: bool = False
write_image: bool = False
line_thickness: int = 2
side_color: Tuple = (255, 255, 0)
top_color: Tuple = (255, 0, 0)
class MyClass:
# init is not here anymore.
def something(
self,
) -> None:
...
...
如果开箱即用火支持就更好了。尽管如此,您可以在运行时定义一个方法来汇总所有数据类属性。该方法可以输入到
Fire
。
from dataclasses import dataclass, fields
from fire import Fire
@dataclass
class MyDataClass:
input_path: str =''
output_path: str = ''
same_size: bool = False
crop_size: int = 300
padding: int = 20
write_json: bool = False
write_image: bool = False
line_thickness: int = 2
side_color: tuple = (255, 255, 0)
top_color: tuple = (255, 0, 0)
@classmethod
def generate_init(cls):
init_args = []
for f in fields(cls):
default = f.default
if not default and issubclass(f.type, str):
default = '""'
init_args.append(f'{f.name}:{f.type.__name__}={default}')
code = f'def fn(self, {", ".join(init_args)}):\n'
code += '\n'.join(f' self.{f.name} = {f.name}' for f in fields(cls))
loc = {}
exec(code, None, loc)
fn = loc["fn"]
cls.fire = fn
return
MyDataClass.generate_init()
dc = MyDataClass()
Fire(dc.fire) # use CLI
# Fire(dc.fire, "--crop_size=999") # interactive example
print("Result:", dc)