如何为子图中的seaborn热图获取具有共享x轴和y轴的单个颜色条?

问题描述 投票:0回答:1

我想在一个图中用一个颜色条和共享的 x 轴和 y 轴绘制多个混淆矩阵。这是我到目前为止尝试过的代码

#Calculate the onfusion matrices
predicted_mod1 = df_binary["Model1"]
actual_class = df_binary["Observed"]

out_df_mod1 = pd.DataFrame(np.vstack([predicted_mod1, actual_class]).T,columns=['predicted_class','actual_class'])
CF_mod1 = pd.crosstab(out_df_mod1['actual_class'], out_df_mod1['predicted_class'], rownames=['Actual'], colnames=['Predicted'])

predicted_mod2 = df_binary["Model2"]

out_df_mod2 = pd.DataFrame(np.vstack([predicted_mod2, actual_class]).T,columns=['predicted_class','actual_class'])
CF_mod2 = pd.crosstab(out_df_mod2['actual_class'], out_df_mod2['predicted_class'], rownames=['Actual'], colnames=['Predicted'])

predicted_mod4 = df_binary["Model4"]

out_df_mod4 = pd.DataFrame(np.vstack([predicted_mod4, actual_class]).T,columns=['predicted_class','actual_class'])
CF_mod4 = pd.crosstab(out_df_mod4['actual_class'], out_df_mod4['predicted_class'], rownames=['Actual'], colnames=['Predicted'])

predicted_mod5 = df_binary["Model5"]

out_df_mod5 = pd.DataFrame(np.vstack([predicted_mod5, actual_class]).T,columns=['predicted_class','actual_class'])
CF_mod5 = pd.crosstab(out_df_mod5['actual_class'], out_df_mod5['predicted_class'], rownames=['Actual'], colnames=['Predicted'])

predicted_mod6 = df_binary["Model6"]

out_df_mod6 = pd.DataFrame(np.vstack([predicted_mod6, actual_class]).T,columns=['predicted_class','actual_class'])
CF_mod6 = pd.crosstab(out_df_mod6['actual_class'], out_df_mod6['predicted_class'], rownames=['Actual'], colnames=['Predicted'])

现在我已经使用以下代码绘制了这些矩阵

fig = plt.figure(figsize=(6, 3), dpi=300)
fig.subplots_adjust(hspace=0.8, wspace=0.6)

ax = fig.add_subplot(2, 3, 1)
sns.heatmap(CF_mod1, cmap='Blues', annot=True, fmt='d')

ax = fig.add_subplot(2, 3, 2)
sns.heatmap(CF_mod2, cmap='Blues', annot=True, fmt='d')

ax = fig.add_subplot(2, 3, 3)
sns.heatmap(CF_mod3, cmap='Blues', annot=True, fmt='d')

ax = fig.add_subplot(2, 3, 4)
sns.heatmap(CF_mod4, cmap='Blues', annot=True, fmt='d')

ax = fig.add_subplot(2, 3, 5)
sns.heatmap(CF_mod5, cmap='Blues', annot=True, fmt='d')

ax = fig.add_subplot(2, 3, 6)
sns.heatmap(CF_mod6, cmap='Blues', annot=True, fmt='d')

plt.show()

enter image description here 我的预期输出如下 enter image description here 现在我怎样才能只有一个带有共享 x 轴和 y 轴的颜色条?

数据

Model1,Model2,Model3,Model4,Model5,Model6,Observed
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
No,No,No,No,No,No,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,No,Yes,No,Yes,Yes
No,Yes,No,No,No,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,No,No,No,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,No,Yes,Yes,Yes,No,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,No,Yes,Yes,Yes,No,Yes
Yes,No,Yes,Yes,Yes,No,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
Yes,Yes,Yes,Yes,Yes,Yes,No
No,No,No,No,No,No,No
No,Yes,No,No,No,Yes,No
No,Yes,No,No,No,Yes,No
Yes,Yes,Yes,Yes,Yes,Yes,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,Yes,No,Yes,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
Yes,Yes,Yes,Yes,Yes,Yes,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
python matplotlib seaborn pivot-table heatmap
1个回答
0
投票

首先,您要确保所有绘图具有相同的颜色比例,因此您需要获取全局最小值和最大值。然后,您可以循环遍历子图(使用

sharex=True
sharey=True
来删除不在边缘上的图中的刻度),绘制数据,并删除
ylabel
(如果它不在第一列中)和/或
xlabel 
如果它不在最后一行。最后,在末尾创建一个具有全局比例的
colorbar

import matplotlib as mpl

# your other code

nrows = 2
ncols = 3
fig, axes = plt.subplots(nrows, ncols, sharex=True, sharey=True)
cbar_ax = fig.add_axes([0.91, 0.3, 0.03, 0.4])

data = [CF_mod1, CF_mod2, CF_mod3, CF_mod4, CF_mod5, CF_mod6]

# get global min and max to enforce the same colorscale in all plots
vmin = min([d.min().min() for d in data])
vmax = max([d.max().max() for d in data])

for i, (ax, d) in enumerate(zip(axes.flat, data)):
    p = sns.heatmap(d, ax=ax, annot=True,
                    vmin=vmin, vmax=vmax,
                    cmap="Blues", cbar=False)
    # remove ylabel if not in the first column
    if i%ncols:
        ax.set_ylabel("")
    # remove xlabel if not in the last row
    if i//ncols + 1 != nrows:
        ax.set_xlabel("")

# colorbar with the desired colorscale
# https://stackoverflow.com/a/3374216/12131013
cbar = mpl.colorbar.ColorbarBase(cbar_ax, cmap="Blues",
                                 norm=mpl.colors.Normalize(vmin=vmin,
                                                           vmax=vmax))
fig.show()

结果:

对于轴标签,您还可以使用超标签并删除各个轴标签。

for i, (ax, d) in enumerate(zip(axes.flat, data)):
    p = sns.heatmap(d, ax=ax, annot=True,
                    vmin=vmin, vmax=vmax,
                    cmap="Blues", cbar=False)
    ax.set_xlabel("")
    ax.set_ylabel("")
fig.supxlabel("Predicted")
fig.supylabel("Actual")

结果:

© www.soinside.com 2019 - 2024. All rights reserved.