Is it possible to train Mask R-CNN with an RTX 4090?


I want to train a Mask R-CNN model on an Nvidia RTX 4090 GPU, but that does not seem to be possible. There appears to be a problem with the weights.

I tried the following setups:

With Python 3.7.12, TensorFlow 1.14/1.15 and Keras 2.3.1, Mask R-CNN works fine. However, I cannot use the RTX 4090 with these versions. I need at least TensorFlow 2.12.1 to get the CUDA version that suits the RTX 4090 (CUDA 11.8).
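Before debugging the weights, it can help to confirm which TensorFlow build is actually running and whether it sees the GPU. The following is a minimal sketch that is not part of the original post: `describe_tf_environment` is a hypothetical helper name, and the code only assumes that `tf.sysconfig.get_build_info` and `tf.config.list_physical_devices` may or may not exist in the installed release, so both are looked up defensively.

```python
def describe_tf_environment():
    """Collect the TensorFlow version, its CUDA/cuDNN build info and the visible GPUs."""
    info = {"tensorflow": None, "cuda_build": None, "cudnn_build": None, "gpus": []}
    try:
        import tensorflow as tf
    except ImportError:
        return info  # TensorFlow is not installed in this environment

    info["tensorflow"] = tf.__version__

    # get_build_info() only exists in newer TF 2.x releases, so look it up defensively.
    get_build_info = getattr(tf.sysconfig, "get_build_info", None)
    if get_build_info is not None:
        build = get_build_info()
        info["cuda_build"] = build.get("cuda_version")
        info["cudnn_build"] = build.get("cudnn_version")

    # An empty list means this TensorFlow build cannot see any GPU.
    list_devices = getattr(tf.config, "list_physical_devices", None)
    if list_devices is not None:
        info["gpus"] = [gpu.name for gpu in list_devices("GPU")]
    return info


for key, value in describe_tf_environment().items():
    print(f"{key}: {value}")
```

If `gpus` comes back empty, the weight-loading symptom may be secondary to an installation problem, which would be worth ruling out first.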

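With TensorFlow 2.12.1 or 2.13.1, however, I run into a peculiar problem: the weights are not loaded correctly. The loaded weights always look random. I train the model for 1 epoch, then load the same weights 3 times, and each time the weights are different. See these results: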

[Image: results when loading the same weights 3 times]

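Note that the results naturally look sketchy and odd because I did not really train the model, but I would expect identical results from identical weights.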
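To narrow down which part of the network is affected, two snapshots returned by `get_weights()` can be compared array by array instead of being summed into a single number. This is a minimal sketch using toy arrays; `differing_layers` is a hypothetical helper and not part of Mask_RCNN or Keras.

```python
import numpy as np


def differing_layers(weights_a, weights_b, atol=0.0):
    """Return the indices of the weight arrays that differ between two snapshots."""
    if len(weights_a) != len(weights_b):
        raise ValueError("snapshots contain a different number of arrays")
    changed = []
    for idx, (a, b) in enumerate(zip(weights_a, weights_b)):
        a, b = np.asarray(a), np.asarray(b)
        # A shape mismatch or any element outside the tolerance counts as a difference.
        if a.shape != b.shape or not np.allclose(a, b, rtol=0.0, atol=atol):
            changed.append(idx)
    return changed


# Toy stand-ins for two `model.keras_model.get_weights()` results.
first = [np.ones((2, 2)), np.zeros(3)]
second = [np.ones((2, 2)), np.array([0.0, 0.5, 0.0])]
print(differing_layers(first, second))  # [1]
```

The returned indices line up with the order of `get_weights()`, so they can be mapped back to layers to see whether all arrays differ or only some of them.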

I tried Tom Gross's implementation and mrk1992's implementation, and both show this error.

I went through many existing issues but did not find anything. Maybe you can help.

Does anyone know how to solve this?


Reproducing the error

To make this reproducible, here is my conda environment. Create an environment.yml:

name: maskrcnn
channels:
  - conda-forge
  - defaults
dependencies:
  - _libgcc_mutex=0.1=conda_forge
  - _openmp_mutex=4.5=2_gnu
  - _tflow_select=2.3.0=mkl
  - absl-py=0.15.0=pyhd8ed1ab_0
  - aiohttp=3.7.4.post0=py37h5e8e339_1
  - alembic=1.13.1=pyhd8ed1ab_0
  - alsa-lib=1.2.8=h166bdaf_0
  - anyio=3.7.1=pyhd8ed1ab_0
  - aom=3.5.0=h27087fc_0
  - argon2-cffi=23.1.0=pyhd8ed1ab_0
  - argon2-cffi-bindings=21.2.0=py37h540881e_2
  - astor=0.8.1=pyh9f0ad1d_0
  - async-timeout=3.0.1=py_1000
  - attr=2.5.1=h166bdaf_1
  - attrs=23.2.0=pyh71513ae_0
  - backcall=0.2.0=pyh9f0ad1d_0
  - backports=1.0=pyhd8ed1ab_3
  - backports.functools_lru_cache=2.0.0=pyhd8ed1ab_0
  - bcrypt=3.2.2=py37h540881e_0
  - beautifulsoup4=4.12.3=pyha770c72_0
  - binutils_impl_linux-64=2.40=ha885e6a_0
  - binutils_linux-64=2.40=hdade7a5_3
  - bleach=6.1.0=pyhd8ed1ab_0
  - blinker=1.6.3=pyhd8ed1ab_0
  - bottleneck=1.3.5=py37hda87dfa_0
  - brotli=1.1.0=hd590300_1
  - brotli-bin=1.1.0=hd590300_1
  - brotli-python=1.0.9=py37hd23a5d3_7
  - bzip2=1.0.8=hd590300_5
  - c-ares=1.28.1=hd590300_0
  - ca-certificates=2024.2.2=hbcca054_0
  - cached-property=1.5.2=hd8ed1ab_1
  - cached_property=1.5.2=pyha770c72_1
  - cachetools=5.3.3=pyhd8ed1ab_0
  - cairo=1.16.0=ha61ee94_1014
  - certifi=2024.2.2=pyhd8ed1ab_0
  - cffi=1.15.1=py37h43b0acd_1
  - chardet=4.0.0=py37h89c1867_3
  - charset-normalizer=3.3.2=pyhd8ed1ab_0
  - click=8.1.3=py37h89c1867_0
  - cloudpickle=2.2.1=pyhd8ed1ab_0
  - colorama=0.4.6=pyhd8ed1ab_0
  - colorlog=6.7.0=py37h89c1867_0
  - comm=0.2.2=pyhd8ed1ab_0
  - cryptography=38.0.2=py37h5994e8b_1
  - cycler=0.11.0=pyhd8ed1ab_0
  - cython=0.29.32=py37hd23a5d3_0
  - cytoolz=0.12.0=py37h540881e_0
  - dask-core=2022.2.0=pyhd8ed1ab_0
  - dbus=1.13.6=h5008d03_3
  - debugpy=1.6.3=py37hd23a5d3_0
  - decorator=5.1.1=pyhd8ed1ab_0
  - defusedxml=0.7.1=pyhd8ed1ab_0
  - dill=0.3.8=pyhd8ed1ab_0
  - distro=1.9.0=pyhd8ed1ab_0
  - docker-py=6.1.3=pyhd8ed1ab_0
  - entrypoints=0.4=pyhd8ed1ab_0
  - exceptiongroup=1.2.0=pyhd8ed1ab_2
  - expat=2.6.2=h59595ed_0
  - ffmpeg=4.4.2=gpl_h8dda1f0_112
  - fftw=3.3.10=nompi_hc118613_108
  - flask=1.1.2=pyh9f0ad1d_0
  - font-ttf-dejavu-sans-mono=2.37=hab24e00_0
  - font-ttf-inconsolata=3.000=h77eed37_0
  - font-ttf-source-code-pro=2.038=h77eed37_0
  - font-ttf-ubuntu=0.83=h77eed37_2
  - fontconfig=2.14.2=h14ed4e7_0
  - fonts-conda-ecosystem=1=0
  - fonts-conda-forge=1=0
  - fonttools=4.38.0=py37h540881e_0
  - freeglut=3.2.2=h9c3ff4c_1
  - freetype=2.12.1=h267a509_2
  - fsspec=2023.1.0=pyhd8ed1ab_0
  - gast=0.2.2=py_0
  - gcc_impl_linux-64=13.2.0=h9eb54c0_7
  - gcc_linux-64=13.2.0=h1ed452b_3
  - geos=3.11.0=h27087fc_0
  - gettext=0.22.5=h59595ed_2
  - gettext-tools=0.22.5=h59595ed_2
  - gitdb=4.0.11=pyhd8ed1ab_0
  - gitpython=3.1.43=pyhd8ed1ab_0
  - glib=2.80.2=hf974151_0
  - glib-tools=2.80.2=hb6ce0ca_0
  - gmp=6.3.0=h59595ed_1
  - gnutls=3.7.9=hb077bed_0
  - google-auth=2.23.0=pyh1a96a4e_0
  - google-auth-oauthlib=0.4.6=pyhd8ed1ab_0
  - google-pasta=0.2.0=pyh8c360ce_0
  - graphite2=1.3.13=h59595ed_1003
  - greenlet=1.1.3=py37hd23a5d3_0
  - grpc-cpp=1.48.1=hc2bec63_1
  - grpcio=1.48.1=py37h42e856d_1
  - gst-plugins-base=1.21.3=h4243ec0_1
  - gstreamer=1.21.3=h25f0c4b_1
  - gstreamer-orc=0.4.38=hd590300_0
  - gunicorn=20.1.0=py37h89c1867_2
  - gxx_impl_linux-64=13.2.0=h2a599c4_7
  - gxx_linux-64=13.2.0=he8deefe_3
  - h5py=3.7.0=nompi_py37hf1ce037_101
  - harfbuzz=5.3.0=h418a68e_0
  - hdf5=1.12.2=nompi_h4df4325_101
  - icu=70.1=h27087fc_0
  - idna=3.7=pyhd8ed1ab_0
  - imagecodecs-lite=2019.12.3=py37hc105733_5
  - imageio=2.34.1=pyh4b66e23_0
  - imgaug=0.4.0=pyhd8ed1ab_1
  - importlib-metadata=4.11.4=py37h89c1867_0
  - importlib_resources=6.0.0=pyhd8ed1ab_0
  - imutils=0.5.4=py37h89c1867_2
  - ipykernel=6.16.2=pyh210e3f2_0
  - ipython=7.33.0=py37h89c1867_0
  - ipython_genutils=0.2.0=py_1
  - ipywidgets=8.1.2=pyhd8ed1ab_1
  - itsdangerous=2.1.2=pyhd8ed1ab_0
  - jack=1.9.22=h11f4161_0
  - jasper=2.0.33=h0ff4b12_1
  - jedi=0.19.1=pyhd8ed1ab_0
  - jinja2=3.1.4=pyhd8ed1ab_0
  - joblib=1.2.0=pyhd8ed1ab_0
  - jpeg=9e=h0b41bf4_3
  - jsonschema=4.17.3=pyhd8ed1ab_0
  - jupyter=1.0.0=pyhd8ed1ab_10
  - jupyter_client=7.4.9=pyhd8ed1ab_0
  - jupyter_console=6.5.1=pyhd8ed1ab_0
  - jupyter_core=4.11.1=py37h89c1867_0
  - jupyter_server=1.23.4=pyhd8ed1ab_0
  - jupyterlab_pygments=0.3.0=pyhd8ed1ab_1
  - jupyterlab_widgets=3.0.10=pyhd8ed1ab_0
  - keras=2.3.1=py37_0
  - keras-applications=1.0.8=py_1
  - keras-preprocessing=1.1.2=pyhd8ed1ab_0
  - kernel-headers_linux-64=2.6.32=he073ed8_17
  - keyutils=1.6.1=h166bdaf_0
  - kiwisolver=1.4.4=py37h7cecad7_0
  - krb5=1.20.1=h81ceb04_0
  - lame=3.100=h166bdaf_1003
  - lcms2=2.14=h6ed2654_0
  - ld_impl_linux-64=2.40=h55db66e_0
  - lerc=4.0.0=h27087fc_0
  - libabseil=20220623.0=cxx17_h05df665_6
  - libaec=1.1.3=h59595ed_0
  - libasprintf=0.22.5=h661eb56_2
  - libasprintf-devel=0.22.5=h661eb56_2
  - libblas=3.9.0=20_linux64_openblas
  - libbrotlicommon=1.1.0=hd590300_1
  - libbrotlidec=1.1.0=hd590300_1
  - libbrotlienc=1.1.0=hd590300_1
  - libcap=2.67=he9d0100_0
  - libcblas=3.9.0=20_linux64_openblas
  - libclang=15.0.7=default_h127d8a8_5
  - libclang13=15.0.7=default_h5d6823c_5
  - libcups=2.3.3=h36d4200_3
  - libcurl=8.1.2=h409715c_0
  - libdb=6.2.32=h9c3ff4c_0
  - libdeflate=1.14=h166bdaf_0
  - libdrm=2.4.120=hd590300_0
  - libedit=3.1.20191231=he28a2e2_2
  - libev=4.33=hd590300_2
  - libevent=2.1.10=h28343ad_4
  - libexpat=2.6.2=h59595ed_0
  - libffi=3.4.2=h7f98852_5
  - libflac=1.4.3=h59595ed_0
  - libgcc-devel_linux-64=13.2.0=hceb6213_107
  - libgcc-ng=13.2.0=h77fa898_7
  - libgcrypt=1.10.3=hd590300_0
  - libgettextpo=0.22.5=h59595ed_2
  - libgettextpo-devel=0.22.5=h59595ed_2
  - libgfortran-ng=13.2.0=h69a702a_7
  - libgfortran5=13.2.0=hca663fb_7
  - libglib=2.80.2=hf974151_0
  - libglu=9.0.0=he1b5a44_1001
  - libgomp=13.2.0=h77fa898_7
  - libgpg-error=1.49=h4f305b6_0
  - libgpuarray=0.7.6=h7f98852_1003
  - libiconv=1.17=hd590300_2
  - libidn2=2.3.7=hd590300_0
  - liblapack=3.9.0=20_linux64_openblas
  - liblapacke=3.9.0=20_linux64_openblas
  - libllvm11=11.1.0=he0ac6c6_5
  - libllvm15=15.0.7=hadd5161_1
  - libnghttp2=1.58.0=h47da74e_0
  - libnsl=2.0.1=hd590300_0
  - libogg=1.3.4=h7f98852_1
  - libopenblas=0.3.25=pthreads_h413a1c8_0
  - libopencv=4.6.0=py37hfe11ba8_3
  - libopus=1.3.1=h7f98852_1
  - libpciaccess=0.18=hd590300_0
  - libpng=1.6.43=h2797004_0
  - libpq=15.3=hbcd7760_1
  - libprotobuf=3.20.1=h6239696_4
  - libsanitizer=13.2.0=h6ddb7a1_7
  - libsndfile=1.2.2=hc60ed4a_1
  - libsodium=1.0.18=h36c2ea0_1
  - libsqlite=3.45.3=h2797004_0
  - libssh2=1.11.0=h0841786_0
  - libstdcxx-devel_linux-64=13.2.0=hceb6213_107
  - libstdcxx-ng=13.2.0=hc0a3c3a_7
  - libsystemd0=253=h8c4010b_1
  - libtasn1=4.19.0=h166bdaf_0
  - libtiff=4.4.0=h82bc61c_5
  - libtool=2.4.7=h27087fc_0
  - libudev1=253=h0b41bf4_1
  - libunistring=0.9.10=h7f98852_0
  - libuuid=2.38.1=h0b41bf4_0
  - libva=2.18.0=h0b41bf4_0
  - libvorbis=1.3.7=h9c3ff4c_0
  - libvpx=1.11.0=h9c3ff4c_3
  - libwebp-base=1.4.0=hd590300_0
  - libxcb=1.13=h7f98852_1004
  - libxkbcommon=1.5.0=h79f4944_1
  - libxml2=2.10.3=hca2bb57_4
  - libzlib=1.2.13=hd590300_5
  - llvmlite=0.39.1=py37h0761922_0
  - locket=1.0.0=pyhd8ed1ab_0
  - lz4-c=1.9.4=hcb278e6_0
  - mako=1.3.5=pyhd8ed1ab_0
  - markdown=3.6=pyhd8ed1ab_0
  - markupsafe=2.1.1=py37h540881e_1
  - matplotlib-base=3.5.3=py37hf395dca_2
  - matplotlib-inline=0.1.7=pyhd8ed1ab_0
  - mistune=3.0.2=pyhd8ed1ab_0
  - mlflow=1.30.0=py37h02d9ccd_0
  - mpg123=1.32.6=h59595ed_0
  - multidict=6.0.2=py37h540881e_1
  - multiprocess=0.70.14=py37h540881e_0
  - munkres=1.1.4=pyh9f0ad1d_0
  - mysql-common=8.0.33=hf1915f5_6
  - mysql-libs=8.0.33=hca2cd23_6
  - nbclassic=1.0.0=pyhb4ecaf3_1
  - nbclient=0.7.0=pyhd8ed1ab_0
  - nbconvert=7.6.0=pyhd8ed1ab_0
  - nbconvert-core=7.6.0=pyhd8ed1ab_0
  - nbconvert-pandoc=7.6.0=pyhd8ed1ab_0
  - nbformat=5.8.0=pyhd8ed1ab_0
  - ncurses=6.5=h59595ed_0
  - nest-asyncio=1.6.0=pyhd8ed1ab_0
  - nettle=3.9.1=h7ab15ed_0
  - networkx=2.6.3=pyhd8ed1ab_1
  - nomkl=1.0=h5ca1d4c_0
  - notebook=6.5.7=pyha770c72_0
  - notebook-shim=0.2.4=pyhd8ed1ab_0
  - nspr=4.35=h27087fc_0
  - nss=3.100=hca3bf56_0
  - numba=0.56.3=py37hf081915_0
  - numexpr=2.8.3=py37h85a3170_100
  - numpy=1.21.6=py37h976b520_0
  - oauthlib=3.2.2=pyhd8ed1ab_0
  - opencv=4.6.0=py37h89c1867_3
  - openh264=2.3.1=hcb278e6_2
  - openjpeg=2.5.0=h7d73246_1
  - openssl=3.1.5=hd590300_0
  - opt_einsum=3.3.0=pyhc1e730c_2
  - p11-kit=0.24.1=hc5aa10d_0
  - packaging=21.3=pyhd8ed1ab_0
  - pandas=1.3.5=py37h8c16a72_0
  - pandoc=3.2=ha770c72_0
  - pandocfilters=1.5.0=pyhd8ed1ab_0
  - paramiko=3.4.0=pyhd8ed1ab_0
  - parso=0.8.4=pyhd8ed1ab_0
  - partd=1.4.1=pyhd8ed1ab_0
  - pcre2=10.43=hcad00b1_0
  - pexpect=4.9.0=pyhd8ed1ab_0
  - pickleshare=0.7.5=py_1003
  - pillow=9.2.0=py37h850a105_2
  - pip=24.0=pyhd8ed1ab_0
  - pixman=0.43.2=h59595ed_0
  - pkgutil-resolve-name=1.3.10=pyhd8ed1ab_1
  - prometheus_client=0.17.1=pyhd8ed1ab_0
  - prometheus_flask_exporter=0.23.0=pyhd8ed1ab_0
  - prompt-toolkit=3.0.42=pyha770c72_0
  - prompt_toolkit=3.0.42=hd8ed1ab_0
  - protobuf=3.20.1=py37hd23a5d3_0
  - psutil=5.9.3=py37h540881e_0
  - pthread-stubs=0.4=h36c2ea0_1001
  - ptyprocess=0.7.0=pyhd3deb0d_0
  - pulseaudio=16.1=hcb278e6_3
  - pulseaudio-client=16.1=h5195f5e_3
  - pulseaudio-daemon=16.1=ha8d29e2_3
  - py-opencv=4.6.0=py37h25bab4e_3
  - pyasn1=0.5.1=pyhd8ed1ab_0
  - pyasn1-modules=0.3.0=pyhd8ed1ab_0
  - pycocotools=2.0.4=py37hda87dfa_2
  - pycparser=2.21=pyhd8ed1ab_0
  - pygments=2.17.2=pyhd8ed1ab_0
  - pygpu=0.7.6=py37hb1e94ed_1003
  - pyjwt=2.8.0=pyhd8ed1ab_1
  - pynacl=1.5.0=py37h540881e_1
  - pyopenssl=23.2.0=pyhd8ed1ab_1
  - pyparsing=3.1.2=pyhd8ed1ab_0
  - pyrsistent=0.18.1=py37h540881e_1
  - pysocks=1.7.1=py37h89c1867_5
  - python=3.7.12=hf930737_100_cpython
  - python-dateutil=2.9.0=pyhd8ed1ab_0
  - python-fastjsonschema=2.19.1=pyhd8ed1ab_0
  - python_abi=3.7=4_cp37m
  - pytz=2022.7.1=pyhd8ed1ab_0
  - pyu2f=0.1.5=pyhd8ed1ab_0
  - pywavelets=1.3.0=py37hda87dfa_1
  - pywin32-on-windows=0.1.0=pyh1179c8e_3
  - pyyaml=6.0=py37h540881e_4
  - pyzmq=24.0.1=py37h0c0c2a8_0
  - qt-main=5.15.6=hf6cd601_5
  - qtconsole-base=5.4.4=pyha770c72_0
  - qtpy=2.4.1=pyhd8ed1ab_0
  - querystring_parser=1.2.4=py_0
  - re2=2022.06.01=h27087fc_1
  - readline=8.2=h8228510_1
  - requests=2.31.0=pyhd8ed1ab_0
  - requests-oauthlib=2.0.0=pyhd8ed1ab_0
  - rsa=4.9=pyhd8ed1ab_0
  - ruamel.yaml=0.17.10=py37h5e8e339_0
  - ruamel.yaml.clib=0.2.6=py37h540881e_1
  - scikit-build=0.17.6=pyh4af843d_0
  - scikit-image=0.19.3=py37hfb7772e_1
  - scikit-learn=1.0.2=py37hf9e9bfc_0
  - scipy=1.7.3=py37hf2a6cf1_0
  - send2trash=1.8.3=pyh0d859eb_0
  - setproctitle=1.3.2=py37h540881e_0
  - setuptools=69.0.3=pyhd8ed1ab_0
  - shapely=1.8.5=py37ha4e3bd1_0
  - six=1.16.0=pyh6c4a22f_0
  - smmap=5.0.0=pyhd8ed1ab_0
  - sniffio=1.3.1=pyhd8ed1ab_0
  - soupsieve=2.3.2.post1=pyhd8ed1ab_0
  - sqlalchemy=1.4.42=py37h540881e_0
  - sqlite=3.45.3=h2c6b66d_0
  - sqlparse=0.4.4=pyhd8ed1ab_0
  - svt-av1=1.4.1=hcb278e6_0
  - sysroot_linux-64=2.12=he073ed8_17
  - tensorboard=2.8.0=pyhd8ed1ab_1
  - tensorboard-data-server=0.6.1=py37h52d8a92_0
  - tensorboard-plugin-wit=1.8.1=pyhd8ed1ab_0
  - tensorflow=2.0.0=mkl_py37h66b46cc_0
  - tensorflow-base=2.0.0=mkl_py37h9204916_0
  - tensorflow-estimator=2.6.0=py37hcd2ae1e_0
  - termcolor=1.1.0=pyhd8ed1ab_3
  - terminado=0.17.1=pyh41d4057_0
  - theano=1.0.4=py37hf484d3e_1000
  - threadpoolctl=3.1.0=pyh8a188c0_0
  - tifffile=2020.6.3=py_0
  - tinycss2=1.3.0=pyhd8ed1ab_0
  - tk=8.6.13=noxft_h4845f30_101
  - tomli=2.0.1=pyhd8ed1ab_0
  - toolz=0.12.1=pyhd8ed1ab_0
  - tornado=6.2=py37h540881e_0
  - tqdm=4.66.4=pyhd8ed1ab_0
  - traitlets=5.9.0=pyhd8ed1ab_0
  - typing-extensions=4.7.1=hd8ed1ab_0
  - typing_extensions=4.7.1=pyha770c72_0
  - unicodedata2=14.0.0=py37h540881e_1
  - urllib3=1.26.18=pyhd8ed1ab_0
  - wcwidth=0.2.10=pyhd8ed1ab_0
  - webencodings=0.5.1=pyhd8ed1ab_2
  - websocket-client=1.6.1=pyhd8ed1ab_0
  - werkzeug=0.16.1=py_0
  - wheel=0.42.0=pyhd8ed1ab_0
  - widgetsnbextension=4.0.10=pyhd8ed1ab_0
  - wrapt=1.14.1=py37h540881e_0
  - x264=1!164.3095=h166bdaf_2
  - x265=3.5=h924138e_3
  - xcb-util=0.4.0=h516909a_0
  - xcb-util-image=0.4.0=h166bdaf_0
  - xcb-util-keysyms=0.4.0=h516909a_0
  - xcb-util-renderutil=0.3.9=h166bdaf_0
  - xcb-util-wm=0.4.1=h516909a_0
  - xkeyboard-config=2.38=h0b41bf4_0
  - xorg-fixesproto=5.0=h7f98852_1002
  - xorg-inputproto=2.3.2=h7f98852_1002
  - xorg-kbproto=1.0.7=h7f98852_1002
  - xorg-libice=1.1.1=hd590300_0
  - xorg-libsm=1.2.4=h7391055_0
  - xorg-libx11=1.8.4=h0b41bf4_0
  - xorg-libxau=1.0.11=hd590300_0
  - xorg-libxdmcp=1.1.3=h7f98852_0
  - xorg-libxext=1.3.4=h0b41bf4_2
  - xorg-libxfixes=5.0.3=h7f98852_1004
  - xorg-libxi=1.7.10=h7f98852_0
  - xorg-libxrender=0.9.10=h7f98852_1003
  - xorg-renderproto=0.11.1=h7f98852_1002
  - xorg-xextproto=7.3.0=h0b41bf4_1003
  - xorg-xproto=7.0.31=h7f98852_1007
  - xz=5.2.6=h166bdaf_0
  - yaml=0.2.5=h7f98852_2
  - yarl=1.7.2=py37h540881e_2
  - zeromq=4.3.5=h59595ed_1
  - zipp=3.15.0=pyhd8ed1ab_0
  - zlib=1.2.13=hd590300_5
  - zstd=1.5.6=ha6fb4c9_0
  - pip:
      - databricks-cli==0.18.0
      - tabulate==0.9.0
prefix: /home/local-admin/.conda/envs/maskrcnn

Then just run:

conda env create -n maskrcnn python=3.8.19 -f environment.yml

Also download the Mask R-CNN implementation:

git clone -b tensorflow-2.0 https://github.com/tomgross/Mask_RCNN.git Mask_RCNN

Here is the code to test the model:

import sys
sys.path.append("./Mask_RCNN")

import os

import numpy as np
import cv2
import imutils
import matplotlib.pyplot as plt

import tensorflow as tf

from IPython.display import clear_output


from mrcnn.config import Config
from mrcnn import model as modellib
from mrcnn import visualize

img_path = "./test.jpg"

image = cv2.imread(img_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = imutils.resize(image, width=512)

plt.figure(figsize=(10, 10))
plt.axis("off")
plt.imshow(image);

class TestConfig(Config):
    NAME = "mask-rcnn test"
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1
    NUM_CLASSES = 2
    # BACKBONE = "resnet50"
    # IMAGE_MIN_DIM = 800
    # IMAGE_MAX_DIM = 1024

CLASS_NAMES = ['BG', 'FG']
TEST_MODEL_PATH = "./test_model.h5"
config = TestConfig()
# config.display()

# create new model with random weights
model = modellib.MaskRCNN(mode="training", config=config, model_dir=os.getcwd())
model.keras_model.save_weights(TEST_MODEL_PATH)

clear_output()

# prepare visualization
fig, ax = plt.subplots(ncols=3, nrows=1, figsize=(20, 5))
fig.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=0.01, hspace=None)

# load the created random weights 3 times
weights = []
differences = []
for i in range(3):
    model = modellib.MaskRCNN(mode="inference", config=config, model_dir=os.getcwd())
    model.load_weights(TEST_MODEL_PATH, by_name=True)
    result = model.detect([image], verbose=0)[0]
    ax[i].set_title(f"{i}. Results")
    ax[i].axis("off")
    visualize.display_instances(image, result['rois'], result['masks'], result['class_ids'], CLASS_NAMES, result['scores'],
                                ax=ax[i], show_mask=True, show_bbox=True)
    weights += [model.keras_model.get_weights()]

clear_output()

# detect differences between the models
for cur_weights_1 in weights:
    cur_differences = []
    for cur_weights_2 in weights:
        all_weight_differences = [np.abs(w1-w2) for w1, w2 in zip(cur_weights_1, cur_weights_2)]
        all_layer_differences = [np.sum(diff) for diff in all_weight_differences]
        cur_differences += [np.sum(all_layer_differences)]
    differences += [cur_differences]
complete_sum_difference = np.sum(differences)
if complete_sum_difference == 0.0:
    print("Congratulations!\nYour implementation works and is ready to segment!")
else:
    print("Difference detected. Your Mask R-CNN seems to have a problem.")
    print("Check if you correctly applied all steps of the installation.")
    print("\nMore detail of differences:")
    print("Absolute Weight Differences:", complete_sum_difference)
    print("Comparison: Weight 1, Weight 2, Weight 3")
    for idx in range(len(differences)):
        print(f"Weight {idx+1}: {differences[idx]}")

# try to remove the test model file
try:
    os.remove(TEST_MODEL_PATH)
except Exception:
    print(f"Wasn't able to delete the test model at: {TEST_MODEL_PATH}")

plt.show();

Here is the test image:

[Image: test image]

I hope this makes it easy to reproduce.

Does anyone have a solution?


What I tried:

I also tried TensorFlow versions below 2.12.1, but those do not work. You can look up the CUDA version that belongs to each TensorFlow version here: https://www.tensorflow.org/install/source
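The pairing of TensorFlow release and CUDA version that the question relies on can be written down as a small lookup. The rows below are only a partial excerpt of the tested Linux GPU configurations listed on the linked page, so treat them as an assumption and verify them there; `TESTED_CUDA` and `cuda_for` are hypothetical names.

```python
# Partial excerpt (assumed, verify against the linked page):
# TensorFlow release (major.minor) -> CUDA version it was tested with.
TESTED_CUDA = {
    "1.15": "10.0",
    "2.11": "11.2",
    "2.12": "11.8",
    "2.13": "11.8",
}


def cuda_for(tf_version):
    """Look up the tested CUDA version for a TensorFlow version such as '2.12.1'."""
    major_minor = ".".join(tf_version.split(".")[:2])
    return TESTED_CUDA.get(major_minor)  # None when the release is not in the excerpt


for version in ("1.15.0", "2.12.1", "2.13.1"):
    print(version, "->", cuda_for(version))
```

This matches the observation in the question: the 1.x releases were built against CUDA 10.0, while 2.12 and 2.13 are the releases in this excerpt that were tested with CUDA 11.8.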

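I expect to get identical weights when loading the same weights file 3 times. As a sanity check: earlier, when I trained and used the model on a GTX 1080 Ti with Python 3.7.12, TensorFlow 1.15 and Keras 2.3.1, the code above succeeded and the weights were equal, as expected. With the new GPU, the behavior is different.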
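One thing that may be worth checking, as an assumption rather than a confirmed diagnosis: with `by_name=True`, Keras restores only those layers whose names match the names stored in the HDF5 file, and layers without a match are skipped and keep their fresh initialization. The sketch below illustrates the check with plain, hypothetical name lists; in practice the names would come from something like `[layer.name for layer in model.keras_model.layers]` for the model that saved the file and the model that loads it.

```python
def unmatched_layer_names(saved_names, model_names):
    """Split model layer names into those found in the file and those that are not."""
    saved = set(saved_names)
    matched = [name for name in model_names if name in saved]
    missing = [name for name in model_names if name not in saved]
    return matched, missing


# Hypothetical names: one layer of the loading model got a different auto-generated suffix.
saved_names = ["conv1", "bn_conv1", "mrcnn_mask_conv1", "dense"]
model_names = ["conv1", "bn_conv1", "mrcnn_mask_conv1", "dense_1"]

matched, missing = unmatched_layer_names(saved_names, model_names)
print("restored from file:", matched)
print("left at initial values:", missing)  # ['dense_1']
```

If the second list is non-empty for the real models, those layers would start from random values on every load, which would look like the symptom described above.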

python machine-learning deep-learning mask-rcnn instance-segmentation
1 Answer

Sorry, but why don't you use the model from torchvision? You can load it for inference, and torch also provides a way to train it from scratch:

from torchvision.models.detection import (
    MaskRCNN_ResNet50_FPN_Weights, 
    maskrcnn_resnet50_fpn
)
import torch
import PIL
import numpy as np

weights = MaskRCNN_ResNet50_FPN_Weights.COCO_V1

# image preprocessing
transform = weights.transforms()
model = maskrcnn_resnet50_fpn(weights=weights)


img = PIL.Image.open("airplane.jpeg")
input = transform(img)


model.eval()

# add batch dimension
input = input[None,...] 


with torch.no_grad():  # inference only, so no gradients are needed
    preds = model(input)
print(preds[0]['labels']) # tensor([5])


# for training, check these scripts
# https://github.com/pytorch/vision/tree/main/references/detection#mask-r-cnn
