gpu

问题描述 投票:0回答:1

目前使用Pywavelet在分类器上工作,这是我的计算块:

class WaveletLayer(nn.Module):
    def __init__(self):
        super(WaveletLayer, self).__init__()

    def forward(self, x):
        def wavelet_transform(img):
            coeffs = pywt.dwt2(img.cpu().numpy(), "haar")
            LL, (LH, HL, HH) = coeffs
            return (
                torch.from_numpy(LL).to(img.device),
                torch.from_numpy(LH).to(img.device),
                torch.from_numpy(HL).to(img.device),
                torch.from_numpy(HH).to(img.device),
            )

        # Apply wavelet transform to each channel separately
        LL, LH, HL, HH = zip(
            *[wavelet_transform(x[:, i : i + 1]) for i in range(x.shape[1])]
        )

        # Concatenate the results
        LL = torch.cat(LL, dim=1)
        LH = torch.cat(LH, dim=1)
        HL = torch.cat(HL, dim=1)
        HH = torch.cat(HH, dim=1)

        return torch.cat([LL, LH, HL, HH], dim=1)

该模块的输出进入一个重新网络进行学习,而这样做的同时,我发现我的CPU堵塞了,从而减慢了我的训练过程

我正在尝试将GPU用于这些计算。

python tensorflow machine-learning deep-learning pytorch
1个回答
0
投票

沿每个维度的HAAR小波的高频分量是成对差,除以2的平方根。
    沿每个维度的HAAR小波的低频分量是成对总和,除以2的平方根。
  • 以下代码在纯pytorch中实现了这一点:
  • class HaarWaveletLayer(nn.Module): R_SQRT_2 = 1 / 2 ** .5 # 1 / sqrt(2) for normalization # avg ("low") along cols def l_0(self, t): return (t[..., ::2, :] + t[..., 1::2, :]) * self.R_SQRT_2 # avg ("low") along rows def l_1(self, t): return (t[..., :, ::2] + t[..., :, 1::2]) * self.R_SQRT_2 # diff ("hi") along cols def h_0(self, t): return (t[..., ::2, :] - t[..., 1::2, :]) * self.R_SQRT_2 # diff ("hi") along rows def h_1(self, t): return (t[..., :, ::2] - t[..., :, 1::2]) * self.R_SQRT_2 def forward(self, x): l_1 = self.l_1(x) h_1 = self.h_1(x) ll = self.l_0(l_1) lh = self.h_0(l_1) hl = self.l_0(h_1) hh = self.h_0(h_1) return torch.cat([ll, lh, hl, hh], dim=1)

结合给定的代码,您可以说服自己的等效性,如下所示:

t = torch.rand((7, 3, 14, 14))
result_given = WaveletLayer()(t)
result_own = HaarWaveletLayer()(t)
assert (result_given - result_own).abs().max() < 1e-5

注意提供的代码仅适用于均匀形状的图像。


最新问题
© www.soinside.com 2019 - 2025. All rights reserved.