I am trying to load a local image dataset and use it to train my model. I load the dataset like this:
data_load = tk.utils.image_dataset_from_directory(
    dir,
    labels="inferred",
    batch_size=128,
    image_size=image_shape,
    shuffle=True,
    seed=42,
    validation_split=0.2,
    subset="training",
)
Here dir is the local path where my data is stored. When I train my model on this data with model.fit, I get this error:
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
c:\Users\HP\Desktop\SBU\Courses\spring23\ese577\Labs\Lab3\lenet.ipynb Cell 8 in 2
1 epoch = 15
----> 2 hist = model.fit(data.train, batch_size=batch_size, epochs=epoch)
File c:\python10\lib\site-packages\keras\utils\traceback_utils.py:70, in filter_traceback.<locals>.error_handler(*args, **kwargs)
67 filtered_tb = _process_traceback_frames(e.__traceback__)
68 # To get the full stack trace, call:
69 # `tf.debugging.disable_traceback_filtering()`
---> 70 raise e.with_traceback(filtered_tb) from None
71 finally:
72 del filtered_tb
File c:\python10\lib\site-packages\tensorflow\python\eager\execute.py:52, in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
50 try:
51 ctx.ensure_initialized()
---> 52 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
53 inputs, attrs, num_outputs)
54 except core._NotOkStatusException as e:
55 if name is not None:
InvalidArgumentError: Graph execution error:
Number of channels inherent in the image must be 1, 3 or 4, was 2
[[{{node decode_image/DecodeImage}}]]
[[IteratorGetNext]] [Op:__inference_train_function_4955]
Interestingly, it always throws this error at the same point:
Epoch 1/15
9/124 [=>............................] - ETA: 6:36 - loss: 5.0484 - accuracy: 0.4852
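When I searched for similar errors, I found that this one is usually reported when reading BMP images. All of my images are JPG, yet I still get this error.
How can I fix this error, or how can I identify the bad images so that I can remove them from the dataset and continue training?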
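You have to filter out any images that cause the problem. I use the code below to process every image in a directory and detect defective ones so that I can remove them from the dataset.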
import os
import cv2
from tqdm import tqdm

datadir = r'c:\datasets\autism\test'  # path to directory with class subdirectories holding the image files
good_ext = ['jpg', 'jpeg', 'bmp', 'png']  # allowable extensions for image_dataset_from_directory
bad_img_list = []  # will be a list of defective images
classes = sorted(os.listdir(datadir))  # a list of classes within the datadir
for klass in classes:  # iterate through each class
    classpath = os.path.join(datadir, klass)
    flist = sorted(os.listdir(classpath))  # list of files in the current class
    for f in tqdm(flist, ncols=100, unit='files', colour='blue', desc=klass):  # iterate through the files
        fpath = os.path.join(classpath, f)  # path to image file
        try:
            index = f.rfind('.')  # find the rightmost . in f
            ext = f[index + 1:].lower()  # get the file's extension and convert to lower case
            if ext not in good_ext:
                raise ValueError('image had improper extension')  # raise so the file is appended to bad_img_list
            img = cv2.imread(fpath)  # read in the image; returns None if the file cannot be decoded
            # note: the default flag (cv2.IMREAD_COLOR) always yields a 3-channel array
            if img is None:
                raise ValueError('image could not be read')
            shape = img.shape  # (height, width) or (height, width, channels)
            count = len(shape)
            if count == 2:  # shape is (height, width): image is single channel
                channels = 1
            else:
                channels = shape[2]  # shape is (height, width, channels)
            if channels == 2:
                raise ValueError('image had 2 channels')  # raise so the file is appended to bad_img_list
        except Exception:
            bad_img_list.append(fpath)  # append to bad_img_list if there is an exception
if len(bad_img_list) > 0:
    print('below is a list of defective image filepaths')
    for f in bad_img_list:
        print(f)
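Since the question says every file is named .jpg while the error is one usually reported for BMP files, it may also be worth checking what the files actually contain rather than what they are called. The sketch below uses only the standard library: it reads the first 30 bytes of each file, compares the file signature with the extension, and, for BMP content, reads the bits-per-pixel field at byte offset 28 of the header. It assumes the error comes from a BMP whose bits-per-pixel divided by 8 is not 1, 3 or 4 (for example a 16-bit BMP); that is my reading of the message, not something the traceback confirms. The names sniff_format, bmp_channels and find_suspect_files are made up for this example.

```python
import os
import struct

# Leading bytes ("magic numbers") of the formats image_dataset_from_directory accepts.
SIGNATURES = (
    (b"\xff\xd8\xff", "jpeg"),
    (b"\x89PNG\r\n\x1a\n", "png"),
    (b"GIF87a", "gif"),
    (b"GIF89a", "gif"),
    (b"BM", "bmp"),
)

# Format each extension is expected to contain.
EXPECTED = {".jpg": "jpeg", ".jpeg": "jpeg", ".png": "png", ".gif": "gif", ".bmp": "bmp"}


def sniff_format(head):
    """Guess the image format from the first bytes of a file."""
    for magic, fmt in SIGNATURES:
        if head.startswith(magic):
            return fmt
    return "unknown"


def bmp_channels(head):
    """Return bits-per-pixel // 8 for a BMP header, or None if the header is too short."""
    if len(head) < 30:
        return None
    (bpp,) = struct.unpack_from("<H", head, 28)  # little-endian uint16 at offset 28
    return bpp // 8


def find_suspect_files(root):
    """Walk `root` and return (path, reason) pairs for files that look problematic."""
    suspects = []
    for dirpath, _dirnames, filenames in os.walk(root):
        for name in sorted(filenames):
            path = os.path.join(dirpath, name)
            with open(path, "rb") as fh:
                head = fh.read(30)
            fmt = sniff_format(head)
            ext = os.path.splitext(name)[1].lower()
            if fmt == "unknown":
                suspects.append((path, "unrecognized file signature"))
            elif fmt == "bmp" and bmp_channels(head) not in (1, 3, 4):
                suspects.append((path, f"BMP with {bmp_channels(head)} channels"))
            elif fmt != EXPECTED.get(ext):
                suspects.append((path, f"extension {ext} but content is {fmt}"))
    return suspects
```

Calling find_suspect_files(datadir) returns a list you can print or review before removing anything. A file flagged only for an extension mismatch is not necessarily unreadable, so treat those entries as candidates to inspect rather than as certain failures.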
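Strange that the message describes the problem but not the filename, isn't it?
I use the following to clean my training set. Because it reads the images with tensorflow, it should be consistent about which images are problematic: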
import glob
import os
import tensorflow

for fn in glob.glob("images/**/*"):
    with open(fn, 'rb') as f:
        try:
            tensorflow.io.decode_image(f.read())
        except Exception as e:
            # os.remove(fn)  # Warning: Make sure the glob above doesn't match critical non-image files
            print(fn, e)
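If deleting files outright feels risky, one alternative is to move the flagged files somewhere outside the dataset directory so they can be inspected or restored later. This is only a sketch using the standard library. The function name quarantine and its parameters are made up for this example, and it assumes you have already collected the failing paths into a list (for instance by appending fn in the except branch above).

```python
import os
import shutil


def quarantine(paths, root, quarantine_dir):
    """Move each file in `paths` from under `root` into `quarantine_dir`.

    The path relative to `root` is preserved, so the class subdirectory stays
    visible. `quarantine_dir` should live outside `root`; otherwise the dataset
    loader would still pick the files up.
    """
    moved = []
    for path in paths:
        rel = os.path.relpath(path, root)
        dest = os.path.join(quarantine_dir, rel)
        os.makedirs(os.path.dirname(dest), exist_ok=True)
        shutil.move(path, dest)
        moved.append(dest)
    return moved
```

For example, quarantine(bad_files, "images", "images_quarantine") would move images/cats/x.jpg to images_quarantine/cats/x.jpg, where bad_files is whatever list of paths you collected.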