I'm trying to perform a wavelet scattering transform on a 2D image using the following code:
import torch
from kymatio import Scattering2D
import numpy as np
import PIL
from PIL import Image
FILENAME = "./square.png"
image = PIL.Image.open(FILENAME).convert("L")
a = np.array(image).astype(np.float64)
x = torch.from_numpy(a)
imageSize = x.shape
print(imageSize)
scattering = Scattering2D(J=2, shape=imageSize, frontend='numpy', L=8)
Sx = scattering(x)
print(Sx.size())
When I run it, I get the following error message. Can anyone help?
Traceback (most recent call last):
  File "/Users/pminev/Desktop/scatter2d.py", line 18, in <module>
    Sx = scattering(x)
         ^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/kymatio/frontend/torch_frontend.py", line 22, in forward
    return self.scattering(x)
           ^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/kymatio/scattering2d/frontend/torch_frontend.py", line 98, in scattering
    S = scattering2d(input, self.pad, self.unpad, self.backend, self.J,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/kymatio/scattering2d/core/scattering2d.py", line 19, in scattering2d
    U_1_c = cdgmm(U_0_c, phi['levels'][0])
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/kymatio/backend/torch_backend.py", line 192, in cdgmm
    raise TypeError('Input and filter must be of the same dtype.')
TypeError: Input and filter must be of the same dtype.
I am trying to apply the wavelet scattering transform to the image.
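You can try calling kymatio through its NumPy frontend instead, like this. The torch frontend seems to be the problem: the image is converted with astype(np.float64), and the error says the input must have the same dtype as the scattering filters, which the torch frontend presumably stores in float32. Casting the input first (for example torch.from_numpy(a).float()) may therefore also work.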
import numpy as np
import matplotlib.pyplot as plt
import PIL.Image
import kymatio.numpy as kp
# Load the example image as a 2D grayscale float array
FILENAME = "./square.png"
pil_img = PIL.Image.open(FILENAME).convert("L")
img = np.array(pil_img).astype(float)
# Define wavelet scattering parameters
J = 2  # Number of scales
L = 8  # Number of angles
# Compute the wavelet scattering transform with the NumPy frontend.
# Passing img.shape (rows, cols) also handles non-square images.
scattering = kp.Scattering2D(J=J, shape=img.shape, L=L)
scattering_coeffs = scattering(img)  # array of shape (channels, rows / 2**J, cols / 2**J)
# Plot the original image
plt.figure(figsize=(8, 4))
plt.subplot(1, 3, 1)
plt.imshow(img, cmap='gray')
plt.title('Original Image')
plt.axis('off')
# Plot one first-order coefficient (channel 1: scale j=0, angle theta=0)
plt.subplot(1, 3, 2)
plt.imshow(scattering_coeffs[1], cmap='viridis')
plt.title('First-order Scattering Coefficient')
plt.axis('off')
# Plot one second-order coefficient: channels 1..J*L are first order,
# so the first second-order channel is at index 1 + J*L
plt.subplot(1, 3, 3)
plt.imshow(scattering_coeffs[1 + J * L], cmap='viridis')
plt.title('Second-order Scattering Coefficient')
plt.axis('off')
plt.tight_layout()
plt.show()
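The plotting code picks channels of scattering_coeffs by integer index, so it helps to know which index belongs to which scattering order. Below is a minimal sketch, assuming kymatio orders the output channels as order 0, then order 1 (scale j outer, angle theta inner), then order 2 (pairs with j1 < j2). It also assumes the spatial output is subsampled by 2**J along each axis. The helpers scattering2d_layout and first_order_index are hypothetical illustrations, not part of kymatio's API.

```python
def scattering2d_layout(J, L, M, N, max_order=2):
    """Channel and spatial layout of a 2D scattering output.

    Assumes kymatio's channel order (order 0, order 1, order 2) and an
    output subsampled by 2**J per axis; hypothetical helper, not kymatio API.
    """
    n0 = 1          # zeroth order: a single low-pass channel
    n1 = J * L      # first order: one channel per (j, theta)
    # second order: J*(J-1)/2 scale pairs with j1 < j2, times L*L angle pairs
    n2 = L * L * J * (J - 1) // 2 if max_order >= 2 else 0
    return {
        "n_channels": n0 + n1 + n2,
        "first_order": range(n0, n0 + n1),
        "second_order": range(n0 + n1, n0 + n1 + n2),
        # assumed: each spatial axis is reduced by a factor 2**J
        "spatial_shape": (M // 2 ** J, N // 2 ** J),
    }


def first_order_index(j, theta, L):
    """Channel index of the first-order map for scale j and angle theta (hypothetical helper)."""
    return 1 + j * L + theta


if __name__ == "__main__":
    layout = scattering2d_layout(J=2, L=8, M=64, N=64)
    print(layout["n_channels"])       # 81
    print(layout["second_order"][0])  # 17
```

With the answer's J = 2 and L = 8, and under the ordering assumption above, channels 1 to 16 are first order and 17 to 80 are second order. So scattering_coeffs[1] and scattering_coeffs[2] would both be first-order maps, while scattering_coeffs[17] is the first second-order map.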