我正在寻找一种“干净”的方法来从使用plotly-express 创建的边际分布子图中删除趋势线。我知道这有点不清楚,所以请看下面的例子:
生成一些虚假数据:
np.random.seed(42)
data = pd.DataFrame(np.random.randint(0, 100, (100, 4)), columns=["feature1", "feature2", "feature3", "feature4"])
data["label"] = np.random.choice(list("ABC"), 100)
data["is_outlier"] = np.random.choice([True, False], 100)
使用
marginal
和 trendline
选项创建散点图:
fig = px.scatter(
data, x="feature1", y="feature2",
color="label", symbol="is_outlier", symbol_map={True: "x", False: "circle"},
log_x=False, marginal_x="box",
log_y=False, marginal_y="box",
trendline="ols", trendline_scope="overall", trendline_color_override='black',
trendline_options=dict(log_x=False, log_y=False),
)
我查看了
fig.data
结构,发现趋势线是其中的最后 3 个对象,最后 2 个是出现在顶部和右侧面板中的线条。从结构中删除这些对象将导致从这些面板中删除线条。看到这里:
fig2 = copy.deepcopy(fig)
fig2.data = fig2.data[:-2]
这会产生一个新问题,因为它还会从图例中删除
trendline
,这不是我满意的行为。所以我需要首先更新倒数第三个对象(主面板的趋势线)以具有 showlegend=True
属性:
fig3 = copy.deepcopy(fig)
fig3.data[-3].showlegend = True
fig3.data = fig3.data[:-2]
所以我确实有一个解决方案,但它需要“人工处理”
fig
对象。################
完整代码:
import copy
import numpy as np
import pandas as pd
import plotly.io as pio
import plotly.express as px
pio.renderers.default = "browser"
np.random.seed(42)
data = pd.DataFrame(np.random.randint(0, 100, (100, 4)), columns=["feature1", "feature2", "feature3", "feature4"])
data["label"] = np.random.choice(list("ABC"), 100)
data["is_outlier"] = np.random.choice([True, False], 100)
fig = px.scatter(
data, x="feature1", y="feature2",
color="label", symbol="is_outlier", symbol_map={True: "x", False: "circle"},
log_x=False, marginal_x="box",
log_y=False, marginal_y="box",
trendline="ols", trendline_scope="overall", trendline_color_override='black',
trendline_options=dict(log_x=False, log_y=False),
)
fig.show()
fig2 = copy.deepcopy(fig)
fig2.data = fig2.data[:-2]
fig2.show()
fig3 = copy.deepcopy(fig)
fig3.data[-3].showlegend = True
fig3.data = fig3.data[:-2]
fig3.show()
Figure.update_traces()
方法,该方法允许将特定属性应用于满足 selector
参数的所有迹线(没有用于 remove 迹线的功能,但我们可以使用 visible
隐藏它们)属性)。
name
(“总体趋势线”,由 trendline_scope
给出),您可以使用它们的 xaxis
(或 yaxis
)参考来区分它们(即。 "x"
指主子图的 x 轴,"x2"
和 "x3"
分别指右侧和顶部轴/子图)。
例如:
np.random.seed(42)
data = pd.DataFrame(np.random.randint(0, 100, (100, 4)), columns=["feature1", "feature2", "feature3", "feature4"])
data["label"] = np.random.choice(list("ABC"), 100)
data["is_outlier"] = np.random.choice([True, False], 100)
fig = px.scatter(
data, x="feature1", y="feature2",
color="label", symbol="is_outlier", symbol_map={True: "x", False: "circle"},
log_x=False, marginal_x="box",
log_y=False, marginal_y="box",
trendline="ols", trendline_scope="overall", trendline_color_override='black',
trendline_options=dict(log_x=False, log_y=False),
)
fig.update_traces(visible=False, selector=dict(name='Overall Trendline'))
fig.update_traces(visible=True, showlegend=True, selector=dict(name='Overall Trendline', xaxis='x'))
fig.show()