我有一个函数 -
main()
,它接受类 InputPermutation
的实例作为唯一的参数。这个想法是,main()
运行结果的差异仅基于 InputPermutation
类的差异。我需要迭代 InputPermutation
类的所有可能配置,并记录每次运行的 main()
的结果。
InputPermutation
具有作为不同 Enum
类的实例的属性。我需要迭代每个 Enum
属性的所有可能的 InputPermutation
类实例以获取所有可能的配置。
这是我面临的问题的简化模型。
from enum import Enum
class Colour(Enum):
GREEN = "green"
BLUE = "blue"
RED = "red"
class Country(Enum):
ENGLAND = "england"
JAPAN = "japan"
AUSTRALIA = "australia"
class InputPermutation:
def __init__(self, colour: Colour, country: Country):
self.colour = colour
self.country = country
def main(input_permutation: InputPermutation) -> dict:
colour = input_permutation.colour.value
country = input_permutation.country.value
result = {colour: country}
return result
def iterate() -> dict:
pass
我需要帮助来制作这个 iterate() 函数,不知道如何使其工作......
我希望该函数返回一个字典,其中每个键都是一个“run_index”,只是每次运行 main() 时都会加 1 的数字。理想情况下,每个值都是以下结构的字典:
{
1: {"green": "england"},
2: {"green": "japan"},
3: {"green": "australia"},
4: {"blue": "england"},
# etc...
}
我希望它能够扩展,因此无论有多少不同的 Enum 类(或现有 Enum 类中的新选项)添加到 InputPermutation 类中,该函数仍然会迭代所有选项。我已经设法获得此输出,而无需以这种方式进行扩展。
该问题可能特定于我对 Enum 类的使用。我选择此选项的原因是因为当我选择要选择的选项时它会提供下拉列表。它还通过在将选项输入转换为字符串之前锁定您以特定格式输入来标准化选项输入,从而减少输入错误的可能性(如果有意义的话)。
这个问题确实对我正在制作的模型具有现实世界的适用性,但我认为这个
country : colour
的东西在这里会更容易使用......
我认为这就是您正在寻找的:
# new imports:
import typing
import itertools
def iterate() -> dict:
# get all options:
input_args_types = typing.get_type_hints(InputPermutation.__init__)
all_options_per_type = list(map(list, input_args_types))
all_InputPermutation_combos = list(
itertools.product(*all_options_per_type)
)
return {
i: {a.name for a in arg}
for i, arg in enumerate(all_InputPermutation_combos)
}
输出:
{0: {'england', 'green'},
1: {'green', 'japan'},
2: {'australia', 'green'},
3: {'blue', 'england'},
4: {'blue', 'japan'},
5: {'australia', 'blue'},
6: {'england', 'red'},
7: {'japan', 'red'},
8: {'australia', 'red'}}
即使您添加枚举:
class Extra(Enum):
TEST = "test1"
TEST2 = "test2"
class InputPermutation:
def __init__(self, colour: Colour, country: Country, extra: Extra):
self.colour = colour
self.country = country
那么输出是:
>>> iterate()
{0: {'england', 'green', 'test1'},
1: {'england', 'green', 'test2'},
2: {'green', 'japan', 'test1'},
3: {'green', 'japan', 'test2'},
4: {'australia', 'green', 'test1'},
5: {'australia', 'green', 'test2'},
6: {'blue', 'england', 'test1'},
7: {'blue', 'england', 'test2'},
8: {'blue', 'japan', 'test1'},
9: {'blue', 'japan', 'test2'},
10: {'australia', 'blue', 'test1'},
11: {'australia', 'blue', 'test2'},
12: {'england', 'red', 'test1'},
13: {'england', 'red', 'test2'},
14: {'japan', 'red', 'test1'},
15: {'japan', 'red', 'test2'},
16: {'australia', 'red', 'test1'},
17: {'australia', 'red', 'test2'}}
该函数利用来自
__init__
函数的输入信息。如果没有输入信息,这个功能当然不起作用。这是使用类型系统的一种方法,检查哪些枚举被用作输入。 Itertools 用于至少创建每个组合一次。