Assigning colors to groups in a Plotly histogram

Question   Votes: 0   Answers: 1

I'm trying to get started with Dash, but I've run into a simple problem. I have a data table

Index, Product, Customer_Age, Revenue
1, A, 12, 10
2, B, 99, 12

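and so on.

I want to plot the revenue per product as a histogram. However, I would like the bars split into different colors for different age groups (e.g., in groups of 10). How can I achieve this? At the moment, I have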

from dash import Dash, html, dash_table, dcc
import plotly.express as px, pandas as pd
data = pd.read_csv('data.csv')
app = Dash(__name__)
app.layout = html.Div([
    dcc.Graph(figure=px.histogram(data, x='Product', y='Revenue', histfunc='sum', color='Customer_Age')),
])
if __name__ == '__main__':
    app.run(debug=True)

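which assigns a different color to every single age. Moreover, the colors are fairly random rather than a nice continuous color sequence. Is there an elegant way to achieve what I want?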

python pandas plotly histogram plotly-dash
1 Answer

1 vote

Here is an example solution that uses bar charts, in addition to a histogram, to visualize the grouped data

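In your particular use case, since you put the emphasis on categories, visualizing the data with a bar chart probably makes more sense than using a histogram. Typically, a histogram shows the binned distribution of a series of values (an ordered sample) from one "continuous" (or at least effectively treated as continuous) numerical variable. Visualizing a color-coded distribution of age categories in the way you describe therefore gets a bit complicated. (How many histograms should be shown? One per color? And so on.)

In this example, I compute the revenue sums and display them in bar-chart subplots, one subplot per product, with the bars grouped by age interval and the interval set to 10 years. When any bar is hovered over in the app, the Dash table below the subplots updates to show only the hovered data.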

Below, I have modified this demo app so that it additionally updates a single histogram in reaction to each user hover interaction (i.e., triggered via a callback). The histogram shows the distribution of the data from which the sum represented by the currently hovered bar was calculated. For example:

import random

import dash
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from dash import Dash, html, dcc, dash_table
from dash.dependencies import Input, Output
from plotly.subplots import make_subplots

# Sample data
n = 500
ages = list(range(14, 83))
interval = 10
data = pd.DataFrame(
    {
        "Index": list(range(n)),
        "Product": random.choices(["A", "B"], k=n),
        "Customer_Age": random.choices(ages, k=n),
        "Revenue": random.choices(list(range(10, 40, 2)), k=n),
    }
)

# Bin ages into groups of 10.
# pd.cut with right=True builds intervals (lo, hi], so each label runs
# from lo + 1 to hi.
bins = list(range(min(ages) - 1, max(ages) + interval, interval))
labels = [f"{i + 1}-{i + interval}" for i in bins[:-1]]
data["Age_Group"] = pd.cut(
    data["Customer_Age"], bins=bins, labels=labels, right=True
)

# Group data by product and age group and sum the revenue
grouped = (
    data.groupby(["Product", "Age_Group"], observed=False)["Revenue"]
    .sum()
    .reset_index()
)
print(grouped)

# Sort products alphabetically
products = sorted(data["Product"].unique())

# Create subplot layout with shared y-axis, one column per product
fig = make_subplots(
    rows=1, cols=len(products), subplot_titles=products, shared_yaxes=True
)

# Define a custom color map for age groups
# Sequential HSL-like colormap
color_map = {
    label: f"hsl({i * 360 / len(labels)}, 100%, 50%)"
    for i, label in enumerate(labels)
}
# Standard colors:
# color_map = {
#     label: color for label, color in zip(labels, px.colors.qualitative.Set1)
# }

# For each product, plot the revenue by age group as a bar chart.
# Traces are added product by product, one trace per age group; the
# callback below relies on this order to decode curveNumber.
for idx, product in enumerate(products):
    product_data = grouped[grouped["Product"] == product]
    for age_group in labels:
        age_data = product_data[product_data["Age_Group"] == age_group]
        if not age_data.empty:
            revenue = age_data["Revenue"].values[0]
        else:
            revenue = 0
        # Only show legend for the first subplot but ensure both linked
        showlegend = idx == 0
        fig.add_trace(
            go.Bar(
                x=[age_group],
                y=[revenue],
                name=age_group,
                marker_color=color_map[age_group],
                showlegend=showlegend,
                legendgroup=age_group,
            ),
            row=1,
            col=idx + 1,
        )

fig.update_layout(barmode="group", title="Revenue by Product and Age Group")

app = Dash(__name__)
app.layout = html.Div(
    [
        dcc.Graph(id="revenue-graph", figure=fig),
        html.Div(
            [
                html.Button("Reset Table", id="reset-button", n_clicks=0),
                html.Button("Show/Hide Sum", id="toggle-button", n_clicks=0),
            ],
            style={"textAlign": "center", "margin": "20px"},
        ),
        html.Div(
            # Container Div using CSS Grid
            style={
                "display": "grid",
                "gridTemplateColumns": "1fr 1fr",
                "gap": "10px",
            },
            children=[
                dcc.Graph(id="histogram-plot"),
                dash_table.DataTable(
                    id="data-table",
                    columns=[{"name": i, "id": i} for i in data.columns],
                    data=grouped.to_dict("records"),
                    style_table={"height": "500px", "overflowY": "auto"},
                    style_cell={"textAlign": "center"},
                ),
            ],
        ),
    ],
    style={"padding": "20px", "margin": "5%"},
)

# Determine number of bins
nbins = 10


@app.callback(
    [
        Output("data-table", "data"),
        Output("histogram-plot", "figure"),
        Output("histogram-plot", "style"),
    ],
    [
        Input("revenue-graph", "hoverData"),
        Input("reset-button", "n_clicks"),
        Input("toggle-button", "n_clicks"),
    ],
)
def update_output(hoverData, reset_clicks, toggle_clicks):
    ctx = dash.callback_context
    # Initial call: nothing has triggered the callback yet
    if not ctx.triggered_id:
        return (
            grouped.to_dict("records"),
            {},
            {},
        )
    if ctx.triggered_id == "reset-button":
        return (
            grouped.to_dict("records"),
            {},
            {},
        )
    elif ctx.triggered_id == "toggle-button":
        # Odd click count: show the raw rows; even: show the grouped sums
        if toggle_clicks % 2 == 1:
            return (
                data.to_dict("records"),
                {},
                {},
            )
        else:
            return (
                grouped.to_dict("records"),
                {},
                {},
            )
    elif ctx.triggered_id == "revenue-graph":
        if hoverData:
            point_data = hoverData["points"][0]
            curve_number = point_data["curveNumber"]
            # Each product contributes len(labels) consecutive traces
            product = products[curve_number // len(labels)]
            age_group = point_data["x"]
            filtered_data = data[
                (data["Product"] == product)
                & (data["Age_Group"] == age_group)
            ]
        else:
            return (
                dash.no_update,
                dash.no_update,
                dash.no_update,
            )

    # Distribution of the revenue values behind the hovered bar
    histogram = px.histogram(
        filtered_data,
        x="Revenue",
        color_discrete_sequence=["rgba(0, 0, 0, 0.1)"],
        nbins=nbins,
        histnorm="probability density",
    )
    histogram.update_traces(
        marker_line_color="black", marker_line_width=1
    )
    histogram.update_layout(
        yaxis_range=[0, 0.1],
        xaxis_range=[0, 50],
    )
    if toggle_clicks % 2 == 1:
        return (
            filtered_data.to_dict("records"),
            histogram,
            {"display": "block"},
        )
    else:
        sum_data = grouped[
            (grouped["Product"] == product)
            & (grouped["Age_Group"] == age_group)
        ]
        return (
            sum_data.to_dict("records"),
            histogram,
            {"display": "block"},
        )


if __name__ == "__main__":
    app.run(debug=True)
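The callback above recovers the hovered product from `curveNumber` because the bar traces were added product by product, one trace per age group. A small sketch of that index arithmetic, isolated from Dash, is given here; the helper name `locate_trace` and the hand-written `hover_data` dictionary are hypothetical and only mimic the shape of the `hoverData` property.

```python
def locate_trace(curve_number, products, labels):
    """Map a Plotly curveNumber back to (product, age_group).

    Assumes traces were added product by product, with one trace per
    age-group label for each product, in the order of `labels`.
    """
    product_idx, label_idx = divmod(curve_number, len(labels))
    return products[product_idx], labels[label_idx]


products = ["A", "B"]
labels = ["14-23", "24-33", "34-43"]

# Hand-written stand-in for what dcc.Graph reports in hoverData
hover_data = {"points": [{"curveNumber": 4, "x": "24-33"}]}
point = hover_data["points"][0]

product, age_group = locate_trace(point["curveNumber"], products, labels)
print(product, age_group)
```

This mapping only holds while the trace order stays fixed. If traces are added in a different order, or some age groups are skipped, attaching the product to each trace via `customdata` would likely be a more robust way to identify the hovered bar.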

Result: (screenshot omitted)
