How do I identify problematic images in TensorFlow Keras?

Question (votes: 0, answers: 2)

I am trying to load a local image dataset and use it to train my model. I load the dataset like this.

        data_load = tk.utils.image_dataset_from_directory(
            dir,
            labels="inferred",
            batch_size=128,
            image_size=image_shape,
            shuffle=True,
            seed=42,
            validation_split=0.2,
            subset="training",
        )

Here, dir is the local path where my data is stored. When I train my model on this data with model.fit, I get the following error.

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
c:\Users\HP\Desktop\SBU\Courses\spring23\ese577\Labs\Lab3\lenet.ipynb Cell 8 in 2
      1 epoch = 15
----> 2 hist = model.fit(data.train, batch_size=batch_size, epochs=epoch)

File c:\python10\lib\site-packages\keras\utils\traceback_utils.py:70, in filter_traceback.<locals>.error_handler(*args, **kwargs)
     67     filtered_tb = _process_traceback_frames(e.__traceback__)
     68     # To get the full stack trace, call:
     69     # `tf.debugging.disable_traceback_filtering()`
---> 70     raise e.with_traceback(filtered_tb) from None
     71 finally:
     72     del filtered_tb

File c:\python10\lib\site-packages\tensorflow\python\eager\execute.py:52, in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     50 try:
     51   ctx.ensure_initialized()
---> 52   tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
     53                                       inputs, attrs, num_outputs)
     54 except core._NotOkStatusException as e:
     55   if name is not None:

InvalidArgumentError: Graph execution error:

Number of channels inherent in the image must be 1, 3 or 4, was 2
     [[{{node decode_image/DecodeImage}}]]
     [[IteratorGetNext]] [Op:__inference_train_function_4955]

Interestingly, it always throws this error at the same point:

Epoch 1/15
  9/124 [=>............................] - ETA: 6:36 - loss: 5.0484 - accuracy: 0.4852

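When I searched for similar errors, I found that they are usually reported when reading bmp images. All of my images are jpg, yet I still get this error.

How can I fix this error, or how can I identify the bad images so that I can remove them from the dataset and continue training?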

tensorflow keras
2 Answers
0
votes

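You have to filter out any images that cause the problem. I use the code below to process all images in a directory and detect the defective ones so that I can remove them from the dataset.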

import os
import cv2
from tqdm import tqdm

datadir = r'c:\datasets\autism\test'  # directory whose class subdirectories hold the image files
good_ext = ['jpg', 'jpeg', 'bmp', 'png']  # extensions allowed for image_dataset_from_directory
bad_img_list = []  # collects the paths of defective images
classes = sorted(os.listdir(datadir))  # class names within datadir
for klass in classes:  # iterate through each class
    classpath = os.path.join(datadir, klass)
    flist = sorted(os.listdir(classpath))  # files in the current class
    for f in tqdm(flist, ncols=100, unit='files', colour='blue', desc=klass):
        fpath = os.path.join(classpath, f)  # path to the image file
        try:
            ext = os.path.splitext(f)[1].lower().lstrip('.')  # extension, lower case, no dot
            if ext not in good_ext:
                raise ValueError('image had improper extension')
            # IMREAD_UNCHANGED keeps the stored channel layout; the default flag
            # converts every image to 3-channel BGR, so the check below could never fire.
            img = cv2.imread(fpath, cv2.IMREAD_UNCHANGED)
            if img is None:  # cv2.imread returns None instead of raising on unreadable files
                raise ValueError('image could not be read')
            # shape is (height, width) for single-channel images, else (height, width, channels)
            channels = 1 if img.ndim == 2 else img.shape[2]
            if channels == 2:
                raise ValueError('image had 2 channels')
        except Exception:  # any failure marks the file as defective
            bad_img_list.append(fpath)
if len(bad_img_list) > 0:
    print('below is a list of defective image filepaths')
    for f in bad_img_list:
        print(f)
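
Since the question says every file is named .jpg while the error is one usually seen with bmp files, it may also be worth checking what each file really contains. The sketch below uses only the standard library to compare each file's leading bytes with its extension and, for BMP files, to read the bits-per-pixel field. It rests on some assumptions: that the decoder picks the format from the file contents rather than the extension, that a BMP with 16 bits per pixel is what gets reported as 2 channels, and that BMP files use the common 40-byte info header. The names `sniff_format`, `bmp_bits_per_pixel` and `find_mismatched` are hypothetical helpers, not part of any library.

```python
import os
import struct

# Leading bytes ("magic numbers") of the formats discussed in this thread.
SIGNATURES = {
    b"\xff\xd8\xff": "jpeg",
    b"\x89PNG\r\n\x1a\n": "png",
    b"GIF8": "gif",
    b"BM": "bmp",
}


def sniff_format(header):
    """Return the format name implied by the leading bytes, or None."""
    for magic, name in SIGNATURES.items():
        if header.startswith(magic):
            return name
    return None


def bmp_bits_per_pixel(header):
    """Read bits-per-pixel, assuming the common 40-byte BMP info header layout."""
    if len(header) < 30:
        return None
    # 14-byte file header, then size(4), width(4), height(4), planes(2), bpp(2).
    return struct.unpack_from("<H", header, 28)[0]


def find_mismatched(root):
    """Yield (path, real_format, bmp_bpp) for files whose contents differ from the extension."""
    for dirpath, _dirs, files in os.walk(root):
        for name in sorted(files):
            path = os.path.join(dirpath, name)
            with open(path, "rb") as fh:
                header = fh.read(30)
            real = sniff_format(header)
            ext = os.path.splitext(name)[1].lower().lstrip(".")
            ext = "jpeg" if ext == "jpg" else ext  # treat .jpg and .jpeg alike
            if real != ext:
                bpp = bmp_bits_per_pixel(header) if real == "bmp" else None
                yield path, real, bpp
```

Calling `list(find_mismatched(datadir))` on the dataset directory would list every file whose extension does not match its contents; a BMP entry with a bits-per-pixel value of 16 would be a likely suspect under the assumptions above.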

0
votes

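It is odd that the message describes the problem but does not include the file name, isn't it?

I use the following approach to clean my training set. Because it reads the images with TensorFlow itself, it should flag the same problem images that training would: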

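import glob
import os
import tensorflow

for fn in glob.glob("images/**/*"):
    if not os.path.isfile(fn):  # skip directories matched by the pattern
        continue
    with open(fn, 'rb') as f:
        data = f.read()  # read first so the file is closed before any removal
    try:
        tensorflow.io.decode_image(data)
    except Exception as e:
        # os.remove(fn)  # Warning: Make sure the glob above doesn't match critical non-image files
        print(fn, e)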
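
If deleting files outright feels risky, one alternative is to move the flagged files out of the dataset so they can be inspected later. This is a minimal sketch using only the standard library; `quarantine` and its parameters are hypothetical names, and it assumes the flagged paths were collected into a list instead of being printed.

```python
import os
import shutil


def quarantine(bad_paths, dataset_root, quarantine_root):
    """Move each flagged file under quarantine_root, keeping its class subfolder."""
    moved = []
    for path in bad_paths:
        # Path relative to the dataset, e.g. "cats/img_001.jpg".
        rel = os.path.relpath(path, dataset_root)
        target = os.path.join(quarantine_root, rel)
        os.makedirs(os.path.dirname(target), exist_ok=True)
        shutil.move(path, target)
        moved.append(target)
    return moved
```

The quarantine directory should sit outside the dataset directory; otherwise image_dataset_from_directory with labels="inferred" would pick it up as another class folder.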