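I'm having trouble retrieving hover data with Plotly and Dash when hovering the cursor over points in a scatter plot. The hover data retrieved from the Dash app seems to contain the same pointNumber and pointIndex for multiple points in the same plot. This makes it impossible to display the correct information associated with a given instance when hovering over the corresponding point.
Here is a simplified example that can be run in a Jupyter Notebook. Eventually I want to display images on hover.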
from sklearn.datasets import load_iris
import numpy as np
import pandas as pd
from jupyter_dash import JupyterDash
from dash import dcc, html, Input, Output, no_update
import plotly.express as px
# Loading iris data to pandas dataframe
data = load_iris()
images = data.data
labels = data.target
df = pd.DataFrame(images[:, :2], columns=["feat1", "feat2"])
df["label"] = labels
# Color for each class
color_map = {0: "setosa",
             1: "versicolor",
             2: "virginica"}
colors = [color_map[l] for l in labels]
df["color"] = colors
pd.set_option("display.max_rows", None, "display.max_columns", None)
print(df)
# Setup plotly scatter plot
fig = px.scatter(df, x="feat1", y="feat2", color="color")
fig.update_traces(hoverinfo="none",
                  hovertemplate=None)
# Setup Dash
app = JupyterDash(__name__)
app.layout = html.Div(className="container",
                      children=[dcc.Graph(id="graph-5", figure=fig, clear_on_unhover=True),
                                dcc.Tooltip(id="graph-tooltip-5", direction="bottom")])

@app.callback(Output("graph-tooltip-5", "show"),
              Output("graph-tooltip-5", "bbox"),
              Output("graph-tooltip-5", "children"),
              Input("graph-5", "hoverData"))
def display_hover(hoverData):
    if hoverData is None:
        return False, no_update, no_update
    print(hoverData)
    hover_data = hoverData["points"][0]
    bbox = hover_data["bbox"]
    num = hover_data["pointNumber"]
    children = [html.Div([html.Img(style={"height": "50px",
                                          "width": "50px",
                                          "display": "block",
                                          "margin": "0 auto"}),
                          html.P("Feat1: {}".format(str(df.loc[num]["feat1"]))),
                          html.P("Feat2: {}".format(str(df.loc[num]["feat2"])))])]
    return True, bbox, children

if __name__ == "__main__":
    app.run_server(mode="inline", debug=True)
The problem can be observed with the following two instances, retrieved via print(df):
index  feat1  feat2  label  color
31     5.4    3.4    0      setosa
131    7.9    3.8    2      virginica
Both are assigned the same pointNumber and pointIndex, as retrieved via print(hoverData):
{'points': [{'curveNumber': 2, 'pointNumber': 31, 'pointIndex': 31, 'x': 7.9, 'y': 3.8, 'bbox': {'x0': 1235.5, 'x1': 1241.5, 'y0': 152.13, 'y1': 158.13}}]}
{'points': [{'curveNumber': 0, 'pointNumber': 31, 'pointIndex': 31, 'x': 5.4, 'y': 3.4, 'bbox': {'x0': 481.33, 'x1': 487.33, 'y0': 197.38, 'y1': 203.38}}]}
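This is the visualization when hovering over the two instances. The hover information in the right-hand image is wrong.
Interestingly, the problem goes away when using
fig = px.scatter(df, x="feat1", y="feat2", color="label")
However, this causes the legend to be displayed as a continuous color scale and removes the ability to selectively show the instances associated with a particular class in the HTML.
Is this a bug, or am I overlooking something? Any help is much appreciated!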
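It turned out that I incorrectly expected pointNumber and pointIndex to be unique. As soon as a non-numeric column is used as the color parameter in px.scatter(), the point numbers and indices are renumbered for each class. A point in the scatter plot can be uniquely identified by combining curveNumber with either pointNumber or pointIndex.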
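To illustrate the renumbering, here is a minimal pure-Python sketch that mimics how the (curveNumber, pointNumber) pairs seem to be assigned. It assumes that one trace is created per distinct color value in order of first appearance, and that points keep their row order within a trace; this matches the hover data in the question, but it is an assumption about the default behavior rather than something guaranteed here. The helper name trace_numbering is hypothetical.

```python
def trace_numbering(categories):
    """Return the assumed (curve_number, point_number) pair for each row."""
    curve_of = {}    # category -> curve number, by order of first appearance
    next_point = {}  # category -> next point number within that trace
    numbering = []
    for category in categories:
        if category not in curve_of:
            curve_of[category] = len(curve_of)
            next_point[category] = 0
        numbering.append((curve_of[category], next_point[category]))
        next_point[category] += 1
    return numbering


# Same class layout as the iris data frame: 50 rows per species, in order.
colors = ["setosa"] * 50 + ["versicolor"] * 50 + ["virginica"] * 50
numbering = trace_numbering(colors)

print(numbering[31])   # (0, 31)
print(numbering[131])  # (2, 31)
```

Rows 31 and 131 share pointNumber 31 and differ only in curveNumber, which is exactly the collision seen in the printed hover data.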
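One potential solution is to generate a separate index for each class and add it to the data frame:
# class_annot holds the class label of each row (the labels array in the question);
# this assumes the rows of df are sorted by class, as they are for the iris data
curve_indices = np.array([np.arange(0, num_samples) for num_samples in np.unique(class_annot, return_counts=True)[1]], dtype="object")
curve_indices = np.concatenate(curve_indices).ravel()
df["curve_index"] = curve_indices
In the callback function, the correct data frame index for each instance can then be identified using
df_index = df[(df.label == curve) & (df.curve_index == num)].index[0]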
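If the rows are not sorted by class, a lookup table from (curveNumber, pointNumber) to row position may be easier to reason about. The sketch below is pure Python and only an illustration: build_row_lookup and row_from_hover are hypothetical helper names, and it assumes traces are numbered by first appearance of each class with row order preserved inside a trace.

```python
def build_row_lookup(categories):
    """Map (curve_number, point_number) to the original row position."""
    curve_of = {}     # category -> curve number, by order of first appearance
    points_seen = {}  # curve number -> points already assigned to that curve
    lookup = {}
    for row, category in enumerate(categories):
        curve = curve_of.setdefault(category, len(curve_of))
        point = points_seen.get(curve, 0)
        lookup[(curve, point)] = row
        points_seen[curve] = point + 1
    return lookup


def row_from_hover(hover_data, lookup):
    """Resolve a Dash hoverData dict to a row position via the lookup."""
    point = hover_data["points"][0]
    return lookup[(point["curveNumber"], point["pointNumber"])]


# Interleaved classes, so positions within a class differ from row positions.
categories = ["a", "b", "a", "c", "b", "a"]
lookup = build_row_lookup(categories)

hover = {"points": [{"curveNumber": 1, "pointNumber": 1}]}
print(row_from_hover(hover, lookup))  # 4
```

The returned value is a positional row number, so it would be used with df.iloc. With pandas, the within-class counter on its own can also be computed with df.groupby("color").cumcount().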
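Adding a full working example based on the accepted answer by @C.S., in case anyone else finds it helpful. At first it wasn't clear to me how to incorporate the new indices.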
from sklearn.datasets import load_iris
import numpy as np
import pandas as pd
from jupyter_dash import JupyterDash
from dash import dcc, html, Input, Output, no_update
import plotly.express as px
# Loading iris data to pandas dataframe
data = load_iris()
images = data.data
labels = data.target
df = pd.DataFrame(images[:, :2], columns=["feat1", "feat2"])
df["label"] = labels
# Color for each class
color_map = {0: "setosa",
             1: "versicolor",
             2: "virginica"}
colors = [color_map[l] for l in labels]
df["color"] = colors
# This will create an index from 0 to the max number of samples in each class (color, in this case corresponding to species) and add it to the data frame
curve_indices = np.array([np.arange(0, num_samples) for num_samples in np.unique(labels, return_counts=True)[1]], dtype="object")
curve_indices = np.concatenate(curve_indices).ravel()
df["curve_index"] = curve_indices
pd.set_option("display.max_rows", None, "display.max_columns", None)
# Look at the df to see what the curve_index does
print(df)
# Setup plotly scatter plot
fig = px.scatter(df, x="feat1", y="feat2", color="color")
fig.update_traces(hoverinfo="none",
                  hovertemplate=None)
# Setup Dash
app = JupyterDash(__name__)
app.layout = html.Div(className="container",
                      children=[dcc.Graph(id="graph-5", figure=fig, clear_on_unhover=True),
                                dcc.Tooltip(id="graph-tooltip-5", direction="bottom")])

@app.callback(Output("graph-tooltip-5", "show"),
              Output("graph-tooltip-5", "bbox"),
              Output("graph-tooltip-5", "children"),
              Input("graph-5", "hoverData"))
def display_hover(hoverData):
    if hoverData is None:
        return False, no_update, no_update
    # Printing the hover data makes it clear which values we are drawing from
    print(hoverData)
    hover_data = hoverData["points"][0]
    bbox = hover_data["bbox"]
    num = hover_data["pointNumber"]
    # We need to pull the curve data out of the hoverData
    curve = hover_data["curveNumber"]
    # This finds the index (row number) of the df where the curve number and point index within that curve match our newly added indices
    df_index = df[(df.label == curve) & (df.curve_index == num)].index[0]
    # We can print df_index if we want to see this
    #print(df_index)
    children = [html.Div([html.Img(style={"height": "50px",
                                          "width": "50px",
                                          "display": "block",
                                          "margin": "0 auto"}),
                          # df.iloc[df_index] will return the row of the df that the pointer is hovering over
                          html.P("Feat1: {}".format(str(df.iloc[df_index]["feat1"]))),
                          html.P("Feat2: {}".format(str(df.iloc[df_index]["feat2"])))])]
    return True, bbox, children

if __name__ == "__main__":
    app.run_server(mode="inline", debug=True)