这是我的数据结构:
betterTrue = [random.randint(0,1) for x in range(500)]
betterFalse = [(x + 1) % 2 for x in betterTrue]
data = {
"model": ["A" for x in range(500)] + ["B" for x in range(500)],
"safety": [random.randint(0,4) for x in range(1000)],
"honesty": [random.randint(0,4) for x in range(1000)],
"quality": [random.randint(0,4) for x in range(1000)],
"better": betterTrue + betterFalse
}
我想生成计数图,比较每个模型在
safety
、honesty
、quality
和 better
列中的性能。对于前三个,数据为 0 到 5 之间的整数值,对于 better
,数据为 0
或 1
。
但是对于前三列,我只关心数据点是否大于或等于 3 或小于 3。有没有办法生成一个计数图,将数据放入两个 bin
>= 3
和 < 3
?
作为参考,这就是我们不这样做时的样子,而是按每个可能的值离散地进行分箱
fig = sns.countplot(x = 'safety', hue='model', data=data, stat='count')
有很多方法可以将值范围转换为少数值。纯 Python 方式是
data["good quality"] = [0 if q < 3 else 1 for q in data["quality"]]
。除了 0
和 1
,您还可以使用字符串。
当处理大量数字时,numpy 提供速度和强大的数组表示法。
下面的代码使用 numpy 创建随机数组(由于某些神秘的原因,
random.randint()
和np.random.randint()
对第二个参数的解释不同)。该代码使用 pd.cut()
有效地将数字范围映射到字符串。
plt.subplots()
创建具有多个子图的图形。
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
betterTrue = np.random.randint(0, 2, 500)
betterFalse = 1 - betterTrue
data = {
"model": np.repeat(["A", "B"], 500),
"safety": np.random.randint(0, 5, 1000),
"honesty": np.random.randint(0, 5, 1000),
"quality": np.random.randint(0, 5, 1000),
"better": np.concatenate([betterTrue, betterFalse])
}
# add a new column, with "bad" for quality 0 till 2 and "good" for quality 3 or 4
data["good quality"] = pd.cut(data["quality"], bins=[0, 2, 4], labels=["bad", "good"])
columns_to_plot = [key for key in data.keys() if key != "quality"]
# choose a seaborn style
sns.set_style('whitegrid')
# create a figure with subplots
fig, axs = plt.subplots(ncols=3, nrows=2, figsize=(12, 7))
# flatten the array of axes, so it can be indexed as a 1D array
axs = axs.flatten()
# create a countplot for each of the desired columns
for column, ax in zip(columns_to_plot, axs):
sns.countplot(data, x=column, ax=ax)
# hide unused subplots
for ax in axs[len(columns_to_plot):]:
ax.remove()
# remove upper and right spine
sns.despine()
plt.tight_layout()
plt.show()