I want to implement the Vision Transformer model. In the paper, the authors say they split the input image into patches of a fixed resolution: for example, if the image is 64x64 and the patch resolution is 16x16, it is split into 16 patches, each of resolution 16x16. So the final shape is (N, P, P, C), where N is the number of patches, P is the patch resolution, and C is the number of channels.
I tried to vectorize the splitting:
def image_to_patches_fast(image, res_patch):
    (H, W, C) = get_image_shape(image)
    if C == 1:
        image = image.convert('RGB')
        (H, W, C) = get_image_shape(image)
    P = res_patch
    N = (H*W)//(P**2)
    image_tensor = torchvision.transforms.PILToTensor()(image).permute(1,2,0)
    image_patches = image_tensor.view(N,P,P,C)
    return image_patches
The function runs, but the output is not what I expect: when I try to visualize the patches, something is wrong, as if the patches end up in the wrong positions. Here is an example.
The function for visualizing the patches:
def show_patches(patches):
    N, P = patches.shape[0], patches.shape[1]
    nrows, ncols = int(N**0.5), int(N**0.5)
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols)
    for row in range(nrows):
        for col in range(ncols):
            idx = col + (row*ncols)
            axes[row][col].imshow(patches[idx,:,:,:])
            axes[row][col].axis("off")
    plt.subplots_adjust(left=0.1, bottom=0.1, right=0.9, top=0.9,
                        wspace=0.1, hspace=0.1)
    plt.show()
I tried another function to split the image. It works as expected, but it is slow because of the loops:
def image_to_patches_slow(image, res_patch):
    (H, W, C) = get_image_shape(image)
    if C == 1:
        image = image.convert('RGB')
        (H, W, C) = get_image_shape(image)
    P = res_patch
    N = (H*W)//(P**2)
    nrows, ncols = int(N**0.5), int(N**0.5)
    image_tensor = torchvision.transforms.PILToTensor()(image).permute(1,2,0)
    image_patches = torch.zeros((N,P,P,C), dtype=torch.int)
    for row in range(nrows):
        s_row = row * P
        e_row = s_row + P
        for col in range(ncols):
            idx = col + (row*ncols)
            s_col = col * P
            e_col = s_col + P
            image_patches[idx] = image_tensor[s_row:e_row, s_col:e_col]
    return image_patches
Any help would be appreciated, since this slow version has become a bottleneck in training.
This method does the patching with a one-line reshape operation, applied once per channel. It assumes the image dimensions are divisible by the patch width; if they are not, you need to crop or resize the image first.
import torch
import torchvision
import matplotlib.pyplot as plt

img = torchvision.io.read_image('../image.png').permute(1, 2, 0)
H, W, C = img.shape
patch_width = 100
n_rows = H // patch_width
n_cols = W // patch_width
#
# Into patches
# [n_rows, n_cols, patch_width, patch_width, C]
#
patches = torch.empty(n_rows, n_cols, patch_width, patch_width, C)
for chan in range(C):
    patches[..., chan] = img[..., chan].reshape(n_rows, patch_width, n_cols, patch_width).permute(0, 2, 1, 3)
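As an aside, the per-channel loop can be avoided by reshaping all channels at once with a 5D view and a permute. A minimal sketch (the helper name `to_patches` is mine, not from the code above), assuming H and W are divisible by the patch width:

```python
import torch

def to_patches(img, patch_width):
    # img: (H, W, C) tensor; assumes H and W are divisible by patch_width.
    H, W, C = img.shape
    n_rows, n_cols = H // patch_width, W // patch_width
    # (H, W, C) -> (n_rows, P, n_cols, P, C) -> (n_rows, n_cols, P, P, C)
    return img.reshape(n_rows, patch_width, n_cols, patch_width, C).permute(0, 2, 1, 3, 4)
```

This is the same trick as the channel loop, just with the channel axis carried along as a trailing dimension.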
#
# Plot
#
f, axs = plt.subplots(n_rows, n_cols, figsize=(5, 5))
for row_idx in range(n_rows):
    for col_idx in range(n_cols):
        axs[row_idx, col_idx].imshow(patches[row_idx, col_idx, ...] / 255)
for ax in axs.flatten():
    ax.set_xticks([])
    ax.set_yticks([])
f.subplots_adjust(wspace=0.05, hspace=0.05)
plt.show()
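If you want the flat (N, P, P, C) shape from the question directly, `Tensor.unfold` over the two spatial dimensions is another loop-free option. A sketch under the same divisibility assumption (the function name is illustrative):

```python
import torch

def image_to_patches(image_tensor, P):
    # image_tensor: (H, W, C); assumes H and W are divisible by P.
    H, W, C = image_tensor.shape
    # unfold each spatial dim into windows: (n_rows, n_cols, C, P, P)
    patches = image_tensor.unfold(0, P, P).unfold(1, P, P)
    # move channels last and flatten the patch grid: (N, P, P, C)
    return patches.permute(0, 1, 3, 4, 2).reshape(-1, P, P, C)
```

The patches come out in row-major order (left to right, top to bottom), which matches the indexing `idx = col + row*ncols` used in `show_patches`.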