我有一个包含几个
np.array
的课程:
class VECMParams(ModelParams):
def __init__(
self,
ecm_gamma: np.ndarray,
ecm_mu: Optional[np.ndarray],
ecm_lambda: np.ndarray,
ecm_beta: np.ndarray,
intercept_coint: bool,
):
self.ecm_gamma = ecm_gamma
self.ecm_mu = ecm_mu
self.ecm_lambda = ecm_lambda
self.ecm_beta = ecm_beta
self.intercept_coint = intercept_coint
我想覆盖
==
运营商。基本上,当所有数组都等于 VECMParam
一个时,rhs
等于另一个:
def __eq__(self, rhs: object) -> bool:
if not isinstance(rhs, VECMParams):
raise NotImplementedError()
return (
np.all(self.ecm_gamma == rhs.ecm_gamma) and
np.all(self.ecm_mu == rhs.ecm_mu) and
np.all(self.ecm_lambda == rhs.ecm_lambda) and
np.all(self.ecm_beta == rhs.ecm_beta)
)
仍然,mypy 一直说
Incompatible return value type (got "Union[bool_, bool]", expected "bool") [return-value]
因为 np.all
返回 bool_
和 __eq__
需要返回原生 bool
。我搜索了几个小时,看起来没有办法将这些 bool_ 转换为本机 bool。有人有同样的问题吗?
PS:做
my_bool_ is True
没有被评估为正确的原生布尔值