如何从tensorflow数据集中绘制类名?

问题描述 投票:0回答:1

我从TensorFlow网站上获取了一个功能,可以在笔记本中显示一批图像。我想以上面图像类别在website上显示的方式打印它。这是该函数的代码:

def show_batch(image_batch, label_batch):
plt.figure(figsize=(10,10))
for n in range(25):
    ax = plt.subplot(5,5,n+1)
    plt.imshow(image_batch[n])
    plt.title(CLASS_NAMES[label_batch[n]==1][0].title())
    plt.axis('off')

问题出在plt.title ...行。我收到错误消息:无法将1转换为dtype bool的EagerTensor

我没有得到什么问题,因为我完全按照网站教程中的方式处理了数据。

标签返回一个形状数组:[False False True True False],并应据此打印类名(我有4个类)。但事实并非如此。该函数的其余部分工作正常,但仅显示图像而不显示每个图像所属的类的名称是没有用的。

python tensorflow tensorflow2.0 tensor tensorflow-datasets
1个回答
0
投票

我没有找到执行此操作的漂亮方法,因此我使用了一个附加的for循环。我遍历了标签批处理,并用真实值保存了索引。

def show_batch(image_batch, label_batch):
plt.figure(figsize=(10,10))
for n in range(25):
    ax = plt.subplot(5,5,n+1)
    plt.imshow(np.squeeze(image_batch[n]), cmap = 'gray')
    ix = 0
    for a in label_batch[n]:
        if a == 1:
            break;
        else:
            ix+=1
    plt.title(CLASS_NAMES[ix].title())
    plt.axis('off')  

仅通过示例使其更清楚:

  • class_names如下[class1,class2,class3,class4]
  • label_batch是另一个数组[false,false,true,false]
  • 在这种情况下,正确的索引是2(计数从0开始),我需要的类是class3
© www.soinside.com 2019 - 2024. All rights reserved.