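I want to train a Mask R-CNN model on an Nvidia RTX 4090 GPU, but it seems to be impossible. There appears to be a problem with loading the weights.
I tried the following:
With Python 3.7.12, TensorFlow 1.14/1.15 and Keras 2.3.1, Mask R-CNN works fine. But I can't use the RTX 4090 with these versions: I need at least TensorFlow 2.12.1 to get a CUDA version that supports the RTX 4090 (CUDA 11.8).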
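The version constraint above can be spelled out as a small lookup. This is a sketch, not an official compatibility checker. The table copies a handful of rows from the tested GPU build configurations at https://www.tensorflow.org/install/source. `MIN_CUDA_FOR_ADA` encodes the assumption that the RTX 4090 (compute capability 8.9) needs CUDA 11.8 or newer, the first CUDA release with Ada support. The helper name `supports_rtx_4090` is hypothetical.

```python
# Tested CUDA version per TensorFlow release (subset of the "GPU" table at
# https://www.tensorflow.org/install/source, keyed by major.minor).
TESTED_CUDA = {
    "1.14": (10, 0),
    "1.15": (10, 0),
    "2.0": (10, 0),
    "2.11": (11, 2),
    "2.12": (11, 8),
    "2.13": (11, 8),
}

# Assumption: Ada GPUs such as the RTX 4090 (compute capability 8.9)
# need CUDA 11.8 or newer for native sm_89 support.
MIN_CUDA_FOR_ADA = (11, 8)


def supports_rtx_4090(tf_version: str) -> bool:
    """Return True if the tested CUDA version of `tf_version` is at least 11.8."""
    major_minor = ".".join(tf_version.split(".")[:2])
    cuda = TESTED_CUDA.get(major_minor)
    if cuda is None:
        raise KeyError(f"no table entry for TensorFlow {major_minor}")
    # tuples compare element-wise: (11, 2) < (11, 8) <= (11, 8)
    return cuda >= MIN_CUDA_FOR_ADA


for version in ["1.15.5", "2.0.0", "2.11.1", "2.12.1", "2.13.1"]:
    print(version, supports_rtx_4090(version))
```

The table only covers the versions mentioned in this question. Other releases raise `KeyError` rather than guessing.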
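So when I use TensorFlow 2.12.1 or 2.13.1, there is a strange problem: the weights are loaded incorrectly, and the loaded weights always look random. I train the model for 1 epoch, then load the weights 3 times, and get different weights each time. See these results:
Note that the results will of course look sketchy and strange because I didn't train the model, but I would expect identical weights to give identical results.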
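The test script further down sums all absolute differences into a single number. That number only tells you *that* something differs. A per-array comparison shows *which* weights differ, for example whether every layer changes between loads or only a subset. Below is a minimal sketch. It assumes each load yields a list of NumPy arrays, as returned by Keras `get_weights()`. The function name `mismatched_layers` and its `names`/`atol` parameters are hypothetical.

```python
import numpy as np


def mismatched_layers(weights_a, weights_b, names=None, atol=0.0):
    """Compare two weight lists (e.g. from Keras `get_weights()`) array by array.

    Returns (label, max_abs_difference) for every array whose values differ by
    more than `atol`; arrays with different shapes are reported with inf.
    """
    if len(weights_a) != len(weights_b):
        raise ValueError(
            f"different number of arrays: {len(weights_a)} vs {len(weights_b)}"
        )
    mismatches = []
    for i, (a, b) in enumerate(zip(weights_a, weights_b)):
        label = names[i] if names is not None else i
        if a.shape != b.shape:
            mismatches.append((label, float("inf")))
            continue
        # empty arrays cannot differ, and np.max would fail on them
        diff = float(np.max(np.abs(a - b))) if a.size else 0.0
        if diff > atol:
            mismatches.append((label, diff))
    return mismatches


# toy example: the first array is identical, the second one differs
first_load = [np.zeros((3, 3)), np.zeros(4)]
second_load = [np.zeros((3, 3)), np.full(4, 0.5)]
print(mismatched_layers(first_load, second_load, names=["conv/kernel", "conv/bias"]))
# [('conv/bias', 0.5)]
```

In the test script, names in the same order as `get_weights()` should be available as `[w.name for w in model.keras_model.weights]`. With those, `mismatched_layers(weights[0], weights[1], names)` lists the variables that changed between two loads.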
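I tried Tom Gross's implementation and mrk1992's implementation, and both have this bug.
I've looked through many issues but didn't find anything. Maybe you can help.
Does anyone know how to solve this?
To reproduce the setup, here is my conda environment. Create environment.yml: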
name: maskrcnn
channels:
- conda-forge
- defaults
dependencies:
- _libgcc_mutex=0.1=conda_forge
- _openmp_mutex=4.5=2_gnu
- _tflow_select=2.3.0=mkl
- absl-py=0.15.0=pyhd8ed1ab_0
- aiohttp=3.7.4.post0=py37h5e8e339_1
- alembic=1.13.1=pyhd8ed1ab_0
- alsa-lib=1.2.8=h166bdaf_0
- anyio=3.7.1=pyhd8ed1ab_0
- aom=3.5.0=h27087fc_0
- argon2-cffi=23.1.0=pyhd8ed1ab_0
- argon2-cffi-bindings=21.2.0=py37h540881e_2
- astor=0.8.1=pyh9f0ad1d_0
- async-timeout=3.0.1=py_1000
- attr=2.5.1=h166bdaf_1
- attrs=23.2.0=pyh71513ae_0
- backcall=0.2.0=pyh9f0ad1d_0
- backports=1.0=pyhd8ed1ab_3
- backports.functools_lru_cache=2.0.0=pyhd8ed1ab_0
- bcrypt=3.2.2=py37h540881e_0
- beautifulsoup4=4.12.3=pyha770c72_0
- binutils_impl_linux-64=2.40=ha885e6a_0
- binutils_linux-64=2.40=hdade7a5_3
- bleach=6.1.0=pyhd8ed1ab_0
- blinker=1.6.3=pyhd8ed1ab_0
- bottleneck=1.3.5=py37hda87dfa_0
- brotli=1.1.0=hd590300_1
- brotli-bin=1.1.0=hd590300_1
- brotli-python=1.0.9=py37hd23a5d3_7
- bzip2=1.0.8=hd590300_5
- c-ares=1.28.1=hd590300_0
- ca-certificates=2024.2.2=hbcca054_0
- cached-property=1.5.2=hd8ed1ab_1
- cached_property=1.5.2=pyha770c72_1
- cachetools=5.3.3=pyhd8ed1ab_0
- cairo=1.16.0=ha61ee94_1014
- certifi=2024.2.2=pyhd8ed1ab_0
- cffi=1.15.1=py37h43b0acd_1
- chardet=4.0.0=py37h89c1867_3
- charset-normalizer=3.3.2=pyhd8ed1ab_0
- click=8.1.3=py37h89c1867_0
- cloudpickle=2.2.1=pyhd8ed1ab_0
- colorama=0.4.6=pyhd8ed1ab_0
- colorlog=6.7.0=py37h89c1867_0
- comm=0.2.2=pyhd8ed1ab_0
- cryptography=38.0.2=py37h5994e8b_1
- cycler=0.11.0=pyhd8ed1ab_0
- cython=0.29.32=py37hd23a5d3_0
- cytoolz=0.12.0=py37h540881e_0
- dask-core=2022.2.0=pyhd8ed1ab_0
- dbus=1.13.6=h5008d03_3
- debugpy=1.6.3=py37hd23a5d3_0
- decorator=5.1.1=pyhd8ed1ab_0
- defusedxml=0.7.1=pyhd8ed1ab_0
- dill=0.3.8=pyhd8ed1ab_0
- distro=1.9.0=pyhd8ed1ab_0
- docker-py=6.1.3=pyhd8ed1ab_0
- entrypoints=0.4=pyhd8ed1ab_0
- exceptiongroup=1.2.0=pyhd8ed1ab_2
- expat=2.6.2=h59595ed_0
- ffmpeg=4.4.2=gpl_h8dda1f0_112
- fftw=3.3.10=nompi_hc118613_108
- flask=1.1.2=pyh9f0ad1d_0
- font-ttf-dejavu-sans-mono=2.37=hab24e00_0
- font-ttf-inconsolata=3.000=h77eed37_0
- font-ttf-source-code-pro=2.038=h77eed37_0
- font-ttf-ubuntu=0.83=h77eed37_2
- fontconfig=2.14.2=h14ed4e7_0
- fonts-conda-ecosystem=1=0
- fonts-conda-forge=1=0
- fonttools=4.38.0=py37h540881e_0
- freeglut=3.2.2=h9c3ff4c_1
- freetype=2.12.1=h267a509_2
- fsspec=2023.1.0=pyhd8ed1ab_0
- gast=0.2.2=py_0
- gcc_impl_linux-64=13.2.0=h9eb54c0_7
- gcc_linux-64=13.2.0=h1ed452b_3
- geos=3.11.0=h27087fc_0
- gettext=0.22.5=h59595ed_2
- gettext-tools=0.22.5=h59595ed_2
- gitdb=4.0.11=pyhd8ed1ab_0
- gitpython=3.1.43=pyhd8ed1ab_0
- glib=2.80.2=hf974151_0
- glib-tools=2.80.2=hb6ce0ca_0
- gmp=6.3.0=h59595ed_1
- gnutls=3.7.9=hb077bed_0
- google-auth=2.23.0=pyh1a96a4e_0
- google-auth-oauthlib=0.4.6=pyhd8ed1ab_0
- google-pasta=0.2.0=pyh8c360ce_0
- graphite2=1.3.13=h59595ed_1003
- greenlet=1.1.3=py37hd23a5d3_0
- grpc-cpp=1.48.1=hc2bec63_1
- grpcio=1.48.1=py37h42e856d_1
- gst-plugins-base=1.21.3=h4243ec0_1
- gstreamer=1.21.3=h25f0c4b_1
- gstreamer-orc=0.4.38=hd590300_0
- gunicorn=20.1.0=py37h89c1867_2
- gxx_impl_linux-64=13.2.0=h2a599c4_7
- gxx_linux-64=13.2.0=he8deefe_3
- h5py=3.7.0=nompi_py37hf1ce037_101
- harfbuzz=5.3.0=h418a68e_0
- hdf5=1.12.2=nompi_h4df4325_101
- icu=70.1=h27087fc_0
- idna=3.7=pyhd8ed1ab_0
- imagecodecs-lite=2019.12.3=py37hc105733_5
- imageio=2.34.1=pyh4b66e23_0
- imgaug=0.4.0=pyhd8ed1ab_1
- importlib-metadata=4.11.4=py37h89c1867_0
- importlib_resources=6.0.0=pyhd8ed1ab_0
- imutils=0.5.4=py37h89c1867_2
- ipykernel=6.16.2=pyh210e3f2_0
- ipython=7.33.0=py37h89c1867_0
- ipython_genutils=0.2.0=py_1
- ipywidgets=8.1.2=pyhd8ed1ab_1
- itsdangerous=2.1.2=pyhd8ed1ab_0
- jack=1.9.22=h11f4161_0
- jasper=2.0.33=h0ff4b12_1
- jedi=0.19.1=pyhd8ed1ab_0
- jinja2=3.1.4=pyhd8ed1ab_0
- joblib=1.2.0=pyhd8ed1ab_0
- jpeg=9e=h0b41bf4_3
- jsonschema=4.17.3=pyhd8ed1ab_0
- jupyter=1.0.0=pyhd8ed1ab_10
- jupyter_client=7.4.9=pyhd8ed1ab_0
- jupyter_console=6.5.1=pyhd8ed1ab_0
- jupyter_core=4.11.1=py37h89c1867_0
- jupyter_server=1.23.4=pyhd8ed1ab_0
- jupyterlab_pygments=0.3.0=pyhd8ed1ab_1
- jupyterlab_widgets=3.0.10=pyhd8ed1ab_0
- keras=2.3.1=py37_0
- keras-applications=1.0.8=py_1
- keras-preprocessing=1.1.2=pyhd8ed1ab_0
- kernel-headers_linux-64=2.6.32=he073ed8_17
- keyutils=1.6.1=h166bdaf_0
- kiwisolver=1.4.4=py37h7cecad7_0
- krb5=1.20.1=h81ceb04_0
- lame=3.100=h166bdaf_1003
- lcms2=2.14=h6ed2654_0
- ld_impl_linux-64=2.40=h55db66e_0
- lerc=4.0.0=h27087fc_0
- libabseil=20220623.0=cxx17_h05df665_6
- libaec=1.1.3=h59595ed_0
- libasprintf=0.22.5=h661eb56_2
- libasprintf-devel=0.22.5=h661eb56_2
- libblas=3.9.0=20_linux64_openblas
- libbrotlicommon=1.1.0=hd590300_1
- libbrotlidec=1.1.0=hd590300_1
- libbrotlienc=1.1.0=hd590300_1
- libcap=2.67=he9d0100_0
- libcblas=3.9.0=20_linux64_openblas
- libclang=15.0.7=default_h127d8a8_5
- libclang13=15.0.7=default_h5d6823c_5
- libcups=2.3.3=h36d4200_3
- libcurl=8.1.2=h409715c_0
- libdb=6.2.32=h9c3ff4c_0
- libdeflate=1.14=h166bdaf_0
- libdrm=2.4.120=hd590300_0
- libedit=3.1.20191231=he28a2e2_2
- libev=4.33=hd590300_2
- libevent=2.1.10=h28343ad_4
- libexpat=2.6.2=h59595ed_0
- libffi=3.4.2=h7f98852_5
- libflac=1.4.3=h59595ed_0
- libgcc-devel_linux-64=13.2.0=hceb6213_107
- libgcc-ng=13.2.0=h77fa898_7
- libgcrypt=1.10.3=hd590300_0
- libgettextpo=0.22.5=h59595ed_2
- libgettextpo-devel=0.22.5=h59595ed_2
- libgfortran-ng=13.2.0=h69a702a_7
- libgfortran5=13.2.0=hca663fb_7
- libglib=2.80.2=hf974151_0
- libglu=9.0.0=he1b5a44_1001
- libgomp=13.2.0=h77fa898_7
- libgpg-error=1.49=h4f305b6_0
- libgpuarray=0.7.6=h7f98852_1003
- libiconv=1.17=hd590300_2
- libidn2=2.3.7=hd590300_0
- liblapack=3.9.0=20_linux64_openblas
- liblapacke=3.9.0=20_linux64_openblas
- libllvm11=11.1.0=he0ac6c6_5
- libllvm15=15.0.7=hadd5161_1
- libnghttp2=1.58.0=h47da74e_0
- libnsl=2.0.1=hd590300_0
- libogg=1.3.4=h7f98852_1
- libopenblas=0.3.25=pthreads_h413a1c8_0
- libopencv=4.6.0=py37hfe11ba8_3
- libopus=1.3.1=h7f98852_1
- libpciaccess=0.18=hd590300_0
- libpng=1.6.43=h2797004_0
- libpq=15.3=hbcd7760_1
- libprotobuf=3.20.1=h6239696_4
- libsanitizer=13.2.0=h6ddb7a1_7
- libsndfile=1.2.2=hc60ed4a_1
- libsodium=1.0.18=h36c2ea0_1
- libsqlite=3.45.3=h2797004_0
- libssh2=1.11.0=h0841786_0
- libstdcxx-devel_linux-64=13.2.0=hceb6213_107
- libstdcxx-ng=13.2.0=hc0a3c3a_7
- libsystemd0=253=h8c4010b_1
- libtasn1=4.19.0=h166bdaf_0
- libtiff=4.4.0=h82bc61c_5
- libtool=2.4.7=h27087fc_0
- libudev1=253=h0b41bf4_1
- libunistring=0.9.10=h7f98852_0
- libuuid=2.38.1=h0b41bf4_0
- libva=2.18.0=h0b41bf4_0
- libvorbis=1.3.7=h9c3ff4c_0
- libvpx=1.11.0=h9c3ff4c_3
- libwebp-base=1.4.0=hd590300_0
- libxcb=1.13=h7f98852_1004
- libxkbcommon=1.5.0=h79f4944_1
- libxml2=2.10.3=hca2bb57_4
- libzlib=1.2.13=hd590300_5
- llvmlite=0.39.1=py37h0761922_0
- locket=1.0.0=pyhd8ed1ab_0
- lz4-c=1.9.4=hcb278e6_0
- mako=1.3.5=pyhd8ed1ab_0
- markdown=3.6=pyhd8ed1ab_0
- markupsafe=2.1.1=py37h540881e_1
- matplotlib-base=3.5.3=py37hf395dca_2
- matplotlib-inline=0.1.7=pyhd8ed1ab_0
- mistune=3.0.2=pyhd8ed1ab_0
- mlflow=1.30.0=py37h02d9ccd_0
- mpg123=1.32.6=h59595ed_0
- multidict=6.0.2=py37h540881e_1
- multiprocess=0.70.14=py37h540881e_0
- munkres=1.1.4=pyh9f0ad1d_0
- mysql-common=8.0.33=hf1915f5_6
- mysql-libs=8.0.33=hca2cd23_6
- nbclassic=1.0.0=pyhb4ecaf3_1
- nbclient=0.7.0=pyhd8ed1ab_0
- nbconvert=7.6.0=pyhd8ed1ab_0
- nbconvert-core=7.6.0=pyhd8ed1ab_0
- nbconvert-pandoc=7.6.0=pyhd8ed1ab_0
- nbformat=5.8.0=pyhd8ed1ab_0
- ncurses=6.5=h59595ed_0
- nest-asyncio=1.6.0=pyhd8ed1ab_0
- nettle=3.9.1=h7ab15ed_0
- networkx=2.6.3=pyhd8ed1ab_1
- nomkl=1.0=h5ca1d4c_0
- notebook=6.5.7=pyha770c72_0
- notebook-shim=0.2.4=pyhd8ed1ab_0
- nspr=4.35=h27087fc_0
- nss=3.100=hca3bf56_0
- numba=0.56.3=py37hf081915_0
- numexpr=2.8.3=py37h85a3170_100
- numpy=1.21.6=py37h976b520_0
- oauthlib=3.2.2=pyhd8ed1ab_0
- opencv=4.6.0=py37h89c1867_3
- openh264=2.3.1=hcb278e6_2
- openjpeg=2.5.0=h7d73246_1
- openssl=3.1.5=hd590300_0
- opt_einsum=3.3.0=pyhc1e730c_2
- p11-kit=0.24.1=hc5aa10d_0
- packaging=21.3=pyhd8ed1ab_0
- pandas=1.3.5=py37h8c16a72_0
- pandoc=3.2=ha770c72_0
- pandocfilters=1.5.0=pyhd8ed1ab_0
- paramiko=3.4.0=pyhd8ed1ab_0
- parso=0.8.4=pyhd8ed1ab_0
- partd=1.4.1=pyhd8ed1ab_0
- pcre2=10.43=hcad00b1_0
- pexpect=4.9.0=pyhd8ed1ab_0
- pickleshare=0.7.5=py_1003
- pillow=9.2.0=py37h850a105_2
- pip=24.0=pyhd8ed1ab_0
- pixman=0.43.2=h59595ed_0
- pkgutil-resolve-name=1.3.10=pyhd8ed1ab_1
- prometheus_client=0.17.1=pyhd8ed1ab_0
- prometheus_flask_exporter=0.23.0=pyhd8ed1ab_0
- prompt-toolkit=3.0.42=pyha770c72_0
- prompt_toolkit=3.0.42=hd8ed1ab_0
- protobuf=3.20.1=py37hd23a5d3_0
- psutil=5.9.3=py37h540881e_0
- pthread-stubs=0.4=h36c2ea0_1001
- ptyprocess=0.7.0=pyhd3deb0d_0
- pulseaudio=16.1=hcb278e6_3
- pulseaudio-client=16.1=h5195f5e_3
- pulseaudio-daemon=16.1=ha8d29e2_3
- py-opencv=4.6.0=py37h25bab4e_3
- pyasn1=0.5.1=pyhd8ed1ab_0
- pyasn1-modules=0.3.0=pyhd8ed1ab_0
- pycocotools=2.0.4=py37hda87dfa_2
- pycparser=2.21=pyhd8ed1ab_0
- pygments=2.17.2=pyhd8ed1ab_0
- pygpu=0.7.6=py37hb1e94ed_1003
- pyjwt=2.8.0=pyhd8ed1ab_1
- pynacl=1.5.0=py37h540881e_1
- pyopenssl=23.2.0=pyhd8ed1ab_1
- pyparsing=3.1.2=pyhd8ed1ab_0
- pyrsistent=0.18.1=py37h540881e_1
- pysocks=1.7.1=py37h89c1867_5
- python=3.7.12=hf930737_100_cpython
- python-dateutil=2.9.0=pyhd8ed1ab_0
- python-fastjsonschema=2.19.1=pyhd8ed1ab_0
- python_abi=3.7=4_cp37m
- pytz=2022.7.1=pyhd8ed1ab_0
- pyu2f=0.1.5=pyhd8ed1ab_0
- pywavelets=1.3.0=py37hda87dfa_1
- pywin32-on-windows=0.1.0=pyh1179c8e_3
- pyyaml=6.0=py37h540881e_4
- pyzmq=24.0.1=py37h0c0c2a8_0
- qt-main=5.15.6=hf6cd601_5
- qtconsole-base=5.4.4=pyha770c72_0
- qtpy=2.4.1=pyhd8ed1ab_0
- querystring_parser=1.2.4=py_0
- re2=2022.06.01=h27087fc_1
- readline=8.2=h8228510_1
- requests=2.31.0=pyhd8ed1ab_0
- requests-oauthlib=2.0.0=pyhd8ed1ab_0
- rsa=4.9=pyhd8ed1ab_0
- ruamel.yaml=0.17.10=py37h5e8e339_0
- ruamel.yaml.clib=0.2.6=py37h540881e_1
- scikit-build=0.17.6=pyh4af843d_0
- scikit-image=0.19.3=py37hfb7772e_1
- scikit-learn=1.0.2=py37hf9e9bfc_0
- scipy=1.7.3=py37hf2a6cf1_0
- send2trash=1.8.3=pyh0d859eb_0
- setproctitle=1.3.2=py37h540881e_0
- setuptools=69.0.3=pyhd8ed1ab_0
- shapely=1.8.5=py37ha4e3bd1_0
- six=1.16.0=pyh6c4a22f_0
- smmap=5.0.0=pyhd8ed1ab_0
- sniffio=1.3.1=pyhd8ed1ab_0
- soupsieve=2.3.2.post1=pyhd8ed1ab_0
- sqlalchemy=1.4.42=py37h540881e_0
- sqlite=3.45.3=h2c6b66d_0
- sqlparse=0.4.4=pyhd8ed1ab_0
- svt-av1=1.4.1=hcb278e6_0
- sysroot_linux-64=2.12=he073ed8_17
- tensorboard=2.8.0=pyhd8ed1ab_1
- tensorboard-data-server=0.6.1=py37h52d8a92_0
- tensorboard-plugin-wit=1.8.1=pyhd8ed1ab_0
- tensorflow=2.0.0=mkl_py37h66b46cc_0
- tensorflow-base=2.0.0=mkl_py37h9204916_0
- tensorflow-estimator=2.6.0=py37hcd2ae1e_0
- termcolor=1.1.0=pyhd8ed1ab_3
- terminado=0.17.1=pyh41d4057_0
- theano=1.0.4=py37hf484d3e_1000
- threadpoolctl=3.1.0=pyh8a188c0_0
- tifffile=2020.6.3=py_0
- tinycss2=1.3.0=pyhd8ed1ab_0
- tk=8.6.13=noxft_h4845f30_101
- tomli=2.0.1=pyhd8ed1ab_0
- toolz=0.12.1=pyhd8ed1ab_0
- tornado=6.2=py37h540881e_0
- tqdm=4.66.4=pyhd8ed1ab_0
- traitlets=5.9.0=pyhd8ed1ab_0
- typing-extensions=4.7.1=hd8ed1ab_0
- typing_extensions=4.7.1=pyha770c72_0
- unicodedata2=14.0.0=py37h540881e_1
- urllib3=1.26.18=pyhd8ed1ab_0
- wcwidth=0.2.10=pyhd8ed1ab_0
- webencodings=0.5.1=pyhd8ed1ab_2
- websocket-client=1.6.1=pyhd8ed1ab_0
- werkzeug=0.16.1=py_0
- wheel=0.42.0=pyhd8ed1ab_0
- widgetsnbextension=4.0.10=pyhd8ed1ab_0
- wrapt=1.14.1=py37h540881e_0
- x264=1!164.3095=h166bdaf_2
- x265=3.5=h924138e_3
- xcb-util=0.4.0=h516909a_0
- xcb-util-image=0.4.0=h166bdaf_0
- xcb-util-keysyms=0.4.0=h516909a_0
- xcb-util-renderutil=0.3.9=h166bdaf_0
- xcb-util-wm=0.4.1=h516909a_0
- xkeyboard-config=2.38=h0b41bf4_0
- xorg-fixesproto=5.0=h7f98852_1002
- xorg-inputproto=2.3.2=h7f98852_1002
- xorg-kbproto=1.0.7=h7f98852_1002
- xorg-libice=1.1.1=hd590300_0
- xorg-libsm=1.2.4=h7391055_0
- xorg-libx11=1.8.4=h0b41bf4_0
- xorg-libxau=1.0.11=hd590300_0
- xorg-libxdmcp=1.1.3=h7f98852_0
- xorg-libxext=1.3.4=h0b41bf4_2
- xorg-libxfixes=5.0.3=h7f98852_1004
- xorg-libxi=1.7.10=h7f98852_0
- xorg-libxrender=0.9.10=h7f98852_1003
- xorg-renderproto=0.11.1=h7f98852_1002
- xorg-xextproto=7.3.0=h0b41bf4_1003
- xorg-xproto=7.0.31=h7f98852_1007
- xz=5.2.6=h166bdaf_0
- yaml=0.2.5=h7f98852_2
- yarl=1.7.2=py37h540881e_2
- zeromq=4.3.5=h59595ed_1
- zipp=3.15.0=pyhd8ed1ab_0
- zlib=1.2.13=hd590300_5
- zstd=1.5.6=ha6fb4c9_0
- pip:
  - databricks-cli==0.18.0
  - tabulate==0.9.0
prefix: /home/local-admin/.conda/envs/maskrcnn
Then just run:

```shell
conda env create -n maskrcnn -f environment.yml
```
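After creating the environment, check which TensorFlow build is actually active and whether it sees the GPU. Two points in the environment file are worth checking against the text. It pins `tensorflow=2.0.0` and Python 3.7.12, while the problem is described for TensorFlow 2.12.1/2.13.1. Also, per the tested-configurations table, TensorFlow 2.12 and 2.13 are built for Python 3.8 to 3.11. Below is a minimal sketch that needs nothing but TensorFlow itself. The helper name `describe_tensorflow` is hypothetical, and `tf.sysconfig.get_build_info()` is guarded because older 2.x releases lack it.

```python
import importlib.util


def describe_tensorflow():
    """Return TensorFlow version, CUDA/cuDNN build info and visible GPUs, or None."""
    if importlib.util.find_spec("tensorflow") is None:
        return None  # TensorFlow is not installed in this environment
    import tensorflow as tf

    # get_build_info() is missing in older releases such as 2.0;
    # CPU-only builds may not report CUDA/cuDNN versions at all.
    get_build_info = getattr(tf.sysconfig, "get_build_info", None)
    build = dict(get_build_info()) if get_build_info else {}

    # older releases (e.g. 2.0) only expose this under tf.config.experimental
    list_devices = getattr(tf.config, "list_physical_devices", None)
    if list_devices is None:
        list_devices = tf.config.experimental.list_physical_devices

    return {
        "tensorflow": tf.__version__,
        "cuda": build.get("cuda_version"),
        "cudnn": build.get("cudnn_version"),
        "gpus": [device.name for device in list_devices("GPU")],
    }


print(describe_tensorflow())
```

An empty `gpus` list, or a `cuda` entry below 11.8, would suggest that the RTX 4090 is not being used by the active build.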
Also download the Mask R-CNN implementation:

```shell
git clone -b tensorflow-2.0 https://github.com/tomgross/Mask_RCNN.git Mask_RCNN
```
This is the code to test the model:

```python
import sys
sys.path.append("./Mask_RCNN")

import os
import numpy as np
import cv2
import imutils
import matplotlib.pyplot as plt
import tensorflow as tf
from IPython.display import clear_output

from mrcnn.config import Config
from mrcnn import model as modellib
from mrcnn import visualize

img_path = "./test.jpg"
image = cv2.imread(img_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = imutils.resize(image, width=512)

plt.figure(figsize=(10, 10))
plt.axis("off")
plt.imshow(image);


class TestConfig(Config):
    NAME = "mask-rcnn test"
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1
    NUM_CLASSES = 2
    # BACKBONE = "resnet50"
    # IMAGE_MIN_DIM = 800
    # IMAGE_MAX_DIM = 1024


CLASS_NAMES = ['BG', 'FG']
TEST_MODEL_PATH = "./test_model.h5"

config = TestConfig()
# config.display()

# create a new model with random weights and save them
model = modellib.MaskRCNN(mode="training", config=config, model_dir=os.getcwd())
model.keras_model.save_weights(TEST_MODEL_PATH)
clear_output()

# prepare visualization
fig, ax = plt.subplots(ncols=3, nrows=1, figsize=(20, 5))
fig.subplots_adjust(wspace=0.01)

# load the saved random weights 3 times
weights = []
differences = []
for i in range(3):
    model = modellib.MaskRCNN(mode="inference", config=config, model_dir=os.getcwd())
    model.load_weights(TEST_MODEL_PATH, by_name=True)
    result = model.detect([image], verbose=0)[0]

    ax[i].set_title(f"{i}. Results")
    ax[i].axis("off")
    visualize.display_instances(image, result['rois'], result['masks'], result['class_ids'],
                                CLASS_NAMES, result['scores'],
                                ax=ax[i], show_mask=True, show_bbox=True)
    weights += [model.keras_model.get_weights()]
clear_output()

# compare the weights of every pair of loaded models
for cur_weights_1 in weights:
    cur_differences = []
    for cur_weights_2 in weights:
        all_weight_differences = [np.abs(w1 - w2) for w1, w2 in zip(cur_weights_1, cur_weights_2)]
        all_layer_differences = [np.sum(diff) for diff in all_weight_differences]
        cur_differences += [np.sum(all_layer_differences)]
    differences += [cur_differences]

complete_sum_difference = np.sum(differences)
if complete_sum_difference == 0.0:
    print("Congratulations!\nYour implementation works and is ready to segment!")
else:
    print("Difference detected. Your Mask R-CNN seems to have a problem.")
    print("Check if you correctly applied all steps of the installation.")
    print("\nMore detail of differences:")
    print("Absolute Weight Differences:", complete_sum_difference)
    print("Comparison: Weight 1, Weight 2, Weight 3")
    for idx in range(len(differences)):
        print(f"Weight {idx+1}: {differences[idx]}")

# try to remove the test model file
try:
    os.remove(TEST_MODEL_PATH)
except Exception:
    print(f"Wasn't able to delete the test model at: {TEST_MODEL_PATH}")

plt.show();
```
This is the test image:
I hope this makes it easy to reproduce the problem.
Does anyone have a solution?
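What I tried:
I also tried TensorFlow versions below 2.12.1, but they don't work. You can look up the CUDA version for each TensorFlow version here: https://www.tensorflow.org/install/source.
I expected to get identical weights when loading the same weights 3 times. For comparison: previously I trained and used the model on a GTX 1080 Ti with Python 3.7.12, TensorFlow 1.15 and Keras 2.3.1. There, the code above succeeded and the weights were equal, as expected. With the new GPU, the results are different.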
Sorry, but why don't you use the model from torchvision? You can load it for inference, and torch also provides the means to train it from scratch:
```python
import torch
from PIL import Image
from torchvision.models.detection import (
    MaskRCNN_ResNet50_FPN_Weights,
    maskrcnn_resnet50_fpn,
)

weights = MaskRCNN_ResNet50_FPN_Weights.COCO_V1
# image preprocessing that matches the pretrained weights
transform = weights.transforms()
model = maskrcnn_resnet50_fpn(weights=weights)
model.eval()

img = Image.open("airplane.jpeg").convert("RGB")
input_tensor = transform(img)

# the model takes a list of 3D image tensors; no gradients needed for inference
with torch.no_grad():
    preds = model([input_tensor])
print(preds[0]['labels'])  # tensor([5])

# for training, check these reference scripts:
# https://github.com/pytorch/vision/tree/main/references/detection#mask-r-cnn
```