SHAP 瀑布图作为 matplotlib 子图?

问题描述 投票:0回答:1

您好,我想显示 SHAP 库中的并排图表:

  1. 瀑布图:API参考https://shap.readthedocs.io/en/latest/ generated/shap.plots.waterfall.html#shap.plots.waterfall

  2. 条形图:API参考https://shap.readthedocs.io/en/latest/ generated/shap.plots.bar.html

我正在使用标准 Matplotlib 线

fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(10, 5))

整个代码如下:

import matplotlib.pyplot as plt
def shap_diagrams(shapley_values, index=0):
    fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(10, 5))
    ax0 = shap.plots.waterfall(shapley_values[index], show= False)
    ax0.set_title('SHAP Waterfall Plot')
    plt.gca()
    shap.plots.bar(shapley_values,ax=ax1, show= False)
    ax1.set_title('SHAP Bar Plot')
    
    plt.show()

调用函数后

shap_diagrams(shap_values,2)
图表会被覆盖。请指教。

我尝试了不同的分配,但应该是瀑布“轴”的地方是一个空的斧头对象,瀑布图本身跳转到正确渲染的条形图在此处输入图像描述

matplotlib subplot shap
1个回答
0
投票

那是因为你覆盖了子图的轴,这会导致四个:

[
    <Axes: >,
    <Axes: title={'center': 'SHAP Bar Plot'}, xlabel='mean(|SHAP value|)'>,
    <Axes: >,
    <Axes: title={'center': 'SHAP Waterfall Plot'}>
]

由于

waterfall
没有
ax
参数,因此您需要在调用它之前
sca
:

def shap_diagrams(shapley_values, index=0):
    fig, (ax0, ax1) = plt.subplots(1, 2)
    plt.sca(ax0)

    shap.plots.waterfall(shapley_values[index], show=False)
    ax0.set_title("SHAP Waterfall Plot")

    shap.plots.bar(shapley_values, ax=ax1, show=False)
    ax1.set_title("SHAP Bar Plot", pad=25)

    # to horizontally separate the two axes
    plt.subplots_adjust(wspace=1)
    # because waterfall seems to update the figsize
    fig.set_size_inches(10, 3)

    plt.show()

shap_diagrams(shap_values, 2)

enter image description here

使用过的输入

import xgboost
import shap

X, y = shap.datasets.adult(n_points=2000)
model = xgboost.XGBClassifier().fit(X, y)

explainer = shap.Explainer(model, X)
shap_values = explainer(X)
© www.soinside.com 2019 - 2024. All rights reserved.