我正在编写一个自定义图像数据加载函数,以根据其二进制掩码随机裁剪大图像的一部分。该函数将在 PyTorch 数据加载器中使用,因此我希望它尽可能快且节省内存。图像和掩模都很大,宽度和高度都在 10k~20k 像素的数量级。
我希望图像的每一幅图像在二值掩模中至少包含一个正点。我当前的解决方案是首先从掩模图像中随机采样一个正点,然后在其周围生成一个裁剪框。实现包含一段代码如下:
import PIL
import numpy as np
... # Some preprocessing to find all mask and image files
mask = PIL.Image.open(mask_file) # both the width and height have 10k~20k pixels.
# fast_pil_to_numpy: https://uploadcare.com/blog/fast-import-of-pillow-images-to-numpy-opencv-arrays/
mask_np = fast_pil_to_numpy(mask).astype(bool) # dim: [height, width]
mask_loc = np.where(mask_np) # get (loc_y, loc_x) of all positive indices
idx = np.random.randint(low=0, high=len(mask_loc[0]))
... # Generate a crop box around (mask_loc[1][idx], mask_loc[0][idx])
用
line_profiler
分析整个函数后,我发现mask_loc = np.where(mask_np)
行是性能瓶颈之一。我该如何优化这部分?是否有另一种更有效的方法从二值图像中随机采样一个正点?
最好的方法是随机采样mask中的一个像素,如果没有设置,则重试。在达到枚举所有设置像素的性能之前,您可以尝试数千次。
如果您的蒙版非常非常稀疏,那么您当前的方法可能是最好的。
如果您的蒙版是较大图像中的一个小而紧凑的区域,则获取边界框并仅对该框中的随机像素进行采样将加快速度。