如何在 Pytorch 中实际应用 Conv2d 过滤器

问题描述 投票:0回答:2

我是 Python 新手,正在尝试使用 PyTorch 中的过滤器进行一些操作。

我正在努力思考如何应用 Conv2d。我有以下代码,可创建 3x3 移动平均滤波器:

resized_image4D = np.reshape(image_noisy, (1, 1, image_noisy.shape[0], image_noisy.shape[1]))
t = torch.from_numpy(resized_image4D)

conv = torch.nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, padding=1, bias=False)
conv.weight = torch.nn.Parameter(torch.ones((1,1,3, 3))/9.0)

通常在 NumPy 中我只会调用

filtered_image = convolve2d(image, kernel)
,但经过几天的搜索,我无法弄清楚 PyTorch 的等效项是什么。

python pytorch convolution
2个回答
2
投票

我认为您正在寻找

torch.nn.functional.conv2d

因此,你的片段变成:

resized_image4D = np.reshape(image_noisy, (1, 1, image_noisy.shape[0], image_noisy.shape[1]))
t = torch.from_numpy(resized_image4D)

conv = torch.nn.functional.conv2d(in_channels=1, out_channels=1, kernel_size=3, padding=1, bias=False)
conv.weight = torch.nn.Parameter(torch.ones((1,1,3, 3))/9.0)

0
投票

出于学习目的,我使用 scipy 实现了一个简化版本:

import numpy as np
from scipy import signal

def conv2d_simplified(input, weight, bias=None, padding=0):
    # This is an implemention of torch's conv2d using scipy correlate2d. Only
    # limited options are supported for simplicity.
    # Inspired by https://github.com/99991/NumPyConv2D/
    c_out, c_in_by_groups, kh, kw = weight.shape
    if not isinstance(padding, int):
        raise NotImplementedError()
    
    if padding:
        input = np.pad(input, ((0, 0), (0, 0), (padding, padding), (padding, padding)), "constant")

    outArr = np.empty((input.shape[0], c_out, input.shape[2]+1-kh, input.shape[3]+1-kw))
    al = np.empty((outArr.shape[2], outArr.shape[3]))

    for k in range(input.shape[0]):
        for i in range(weight.shape[0]):
            al[:, :] = 0.0

            for j in range(weight.shape[1]):
                al += signal.correlate2d(input[k, j, :, :], weight[i, j, :, :], 'valid')

            outArr[k, i, :, :] = al

    if bias is not None:
        outArr = outArr + bias.reshape(1, c_out, 1, 1)

    return outArr
最新问题
© www.soinside.com 2019 - 2025. All rights reserved.