我正在尝试使用Python中的plotly创建一些3D散点图。我有一个包含一些性格特征和聚类标签的数据框,我想使用颜色和符号来绘制每个特征的 TSNE 嵌入来表示特征和聚类值。这是我的代码:
from plotly.subplots import make_subplots
# Replace y and n with numbers
df2 = df.replace({'y': 1, 'n': 0})
# Assign cluster labels
df2 = df2.assign(Cluster=labels)
# Create a list of plot titles for each personality trait
plot_titles = ['Extraversion', 'Neuroticism', 'Agreeableness', 'Conscientiousness', 'Openness']
# Plot the TSNE embeddings for each personality trait
for i, trait in enumerate(['cEXT', 'cNEU', 'cAGR', 'cCON', 'cOPN']):
# Create a single plot figure for each trait
fig = px.scatter_3d(df2,
x=embeddings_3d[:, 0], y=embeddings_3d[:, 1], z=embeddings_3d[:, 2],
color = df2[trait], symbol=df2['Cluster'],
color_discrete_map={0: '#FF0000', 1: '#0000FF'},
size_max=1, symbol_map={0: 'circle', 1: 'square', 2: 'diamond', 3: 'cross', 4: 'x'},
opacity=0.3)
# Update the layout of the plot figure with the title
fig.update_layout(title=plot_titles[i])
# Show the plot figure
fig.show()
然而,剧情的传说却是一团糟。它显示了颜色和簇表示的渐变图例,但我不需要渐变。我的类别是标签,因此是分类的。只是名字而已!而且,这一切都是混合的。所以我需要呈现的只是每个数据点代表的集群的标签。
如何修复图的图例以仅显示具有相应颜色和符号的簇标签?任何帮助,将不胜感激。谢谢!
这是数据:df2
这是一个可能的解决方案的草案,因为我仍然缺少问题中未提供的一些参数(标签和嵌入_3d)。检查这是否是您正在寻找的内容。但请注意,我无法运行该代码,因此这是我刚刚在脑海中写下的内容。
from plotly.subplots import make_subplots
import plotly.express as px
# Replace y and n with numbers
df2 = df.replace({'y': 1, 'n': 0})
# Assign cluster labels
df2 = df2.assign(Cluster=labels)
# Create a list of plot titles for each personality trait
plot_titles = ['Extraversion', 'Neuroticism', 'Agreeableness', 'Conscientiousness', 'Openness']
# Initialize an empty list to store traces
traces = []
# Plot the TSNE embeddings for each personality trait
for i, trait in enumerate(['cEXT', 'cNEU', 'cAGR', 'cCON', 'cOPN']):
# Create a scatter plot trace for the current trait
trace = px.scatter_3d(df2,
x=embeddings_3d[:, 0], y=embeddings_3d[:, 1], z=embeddings_3d[:, 2],
color=df2[trait], symbol=df2['Cluster'],
color_discrete_map={0: '#FF0000', 1: '#0000FF'},
size_max=1, symbol_map={0: 'circle', 1: 'square', 2: 'diamond', 3: 'cross', 4: 'x'},
opacity=0.3,
title=plot_titles[i],
showlegend=True) # Set showlegend to True
# Append the trace to the list
traces.append(trace)
# Create a subplot with all the traces
fig = make_subplots(rows=1, cols=len(traces))
for i, trace in enumerate(traces):
# Add each trace to the subplot
fig.add_trace(trace, row=1, col=i + 1)
# Update the subplot layout
fig.update_layout(title_text="Personality Traits and Cluster Labels", showlegend=True)
# Show the plot
fig.show()