seaborn countplot that only counts the data points below and above a threshold


This is my data structure:

import random

# 500 random 0/1 flags for model A; model B gets the complement
betterTrue = [random.randint(0, 1) for _ in range(500)]
betterFalse = [1 - x for x in betterTrue]

data = {
    "model": ["A"] * 500 + ["B"] * 500,
    # random.randint(0, 4) includes both ends, so the scores are 0..4
    "safety": [random.randint(0, 4) for _ in range(1000)],
    "honesty": [random.randint(0, 4) for _ in range(1000)],
    "quality": [random.randint(0, 4) for _ in range(1000)],
    "better": betterTrue + betterFalse,
}

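I want to generate count plots that compare the performance of each model on the safety, honesty, quality, and better columns. For the first three, the data are integer values between 0 and 5 (the sample data above only produces 0 to 4); for better, the data is either 0 or 1.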

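However, for the first three columns I only care whether a data point is greater than or equal to 3, or less than 3. Is there a way to generate a count plot that puts the data into two bins, >= 3 and < 3?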

For reference, this is what it looks like when we don't do that and instead bin discretely by each possible value:

import seaborn as sns

# countplot returns a matplotlib Axes, not a Figure
ax = sns.countplot(data=data, x='safety', hue='model', stat='count')

1 Answer

There are many ways to map a range of values onto a small number of values. The pure Python way is

data["good quality"] = [0 if q < 3 else 1 for q in data["quality"]]

Instead of 0 and 1, you can also use strings.
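As a rough sketch of the string variant, using only the standard library: the helper name to_bin, the column name "quality bin", and the label texts are illustrative assumptions, not part of the original answer.

```python
import random
from collections import Counter

random.seed(0)  # fixed seed only so the sketch is reproducible

data = {
    "model": ["A"] * 500 + ["B"] * 500,
    "quality": [random.randint(0, 4) for _ in range(1000)],
}


def to_bin(value, threshold=3):
    """Hypothetical helper: return a string label instead of 0/1."""
    return f">= {threshold}" if value >= threshold else f"< {threshold}"


# new column holding the bin label of every quality score
data["quality bin"] = [to_bin(q) for q in data["quality"]]

# tally (model, bin) pairs; these are the bar heights a count plot would draw
counts = Counter(zip(data["model"], data["quality bin"]))
for (model, label), n in sorted(counts.items()):
    print(model, label, n)
```

String labels have the advantage that they appear directly as tick labels when the new column is passed as x to sns.countplot.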

When working with many numbers, numpy offers speed and a powerful array notation.
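For instance, the same thresholding can be written as one vectorized expression. This is a minimal sketch assuming np.where is an acceptable substitute for the list comprehension; the variable names and label strings are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)  # seeded generator, only for reproducibility

# 1000 scores in 0..4 (the upper bound of rng.integers is exclusive)
quality = rng.integers(0, 5, size=1000)

# element-wise: ">= 3" where the condition holds, "< 3" elsewhere
quality_bin = np.where(quality >= 3, ">= 3", "< 3")

# count how often each label occurs
labels, counts = np.unique(quality_bin, return_counts=True)
print(dict(zip(labels, counts)))
```

No explicit Python loop is needed; the comparison and the selection both operate on the whole array at once.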

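The code below uses numpy to create the random arrays. Note that random.randint() and np.random.randint() interpret their second argument differently: the standard-library function includes the upper bound, while numpy's version excludes it. The code uses pd.cut() to efficiently map ranges of numbers to strings.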
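Both points can be checked with a small sketch. It assumes the default behaviour of pd.cut (right-closed intervals, include_lowest=False); the variable names are illustrative.

```python
import random

import numpy as np
import pandas as pd

random.seed(0)
np.random.seed(0)

# random.randint(0, 4) can return 4; np.random.randint(0, 5) stops at 4
py_draws = [random.randint(0, 4) for _ in range(1000)]
np_draws = np.random.randint(0, 5, 1000)
print(max(py_draws), np_draws.max())

values = np.array([0, 1, 2, 3, 4])

# bins=[0, 2, 4] means the intervals (0, 2] and (2, 4]:
# by default the value 0 lies outside the first interval and becomes NaN
default_cut = pd.cut(values, bins=[0, 2, 4], labels=["bad", "good"])

# include_lowest=True closes the first interval on the left, so 0 is "bad"
fixed_cut = pd.cut(values, bins=[0, 2, 4], labels=["bad", "good"],
                   include_lowest=True)
print(list(fixed_cut))
```

If the lowest score is dropped silently, the count plot still renders but its bars no longer add up to the number of rows, so checking the edge values first is worthwhile.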

plt.subplots() creates a figure with multiple subplots.

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd

betterTrue = np.random.randint(0, 2, 500)
betterFalse = 1 - betterTrue

data = {
    "model": np.repeat(["A", "B"], 500),
    "safety": np.random.randint(0, 5, 1000),
    "honesty": np.random.randint(0, 5, 1000),
    "quality": np.random.randint(0, 5, 1000),
    "better": np.concatenate([betterTrue, betterFalse])
}
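# add a new column, with "bad" for quality 0 to 2 and "good" for quality 3 or 4
# include_lowest=True keeps 0 in the first bin; without it, 0 falls outside (0, 2] and becomes NaN
data["good quality"] = pd.cut(data["quality"], bins=[0, 2, 4], labels=["bad", "good"], include_lowest=True)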

columns_to_plot = [key for key in data.keys() if key != "quality"]

# choose a seaborn style
sns.set_style('whitegrid')
# create a figure with subplots
fig, axs = plt.subplots(ncols=3, nrows=2, figsize=(12, 7))

# flatten the array of axes, so it can be indexed as a 1D array
axs = axs.flatten()

# create a countplot for each of the desired columns
for column, ax in zip(columns_to_plot, axs):
    sns.countplot(data=data, x=column, ax=ax)

# hide unused subplots
for ax in axs[len(columns_to_plot):]:
    ax.remove()

# remove upper and right spine
sns.despine()

plt.tight_layout()
plt.show()
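The question asked for all three score columns to be split at 3 and compared per model. One possible way to do that is sketched below; it assumes a pandas DataFrame instead of a dict, uses np.where in place of pd.cut, and the names df, binned, score_columns and bin_order are illustrative.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

rng = np.random.default_rng(0)  # seeded only for reproducibility
better_true = rng.integers(0, 2, 500)

df = pd.DataFrame({
    "model": np.repeat(["A", "B"], 500),
    "safety": rng.integers(0, 5, 1000),
    "honesty": rng.integers(0, 5, 1000),
    "quality": rng.integers(0, 5, 1000),
    "better": np.concatenate([better_true, 1 - better_true]),
})

score_columns = ["safety", "honesty", "quality"]
bin_order = ["< 3", ">= 3"]  # fixed bar order on the x axis

# replace every score by its bin label; "model" and "better" stay untouched
binned = df.copy()
for column in score_columns:
    binned[column] = np.where(df[column] >= 3, ">= 3", "< 3")

fig, axs = plt.subplots(ncols=4, figsize=(14, 4))

# one subplot per score column, with one bar per model inside each bin
for column, ax in zip(score_columns, axs):
    sns.countplot(data=binned, x=column, hue="model", order=bin_order, ax=ax)

# "better" is already binary, so it is plotted as is in the last subplot
sns.countplot(data=binned, x="better", hue="model", ax=axs[-1])

sns.despine()
fig.tight_layout()
```

Call plt.show() afterwards to display the figure. Passing hue="model" is what puts the bars for A and B side by side within each bin.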

[Figure: sns.countplot on multiple columns, one column with reduced values]
