我正在尝试从下面的数据框中绘制热图:
方法 | sub_1 | sub_2 | sub_3 | sub_4 | ...... | 距离 | count_0 |
---|---|---|---|---|---|---|---|
m1 | 0 | 2 | 2 | 1 | ...... | 0 | 20 |
m2 | 1 | 2 | 2 | 1 | ...... | 7 | 22 |
m3 | 1 | 2 | 2 | 1 | ...... | 26 | 12 |
m4 | 0 | 2 | 2 | 0 | ...... | 21 | 10 |
m5 | 0 | 2 | 2 | 0 | ...... | 17 | 5 |
在此数据框中,
methods
是应用的聚类方法,sub_*
是客户和相关值是每种方法分配的标签,distance
是每种方法距总体标签分配的距离(来自集成结果,参见附图) , count_0
表示集群 0 中分配了多少个客户。
现在,我想使用
distance
和 count_0
列对数据框进行排序并绘制热图。我尝试使用以下代码生成热图,但只有一个键,即distance
。当按两个键排序时,我希望在热图中将所有红色放在一起,白色放在一起,黑色放在一起(参见附图)。
fig, ax = plt.subplots(figsize=(40,6))
df = df.sort_values(by=['distance', 'count_0'], ascending=[True, True])
svm = sns.heatmap(df.iloc[:,0:546], ax=ax, xticklabels=False, yticklabels=True, linewidths=0.4, annot_kws={"size": 16}, cbar_kws={"shrink": 0.3})
ax.set_xlabel("Customer IDs", fontsize=30)
ax.hlines(y = 1, xmin = 0, xmax = 550, colors = 'white', lw = 10)
for i in range(2,10):
ax.hlines(y = i, xmin = 0, xmax = 550, colors = 'white', lw = 4)
figure = svm.get_figure()
ax.tick_params(axis='both', which='major', labelsize=30, labelbottom = False, bottom=False, top = False, labeltop=True)
plt.xticks(rotation=90)
ax.tick_params(labelsize=30)
plt.show()
您需要在绘图之前对数据帧进行相应的排序并应用正确的颜色图。因此
distance
和 count_0
列并保持颜色分组。这是一个例子:
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
data = {
'methods': ['m1', 'm2', 'm3', 'm4', 'm5'],
'sub_1': [0, 1, 1, 0, 0],
'sub_2': [2, 2, 2, 2, 2],
'sub_3': [2, 2, 2, 2, 2],
'sub_4': [1, 1, 1, 0, 0],
'distance': [0, 7, 26, 21, 17],
'count_0': [20, 22, 12, 10, 5]
}
df = pd.DataFrame(data)
df = df.sort_values(by=['distance', 'count_0'], ascending=[True, True])
from matplotlib.colors import LinearSegmentedColormap
cmap = LinearSegmentedColormap.from_list("custom_cmap", ["black", "white", "red"], N=256)
# Set up the heatmap
fig, ax = plt.subplots(figsize=(10, 6))
svm = sns.heatmap(df.iloc[:, 1:5],
ax=ax,
cmap=cmap,
xticklabels=True,
yticklabels=df['methods'],
linewidths=0.4,
annot_kws={"size": 16},
cbar_kws={"shrink": 0.3})
ax.set_xlabel("Customer IDs", fontsize=15)
ax.set_ylabel("Clustering Methods", fontsize=15)
for i in range(1, len(df)):
ax.hlines(y=i, xmin=0, xmax=len(df.columns) - 1, colors='white', lw=2)
# Adjust tick parameters
ax.tick_params(axis='both', which='major', labelsize=12, labelbottom=False, bottom=False, top=False, labeltop=True)
plt.xticks(rotation=90)
plt.show()
这给出了