你可以访问GitHub,整体是相似的,只是架构和数据集不同
这是我的代码
训练模型
epochs = 10
# Save the best weights (by validation accuracy) seen so far during training.
mc = ModelCheckpoint('sequential', monitor='val_accuracy', mode='max', verbose=1, save_best_only=True)
# Stop training once val_loss has not improved for 2 consecutive epochs.
early_stopping = EarlyStopping(monitor='val_loss', patience=2)
# BUG FIX: mc and early_stopping were defined but never passed to fit(),
# so checkpointing and early stopping never ran. Register them as callbacks.
history = model.fit(x=train_gen, epochs=epochs, validation_data=valid_gen,
                    callbacks=[mc, early_stopping])
def print_info( test_gen, preds, print_code, save_dir, subject ):
    """Summarize test-set predictions: log misclassifications, plot per-class
    error counts, and show a confusion matrix plus a classification report.

    Args:
        test_gen: Keras directory iterator exposing ``class_indices``,
            ``labels`` (integer class ids) and ``filenames``.
        preds: array of per-sample class-probability vectors from
            ``model.predict`` (one row per test sample, in test_gen order).
        print_code: max number of misclassified samples to print;
            0 disables the printed error table entirely.
        save_dir: directory where the {class index: class name} mapping
            is written as a small text file.
        subject: label used to build the mapping file's name.

    Side effects: writes one .txt file, prints colored text via
    print_in_color, and draws matplotlib/seaborn figures.
    """
    class_dict = test_gen.class_indices          # {class name: class index}
    labels = test_gen.labels                     # true integer label per sample
    file_names = test_gen.filenames
    error_list = []                              # filenames of misclassified samples
    true_class = []                              # their true class names
    pred_class = []                              # their predicted class names
    prob_list = []                               # confidence of each wrong prediction
    new_dict = {}
    error_indices = []                           # true class index of each error
    y_pred = []
    for key, value in class_dict.items():
        new_dict[value] = key                    # invert to {class index: class name}
    # Store new_dict as a text file in save_dir for later reference.
    classes = list(new_dict.values())            # class names ordered by index
    dict_as_text = str(new_dict)
    dict_name = subject + '-' + str(len(classes)) + '.txt'
    dict_path = os.path.join(save_dir, dict_name)
    with open(dict_path, 'w') as x_file:
        x_file.write(dict_as_text)
    errors = 0
    for i, p in enumerate(preds):
        pred_index = np.argmax(p)                # predicted class = highest probability
        true_index = labels[i]                   # labels are integer values
        if pred_index != true_index:             # a misclassification has occurred
            error_list.append(file_names[i])
            true_class.append(new_dict[true_index])
            pred_class.append(new_dict[pred_index])
            prob_list.append(p[pred_index])
            error_indices.append(true_index)
            errors = errors + 1
        y_pred.append(pred_index)
    if print_code != 0:
        if errors > 0:
            # Print at most print_code rows (capped by the actual error count).
            if print_code > errors:
                r = errors
            else:
                r = print_code
            msg='{0:^28s}{1:^28s}{2:^28s}{3:^16s}'.format('Filename', 'Predicted Class' , 'True Class', 'Probability')
            print_in_color(msg, (0,255,0),(55,65,80))
            for i in range(r):
                # Rebuild "class_dir/filename" from the stored relative path.
                split1 = os.path.split(error_list[i])
                split2 = os.path.split(split1[0])
                fname = split2[1] + '/' + split1[1]
                msg='{0:^28s}{1:^28s}{2:^28s}{3:4s}{4:^6.4f}'.format(fname, pred_class[i],true_class[i], ' ', prob_list[i])
                print_in_color(msg, (255,255,255), (55,65,60))
        else:
            msg='With accuracy of 100 % there are no errors to print'
            print_in_color(msg, (0,255,0),(55,65,80))
    if errors > 0:
        # Horizontal bar chart: number of errors per (true) class.
        plot_bar = []
        plot_class = []
        for key, value in new_dict.items():
            count = error_indices.count(key)
            if count != 0:
                plot_bar.append(count)           # how many times class `key` was misclassified
                plot_class.append(value)         # the class name
        fig = plt.figure()
        fig.set_figheight(len(plot_class) / 3)   # scale height with number of bars
        fig.set_figwidth(10)
        plt.style.use('fivethirtyeight')
        for i in range(0, len(plot_class)):
            c = plot_class[i]
            x = plot_bar[i]
            plt.barh(c, x, )
        plt.title( ' Errors by Class on Test Set')
    y_true = np.array(labels)
    y_pred = np.array(y_pred)
    if len(classes) <= 30:
        # Create a confusion matrix (skipped for many classes — unreadable).
        cm = confusion_matrix(y_true, y_pred )
        length = len(classes)
        # Grow the figure with the class count so cell labels stay legible.
        if length < 8:
            fig_width = 8
            fig_height = 8
        else:
            fig_width = int(length * .5)
            fig_height = int(length * .5)
        plt.figure(figsize=(fig_width, fig_height))
        sns.heatmap(cm, annot=True, vmin=0, fmt='g', cmap='Blues', cbar=False)
        plt.xticks(np.arange(length) + .5, classes, rotation=90)
        plt.yticks(np.arange(length) + .5, classes, rotation=0)
        plt.xlabel("Predicted")
        plt.ylabel("Actual")
        plt.title("Confusion Matrix")
        plt.show()
    clr = classification_report(y_true, y_pred, target_names=classes)
    print("Classification Report:\n----------------------\n", clr)
混淆矩阵
# 0 disables the printed per-error table inside print_info; it still
# draws the confusion matrix and prints the classification report.
print_code=0
# One probability vector per test sample, in test_gen order.
preds=model.predict(test_gen)
print_info( test_gen, preds, print_code, save_dir, subject )
我不仅想显示表输出和召回率、精度和 f1 得分值,还想显示 CM 预测的每个图像的可视化,也许像上面的图像一样,或者可能会更好。
例如,在“Daun Sehat”表中,有一个数据样本被预测为“Karat Merah”,但如果没有可视化,我就无法知道“Daun Sehat”的哪个图像样本被误判为“Karat Merah”。
在最后一行代码“对测试集进行预测并生成混淆矩阵和分类报告”之后,我添加了如下代码来找出每个测试数据的预测
# Print the predicted class name for every test sample.
test_gen.class_indices
print(preds, preds.shape)
# BUG FIX: `np.argmax(preds[])` was a syntax error; take the argmax
# over the class axis to get one predicted index per sample.
result_index = np.argmax(preds, axis=1)
print(result_index)
# Index-to-name lookup replaces the hand-rolled if/elif chain;
# any index outside the known classes falls back to "Lainya".
class_names = ["Bercak Daun", "Daun Sehat", "Karat Merah"]
for idx in result_index:
    if 0 <= idx < len(class_names):
        print(class_names[idx])
    else:
        print("Lainya")
输出
Bercak Daun
Daun Sehat
.
.
.
up to 177
Daun Sehat
不仅仅是用字符串显示预测,我还想将其与图像一起显示。也许有更好、更有效的代码来解决我的问题?
目前还不清楚如何绘制混淆矩阵,但您可以迭代
preds
或 test_gen
并在标签不同时绘制样本和输出。
由于您没有显示模型使用的输入类型,因此要使用的显示方法取决于您。
代码看起来像这样:
# You should have before :
# x_test : test dataset
# y_true : labels of x_test
y_pred = model.predict(x_test)
# First loop: collect faulty predictions.
# BUG FIXES vs. the original sketch:
#  - enumerate yields (index, value), so the index comes first;
#  - predict() returns a probability vector per sample, so argmax is
#    needed before comparing with the integer label
#    (assumes y_true holds integer class ids — confirm for your data);
#  - store the i-th sample/label, not the whole arrays.
wrong_labels = []
for i, probs in enumerate(y_pred):
    pred_label = np.argmax(probs)
    if pred_label != y_true[i]:
        wrong_labels.append((x_test[i], y_true[i], pred_label))
# Second loop: display errors; the limit of 4 samples is arbitrary.
for i, (x, expected, predicted) in enumerate(wrong_labels):
    if i >= 4:
        break
    plt.subplot(1, 4, i + 1)
    display_sample(x)  # This function depends on the type of sample you have
    plt.title("Label " + str(expected) + " expected, but predicted " + str(predicted))
plt.show()