我正在尝试通过分类列为散点图着色。这是一个示例数据,我想要为散点图着色的列是“cat”。
data = {
'x': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
'y': [2, 3, 5, 7, 11, 13, 17, 19, 23, 29],
'z': [1, 2, 2, 3, 3, 4, 4, 5, 6, 6],
'cat': ['A', 'A', 'B', 'B', 'A', 'A', 'B', 'B', 'A', 'A']
}
pandas_df = pd.DataFrame(data)
pyspark_df = spark.createDataFrame(pandas_df)
我创建了以下函数来测试输出。如果我从参数中删除“色调”,一切都会正常工作,但我似乎无法使其与“色调”一起正常工作。
def facet_plot(df, x, y, color, facet_col, bins = None):
pd_df = df.toPandas()
if bins is not None:
# check col type
if pd_df[facet_col].dtype.name in ['float64', 'int64']:
# bin the facet column
pd_df['facet_col_binned']= pd.cut(pd_df[facet_col], bins = bins)
# convert intervals to midpoints
pd_df['facet_col_binned'] = pd_df['facet_col_binned'].apply(lambda interval: round(interval.mid, 1) if pd.notna(interval) else None)
pd_df['facet_col_binned'] = pd.Categorical(pd_df['facet_col_binned'])
# assigning x as 'x_binned' for remaining code
facet_col = 'facet_col_binned'
pd_df[color] = pd_df[color].astype(str)
g = sns.FacetGrid(pd_df, col=facet_col, col_wrap=4, height=5, aspect=2)
g.map(sns.scatterplot, x, y, hue=color)
# if row => then change to row_template = '{row_name}'
g.set_titles(col_template = '{col_name}')
g.set_axis_labels(x, y)
plt.show()
facet_plot(pyspark_df, 'x', 'y', color = 'cat', facet_col='cat', bins = 2)
首先,创建一个更简单的示例有助于查明问题:
import seaborn as sns
import pandas as pd
data = {
'x': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
'y': [2, 3, 5, 7, 11, 13, 17, 19, 23, 29],
'z': [1, 2, 2, 3, 3, 1, 1, 2, 3, 3],
'cat': ['A', 'A', 'B', 'B', 'A', 'A', 'B', 'B', 'A', 'A']
}
pd_df = pd.DataFrame(data)
g = sns.FacetGrid(pd_df, col='z', col_wrap=3, height=3, aspect=2)
g.map(sns.scatterplot, 'x', 'y', hue='cat')
主要问题是
g.map
在调用sns.scatterplot
时没有提供完整的数据帧。它仅将 'x'
和 'y'
替换为数据帧的相应列。因此,g.map()
无法解析(“解释”)'cat'
列。
一种可能性是使用
g.map_dataframe
代替。由于不会自动创建图形图例,因此您还需要调用 g.add_legend()
。
更好的解决方案是将
hue=
添加到 sns.FacetGrid(...., hue='cat')
并将其保留在 g.map(sns.scatterplot, 'x', 'y')
中。
推荐的解决方案是使用函数的“图形级别”版本。对于
sns.scatterplot
,这是 sns.relplot
。这也会创建一个 FacetGrid
,但针对散点图进行了更精细的调整。
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
data = {
'x': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
'y': [2, 3, 5, 7, 11, 13, 17, 19, 23, 29],
'z': [1, 2, 2, 3, 3, 1, 1, 2, 3, 3],
'cat': ['A', 'A', 'B', 'B', 'A', 'A', 'B', 'B', 'A', 'A']
}
pd_df = pd.DataFrame(data)
g = sns.relplot(pd_df, x='x', y='y', hue='cat', col='z', col_wrap=3, height=3, aspect=2)
plt.show()