CUDA and JAX library incompatibility


My CUDA version is 11.6. The JAX version is 0.4.16, and the jaxlib version is 0.4.16+cuda11.cudnn86.
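
For reference, here is a quick check of what the environment reports from inside Python; the values in the comments are only what I expect given the versions above:

import jax
import jaxlib

print(jax.__version__)     # expected: 0.4.16
print(jaxlib.__version__)  # expected: 0.4.16
print(jax.devices())       # should list the GPU, e.g. [cuda(id=0)]

# On the command line, `ptxas --version` shows the CUDA 11.6 assembler,
# which only accepts PTX ISA versions up to 7.6.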

When I run a simple piece of Python code, the following error message is shown:

W external/xla/xla/service/gpu/buffer_comparator.cc:1054] INTERNAL: ptxas exited with non-zero error code 65280, output: ptxas /tmp/tempfile-meiji-993e158e-113566-608cf99230264, line 10; fatal   : Unsupported .version 7.8; current version is '7.6'

A second error message shows:

jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to load PTX text as a module: CUDA_ERROR_UNSUPPORTED_PTX_VERSION: the provided PTX was compiled with an unsupported toolchain.

The full output is as follows:

WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1698537556.277868  113566 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
random key: [0 0]
2023-10-28 19:59:41.023080: W external/xla/xla/service/gpu/buffer_comparator.cc:1054] INTERNAL: ptxas exited with non-zero error code 65280, output: ptxas /tmp/tempfile-meiji-993e158e-113566-608cf99230264, line 10; fatal   : Unsupported .version 7.8; current version is '7.6'
ptxas fatal   : Ptx assembly aborted due to errors

Relying on driver to perform ptx compilation.
Setting XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda  or modifying $PATH can be used to set the location of ptxas
This message will only be logged once.
2023-10-28 19:59:41.077542: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:857] failed to load PTX text as a module: CUDA_ERROR_UNSUPPORTED_PTX_VERSION: the provided PTX was compiled with an unsupported toolchain.
2023-10-28 19:59:41.077572: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:862] error log buffer (98 bytes): ptxas application ptx input, line 10; fatal   : Unsupported .version 7.8; current version is '7.6
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/shuting/nucleotide_transformer/nucleotide_transformer_test.py", line 76, in <module>
    outs = forward_fn.apply(parameters, random_key, tokens)
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/haiku/_src/transform.py", line 183, in apply_fn
    out, state = f.apply(params, None, *args, **kwargs)
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/haiku/_src/transform.py", line 456, in apply_fn
    out = f(*args, **kwargs)
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/nucleotide_transformer/model.py", line 363, in nucleotide_transformer_fn
    outs = encoder(
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/haiku/_src/module.py", line 458, in wrapped
    out = f(*args, **kwargs)
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/haiku/_src/module.py", line 299, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/nucleotide_transformer/model.py", line 310, in __call__
    x, outs = self.apply_attention_blocks(
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/nucleotide_transformer/model.py", line 226, in apply_attention_blocks
    output = layer(
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/haiku/_src/module.py", line 458, in wrapped
    out = f(*args, **kwargs)
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/haiku/_src/module.py", line 299, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/nucleotide_transformer/layers.py", line 281, in __call__
    output = self.self_attention(
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/nucleotide_transformer/layers.py", line 240, in self_attention
    return self.sa_layer(x, x, x, attention_mask=attention_mask)
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/haiku/_src/module.py", line 458, in wrapped
    out = f(*args, **kwargs)
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/haiku/_src/module.py", line 299, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/nucleotide_transformer/layers.py", line 149, in __call__
    attention_weights = self.attention_weights(query, key, attention_mask)
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/nucleotide_transformer/layers.py", line 83, in attention_weights
    query_heads = self._linear_projection_he_init(query, self.key_size, "query")
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/nucleotide_transformer/layers.py", line 173, in _linear_projection_he_init
    y = hk.Linear(
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/haiku/_src/module.py", line 458, in wrapped
    out = f(*args, **kwargs)
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/haiku/_src/module.py", line 299, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/shuting/anaconda3/envs/transformer5/lib/python3.9/site-packages/haiku/_src/basic.py", line 181, in __call__
    out = jnp.dot(inputs, w, precision=precision)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to load PTX text as a module: CUDA_ERROR_UNSUPPORTED_PTX_VERSION: the provided PTX was compiled with an unsupported toolchain.
I0000 00:00:1698537581.636191  113566 tfrt_cpu_pjrt_client.cc:352] TfrtCpuClient destroyed.

Can anyone help me solve this problem? Thank you very much!

deep-learning cuda gpu nvidia jax
1 Answer

You can find the JAX GPU installation instructions here: https://jax.readthedocs.io/en/latest/installation.html#nvidia-gpu

In particular, JAX v0.4.16 requires CUDA 11.8 or newer, which is exactly what the error indicates: the jaxlib wheel emits PTX targeting ISA version 7.8 (the CUDA 11.8 toolchain), while the ptxas that ships with CUDA 11.6 only understands ISA versions up to 7.6. You can try installing CUDA via pip, as described under the first subheading at that link. If for some reason you must stay on CUDA 11.6, I would suggest trying an older JAX release; versions from before roughly 0.4.8 may be compatible.
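
For concreteness, the two options look roughly like this. The extras names follow the pip-install section of the linked instructions, and the 0.4.7 pin is only an illustration of "an older version"; check the linked docs for the exact current commands.

# Option 1: let pip install a matching CUDA/cuDNN toolchain alongside jaxlib
# (requires only a sufficiently recent NVIDIA driver on the machine)
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Option 2: keep the system CUDA 11.6 and pin an older jax/jaxlib pair
# (0.4.7 is an illustrative pin; the newest version that works with 11.6 may differ)
pip install "jax[cuda11_cudnn86]==0.4.7" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html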
