# Body of the Dash callback that builds the sales-forecast figure; the
# callback's decorator and `def` line are not part of this snippet.
# start_date / end_date / selected_store / selected_product are expected to be
# the callback inputs; forecast_df and historical_df come from outer scope.
if start_date is None or end_date is None:
    # Nothing to draw until both ends of the date range are picked.
    raise PreventUpdate

start_date = pd.to_datetime(start_date)
end_date = pd.to_datetime(end_date)
forecast_df['ds'] = pd.to_datetime(forecast_df['ds'])
historical_df['ds'] = pd.to_datetime(historical_df['ds'])
# NOTE(review): overwrites store_id on every forecast row with one hard-coded
# store, so the store filter below only ever matches 'ЦБ000005' — confirm.
forecast_df['store_id'] = 'ЦБ000005'
# First forecast date: earlier points are history (black), the rest is
# forecast (green).
min_forecast_date = forecast_df['ds'].min()

# History inside the selected window and strictly before the forecast starts.
# This is empty when the window begins on/after min_forecast_date, and it is
# capped at end_date so a window lying entirely in the past is respected.
historical_data = historical_df[(historical_df['store_id'] == selected_store) &
                                (historical_df['sku'] == selected_product) &
                                historical_df['ds'].between(start_date, end_date) &
                                (historical_df['ds'] < min_forecast_date)]
# Forecast inside the selected window (every forecast date is already
# >= min_forecast_date, so no extra lower bound is needed).
forecast_data = forecast_df[(forecast_df['store_id'] == selected_store) &
                            (forecast_df['sku'] == selected_product) &
                            forecast_df['ds'].between(start_date, end_date)]

# A Plotly scatter line is drawn in a single colour: `line.color` does not
# accept a per-point list, which is why a one-trace version with a list of
# colours renders entirely in the default blue. Use one trace per colour.
traces = []
if not historical_data.empty:
    traces.append({'x': historical_data['ds'],
                   'y': historical_data['yhat'],
                   'type': 'scatter',
                   'mode': 'lines',
                   'name': 'Sales history',
                   'line': {'color': 'black'}})
# Prefix the forecast with the last historical point (if any) so the black
# and green segments join into one continuous line.
forecast_line = pd.concat([historical_data.tail(1), forecast_data])
traces.append({'x': forecast_line['ds'],
               'y': forecast_line['yhat'],
               'type': 'scatter',
               'mode': 'lines',
               'name': 'Sales prediction',
               'line': {'color': 'green'}})

forecast_figure = {
    'data': traces,
    'layout': {
        'title': f'Sales prediction for {selected_product}',
        'xaxis': {'title': 'Date'},
        'yaxis': {'title': 'Sales'}
    }
}
我绘制了一条销售预测线，但用户有时会选择一个包含过去数据的区间，即其中有早于当前日期的日期（在我的场景中，当前日期就是 min_forecast_date）。因此我想用两种颜色绘制这条线：日期早于 min_forecast_date 的部分为黑色，日期大于或等于它的部分为绿色。但我目前的做法不起作用，整条线都以默认颜色（蓝色）绘制。
Plotly 折线轨迹的 line.color 只接受单一颜色，不能为每个点分别指定颜色，所以传入颜色列表不会生效。解决办法是把数据拆成两条轨迹：预测开始之前的部分用黑色绘制，之后的部分用绿色绘制。下面是一个包含示例数据框的完整示例：
import pandas as pd
import numpy as np
import plotly.graph_objs as go
# Self-contained example: dates before the first forecast date are drawn in
# black and the forecast in green, as two Plotly traces that meet at the
# boundary so they read as one line.

# Sample data: 120 days of forecast from 2023-01-01 and 365 days of history
# from 2022-01-01, all for a single store and product.
date_range = pd.date_range(start='2023-01-01', periods=120, freq='D')
np.random.seed(0)
sales = np.random.randint(100, 500, size=120)
forecast_df = pd.DataFrame({
    'ds': date_range,
    'yhat': sales,
    'store_id': 'ЦБ000005',
    'sku': 'Product123',
})
historical_df = pd.DataFrame({
    'ds': pd.date_range(start='2022-01-01', periods=365, freq='D'),
    'yhat': np.random.randint(100, 500, size=365),
    'store_id': ['ЦБ000005'] * 365,
    'sku': ['Product123'] * 365,
})

selected_store = 'ЦБ000005'
selected_product = 'Product123'
start_date = pd.to_datetime('2022-06-01')
end_date = pd.to_datetime('2023-04-30')
# Boundary between history and forecast, derived from the data rather than
# hard-coded (for this sample data it is 2023-01-01).
min_forecast_date = forecast_df['ds'].min()

# History inside the selected window and strictly before the forecast starts;
# empty when the window begins on/after min_forecast_date.
historical_data = historical_df[(historical_df['store_id'] == selected_store) &
                                (historical_df['sku'] == selected_product) &
                                historical_df['ds'].between(start_date, end_date) &
                                (historical_df['ds'] < min_forecast_date)]
# Forecast inside the selected window (all forecast dates are already
# >= min_forecast_date).
forecast_data = forecast_df[(forecast_df['store_id'] == selected_store) &
                            (forecast_df['sku'] == selected_product) &
                            forecast_df['ds'].between(start_date, end_date)]

# A Plotly line has a single colour (`line.color` does not accept a per-point
# list), so each colour needs its own trace. The forecast trace is prefixed
# with the last historical point so the black and green segments connect
# instead of leaving a gap at the boundary.
forecast_line = pd.concat([historical_data.tail(1), forecast_data])

trace1 = go.Scatter(
    x=historical_data['ds'],
    y=historical_data['yhat'],
    mode='lines',
    name='Historical Sales',
    line=dict(color='black'),
)
trace2 = go.Scatter(
    x=forecast_line['ds'],
    y=forecast_line['yhat'],
    mode='lines',
    name='Forecast Sales',
    line=dict(color='green'),
)
data = [trace1, trace2]
layout = go.Layout(
    title=f'Sales Prediction for {selected_product}',
    xaxis=dict(title='Date'),
    yaxis=dict(title='Sales'),
)
fig = go.Figure(data=data, layout=layout)
fig.show()