Python 中的歧视联合

问题描述 投票:0回答:2

假设我有一个基类和两个派生类。我还有一个工厂方法,它返回其中一个类的对象。问题是,mypy 或 IntelliJ 无法确定对象是什么类型。他们知道两者都有可能,但不知道到底是哪一个。有什么方法可以帮助 mypy/IntelliJ 解决这个问题,而无需在

conn
变量名称旁边放置类型提示?

import abc
import enum
import typing


class BaseConnection(abc.ABC):
    @abc.abstractmethod
    def sql(self, query: str) -> typing.List[typing.Any]:
        ...


class PostgresConnection(BaseConnection):

    def sql(self, query: str) -> typing.List[typing.Any]:
        return "This is a postgres result".split()

    def only_postgres_things(self):
        pass


class MySQLConnection(BaseConnection):

    def sql(self, query: str) -> typing.List[typing.Any]:
        return "This is a mysql result".split()

    def only_mysql_things(self):
        pass


class ConnectionType(enum.Enum):
    POSTGRES = 1
    MYSQL = 2


def connect(conn_type: ConnectionType) -> typing.Union[PostgresConnection, MySQLConnection]:
    if conn_type is ConnectionType.POSTGRES:
        return PostgresConnection()
    if conn_type is ConnectionType.MYSQL:
        return MySQLConnection()


conn = connect(ConnectionType.POSTGRES)
conn.only_postgres_things()

看看 IntelliJ 如何处理这个问题: enter image description here

正如您所看到的,这两种方法:当我希望 IntelliJ/mypy 从我传递给

only_postgres_things
函数的类型中找出它时,建议使用
only_mysql_things
connect

python pycharm python-typing mypy discriminated-union
2个回答
6
投票

由于您的

ConnectionType
类的目的显然是为了使您的 API 更具可读性和用户友好性,而不是使用
Enum
的任何特定功能,因此您实际上不必将其设为
Enum
类。

相反,您可以创建一个常规类,将每个连接类型分配给一个用户友好名称的类变量,以便您可以使用类型变量键入

connect
函数的返回值,并使用类型键入参数的类型变量。使用类型别名可以使类型变量的类型更具可读性:

class ConnectionTypes:
    POSTGRES = PostgresConnection
    MYSQL = MySQLConnection

Connection = typing.TypeVar('Connection', PostgresConnection, MySQLConnection)
# or make it bound to the base class:
# Connection = typing.TypeVar('Connection', bound=BaseConnection)
ConnectionType: typing.TypeAlias = type[Connection]

def connect(type_: ConnectionType) -> Connection:
    if type_ is ConnectionType.POSTGRES:
        return PostgresConnection()
    if type_ is ConnectionType.MYSQL:
        return MySQLConnection()

Use a regular class in combination with a type variable.


3
投票

您可以尝试将

typing.overload
typing.Literal
结合使用,如下所示:


@typing.overload
def connect(type_: typing.Literal[ConnectionType.POSTGRES]) -> PostgresConnection:
    ...

@typing.overload
def connect(type_: typing.Literal[ConnectionType.MYSQL]) -> MySQLConnection:
    ...

def connect(type_):
    if type_ is ConnectionType.POSTGRES:
        return PostgresConnection()
    if type_ is ConnectionType.MYSQL:
        return MySQLConnection()

我用

type
替换了
type_
,这样你就不会遮蔽内置函数,并且使用
is
而不是
==
来比较枚举值是惯用的。

© www.soinside.com 2019 - 2024. All rights reserved.