我正在尝试创建一个字典,将输入数据类型的元组映射到处理这些数据类型的函数。我一直没能找到一种方法来以令我满意的方式进行类型提示。在禁用文件的静态类型检查之前,我想我会问......
这是 MWE:
from typing import Callable
class BaseType: pass
class IntType(BaseType): pass
class StringType(BaseType): pass
def _compare_int_int(before: IntType, after: IntType) -> int:
return 1
def _compare_int_str(before: IntType, after: StringType) -> int:
return 2
def _compare_str_int(before: StringType, after: IntType) -> int:
return 3
def _compare_str_str(before: StringType, after: StringType) -> int:
return 4
_COMPARERS: dict[tuple[type[BaseType], type[BaseType]], Callable[[BaseType, BaseType], int]] = {
(IntType, IntType): _compare_int_int,
(IntType, StringType): _compare_int_str,
(StringType, StringType): _compare_str_str,
(StringType, IntType): _compare_str_int
}
def compare(before: BaseType, after: BaseType) -> int:
compare_func = _COMPARERS[(type(before), type(after))]
return compare_func(before, after)
a = IntType()
b = StringType()
print(compare(a, a))
## 1
print(compare(a, b))
## 2
print(compare(b, a))
## 3
print(compare(b, b))
## 4
它按预期工作,但是
mypy
给出了错误
Dict entry 0 has incompatible type "Tuple[Type[IntType], Type[IntType]]": "Callable[[IntType, IntType], int]"; expected "Tuple[Type[BaseType], Type[BaseType]]": "Callable[[BaseType, BaseType], int]"
其他条目也是如此。
有什么正确提示类型的技巧吗?
我已经
谢谢。
编辑 我也尝试过像这样使用
TypeVar
:
T = TypeVar('T', bound=BaseType)
U = TypeVar('U', bound=BaseType)
_COMPARERS: dict[tuple[type[T], type[U]], Callable[[T, U], int]] = {...}
虽然我更喜欢这个想法(它阐明了函数的参数类型与元组的相应类型相匹配),但它仍然给出错误:
Dict entry 0 has incompatible type "Tuple[Type[IntType], Type[IntType]]": "Callable[[IntType, IntType], int]"; expected "Tuple[Type[T?], Type[U?]]": "Callable[[T, U], int]"
如果我离开
T
和 U
而没有 bound
值,错误不会改变。
与有关将类型映射到函数的类似问题一样,您可以根据需要使用
typing.Protocol
方法实现 dict
,而不是尝试参数化通用字典:
from typing import Callable, Protocol, TypeVar
class BaseType: pass
class IntType(BaseType): pass
class StringType(BaseType): pass
def _compare_int_int(before: IntType, after: IntType) -> int:
return 1
def _compare_int_str(before: IntType, after: StringType) -> int:
return 2
def _compare_str_int(before: StringType, after: IntType) -> int:
return 3
def _compare_str_str(before: StringType, after: StringType) -> int:
return 4
T1 = TypeVar("T1", bound=BaseType)
T2 = TypeVar("T2", bound=BaseType)
class _Comparers(Protocol):
def __getitem__(self, item: tuple[type[T1], type[T2]]) -> Callable[[T1, T2], int]: ...
def __setitem__(self, key: tuple[type[T1], type[T2]], value: Callable[[T1, T2], int]) -> None: ...
# Add more signatures from `dict` here, if needed
# You'll have to add the entries separately instead of defining a dictionary in one go
_COMPARERS: _Comparers = {}
_COMPARERS[(IntType, IntType)] = _compare_int_int
_COMPARERS[(IntType, StringType)] = _compare_int_str
_COMPARERS[(StringType, StringType)] = _compare_str_str
_COMPARERS[(StringType, IntType)] = _compare_str_int
# mypy: Value of type variable "T1" of "__setitem__" of "_Comparers" cannot be "int" [type-var]
# mypy: Value of type variable "T2" of "__setitem__" of "_Comparers" cannot be "int" [type-var]
# mypy: Incompatible types in assignment (expression has type "Callable[[IntType, IntType], int]", target has type "Callable[[int, int], int]") [assignment]
# _COMPARERS[(int, int)] = _compare_int_int
# This is type-safe - no linting errors should occur anywhere in this function
def compare(before: BaseType, after: BaseType) -> int:
compare_func = _COMPARERS[(type(before), type(after))]
return compare_func(before, after)
a = IntType()
b = StringType()
print(compare(a, a))
## 1
print(compare(a, b))
## 2
print(compare(b, a))
## 3
print(compare(b, b))
## 4
我建议使用以下方法之一来解决该问题:
TypeVar
与您的函数一起使用是获取类型的更好方法。