当我使用keras VGG16预训练模型在tf.dataset.map函数中获得瓶颈功能时,出现了此问题
vgg16 = VGG16(include_top=False, input_shape=input_shape)
vgg16.trainable = False
def _parse_function(example_proto):
features = {'img': tf.FixedLenFeature((), tf.string, default_value=""),
'label': tf.FixedLenFeature((), tf.int64, default_value=0)}
parsed_features = tf.parse_single_example(example_proto, features)
img = tf.decode_raw(parsed_features['img'], tf.uint8)
img = tf.reshape(img, (224, 224, 3))
img = tf.cast(img, tf.float32) / 255.0
img = tf.expand_dims(img, 0)
img = vgg16(img)
img = tf.squeeze(img, [0])
label = parsed_features['label']
return img, label
ds.map(_parse_function)
错误打印是
2019-06-19 13:47:19.378807:E tensorflow / core / common_runtime / executor.cc:624]执行程序无法创建内核。无效的参数:默认MaxPoolingOp仅在设备类型CPU上支持NHWC[[{{node vgg16 / block1_pool / MaxPool}}]]2019-06-19 13:47:19.379204:W tensorflow / core / framework / op_kernel.cc:1401] OP_REQUIRES在iterator_ops.cc:660失败:无效参数:默认MaxPoolingOp仅在设备类型CPU上支持NHWC[[{{node vgg16 / block1_pool / MaxPool}}]]追溯(最近一次通话):_do_call中的文件“ /opt/anaconda3/lib/python3.7/site-packages/tensorflow/python/client/session.py”,行1334返回fn(* args)_run_fn中的文件“ /opt/anaconda3/lib/python3.7/site-packages/tensorflow/python/client/session.py”,第1319行选项,feed_dict,fetch_list,target_list,run_metadata)_call_tf_sessionrun中的文件“ /opt/anaconda3/lib/python3.7/site-packages/tensorflow/python/client/session.py”,行1407run_metadata)tensorflow.python.framework.errors_impl.InvalidArgumentError:默认MaxPoolingOp仅在设备类型CPU上支持NHWC[[{{node vgg16 / block1_pool / MaxPool}}]][[{{node MakeIterator}}]]
在处理以上异常期间,发生了另一个异常:
追踪(最近通话):在第103行中输入文件“ /home/tf/image_cls/train.py”steps_per_epoch = steps_per_epoch,适合的文件“ /opt/anaconda3/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py”,行776shuffle = shuffle)_standardize_user_data中第2200行的文件“ /opt/anaconda3/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py”K.get_session()。run(x.initializer)运行中的文件“ /opt/anaconda3/lib/python3.7/site-packages/tensorflow/python/client/session.py”,第929行run_metadata_ptr)_run中的文件“ /opt/anaconda3/lib/python3.7/site-packages/tensorflow/python/client/session.py”,行1152feed_dict_tensor,选项,run_metadata)_do_run中的文件“ /opt/anaconda3/lib/python3.7/site-packages/tensorflow/python/client/session.py”,行1328run_metadata)_do_call中的文件“ /opt/anaconda3/lib/python3.7/site-packages/tensorflow/python/client/session.py”,行1348提高类型(e)(node_def,op,消息)tensorflow.python.framework.errors_impl.InvalidArgumentError:默认MaxPoolingOp仅在设备类型CPU上支持NHWC[[{{node vgg16 / block1_pool / MaxPool}}]][[node MakeIterator(定义于/home/tf/image_cls/train.py:103)]]
我有你的粗俗问题,您找到解决方案了吗?