Beartype 和 jaxtyping 在运行时错误地捕获 Tensor 类型错误

问题描述 投票:0回答:1

我正在使用

torch
jaxtyping
beartype
来键入注释我的函数(使用运行时类型检查)。我使用了
@jaxtyped(typechecker=beartype)
装饰器。

我在运行时收到以下错误:

jaxtyping.TypeCheckError: Type-check error whilst checking the parameters of display_features_from_tokens_and_feature_tensor.
The problem arose whilst typechecking parameter 'tokens'.
Actual value: tensor([41083,   531,   366, 31373,   612,     1,   290,   788,  3332,   866,
13], device='cuda:0')
Expected type: <class 'Int[Tensor, 'seq_len']'>.`

据我所知,实际值的类型与预期类型匹配。我是否遗漏了

beartype
jaxtyping
的细微差别?

我怀疑这可能与我的张量位于 GPU 上有关,但在

jaxtyped
文档中找不到任何提及这一点的信息。

我也尝试使用

Int64
但无济于事(同样的错误)。

python pytorch typechecking
1个回答
0
投票

是否与 https://docs.kidger.site/jaxtyping/faq/“flake8 或 ruff 抛出错误”有关?

对于单维形状检查,您可能必须在形状字符串前面添加一个空格,即“seq_len”。

© www.soinside.com 2019 - 2024. All rights reserved.