我已经编写了代码来加载 ImageNet 2012 的片段并将其作为 np.array 返回,虽然该代码适用于大多数情况,但某些图像对于代码来说是有问题的,因为它们不是 RGB 格式。
from PIL import Image
import numpy as np
import argparse
def load_image_list (PATH, IMAGE_LIST):
output_list = []
for fname in IMAGE_LIST:
img_pil = Image.open(PATH + '/' + fname)
if img_pil.mode == 'CMYK':
img_pil = img_pil.convert('RGB')
width, height = img_pil.size
smallest_dim = min((width, height))
left = (width - smallest_dim)/2
right = (width + smallest_dim)/2
top = (height - smallest_dim)/2
bottom = (height + smallest_dim)/2
img_pil = img_pil.crop((left, top, right, bottom))
img_pil = img_pil.resize((224,224))
img_pil.load()
img1 = np.asarray(img_pil, dtype="float32")
if img1.ndim == 2:
img1 = np.stack((img1,)*3, axis=-1)
img1 = np.rollaxis(np.array(img1),2,0).astype(np.float32)
output_list.append(img1)
try:
output_list = np.array(output_list, dtype="float32")
except:
counter = 0
helper = -1
for o in output_list:
print(counter, o.shape)
counter += 1
if(o.shape[0] == 4):
print(o)
helper = output_list.index(o)
counter = 0
for name in IMAGE_LIST:
print(counter, name)
counter += 1
print(helper, IMAGE_LIST[helper])
print(helper, output_list[helper].shape)
retry_item = IMAGE_LIST[helper]
img_pil = Image.open(PATH + '/' + retry_item)
width, height = img_pil.size
smallest_dim = min((width, height))
left = (width - smallest_dim)/2
right = (width + smallest_dim)/2
top = (height - smallest_dim)/2
bottom = (height + smallest_dim)/2
img_pil = img_pil.crop((left, top, right, bottom))
img_pil = img_pil.resize((224,224))
img_pil.load()
img1 = np.asarray(img_pil, dtype="float32")
print('number of dimensions: ', img1.ndim)
print('shape: ', img1.shape)
if img1.ndim == 2:
print('ndim was 2')
print(img1)
img1 = np.stack((img1,)*3, axis=-1)
img1 = np.rollaxis(np.array(img1),2,0).astype(np.float32)
print('number of dimensions: ', img1.ndim)
print('shape: ', img1.shape)
output_list = np.array(output_list, dtype="float32")
output_list = output_list/255
return output_list
except 行是为了找出它何时中断、为什么会中断。我已经可以加载验证集中的每个图像,但现在我尝试在测试集上运行,但图像 n02105855_2933.JPEG 再次出现错误,因为加载的图像是 4 维而不是 3 维。有没有更好的方法将未知颜色格式的 JPEG 加载为 RGB 图像? ImageNet 2012 中还使用了哪些其他颜色格式?
特定图像的问题在于它是rgba格式,a尺寸均为255。所以下面的代码解决了这个问题:
from PIL import Image
import numpy as np
import argparse
def load_image_list (PATH, IMAGE_LIST):
output_list = []
for fname in IMAGE_LIST:
img_pil = Image.open(PATH + '/' + fname)
if img_pil.mode == 'CMYK':
img_pil = img_pil.convert('RGB')
width, height = img_pil.size
smallest_dim = min((width, height))
left = (width - smallest_dim)/2
right = (width + smallest_dim)/2
top = (height - smallest_dim)/2
bottom = (height + smallest_dim)/2
img_pil = img_pil.crop((left, top, right, bottom))
img_pil = img_pil.resize((224,224))
img_pil.load()
img1 = np.asarray(img_pil, dtype="float32")
if img1.ndim == 2:
img1 = np.stack((img1,)*3, axis=-1)
if img1.shape[2] == 4:
img1 = np.delete(img1, 3, axis=2)
img1 = np.rollaxis(np.array(img1),2,0).astype(np.float32)
output_list.append(img1)
try:
output_list = np.array(output_list, dtype="float32")
except:
counter = 0
helper = -1
for o in output_list:
print(counter, o.shape)
counter += 1
if(o.shape[0] == 4):
print(o)
helper = output_list.index(o)
counter = 0
for name in IMAGE_LIST:
print(counter, name)
counter += 1
print(helper, IMAGE_LIST[helper])
print(helper, output_list[helper].shape)
retry_item = IMAGE_LIST[helper]
img_pil = Image.open(PATH + '/' + retry_item)
width, height = img_pil.size
smallest_dim = min((width, height))
left = (width - smallest_dim)/2
right = (width + smallest_dim)/2
top = (height - smallest_dim)/2
bottom = (height + smallest_dim)/2
img_pil = img_pil.crop((left, top, right, bottom))
img_pil = img_pil.resize((224,224))
img_pil.load()
img1 = np.asarray(img_pil, dtype="float32")
print('number of dimensions: ', img1.ndim)
print('shape: ', img1.shape)
if img1.ndim == 2:
print('ndim was 2')
print(img1)
img1 = np.stack((img1,)*3, axis=-1)
if img1.shape[2] == 4:
print('color dimension was size 4')
img1 = np.delete(img1, 3, axis=2)
img1 = np.rollaxis(np.array(img1),2,0).astype(np.float32)
print('number of dimensions: ', img1.ndim)
print('shape: ', img1.shape)
output_list = np.array(output_list, dtype="float32")
output_list = output_list/255
return output_list