我正在尝试从
shap
包中绘制依赖图网格。这是我想要的示例的 MWE 代码:
fig, axs = plt.subplots(2,8, figsize=(16, 4), facecolor='w', edgecolor='k') # figsize=(width, height)
fig.subplots_adjust(hspace = .5, wspace=.001)
axs = axs.ravel()
for i in range(10):
axs[i].contourf(np.random.rand(12,12),5,cmap=plt.cm.Oranges)
axs[i].set_title(str(250+i))
plt.show()
这是我到目前为止的代码。有些事情不起作用:
figsize
参数的影响fig, axs = plt.subplots(1,8, figsize=(4, 2))
axs = axs.ravel()
for b in X_test.columns[:3]:
for a in X_test.columns[:3]:
shap.dependence_plot((a, b), shap_interaction_values, X_test)
import shap
import matplotlib.pyplot as plt
X = ...
shap_values = ...
columns = X.columns
# adjust nrows, ncols to fit all your columns
fig, axes = plt.subplots(nrows=4, ncols=3, figsize=(20, 14))
axes = axes.ravel()
for i, col in enumerate(columns):
shap.dependence_plot(col, shap_values, X, ax=axes[i], show=False)
我遇到了和你一样的问题 - 代码说 dependency_plot 需要一个可选参数:ax
因此,您可以制作子图并将后续图放入其中:
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2)
shap.dependence_plot((a, b), shap_interaction_values, X_test, ax=ax1)
shap.dependence_plot((a, b), shap_interaction_values, X_test, ax=ax2)
在您的情况下,您可以
zip()
将轴和列放在一起。
我还没有解决使用
interaction_index
时该怎么做 - 在这种情况下,你会在图的末尾得到所有可能的交互索引热图,这看起来非常糟糕。
编辑:丑陋的黑客,但它似乎可以解决问题 - 如果您为每个依赖图指定交互索引,那么它将为每个图绘制一个颜色条到最后一个子图中,这看起来很糟糕。
我最终手动删除了轴(每个颜色条都是一个附加轴),然后自动重新调整子图:
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2, constrained_layout=True)
shap.dependence_plot((a, b), shap_interaction_values, X_test, ax=ax1)
shap.dependence_plot((a, b), shap_interaction_values, X_test, ax=ax2)
fig.axes[-1].remove()
fig.axes[-1].remove()
这将消除所有颜色条,并且constrained_layout = True将确保正确重绘最后一个子图,如果没有此参数,它将保持“压扁”状态以为不存在的颜色条腾出空间。
经过几分钟的摆弄,我找到了解决方案。设置
ax
参数本身并不够。您还必须将 show
参数设置为 False
,如下所示:
# Configure subplots to your liking
fig, axs = plt.subplots(nrows=1, ncols=2)
# Plot dependence plots
shap.dependence_plot('age', shap_values, X_test, ax=axs[0], show=False) # set show to False
shap.dependence_plot('age', shap_values, X_test, ax=axs[1], show=False) # set show to False
# Show plots
plt.tight_layout()
plt.show()