Splitting an image into patches


I want to implement the Vision Transformer model. In the paper, the authors say they split the input image into patches of a fixed resolution: for example, if the image is 64x64 and the patch resolution is 16x16, it is split into 16 patches, each of resolution 16x16. So the final shape is (N, P, P, C), where N is the number of patches, P is the patch resolution, and C is the number of channels.
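For example, for a 64x64 RGB image with a 16x16 patch resolution, the numbers work out as follows (a quick sanity check of the arithmetic, assuming a channels-last layout; the values are only illustrative):

H, W, C = 64, 64, 3          # example image: height, width, channels
P = 16                       # patch resolution
N = (H * W) // (P ** 2)      # number of patches: 4096 // 256 = 16
print(N, (N, P, P, C))       # 16 (16, 16, 16, 3)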

I tried to vectorize the splitting:

import torch
import torchvision

def image_to_patches_fast(image, res_patch):

    # get_image_shape is a helper defined elsewhere that returns (H, W, C)
    (H, W, C) = get_image_shape(image)

    # convert grayscale images to RGB
    if C == 1:
        image = image.convert('RGB')
        (H, W, C) = get_image_shape(image)

    P = res_patch
    N = (H * W) // (P ** 2)

    # (C, H, W) -> (H, W, C), then reshape directly into patches
    image_tensor = torchvision.transforms.PILToTensor()(image).permute(1, 2, 0)
    image_patches = image_tensor.view(N, P, P, C)

    return image_patches

The function runs, but the output is not as expected: when I try to visualize the patches, something is wrong, the patches seem to be in the wrong positions, or maybe something else. Here is an example:

Input image:

Visualization of the output patches:

The function used to visualize the patches:

import matplotlib.pyplot as plt

def show_patches(patches):

    N, P = patches.shape[0], patches.shape[1]

    # assume the patches form a square grid
    nrows, ncols = int(N ** 0.5), int(N ** 0.5)
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols)
    for row in range(nrows):

        for col in range(ncols):

            idx = col + (row * ncols)  # row-major patch index

            axes[row][col].imshow(patches[idx, :, :, :])
            axes[row][col].axis("off")

    plt.subplots_adjust(left=0.1,
                    bottom=0.1,
                    right=0.9,
                    top=0.9,
                    wspace=0.1,
                    hspace=0.1)
    plt.show()
    

I tried another function to split the image, which works as expected, but it is slow because it uses loops:

def image_to_patches_slow(image, res_patch):

    (H, W, C) = get_image_shape(image)

    # convert grayscale images to RGB
    if C == 1:
        image = image.convert('RGB')
        (H, W, C) = get_image_shape(image)

    P = res_patch
    N = (H * W) // (P ** 2)

    nrows, ncols = int(N ** 0.5), int(N ** 0.5)

    image_tensor = torchvision.transforms.PILToTensor()(image).permute(1, 2, 0)
    image_patches = torch.zeros((N, P, P, C), dtype=torch.int)

    # copy each P x P block out of the image, one patch at a time
    for row in range(nrows):
        s_row = row * P
        e_row = s_row + P
        for col in range(ncols):

            idx = col + (row * ncols)

            s_col = col * P
            e_col = s_col + P

            image_patches[idx] = image_tensor[s_row:e_row, s_col:e_col]

    return image_patches

Its output:

Any help is appreciated, because this slow version has become a bottleneck for training.

python matplotlib deep-learning pytorch python-imaging-library
1 Answer

This approach builds the patches with a single reshaping operation, applied once per channel.

If the image dimensions are not divisible by the patch width, the image is cropped by trimming off the ends. It would be better to replace this basic crop with something smarter, such as a center crop, a resize, or the combination (resize then center crop) available in torchvision.
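For instance, a minimal sketch of the resize-then-center-crop variant using torchvision.transforms (the 200x200 target size is only an illustration, and this assumes a torchvision version whose transforms accept tensors):

import torchvision

# Resize the shorter side to 200 px, then take a 200x200 center crop,
# so the result is always divisible by a 50-px patch width.
preprocess = torchvision.transforms.Compose([
    torchvision.transforms.Resize(200),
    torchvision.transforms.CenterCrop(200),
])

img = torchvision.io.read_image('../image.png')    # (C, H, W)
img = preprocess(img).permute(1, 2, 0)              # (H, W, C)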

The example below breaks a 200x200 image into 50-pixel patches.

import torch
import torchvision
import matplotlib.pyplot as plt

img = torchvision.io.read_image('../image.png').permute(1, 2, 0)  # (H, W, C)

H, W, C = img.shape

patch_width = 50
n_rows = H // patch_width
n_cols = W // patch_width

# Trim off any remainder so both dimensions divide evenly by the patch width.
cropped_img = img[:n_rows * patch_width, :n_cols * patch_width, :]

#
# Into patches
# [n_rows, n_cols, patch_width, patch_width, C]
#
patches = torch.empty(n_rows, n_cols, patch_width, patch_width, C)
for chan in range(C):
    patches[..., chan] = (
        cropped_img[..., chan]
        .reshape(n_rows, patch_width, n_cols, patch_width)
        .permute(0, 2, 1, 3)
    )

#
# Plot
#
f, axs = plt.subplots(n_rows, n_cols, figsize=(5, 5))

for row_idx in range(n_rows):
    for col_idx in range(n_cols):
        axs[row_idx, col_idx].imshow(patches[row_idx, col_idx, ...] / 255)

for ax in axs.flatten():
    ax.set_xticks([])
    ax.set_yticks([])
f.subplots_adjust(wspace=0.05, hspace=0.05)
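If you need the flat (N, P, P, C) shape from the question, the row/column grid can be collapsed afterwards, and the per-channel loop can also be replaced with Tensor.unfold. A sketch of that variant, reusing cropped_img and patch_width from above (an alternative, not part of the answer's original code):

# Vectorized alternative: unfold height and width in one pass.
patches = (
    cropped_img
    .unfold(0, patch_width, patch_width)   # (n_rows, W', C, P)
    .unfold(1, patch_width, patch_width)   # (n_rows, n_cols, C, P, P)
    .permute(0, 1, 3, 4, 2)                # (n_rows, n_cols, P, P, C)
)

# Flatten the grid dimensions to the (N, P, P, C) shape from the question.
patches_flat = patches.reshape(-1, patch_width, patch_width, C)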