我想在一个图中用一个颜色条和共享的 x 轴和 y 轴绘制多个混淆矩阵。这是我到目前为止尝试过的代码
#Calculate the onfusion matrices
predicted_mod1 = df_binary["Model1"]
actual_class = df_binary["Observed"]
out_df_mod1 = pd.DataFrame(np.vstack([predicted_mod1, actual_class]).T,columns=['predicted_class','actual_class'])
CF_mod1 = pd.crosstab(out_df_mod1['actual_class'], out_df_mod1['predicted_class'], rownames=['Actual'], colnames=['Predicted'])
predicted_mod2 = df_binary["Model2"]
out_df_mod2 = pd.DataFrame(np.vstack([predicted_mod2, actual_class]).T,columns=['predicted_class','actual_class'])
CF_mod2 = pd.crosstab(out_df_mod2['actual_class'], out_df_mod2['predicted_class'], rownames=['Actual'], colnames=['Predicted'])
predicted_mod4 = df_binary["Model4"]
out_df_mod4 = pd.DataFrame(np.vstack([predicted_mod4, actual_class]).T,columns=['predicted_class','actual_class'])
CF_mod4 = pd.crosstab(out_df_mod4['actual_class'], out_df_mod4['predicted_class'], rownames=['Actual'], colnames=['Predicted'])
predicted_mod5 = df_binary["Model5"]
out_df_mod5 = pd.DataFrame(np.vstack([predicted_mod5, actual_class]).T,columns=['predicted_class','actual_class'])
CF_mod5 = pd.crosstab(out_df_mod5['actual_class'], out_df_mod5['predicted_class'], rownames=['Actual'], colnames=['Predicted'])
predicted_mod6 = df_binary["Model6"]
out_df_mod6 = pd.DataFrame(np.vstack([predicted_mod6, actual_class]).T,columns=['predicted_class','actual_class'])
CF_mod6 = pd.crosstab(out_df_mod6['actual_class'], out_df_mod6['predicted_class'], rownames=['Actual'], colnames=['Predicted'])
现在我已经使用以下代码绘制了这些矩阵
fig = plt.figure(figsize=(6, 3), dpi=300)
fig.subplots_adjust(hspace=0.8, wspace=0.6)
ax = fig.add_subplot(2, 3, 1)
sns.heatmap(CF_mod1, cmap='Blues', annot=True, fmt='d')
ax = fig.add_subplot(2, 3, 2)
sns.heatmap(CF_mod2, cmap='Blues', annot=True, fmt='d')
ax = fig.add_subplot(2, 3, 3)
sns.heatmap(CF_mod3, cmap='Blues', annot=True, fmt='d')
ax = fig.add_subplot(2, 3, 4)
sns.heatmap(CF_mod4, cmap='Blues', annot=True, fmt='d')
ax = fig.add_subplot(2, 3, 5)
sns.heatmap(CF_mod5, cmap='Blues', annot=True, fmt='d')
ax = fig.add_subplot(2, 3, 6)
sns.heatmap(CF_mod6, cmap='Blues', annot=True, fmt='d')
plt.show()
我的预期输出如下 现在我怎样才能只有一个带有共享 x 轴和 y 轴的颜色条?
数据
Model1,Model2,Model3,Model4,Model5,Model6,Observed
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
No,No,No,No,No,No,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,No,Yes,No,Yes,Yes
No,Yes,No,No,No,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,No,No,No,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,No,Yes,Yes,Yes,No,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,No,Yes,Yes,Yes,No,Yes
Yes,No,Yes,Yes,Yes,No,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
Yes,Yes,Yes,Yes,Yes,Yes,No
No,No,No,No,No,No,No
No,Yes,No,No,No,Yes,No
No,Yes,No,No,No,Yes,No
Yes,Yes,Yes,Yes,Yes,Yes,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,Yes,No,Yes,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
Yes,Yes,Yes,Yes,Yes,Yes,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
首先,您要确保所有绘图具有相同的颜色比例,因此您需要获取全局最小值和最大值。然后,您可以循环遍历子图(使用
sharex=True
和 sharey=True
来删除不在边缘上的图中的刻度),绘制数据,并删除 ylabel
(如果它不在第一列中)和/或 xlabel
如果它不在最后一行。最后,在末尾创建一个具有全局比例的 colorbar
。
import matplotlib as mpl
# your other code
nrows = 2
ncols = 3
fig, axes = plt.subplots(nrows, ncols, sharex=True, sharey=True)
cbar_ax = fig.add_axes([0.91, 0.3, 0.03, 0.4])
data = [CF_mod1, CF_mod2, CF_mod3, CF_mod4, CF_mod5, CF_mod6]
# get global min and max to enforce the same colorscale in all plots
vmin = min([d.min().min() for d in data])
vmax = max([d.max().max() for d in data])
for i, (ax, d) in enumerate(zip(axes.flat, data)):
p = sns.heatmap(d, ax=ax, annot=True,
vmin=vmin, vmax=vmax,
cmap="Blues", cbar=False)
# remove ylabel if not in the first column
if i%ncols:
ax.set_ylabel("")
# remove xlabel if not in the last row
if i//ncols + 1 != nrows:
ax.set_xlabel("")
# colorbar with the desired colorscale
# https://stackoverflow.com/a/3374216/12131013
cbar = mpl.colorbar.ColorbarBase(cbar_ax, cmap="Blues",
norm=mpl.colors.Normalize(vmin=vmin,
vmax=vmax))
fig.show()
结果:
对于轴标签,您还可以使用超标签并删除各个轴标签。
for i, (ax, d) in enumerate(zip(axes.flat, data)):
p = sns.heatmap(d, ax=ax, annot=True,
vmin=vmin, vmax=vmax,
cmap="Blues", cbar=False)
ax.set_xlabel("")
ax.set_ylabel("")
fig.supxlabel("Predicted")
fig.supylabel("Actual")
结果: