How to visualize predicted samples using my code

Question · Votes: 0 · Answers: 1

How can I visualize the test samples behind the confusion matrix of my model? For example, something like this: [image]

You can look at the notebook on GitHub. The overall approach is the same; only the architecture and the dataset differ.

https://github.com/cendekialnazalia/CaisimPestDetection/blob/main/Percobaan%20E%20-%20CNN%20add%20Models%20Xception.ipynb

Here is my code:

Train the model

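import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

epochs = 10

# Keep the best model by validation accuracy; stop when validation loss stops improving
mc = ModelCheckpoint('sequential', monitor='val_accuracy', mode='max', verbose=1, save_best_only=True)
early_stopping = EarlyStopping(monitor='val_loss', patience=2)
# The callbacks only take effect when they are passed to fit()
history = model.fit(x=train_gen, epochs=epochs, validation_data=valid_gen, callbacks=[mc, early_stopping])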

def print_info(test_gen, preds, print_code, save_dir, subject):
    # preds[i] must belong to test_gen.filenames[i] and test_gen.labels[i],
    # so test_gen has to be created with shuffle=False.
    # print_in_color() is a helper function that is not shown here.
    class_dict=test_gen.class_indices
    labels= test_gen.labels
    file_names= test_gen.filenames 
    error_list=[]
    true_class=[]
    pred_class=[]
    prob_list=[]
    new_dict={}
    error_indices=[]
    y_pred=[]
    for key,value in class_dict.items():
        new_dict[value]=key             # dictionary {integer of class number: string of class name}
    # store new_dict as a text file in the save_dir
    classes=list(new_dict.values())     # list of string of class names
    dict_as_text=str(new_dict)
    dict_name= subject + '-' +str(len(classes)) +'.txt'  
    dict_path=os.path.join(save_dir,dict_name)    
    with open(dict_path, 'w') as x_file:
        x_file.write(dict_as_text)    
    errors=0      
    for i, p in enumerate(preds):
        pred_index=np.argmax(p)        
        true_index=labels[i]  # labels are integer values
        if pred_index != true_index: # a misclassification has occurred
            error_list.append(file_names[i])
            true_class.append(new_dict[true_index])
            pred_class.append(new_dict[pred_index])
            prob_list.append(p[pred_index])
            error_indices.append(true_index)            
            errors=errors + 1
        y_pred.append(pred_index)    
    if print_code !=0:
        if errors>0:
            if print_code>errors:
                r=errors
            else:
                r=print_code           
            msg='{0:^28s}{1:^28s}{2:^28s}{3:^16s}'.format('Filename', 'Predicted Class' , 'True Class', 'Probability')
            print_in_color(msg, (0,255,0),(55,65,80))
            for i in range(r):                
                split1=os.path.split(error_list[i])                
                split2=os.path.split(split1[0])                
                fname=split2[1] + '/' + split1[1]
                msg='{0:^28s}{1:^28s}{2:^28s}{3:4s}{4:^6.4f}'.format(fname, pred_class[i],true_class[i], ' ', prob_list[i])
                print_in_color(msg, (255,255,255), (55,65,60))
                #print(error_list[i]  , pred_class[i], true_class[i], prob_list[i])               
        else:
            msg='With accuracy of 100 % there are no errors to print'
            print_in_color(msg, (0,255,0),(55,65,80))
    if errors>0:
        plot_bar=[]
        plot_class=[]
        for  key, value in new_dict.items():        
            count=error_indices.count(key) 
            if count!=0:
                plot_bar.append(count) # list containing how many times class c had an error
                plot_class.append(value)   # stores the class name
        fig=plt.figure()
        fig.set_figheight(max(len(plot_class)/3, 2))  # keep the figure readable when only a few classes have errors
        fig.set_figwidth(10)
        plt.style.use('fivethirtyeight')
        for i in range(0, len(plot_class)):
            c=plot_class[i]
            x=plot_bar[i]
            plt.barh(c, x)
        plt.title('Errors by Class on Test Set')
    y_true= np.array(labels)        
    y_pred=np.array(y_pred)
    if len(classes)<= 30:
        # create a confusion matrix 
        cm = confusion_matrix(y_true, y_pred )        
        length=len(classes)
        if length<8:
            fig_width=8
            fig_height=8
        else:
            fig_width= int(length * .5)
            fig_height= int(length * .5)
    
        plt.figure(figsize=(fig_width, fig_height))
        sns.heatmap(cm, annot=True, vmin=0, fmt='g', cmap='Blues', cbar=False)       
        plt.xticks(np.arange(length)+.5, classes, rotation= 90)
        plt.yticks(np.arange(length)+.5, classes, rotation=0)
        plt.xlabel("Predicted")
        plt.ylabel("Actual")
        plt.title("Confusion Matrix")
        plt.show()
    clr = classification_report(y_true, y_pred, target_names=classes)
    print("Classification Report:\n----------------------\n", clr)
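`print_info` already collects the misclassified file names, but it only prints them. A minimal sketch of a variant that returns them instead, sorted so the most confident mistakes come first, could look like this. The function name `misclassified_records` is hypothetical, and it assumes (as `print_info` does) that `preds[i]` lines up with `test_gen.filenames[i]` and `test_gen.labels[i]`, which only holds when the test generator is not shuffled.

```python
import numpy as np

def misclassified_records(file_names, labels, preds, class_indices):
    """Return (filename, true_class, predicted_class, probability) for every
    wrong prediction, most confident mistakes first.

    Assumes preds[i] belongs to file_names[i] and labels[i], i.e. the test
    generator was created with shuffle=False.
    """
    idx_to_class = {v: k for k, v in class_indices.items()}  # {index: class name}
    preds = np.asarray(preds)
    labels = np.asarray(labels)
    pred_idx = np.argmax(preds, axis=1)                  # predicted class index per sample
    pred_prob = preds[np.arange(len(preds)), pred_idx]   # probability of that prediction
    wrong = np.flatnonzero(pred_idx != labels)           # positions of misclassified samples
    # Order the wrong samples by predicted probability, highest first
    wrong = wrong[np.argsort(-pred_prob[wrong], kind="stable")]
    return [
        (file_names[i], idx_to_class[int(labels[i])], idx_to_class[int(pred_idx[i])], float(pred_prob[i]))
        for i in wrong
    ]
```

With the objects from the question this would be called as `misclassified_records(test_gen.filenames, test_gen.labels, preds, test_gen.class_indices)`. Returning data rather than printing it makes it easy to feed the file names into an image plot afterwards.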

Confusion matrix

print_code=0
preds=model.predict(test_gen) 
print_info( test_gen, preds, print_code, save_dir, subject ) 

Output: [image]

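Besides the table output and the recall, precision and F1-score values, I also want to visualize every image that ends up in the confusion matrix, perhaps like the image above, or in some better way.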

For example, in the "Daun Sehat" row there is one sample that was predicted as "Karat Merah", but without a visualization I cannot tell which "Daun Sehat" image was classified as "Karat Merah".
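One confusion-matrix cell is simply the set of samples whose true class is one value and whose predicted class is another, so it can be selected with a boolean mask. The sketch below assumes integer labels and predictions in the same order as the file names (an unshuffled test generator); `files_in_cell` is a hypothetical helper name.

```python
import numpy as np

def files_in_cell(file_names, y_true, y_pred, class_indices, true_name, pred_name):
    """Return the file names in one confusion-matrix cell: samples whose
    actual class is true_name and whose predicted class is pred_name."""
    t = class_indices[true_name]
    p = class_indices[pred_name]
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    rows = np.flatnonzero((y_true == t) & (y_pred == p))  # sample indices in this cell
    return [file_names[i] for i in rows]
```

For the "Daun Sehat" predicted as "Karat Merah" case this would be something like `files_in_cell(test_gen.filenames, test_gen.labels, np.argmax(preds, axis=1), test_gen.class_indices, "Daun Sehat", "Karat Merah")`.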

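After the last code cell ("Make predictions on the test set and generate the confusion matrix and classification report"), I added the following code to get the prediction for each test sample: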

print(test_gen.class_indices)
print(preds, preds.shape)
result_index = np.argmax(preds, axis=1)   # predicted class index for each test sample
print(result_index)
for i in range(len(preds)):
    if result_index[i] == 0:
        print("Bercak Daun")
    elif result_index[i] == 1:
        print("Daun Sehat")
    elif result_index[i] == 2:
        print("Karat Merah")
    else:
        print("Lainya")
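The `if/elif` chain hard-codes the class order. A shorter sketch, assuming only that `test_gen.class_indices` maps class names to the integer indices the model outputs, inverts that dictionary and looks every prediction up at once; `predicted_class_names` is a hypothetical helper name.

```python
import numpy as np

def predicted_class_names(preds, class_indices):
    """Map each row of class probabilities to its class name, using the
    generator's class_indices instead of a hard-coded if/elif chain."""
    idx_to_class = {v: k for k, v in class_indices.items()}  # {0: 'Bercak Daun', ...}
    return [idx_to_class[int(i)] for i in np.argmax(preds, axis=1)]
```

Pairing the result with `test_gen.filenames` via `zip(...)` prints each prediction next to the file it belongs to, which is the first step towards showing the images themselves.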

Output:

Bercak Daun
Daun Sehat
...
(up to 177 predictions)
Daun Sehat

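Rather than only printing the predictions as strings, I want to display them together with the images. Is there a better, more efficient way to do this?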
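A minimal sketch of showing images together with their predictions: plot each file in a grid and color the title red when the prediction is wrong. It assumes the image paths, true names, predicted names and probabilities are already aligned lists (for example produced from an unshuffled test generator); `show_predictions` is a hypothetical name, and reading JPEG files with `matplotlib.image.imread` relies on Pillow being installed.

```python
import math
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

def show_predictions(paths, true_names, pred_names, probs, cols=4):
    """Plot each image with its true and predicted class; wrong predictions
    get a red title, correct ones a green title. Returns the figure."""
    n = len(paths)
    rows = max(1, math.ceil(n / cols))
    fig, axes = plt.subplots(rows, cols, figsize=(3 * cols, 3 * rows), squeeze=False)
    for ax in axes.ravel():
        ax.axis("off")  # hide ticks and frames, including unused grid cells
    for ax, path, t, p, pr in zip(axes.ravel(), paths, true_names, pred_names, probs):
        ax.imshow(mpimg.imread(path))
        color = "green" if t == p else "red"
        ax.set_title(f"true: {t}\npred: {p} ({pr:.2f})", color=color, fontsize=9)
    fig.tight_layout()
    return fig
```

If your Keras version exposes `test_gen.filepaths` (full paths to the image files), those can be passed as `paths`; otherwise the paths can be built from `test_gen.filenames`. Passing only the misclassified samples turns this into a gallery of errors, and a large `cols` value keeps many images on one figure.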

python tensorflow machine-learning keras image-classification
1 Answer
0 votes

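It is not clear how you plot the confusion matrix, but you can iterate over preds and test_gen together and plot the sample and the output whenever the labels differ.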

Since you did not show what type of input your model uses, the display method is up to you.

The code would look something like this:

import numpy as np
import matplotlib.pyplot as plt

# You should have before:
#   x_test : test dataset (e.g. an array of images)
#   y_true : integer labels of x_test

probs = model.predict(x_test)
y_pred = np.argmax(probs, axis=1)  # convert class probabilities to class indices

# First loop: separate faulty predictions from correct ones
wrong_labels = []
for i, label in enumerate(y_pred):  # enumerate yields (index, value)
    if label != y_true[i]:
        wrong_labels.append((x_test[i], y_true[i], label))

# Second loop: display the errors; the limit of 4 samples is arbitrary
for i, (x, true_label, pred_label) in enumerate(wrong_labels[:4]):
    plt.subplot(1, 4, i + 1)
    display_sample(x)  # This function depends on the type of sample you have
    plt.title("Label " + str(true_label) + " expected, but predicted " + str(pred_label))

plt.show()
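`display_sample` is left open in the answer because it depends on the input type. For image classifiers like the one in the question, one possible version is sketched here; it is a hypothetical helper that assumes samples are arrays shaped (H, W), (H, W, 1) or (H, W, 3) with pixel values in either [0, 1] or [0, 255]. If your pipeline applies a different scaling (for example to [-1, 1]), undo it before plotting.

```python
import numpy as np
import matplotlib.pyplot as plt

def display_sample(x):
    """Show one image sample on the current axes and return the array that
    was plotted (float32, values in [0, 1])."""
    img = np.asarray(x, dtype=np.float32)
    if img.ndim == 3 and img.shape[-1] == 1:
        img = img[..., 0]        # drop a single channel axis for grayscale images
    if img.max() > 1.0:
        img = img / 255.0        # rescale 0-255 pixel values to 0-1 for imshow
    img = np.clip(img, 0.0, 1.0)
    plt.imshow(img, cmap="gray" if img.ndim == 2 else None)
    plt.axis("off")
    return img
```

Clipping keeps `imshow` from warning about out-of-range values; if the raw arrays come from a Keras generator, `x_test` here would be a batch from that generator rather than a preloaded array.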
© www.soinside.com 2019 - 2024. All rights reserved.