JAX error: jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed

Question (votes: 0, answers: 1)

I am writing code that uses the JAX library, and no matter how I configure my environment, I keep running into this error:

2024-08-20 16:26:58.037892: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:439] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2024-08-20 16:26:58.037952: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:443] Memory usage: 46637514752 bytes free, 47587131392 bytes total.
Traceback (most recent call last):
  File "GPU_pairwise_pipline.py", line 260, in <module>
    SSMD_res_with_indices = process_blocks(train_set_sick, train_set_healthy, block_size)
  File "GPU_pairwise_pipline.py", line 172, in process_blocks
    mean_block1_sick, var_block1_sick = cal_mean_and_var(block1_sick)
  File "GPU_pairwise_pipline.py", line 17, in cal_mean_and_var
    data_jax = jnp.array(data)
  File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 2035, in array
    out = _array_copy(object) if copy else object
  File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 4447, in _array_copy
    return copy_p.bind(arr)
  File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/core.py", line 380, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/core.py", line 383, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/core.py", line 815, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 4486, in _copy_impl
    return dispatch.apply_primitive(prim, *args, **kwargs)
  File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/dispatch.py", line 132, in apply_primitive
    compiled_fun = xla_primitive_callable(
  File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/util.py", line 284, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/util.py", line 277, in cached
    return f(*args, **kwargs)
  File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/dispatch.py", line 223, in xla_primitive_callable
    compiled = _xla_callable_uncached(
  File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/dispatch.py", line 253, in _xla_callable_uncached
    return computation.compile().unsafe_call
  File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2323, in compile
    executable = UnloadedMeshExecutable.from_hlo(
  File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2645, in from_hlo
    xla_executable, compile_options = _cached_compilation(
  File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2555, in _cached_compilation
    xla_executable = dispatch.compile_or_get_cached(
  File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/dispatch.py", line 497, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options,
  File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/dispatch.py", line 465, in backend_compile
    return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
(jax-env) mikabell@hanamel:~$  sudo apt install nvidia-cuda-toolkit

This is what I am working with:

nvidia-smi
Tue Aug 20 16:52:27 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08              Driver Version: 545.23.08    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA L40                     On  | 00000000:01:00.0 Off |                    0 |
| N/A   39C    P0              79W / 300W |    894MiB / 46068MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA L40                     On  | 00000000:02:00.0 Off |                    0 |
| N/A   29C    P8              34W / 300W |     21MiB / 46068MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA L40                     On  | 00000000:61:00.0 Off |                    0 |
| N/A   29C    P8              34W / 300W |     21MiB / 46068MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA L40                     On  | 00000000:62:00.0 Off |                    0 |
| N/A   30C    P8              35W / 300W |     21MiB / 46068MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A      3774      G   /usr/lib/xorg/Xorg                            4MiB |
|    0   N/A  N/A    999897      C   ...physics/rtk243/anaconda3/bin/python      868MiB |
|    1   N/A  N/A      3774      G   /usr/lib/xorg/Xorg                            4MiB |
|    2   N/A  N/A      3774      G   /usr/lib/xorg/Xorg                            4MiB |
|    3   N/A  N/A      3774      G   /usr/lib/xorg/Xorg                            4MiB |
+---------------------------------------------------------------------------------------+

There is no memory problem, and I have set $LD_LIBRARY_PATH to point to the location of the cuDNN version I downloaded:

echo $LD_LIBRARY_PATH
/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/


#define CUDNN_MAJOR 8
#define CUDNN_MINOR 9

I am using these JAX installations:

pip show jax jaxlib
Name: jax
Version: 0.4.13
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: [email protected]
License: Apache-2.0
Location: /home/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages
Requires: importlib-metadata, ml-dtypes, numpy, opt-einsum, scipy
Required-by:
---
Name: jaxlib
Version: 0.4.13+cuda12.cudnn89
Summary: XLA library for JAX
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: [email protected]
License: Apache-2.0
Location: /home/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages
Requires: ml-dtypes, numpy, scipy
Required-by:

Can someone please help me understand why I keep getting the same error?

Thanks!

python cuda conda jax cudnn
1 Answer
0 votes

This error usually means that your CUDA and cuDNN versions are incompatible with the jaxlib version you have installed.
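One way to see which cuDNN builds the dynamic loader can actually find is to scan the directories on `LD_LIBRARY_PATH` for `libcudnn` files. The following is a minimal sketch, not part of JAX; the helper name `find_cudnn_libraries` is hypothetical, and it assumes the Linux naming convention `libcudnn*.so*`.

```python
import glob
import os


def find_cudnn_libraries(ld_library_path=None):
    """Return libcudnn shared-library files found on a library search path.

    Hypothetical helper: defaults to the LD_LIBRARY_PATH environment variable.
    """
    if ld_library_path is None:
        ld_library_path = os.environ.get("LD_LIBRARY_PATH", "")
    found = []
    for directory in ld_library_path.split(os.pathsep):
        if not directory:
            continue  # skip empty entries such as a trailing ':'
        # Matches e.g. libcudnn.so.8 and libcudnn_ops_infer.so.8
        pattern = os.path.join(directory, "libcudnn*.so*")
        found.extend(sorted(glob.glob(pattern)))
    return found


if __name__ == "__main__":
    for path in find_cudnn_libraries():
        print(path)
```

If this prints nothing, or prints files from more than one cuDNN version, the library being loaded may not be the one you expect.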

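You are using a fairly old version of JAX (v0.4.13; the latest release at the time of writing is v0.4.31) together with a newer CUDA version (v12.3), which was released six months after JAX v0.4.13. Given this, I suspect the issue can be resolved by installing a more recent version of jax and jaxlib that is compatible with your CUDA version.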
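The jaxlib version string in the question, `0.4.13+cuda12.cudnn89`, encodes the CUDA and cuDNN versions that wheel was built against. As a rough sketch, it can be split apart for comparison with what is installed on the machine. The function name `parse_jaxlib_build` is hypothetical, and the `+cudaXX.cudnnYY` suffix format is assumed from the `pip show` output above; other jaxlib releases may not carry such a suffix.

```python
import re


def parse_jaxlib_build(version):
    """Split a jaxlib version such as '0.4.13+cuda12.cudnn89' into its parts.

    Assumption: the local suffix looks like 'cudaXX.cudnnYY', as in the
    question. Versions without that suffix yield None for both fields.
    """
    base, _, local = version.partition("+")
    match = re.fullmatch(r"cuda(\d+)\.cudnn(\d+)", local)
    return {
        "jaxlib": base,
        "cuda": match.group(1) if match else None,
        "cudnn": match.group(2) if match else None,
    }


info = parse_jaxlib_build("0.4.13+cuda12.cudnn89")
print(info)  # {'jaxlib': '0.4.13', 'cuda': '12', 'cudnn': '89'}
```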

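From your traceback it looks like you are using Python 3.8. Recent JAX releases do not support Python 3.8, so fixing this may also require installing a newer Python version. That is worth doing soon anyway, because Python 3.8 reaches end of life in a few weeks (October 2024).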

After updating to Python 3.9 or later, you can update JAX with:

$ pip install -U "jax[cuda12]"

That should resolve your issue.
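After upgrading, a quick sanity check along these lines may help confirm that the new environment is the one actually in use. This is only a sketch: `environment_report` is a hypothetical helper built on the standard-library `importlib.metadata`, and the final step assumes jax is importable.

```python
import sys
from importlib import metadata


def environment_report():
    """Collect the Python version and the installed jax/jaxlib versions."""
    report = {"python": "%d.%d.%d" % sys.version_info[:3]}
    for package in ("jax", "jaxlib"):
        try:
            report[package] = metadata.version(package)
        except metadata.PackageNotFoundError:
            report[package] = None  # not installed in this environment
    return report


if __name__ == "__main__":
    print(environment_report())
    try:
        import jax
    except ImportError:
        print("jax is not importable in this environment")
    else:
        # Lists the devices JAX can see; GPU entries suggest the CUDA backend loaded.
        print(jax.devices())
```

If the report still shows Python 3.8 or jax 0.4.13, the upgrade probably went into a different environment than the one running your script.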
