如 seaborn histplot - 在每个条形上方打印 y 值中所述,可以使用
.bar_label(ax.containers[0])
在一维直方图中的每个条形上显示计数。
我正在努力弄清楚如何对二维直方图(使用
sns.histplot(data, x='var1', y='var2')
创建)进行等效操作。
我知道我可以使用
(a,b)
为 .annotate('foo', xy=(a, b))
bin 进行注释,但我不确定如何检索该 bin 的计数(将其传递给 .annotate()
)。
我希望结果与 https://seaborn.pydata.org/examples/spreadsheet_heatmap.html 中显示的结果类似,只是它是一个直方图,而不是热图。
np.histogram2d
您可以利用
np.histogram2d()
进行计数,然后通过sns.heatmap()
显示结果:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
np.random.seed(20240529)
x = np.random.randn(1000).cumsum()
y = np.random.randn(1000).cumsum()
hist, xbins, ybins = np.histogram2d(x, y)
xlabels = [f'{x:.2f}' for x in (xbins[:-1] + xbins[1:]) / 2]
ylabels = [f'{y:.2f}' for y in (ybins[:-1] + ybins[1:]) / 2]
sns.heatmap(hist.T, xticklabels=xlabels, yticklabels=ylabels, annot=True, fmt='.0f', cbar=False)
sns.histplot
正如Peter的回答中提到的,计数和位置信息也可以从
QuadMesh
创建的sns.histplot
中提取。这是一个概括。
在
.get_array()
上调用 QuadMesh
会给出一个包含每个单元格计数的掩码数组。这是一个二维矩阵。该矩阵的第一行是图中第一行(最低 y 值)的计数。
同样,
.get_coordinates()
给出位置。这些不是细胞中心的位置,而是它们的边缘。边缘比单元格多 1 行和 1 列。坐标被组织为 xy 值的行和列(形成 3d 数组)。
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
np.random.seed(20240529)
x = np.random.randn(1000).cumsum()
y = np.random.randn(1000).cumsum()
ax = sns.histplot(x=x, y=y, bins=(10, 10), cbar=False)
coords = ax.collections[0].get_coordinates()
half_width = (coords[0, 1, 0] - coords[0, 0, 0]) / 2
half_height = (coords[1, 0, 1] - coords[0, 0, 1]) / 2
for v, (xv, yv) in zip(ax.collections[0].get_array().ravel(), coords[:-1, :-1, :].reshape(-1, 2)):
if not np.ma.is_masked(v):
ax.text(xv + half_width, yv + half_height, f'{v:.0f}', ha='center', va='center', color='white')
这是分类数据的样子:
sns.set_style('white')
titanic = sns.load_dataset('titanic')
ax = sns.histplot(data=titanic, x='who', y='class', cbar=False)
coords = ax.collections[0].get_coordinates()
half_width = (coords[0, 1, 0] - coords[0, 0, 0]) / 2
half_height = (coords[1, 0, 1] - coords[0, 0, 1]) / 2
for v, (xv, yv) in zip(ax.collections[0].get_array().ravel(), coords[:-1, :-1, :].reshape(-1, 2)):
if not np.ma.is_masked(v):
ax.text(xv + half_width, yv + half_height, f'{v:.0f}', ha='center', va='center', color='white')
ax.margins(x=0, y=0) # remove unneeded whitespace
plt.tight_layout()
我发现可以从
histplot()
的.collections
属性中提取计数,然后从.annotate()
'd中提取,如下:
import numpy as np
import seaborn as sns
ax = sns.histplot(data, x='var1', y='var2', cbar=True)
w = ax.collections[0].get_coordinates().shape[1] - 1
for k, v in enumerate(ax.collections[0].get_array()):
if not np.ma.is_masked(v):
ax.annotate(v, xy=(k % w, k // w), ha='center', color='white')