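I'm new to Python and am trying to do some filtering in PyTorch.
I'm struggling to work out how to apply `Conv2d`. I have the following code, which creates a 3x3 moving-average filter: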
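```python
import numpy as np
import torch

# image_noisy is a 2D NumPy array (H, W); Conv2d expects (N, C, H, W)
resized_image4D = np.reshape(image_noisy, (1, 1, image_noisy.shape[0], image_noisy.shape[1]))
t = torch.from_numpy(resized_image4D).float()  # float32, like the layer's weights
conv = torch.nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, padding=1, bias=False)
conv.weight = torch.nn.Parameter(torch.ones((1, 1, 3, 3)) / 9.0)
```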
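With NumPy/SciPy I would normally just call `filtered_image = convolve2d(image, kernel)` (from `scipy.signal`), but after several days of searching I still can't figure out what the PyTorch equivalent is.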
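You are looking for `torch.nn.functional.conv2d`. Unlike the `torch.nn.Conv2d` module, it takes the input tensor and the weight tensor directly (`conv2d(input, weight, bias=None, stride=1, padding=0, ...)`) instead of layer hyperparameters such as `in_channels`. The `Conv2d` module you already built can also be applied simply by calling it on the 4D tensor: `filtered = conv(t)`.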
So your snippet becomes:
```python
import numpy as np
import torch
import torch.nn.functional as F

resized_image4D = np.reshape(image_noisy, (1, 1, image_noisy.shape[0], image_noisy.shape[1]))
t = torch.from_numpy(resized_image4D).float()  # float32, like the kernel
kernel = torch.ones((1, 1, 3, 3)) / 9.0  # (out_channels, in_channels, kH, kW)
filtered = F.conv2d(t, kernel, padding=1)  # shape (1, 1, H, W)
```
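A minimal check, assuming a small random array stands in for `image_noisy`: the functional call, the `Conv2d` module from the question, and SciPy should all agree. `padding=1` corresponds to SciPy's `mode="same"` with its default zero fill. Note that PyTorch's `conv2d` computes a cross-correlation (the kernel is not flipped), so it matches `scipy.signal.correlate2d`; `convolve2d` only agrees for a symmetric kernel such as the moving average, or after you flip the kernel yourself.

```python
import numpy as np
import torch
import torch.nn.functional as F
from scipy.signal import convolve2d, correlate2d

rng = np.random.default_rng(0)
image = rng.random((6, 7))  # hypothetical stand-in for image_noisy (float64)

# (H, W) -> (N=1, C=1, H, W), keeping float64 so no dtype conversion is needed
t = torch.from_numpy(image).reshape(1, 1, *image.shape)
kernel = torch.ones((1, 1, 3, 3), dtype=torch.float64) / 9.0

# Functional form: input and weight are passed explicitly.
out_f = F.conv2d(t, kernel, padding=1)

# Module form from the question, with the same fixed weight.
conv = torch.nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False).double()
conv.weight = torch.nn.Parameter(kernel.clone())
with torch.no_grad():
    out_m = conv(t)

# SciPy reference: mode="same" with zero fill corresponds to padding=1.
ref = convolve2d(image, kernel[0, 0].numpy(), mode="same")

print(torch.allclose(out_f, out_m))           # True
print(np.allclose(out_f[0, 0].numpy(), ref))  # True

# An asymmetric kernel shows that PyTorch does not flip the kernel:
# it matches correlate2d, and matches convolve2d only after a 180-degree flip.
k = rng.random((3, 3))
out_k = F.conv2d(t, torch.from_numpy(k).reshape(1, 1, 3, 3), padding=1)[0, 0].numpy()
print(np.allclose(out_k, correlate2d(image, k, mode="same")))             # True
print(np.allclose(out_k, convolve2d(image, k[::-1, ::-1], mode="same")))  # True
```

For a learned or trainable filter the `Conv2d` module is convenient; for a fixed kernel such as this moving average, the functional form avoids wrapping the weights in a `Parameter`.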
For learning purposes, I implemented a simplified version using SciPy:
```python
import numpy as np
from scipy import signal


def conv2d_simplified(input, weight, bias=None, padding=0):
    # This is an implementation of torch's conv2d using scipy's correlate2d.
    # Only limited options are supported for simplicity (stride=1, groups=1,
    # integer padding).
    # Inspired by https://github.com/99991/NumPyConv2D/
    c_out, c_in_by_groups, kh, kw = weight.shape
    if not isinstance(padding, int):
        raise NotImplementedError()
    if padding:
        # Zero-pad only the two spatial dimensions.
        input = np.pad(input, ((0, 0), (0, 0), (padding, padding), (padding, padding)), "constant")
    # Output size of a stride-1 "valid" correlation of the (padded) input.
    outArr = np.empty((input.shape[0], c_out, input.shape[2] + 1 - kh, input.shape[3] + 1 - kw))
    al = np.empty((outArr.shape[2], outArr.shape[3]))
    for k in range(input.shape[0]):  # batch
        for i in range(weight.shape[0]):  # output channels
            al[:, :] = 0.0
            for j in range(weight.shape[1]):  # input channels, summed together
                al += signal.correlate2d(input[k, j, :, :], weight[i, j, :, :], 'valid')
            outArr[k, i, :, :] = al
    if bias is not None:
        outArr = outArr + bias.reshape(1, c_out, 1, 1)
    return outArr
```
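As an alternative without Python-level loops, here is a sketch that swaps in NumPy's `sliding_window_view` (NumPy 1.20 or newer) and `np.einsum` instead of `correlate2d`. `conv2d_einsum` is a hypothetical name, and like the version above it only handles stride 1, `groups=1` and integer padding. The check at the end compares it with `torch.nn.functional.conv2d` on random multi-channel data with a bias.

```python
import numpy as np
import torch
import torch.nn.functional as F
from numpy.lib.stride_tricks import sliding_window_view


def conv2d_einsum(input, weight, bias=None, padding=0):
    # Hypothetical helper: cross-correlation like torch's conv2d,
    # restricted to stride=1, groups=1 and integer padding.
    if padding:
        # Zero-pad only the two spatial dimensions.
        input = np.pad(input, ((0, 0), (0, 0), (padding, padding), (padding, padding)))
    kh, kw = weight.shape[2:]
    # Read-only view of shape (N, C_in, H_out, W_out, kh, kw):
    # one kh x kw window per output pixel.
    windows = sliding_window_view(input, (kh, kw), axis=(2, 3))
    # For each output channel o, sum over input channels c and window offsets (i, j).
    out = np.einsum("nchwij,ocij->nohw", windows, weight)
    if bias is not None:
        out = out + bias.reshape(1, -1, 1, 1)
    return out


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 8, 9))  # (N, C_in, H, W)
w = rng.standard_normal((4, 3, 3, 3))  # (C_out, C_in, kh, kw)
b = rng.standard_normal(4)             # (C_out,)

ours = conv2d_einsum(x, w, b, padding=1)
ref = F.conv2d(torch.from_numpy(x), torch.from_numpy(w), torch.from_numpy(b), padding=1).numpy()
print(ours.shape)              # (2, 4, 8, 9)
print(np.allclose(ours, ref))  # True
```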