我刚开始使用 Dash,遇到了一个简单的问题。我有如下数据表:
Index, Product, Customer_Age, Revenue
1, A, 12, 10
2, B, 99, 12
等等。
我想将每种产品的收入绘制为直方图,但希望把条形按不同的年龄组(例如每 10 岁一组)用不同颜色区分。我该如何实现?目前我的代码如下:
from dash import Dash, html, dash_table, dcc
import plotly.express as px, pandas as pd

# Load the table shown above (columns: Index, Product, Customer_Age, Revenue).
data = pd.read_csv('data.csv')

app = Dash(__name__)
app.layout = html.Div([
    # Sums Revenue per Product. color='Customer_Age' colours by the raw age
    # value, so (as the question describes) each distinct age gets its own
    # colour instead of one colour per age group.
    dcc.Graph(figure=px.histogram(data, x='Product', y='Revenue', histfunc='sum', color='Customer_Age')),
])

if __name__ == '__main__':
    app.run(debug=True)
这会为每一个年龄值(而不是每个年龄组)分配一种不同的颜色。此外,这些颜色看起来相当随机,而不是一个漂亮的连续颜色序列。有没有一种优雅的方式来实现我想要的效果?
在您的具体用例中,由于您关注的是类别,用条形图来可视化数据可能比用直方图更合适。通常,直方图展示的是某个“连续”(或至少可以有效地视为连续)数值变量的一系列取值(有序样本)的分箱分布。因此,按您描述的方式对年龄分类做颜色编码的分布可视化会变得有些复杂(要显示多少个直方图?每种颜色一个?等等)。在下面的示例中,我计算了收入总和并用条形图子图展示:每个产品一个子图,条形按年龄区间分组,区间设为 10 岁。在应用中将鼠标悬停在任一条形上时,子图下方的 Dash 数据表(DataTable)会更新,只显示所悬停条形对应的数据。
下面,我还修改了这个演示应用,使其在每次用户悬停交互时(即通过回调触发)额外更新一个单独的直方图。该直方图显示当前悬停条形所对应的原始数据的分布,而该条形表示的正是根据这些值计算出的总和(sum)。
例如,
import random
import dash
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from dash import Dash, html, dcc, dash_table
from dash.dependencies import Input, Output, State
from plotly.subplots import make_subplots
# Sample data: n random sales rows across two products.
n = 500
ages = list(range(14, 83))
interval = 10
data = pd.DataFrame(
    {
        "Index": list(range(n)),
        "Product": random.choices(["A", "B"], k=n),
        "Customer_Age": random.choices(ages, k=n),
        "Revenue": random.choices(list(range(10, 40, 2)), k=n),
    }
)

# Bin ages into groups of `interval` years. With right=True, pd.cut builds
# half-open bins (lo, hi], i.e. integer ages lo+1 .. hi, so each label is
# derived from the same pair of edges and names exactly the ages it holds.
bins = list(range(min(ages) - 1, max(ages) + interval, interval))
labels = [f"{lo + 1}-{hi}" for lo, hi in zip(bins[:-1], bins[1:])]
data["Age_Group"] = pd.cut(
    data["Customer_Age"], bins=bins, labels=labels, right=True
)

# Sum revenue per (product, age group). Only the Revenue column is summed.
# observed=False is explicit so every age-group category is kept (empty
# groups sum to 0) independent of the pandas version's default; the bar
# traces and the hover callback rely on one row per (product, label).
grouped = (
    data.groupby(["Product", "Age_Group"], observed=False)["Revenue"]
    .sum()
    .reset_index()
)
print(grouped)
# One subplot per product (alphabetical order), all sharing the revenue axis.
# The column count follows the data so the per-product loop below can place
# each product at col=idx + 1 whatever the number of products.
products = sorted(data["Product"].unique())
fig = make_subplots(
    rows=1,
    cols=len(products),
    subplot_titles=products,
    shared_yaxes=True,
)
# Sequential colour map for the age groups: sample evenly spaced points from
# a continuous colorscale so that younger -> older groups read as an ordered
# progression rather than unrelated hues. max(..., 1) keeps a single label
# from dividing by zero.
# (For a qualitative palette instead, zip `labels` with
# px.colors.qualitative.Set1.)
color_map = dict(
    zip(
        labels,
        px.colors.sample_colorscale(
            "Viridis",
            [i / max(len(labels) - 1, 1) for i in range(len(labels))],
        ),
    )
)

# Add one single-bar trace per (product, age group). Traces are added
# product-major with exactly len(labels) traces per product, and the hover
# callback maps curveNumber back to a product with that assumption, so keep
# this ordering.
for idx, product in enumerate(products):
    product_revenue = grouped.loc[grouped["Product"] == product].set_index(
        "Age_Group"
    )["Revenue"]
    for age_group in labels:
        fig.add_trace(
            go.Bar(
                x=[age_group],
                y=[product_revenue.get(age_group, 0)],
                name=age_group,
                marker_color=color_map[age_group],
                # Legend entries come from the first subplot only; the shared
                # legendgroup links the matching bars across subplots.
                showlegend=idx == 0,
                legendgroup=age_group,
            ),
            row=1,
            col=idx + 1,
        )

# Each trace holds one bar at its own category. In "group" mode Plotly would
# reserve a side-by-side slot for every trace at every category, giving
# narrow, off-centre bars; "relative" draws each bar full-width.
fig.update_layout(barmode="relative", title="Revenue by Product and Age Group")
app = Dash(__name__)


def _build_layout():
    """Return the page: revenue bar chart, control buttons, then a grid
    holding the hover histogram next to the data table."""
    buttons = html.Div(
        [
            html.Button("Reset Table", id="reset-button", n_clicks=0),
            html.Button("Show/Hide Sum", id="toggle-button", n_clicks=0),
        ],
        style={"textAlign": "center", "margin": "20px"},
    )

    table = dash_table.DataTable(
        id="data-table",
        columns=[{"name": column, "id": column} for column in data.columns],
        data=grouped.to_dict("records"),
        style_table={"height": "500px", "overflowY": "auto"},
        style_cell={"textAlign": "center"},
    )

    # Two equal CSS-grid columns: histogram on the left, table on the right.
    grid_style = {
        "display": "grid",
        "gridTemplateColumns": "1fr 1fr",
        "gap": "10px",
    }
    details = html.Div(
        style=grid_style,
        children=[dcc.Graph(id="histogram-plot"), table],
    )

    page_children = [
        dcc.Graph(id="revenue-graph", figure=fig),
        buttons,
        details,
    ]
    return html.Div(page_children, style={"padding": "20px", "margin": "5%"})


app.layout = _build_layout()
# Number of bins used for the per-bar revenue histogram.
nbins = 10


def _revenue_histogram(frame):
    """Return a density histogram of ``frame["Revenue"]``.

    Bars are faint grey with black outlines; axis ranges are fixed at
    x 0-50 and y 0-0.1.
    """
    figure = px.histogram(
        frame,
        x="Revenue",
        color_discrete_sequence=["rgba(0, 0, 0, 0.1)"],
        nbins=nbins,
        histnorm="probability density",
    )
    figure.update_traces(marker_line_color="black", marker_line_width=1)
    figure.update_layout(yaxis_range=[0, 0.1], xaxis_range=[0, 50])
    return figure


@app.callback(
    [
        Output("data-table", "data"),
        Output("histogram-plot", "figure"),
        Output("histogram-plot", "style"),
    ],
    [
        Input("revenue-graph", "hoverData"),
        Input("reset-button", "n_clicks"),
        Input("toggle-button", "n_clicks"),
    ],
)
def update_output(hoverData, reset_clicks, toggle_clicks):
    """Refresh the table and histogram after a user interaction.

    - Initial call or reset button: grouped sums, empty histogram.
    - Toggle button: raw rows on odd click counts, grouped sums on even
      counts; histogram emptied either way.
    - Hovering a bar: raw rows (odd toggle count) or the single summed row
      (even toggle count) behind that bar, plus a histogram of those rows.

    Returns ``(table_records, histogram_figure, histogram_style)``.
    """
    trigger = dash.callback_context.triggered_id

    if not trigger or trigger == "reset-button":
        return grouped.to_dict("records"), {}, {}

    if trigger == "toggle-button":
        source = data if toggle_clicks % 2 == 1 else grouped
        return source.to_dict("records"), {}, {}

    if trigger != "revenue-graph":
        # No other component feeds this callback; nothing to return.
        return None

    if not hoverData:
        return dash.no_update, dash.no_update, dash.no_update

    point = hoverData["points"][0]
    # Bars were added product-major, len(labels) traces per product.
    product = products[point["curveNumber"] // len(labels)]
    age_group = point["x"]
    hovered_rows = data[
        (data["Product"] == product) & (data["Age_Group"] == age_group)
    ]
    histogram = _revenue_histogram(hovered_rows)

    if toggle_clicks % 2 == 1:
        table_source = hovered_rows
    else:
        table_source = grouped[
            (grouped["Product"] == product)
            & (grouped["Age_Group"] == age_group)
        ]
    return table_source.to_dict("records"), histogram, {"display": "block"}
if __name__ == "__main__":
    # Dash.run_server is deprecated (removed in Dash 3); Dash.run is the
    # supported way to start the development server.
    app.run(debug=True)
结果: