Plotly Express:从边际分布图中删除趋势线

问题描述 投票:0回答:1

我正在寻找一种“干净”的方法来从使用plotly-express 创建的边际分布子图中删除趋势线。我知道这有点不清楚,所以请看下面的例子:
生成一些虚假数据:

np.random.seed(42)
data = pd.DataFrame(np.random.randint(0, 100, (100, 4)), columns=["feature1", "feature2", "feature3", "feature4"])
data["label"] = np.random.choice(list("ABC"), 100)
data["is_outlier"] = np.random.choice([True, False], 100)

使用

marginal
trendline
选项创建散点图:

fig = px.scatter(
    data, x="feature1", y="feature2",
    color="label", symbol="is_outlier", symbol_map={True: "x", False: "circle"},
    log_x=False, marginal_x="box",
    log_y=False, marginal_y="box",
    trendline="ols", trendline_scope="overall", trendline_color_override='black',
    trendline_options=dict(log_x=False, log_y=False),
)

这会产生一个在所有 3 个面板中都带有趋势线的数字:
original fig

我查看了

fig.data
结构,发现趋势线是其中的最后 3 个对象,最后 2 个是出现在顶部和右侧面板中的线条。从结构中删除这些对象将导致从这些面板中删除线条。看到这里:

fig2 = copy.deepcopy(fig)
fig2.data = fig2.data[:-2]

last 2 objects removed from original fig.data

这会产生一个新问题,因为它还会从图例中删除

trendline
,这不是我满意的行为。所以我需要首先更新倒数第三个对象(主面板的趋势线)以具有
showlegend=True
属性:

fig3 = copy.deepcopy(fig)
fig3.data[-3].showlegend = True
fig3.data = fig3.data[:-2]

这终于给了我我想要的数字:
last 2 objects removed from original fig.data and trendline included in legend

所以我确实有一个解决方案,但它需要“人工处理”

fig
对象。
有没有更好、更干净的方法来达到相同的最终数字?

################
完整代码:

import copy

import numpy as np
import pandas as pd
import plotly.io as pio
import plotly.express as px

pio.renderers.default = "browser"

np.random.seed(42)
data = pd.DataFrame(np.random.randint(0, 100, (100, 4)), columns=["feature1", "feature2", "feature3", "feature4"])
data["label"] = np.random.choice(list("ABC"), 100)
data["is_outlier"] = np.random.choice([True, False], 100)

fig = px.scatter(
    data, x="feature1", y="feature2",
    color="label", symbol="is_outlier", symbol_map={True: "x", False: "circle"},
    log_x=False, marginal_x="box",
    log_y=False, marginal_y="box",
    trendline="ols", trendline_scope="overall", trendline_color_override='black',
    trendline_options=dict(log_x=False, log_y=False),
)
fig.show()

fig2 = copy.deepcopy(fig)
fig2.data = fig2.data[:-2]
fig2.show()

fig3 = copy.deepcopy(fig)
fig3.data[-3].showlegend = True
fig3.data = fig3.data[:-2]
fig3.show()
python plotly plotly-python
1个回答
0
投票

您可以使用

Figure.update_traces()
方法,该方法允许将特定属性应用于满足
selector
参数的所有迹线(没有用于 remove 迹线的功能,但我们可以使用 visible 隐藏
它们) 
属性)。

所有 OLS 趋势线轨迹共享相同的

name
(“总体趋势线”,由
trendline_scope
给出),您可以使用它们的
xaxis
(或
yaxis
)参考来区分它们(即。
"x"
指主子图的 x 轴,
"x2"
"x3"
分别指右侧和顶部轴/子图)。

例如:

np.random.seed(42)
data = pd.DataFrame(np.random.randint(0, 100, (100, 4)), columns=["feature1", "feature2", "feature3", "feature4"])
data["label"] = np.random.choice(list("ABC"), 100)
data["is_outlier"] = np.random.choice([True, False], 100)

fig = px.scatter(
    data, x="feature1", y="feature2",
    color="label", symbol="is_outlier", symbol_map={True: "x", False: "circle"},
    log_x=False, marginal_x="box",
    log_y=False, marginal_y="box",
    trendline="ols", trendline_scope="overall", trendline_color_override='black',
    trendline_options=dict(log_x=False, log_y=False),
)

fig.update_traces(visible=False, selector=dict(name='Overall Trendline'))
fig.update_traces(visible=True, showlegend=True, selector=dict(name='Overall Trendline', xaxis='x'))

fig.show()
© www.soinside.com 2019 - 2024. All rights reserved.