I'm writing code that uses the JAX library, and no matter how I configure my environment, I keep getting this error:
2024-08-20 16:26:58.037892: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:439] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2024-08-20 16:26:58.037952: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:443] Memory usage: 46637514752 bytes free, 47587131392 bytes total.
Traceback (most recent call last):
File "GPU_pairwise_pipline.py", line 260, in <module>
SSMD_res_with_indices = process_blocks(train_set_sick, train_set_healthy, block_size)
File "GPU_pairwise_pipline.py", line 172, in process_blocks
mean_block1_sick, var_block1_sick = cal_mean_and_var(block1_sick)
File "GPU_pairwise_pipline.py", line 17, in cal_mean_and_var
data_jax = jnp.array(data)
File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 2035, in array
out = _array_copy(object) if copy else object
File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 4447, in _array_copy
return copy_p.bind(arr)
File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/core.py", line 380, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/core.py", line 383, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/core.py", line 815, in process_primitive
return primitive.impl(*tracers, **params)
File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 4486, in _copy_impl
return dispatch.apply_primitive(prim, *args, **kwargs)
File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/dispatch.py", line 132, in apply_primitive
compiled_fun = xla_primitive_callable(
File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/util.py", line 284, in wrapper
return cached(config._trace_context(), *args, **kwargs)
File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/util.py", line 277, in cached
return f(*args, **kwargs)
File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/dispatch.py", line 223, in xla_primitive_callable
compiled = _xla_callable_uncached(
File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/dispatch.py", line 253, in _xla_callable_uncached
return computation.compile().unsafe_call
File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2323, in compile
executable = UnloadedMeshExecutable.from_hlo(
File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2645, in from_hlo
xla_executable, compile_options = _cached_compilation(
File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2555, in _cached_compilation
xla_executable = dispatch.compile_or_get_cached(
File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/dispatch.py", line 497, in compile_or_get_cached
return backend_compile(backend, computation, compile_options,
File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages/jax/_src/dispatch.py", line 465, in backend_compile
return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
(jax-env) mikabell@hanamel:~$ sudo apt install nvidia-cuda-toolkit
Here's what I'm working with:
nvidia-smi
Tue Aug 20 16:52:27 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08 Driver Version: 545.23.08 CUDA Version: 12.3 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA L40 On | 00000000:01:00.0 Off | 0 |
| N/A 39C P0 79W / 300W | 894MiB / 46068MiB | 0% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 1 NVIDIA L40 On | 00000000:02:00.0 Off | 0 |
| N/A 29C P8 34W / 300W | 21MiB / 46068MiB | 0% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 2 NVIDIA L40 On | 00000000:61:00.0 Off | 0 |
| N/A 29C P8 34W / 300W | 21MiB / 46068MiB | 0% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 3 NVIDIA L40 On | 00000000:62:00.0 Off | 0 |
| N/A 30C P8 35W / 300W | 21MiB / 46068MiB | 0% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| 0 N/A N/A 3774 G /usr/lib/xorg/Xorg 4MiB |
| 0 N/A N/A 999897 C ...physics/rtk243/anaconda3/bin/python 868MiB |
| 1 N/A N/A 3774 G /usr/lib/xorg/Xorg 4MiB |
| 2 N/A N/A 3774 G /usr/lib/xorg/Xorg 4MiB |
| 3 N/A N/A 3774 G /usr/lib/xorg/Xorg 4MiB |
+---------------------------------------------------------------------------------------+
There are no memory issues, and I've set $LD_LIBRARY_PATH to point to the location of the cuDNN version I downloaded:
echo $LD_LIBRARY_PATH
/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/
#define CUDNN_MAJOR 8
#define CUDNN_MINOR 9
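To double-check the two claims in the paragraph above using only the text already pasted in this question, here is a small Python sketch. The helpers `parse_memory` and `parse_cudnn_version` are hypothetical, written only for illustration: one reads the `Memory usage` field from the XLA log line, the other reads the `CUDNN_MAJOR`/`CUDNN_MINOR` defines quoted above. The question doesn't name the header file those defines came from, so the sketch works on the quoted text directly.

```python
import re

# XLA log line copied from the error output in the question.
LOG_LINE = ("2024-08-20 16:26:58.037952: E external/xla/xla/stream_executor/"
            "cuda/cuda_dnn.cc:443] Memory usage: 46637514752 bytes free, "
            "47587131392 bytes total.")

# The cuDNN header excerpt exactly as quoted in the question.
HEADER_TEXT = """\
#define CUDNN_MAJOR 8
#define CUDNN_MINOR 9
"""

GIB = 1024 ** 3


def parse_memory(line):
    """Return (free_bytes, total_bytes) from an XLA 'Memory usage' log line."""
    m = re.search(r"Memory usage: (\d+) bytes free, (\d+) bytes total", line)
    if m is None:
        raise ValueError("no 'Memory usage' field found")
    return int(m.group(1)), int(m.group(2))


def parse_cudnn_version(header):
    """Return (major, minor) from the CUDNN_MAJOR/CUDNN_MINOR #define lines."""
    # findall with two groups yields (name, value) pairs, which dict() accepts.
    fields = dict(re.findall(r"#define\s+CUDNN_(MAJOR|MINOR)\s+(\d+)", header))
    return int(fields["MAJOR"]), int(fields["MINOR"])


free, total = parse_memory(LOG_LINE)
print(f"free: {free / GIB:.2f} GiB of {total / GIB:.2f} GiB ({free / total:.1%})")
print("cuDNN from header: %d.%d" % parse_cudnn_version(HEADER_TEXT))
```

With the numbers from the log this prints 43.43 GiB free of 44.32 GiB (98.0%), which is consistent with the failure not being an out-of-memory condition.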
These are the JAX packages I have installed:
pip show jax jaxlib
Name: jax
Version: 0.4.13
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: [email protected]
License: Apache-2.0
Location: /home/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages
Requires: importlib-metadata, ml-dtypes, numpy, opt-einsum, scipy
Required-by:
---
Name: jaxlib
Version: 0.4.13+cuda12.cudnn89
Summary: XLA library for JAX
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: [email protected]
License: Apache-2.0
Location: /home/mikabell/anaconda3/envs/jax-env/lib/python3.8/site-packages
Requires: ml-dtypes, numpy, scipy
Required-by:
Could someone please help me understand why I keep getting the same error?
Thanks!
This error usually means that your CUDA and cuDNN versions are not compatible with the version of jaxlib you have installed.
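You're using a fairly old JAX version (v0.4.13; the latest is v0.4.31) together with a newer CUDA version (v12.3), which was released six months after JAX v0.4.13. Given this, I suspect the problem can be fixed by installing a more recent version of jax and jaxlib that is compatible with your CUDA version.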
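Because the answer hinges on version compatibility, it can help to read off what the installed jaxlib wheel was built against. For the wheel in the question, the local version tag `+cuda12.cudnn89` encodes CUDA 12 and cuDNN 8.9. The sketch below parses that tag format and looks up installed versions with the standard-library `importlib.metadata`. `parse_jaxlib_build` and `installed` are hypothetical helpers. The sketch assumes a single-digit cuDNN major version, and newer JAX releases may ship CUDA support differently and not carry such a tag at all.

```python
import re
from importlib import metadata


def parse_jaxlib_build(version):
    """Split a jaxlib version like '0.4.13+cuda12.cudnn89' into its parts.

    Returns (release, cuda_major, cudnn_major, cudnn_minor). The CUDA/cuDNN
    entries are None when there is no '+cudaXX.cudnnYZ' local tag.
    Assumption: the cuDNN major version is a single digit ('89' -> 8.9).
    """
    release, _, local = version.partition("+")
    m = re.fullmatch(r"cuda(\d+)\.cudnn(\d)(\d+)", local)
    if m is None:
        return release, None, None, None
    return release, int(m.group(1)), int(m.group(2)), int(m.group(3))


def installed(dist):
    """Return the installed version of a distribution, or None if absent."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    for dist in ("jax", "jaxlib"):
        print(f"{dist}: {installed(dist)}")
    # Version string taken from the question's `pip show jaxlib` output.
    print(parse_jaxlib_build("0.4.13+cuda12.cudnn89"))
```

For the question's wheel this returns `('0.4.13', 12, 8, 9)`. A tag that matches the local cuDNN header does not by itself guarantee a working combination. That is why upgrading to a release built for the current CUDA stack is the more reliable fix.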
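From your traceback, it looks like you're using Python 3.8. Recent JAX releases do not support Python 3.8, so fixing this may also require installing a newer Python version. It would be good to do that soon anyway, since Python 3.8 reaches its end of life in a few weeks (October 2024).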
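The Python version can be read straight off the traceback paths (`.../lib/python3.8/site-packages/...`). Here is a minimal sketch that assumes the conda environment layout shown in the traceback. The hypothetical helper `python_version_from_path` extracts the version from such a path. The sketch then compares it, along with the running interpreter, against a minimum of 3.9 taken from the advice in this answer. The exact minimum varies between JAX releases.

```python
import re
import sys

# Minimum taken from this answer's advice ("Python 3.9 or newer").
# Assumption: the real floor depends on the JAX release you install.
MIN_PYTHON = (3, 9)


def python_version_from_path(path):
    """Extract (major, minor) from a path like '.../lib/python3.8/site-packages/...'.

    Returns None if the path contains no such component.
    """
    m = re.search(r"/lib/python(\d+)\.(\d+)/site-packages/", path)
    return (int(m.group(1)), int(m.group(2))) if m else None


# A path copied from the question's traceback.
TRACEBACK_PATH = ("/a/home/cc/chemist/mikabell/anaconda3/envs/jax-env/lib/"
                  "python3.8/site-packages/jax/_src/numpy/lax_numpy.py")

env_version = python_version_from_path(TRACEBACK_PATH)
print("env Python:", env_version,
      "OK" if env_version >= MIN_PYTHON else "too old")

current = tuple(sys.version_info[:2])
print("this interpreter:", current,
      "OK" if current >= MIN_PYTHON else "too old")
```

Here the environment reports `(3, 8)`, so it is flagged as too old. The second line reports whichever interpreter runs the snippet.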
Once you've updated to Python 3.9 or newer, you can update JAX with:
$ pip install -U "jax[cuda12]"
That should fix your problem.