如何对 pandas 数据框中的两列进行排序并从结果数据框中绘制热图?

问题描述 投票:0回答:1

我正在尝试从下面的数据框中绘制热图:

方法 sub_1 sub_2 sub_3 sub_4 ...... 距离 count_0
m1 0 2 2 1 ...... 0 20
m2 1 2 2 1 ...... 7 22
m3 1 2 2 1 ...... 26 12
m4 0 2 2 0 ...... 21 10
m5 0 2 2 0 ...... 17 5

在此数据框中,

methods
是应用的聚类方法,
sub_*
是客户和相关值是每种方法分配的标签,
distance
是每种方法距总体标签分配的距离(来自集成结果,参见附图) ,
count_0
表示集群 0 中分配了多少个客户。

现在,我想使用

distance
count_0
列对数据框进行排序并绘制热图。我尝试使用以下代码生成热图,但只有一个键,即
distance
。当按两个键排序时,我希望在热图中将所有红色放在一起,白色放在一起,黑色放在一起(参见附图)。

fig, ax = plt.subplots(figsize=(40,6))
df = df.sort_values(by=['distance', 'count_0'], ascending=[True, True])

svm = sns.heatmap(df.iloc[:,0:546], ax=ax, xticklabels=False, yticklabels=True, linewidths=0.4, annot_kws={"size": 16}, cbar_kws={"shrink": 0.3})

ax.set_xlabel("Customer IDs", fontsize=30)
ax.hlines(y = 1, xmin = 0, xmax = 550, colors = 'white', lw = 10)

for i in range(2,10):
    ax.hlines(y = i, xmin = 0, xmax = 550, colors = 'white', lw = 4)

figure = svm.get_figure()
ax.tick_params(axis='both', which='major', labelsize=30, labelbottom = False, bottom=False, top = False, labeltop=True)
plt.xticks(rotation=90)
ax.tick_params(labelsize=30)
plt.show()

生成以下输出: enter image description here

python pandas dataframe matplotlib heatmap
1个回答
0
投票

您需要在绘图之前对数据帧进行相应的排序并应用正确的颜色图。因此

distance
count_0
列并保持颜色分组。这是一个例子:

import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

data = {
    'methods': ['m1', 'm2', 'm3', 'm4', 'm5'],
    'sub_1': [0, 1, 1, 0, 0],
    'sub_2': [2, 2, 2, 2, 2],
    'sub_3': [2, 2, 2, 2, 2],
    'sub_4': [1, 1, 1, 0, 0],
    'distance': [0, 7, 26, 21, 17],
    'count_0': [20, 22, 12, 10, 5]
}
df = pd.DataFrame(data)

df = df.sort_values(by=['distance', 'count_0'], ascending=[True, True])

from matplotlib.colors import LinearSegmentedColormap

cmap = LinearSegmentedColormap.from_list("custom_cmap", ["black", "white", "red"], N=256)

# Set up the heatmap
fig, ax = plt.subplots(figsize=(10, 6))  

svm = sns.heatmap(df.iloc[:, 1:5],  
                  ax=ax,
                  cmap=cmap,
                  xticklabels=True,  
                  yticklabels=df['methods'],  
                  linewidths=0.4,
                  annot_kws={"size": 16},
                  cbar_kws={"shrink": 0.3})

ax.set_xlabel("Customer IDs", fontsize=15)
ax.set_ylabel("Clustering Methods", fontsize=15)

for i in range(1, len(df)):
    ax.hlines(y=i, xmin=0, xmax=len(df.columns) - 1, colors='white', lw=2)

# Adjust tick parameters
ax.tick_params(axis='both', which='major', labelsize=12, labelbottom=False, bottom=False, top=False, labeltop=True)
plt.xticks(rotation=90)
plt.show()

这给出了

enter image description here

© www.soinside.com 2019 - 2024. All rights reserved.