我从TensorFlow网站上获取了一个功能,可以在笔记本中显示一批图像。我想以上面图像类别在website上显示的方式打印它。这是该函数的代码:
def show_batch(image_batch, label_batch):
plt.figure(figsize=(10,10))
for n in range(25):
ax = plt.subplot(5,5,n+1)
plt.imshow(image_batch[n])
plt.title(CLASS_NAMES[label_batch[n]==1][0].title())
plt.axis('off')
问题出在plt.title ...行。我收到错误消息:无法将1转换为dtype bool的EagerTensor
我没有得到什么问题,因为我完全按照网站教程中的方式处理了数据。
标签返回一个形状数组:[False False True True False],并应据此打印类名(我有4个类)。但事实并非如此。该函数的其余部分工作正常,但仅显示图像而不显示每个图像所属的类的名称是没有用的。
我没有找到执行此操作的漂亮方法,因此我使用了一个附加的for循环。我遍历了标签批处理,并用真实值保存了索引。
def show_batch(image_batch, label_batch):
plt.figure(figsize=(10,10))
for n in range(25):
ax = plt.subplot(5,5,n+1)
plt.imshow(np.squeeze(image_batch[n]), cmap = 'gray')
ix = 0
for a in label_batch[n]:
if a == 1:
break;
else:
ix+=1
plt.title(CLASS_NAMES[ix].title())
plt.axis('off')
仅通过示例使其更清楚: