我正在使用
torch
、jaxtyping
和 beartype
来键入注释我的函数(使用运行时类型检查)。我使用了 @jaxtyped(typechecker=beartype)
装饰器。
我在运行时收到以下错误:
jaxtyping.TypeCheckError: Type-check error whilst checking the parameters of display_features_from_tokens_and_feature_tensor.
The problem arose whilst typechecking parameter 'tokens'.
Actual value: tensor([41083, 531, 366, 31373, 612, 1, 290, 788, 3332, 866,
13], device='cuda:0')
Expected type: <class 'Int[Tensor, 'seq_len']'>.`
据我所知,实际值的类型与预期类型匹配。我是否遗漏了
beartype
或 jaxtyping
的细微差别?
我怀疑这可能与我的张量位于 GPU 上有关,但在
jaxtyped
文档中找不到任何提及这一点的信息。
我也尝试使用
Int64
但无济于事(同样的错误)。
是否与 https://docs.kidger.site/jaxtyping/faq/“flake8 或 ruff 抛出错误”有关?
对于单维形状检查,您可能必须在形状字符串前面添加一个空格,即“seq_len”。