在 python 中组合以 4D 数组形式给出的图像补丁的最快方法

问题描述 投票:0回答:1

给定一个大小为 (N,W,H,3) 的 4D 数组,其中 N 是块的数量,W,H 是图像块的宽度和高度,3 是颜色通道的数量。假设这些补丁是通过获取原始图像 I 并将其分成小方块而生成的。这种划分发生的顺序是逐行的。因此,如果我们将图像分成 3x3 个块(总共 9 个),每个块的大小为 10x10 像素,则 4D 数组将为 (9,10,10,3),其中元素的顺序为 [patch11,patch12,patch13,patch21 ,补丁22,补丁23,补丁31,补丁32,补丁33]。

现在我的问题是,仅使用 python 函数和 numpy(无 PIL 或 OpenCV)将这些补丁组合回以在 python 中生成原始图像的最有效方法。

非常感谢。

我可以编写一个双 for 循环来完成如下工作,但我想知道是否有更好的算法可以提供更快的性能:

import numpy as np

def reconstruct_image(patches, num_rows, num_cols):
    # num_rows and num_cols are the number of patches in the rows and columns respectively
    patch_height, patch_width, channels = patches.shape[1], patches.shape[2], patches.shape[3]

    # Initialize the empty array for the full image
    full_image = np.zeros((num_rows * patch_height, num_cols * patch_width, channels), dtype=patches.dtype)

    # Iterate over the rows and columns of patches
    for i in range(num_rows):
        for j in range(num_cols):
            # Get the index of the current patch in the 4D array
            patch_index = i * num_cols + j
            # Place the patch in the appropriate position in the full image
            full_image[i*patch_height:(i+1)*patch_height, j*patch_width:(j+1)*patch_width, :] = patches[patch_index]

    return full_image

N = 9  # Number of patches
W, H, C = 10, 10, 3  # Patch dimensions (WxHxC)
num_rows, num_cols = 3, 3  # Number of patches in rows and columns (3x3 patches)
patches = np.random.rand(N, W, H, C)  # Example patch data

reconstructed_image = reconstruct_image(patches, num_rows, num_cols)
python algorithm image recursion combinations
1个回答
0
投票

这是一种纯粹的 numpy 方法:

M = 3 # Number of patches per dimesion
N = M*M  # Number of patches
W, H, C = 10, 10, 3  # Patch dimensions (WxHxC)
num_rows, num_cols = 3, 3  # Number of patches in rows and columns (3x3 patches)
patches = np.random.rand(N, W, H, C)  # Example patch data

reconstructed_image = reconstruct_image(patches, num_rows, num_cols)

reconstructed_image_2 = np.transpose(np.reshape(patches, (M,M,W, H, C)), axes=(0,2,1,3,4)).reshape(M*W, M*H, C)

assert np.all(reconstructed_image == reconstructed_image) # True
© www.soinside.com 2019 - 2024. All rights reserved.