让我们假设我们有以下简化的代码:

import pandas as pd
import shap
from sklearn.ensemble import  RandomForestRegressor
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
mylabel =LabelEncoder()
data =pd.read_csv("https://raw.githubusercontent.com/krishnaik06/Multiple-Linear-Regression/master/50_Startups.csv")
data['State'] =mylabel.fit_transform(data['State'])
print(data.head())
model =RandomForestRegressor()
y =data['Profit']
X =data.drop('Profit',axis=1)
X_train, X_test, y_train, y_test = train_test_split(X,y,test_size=0.1,random_state=1)
model.fit(X_train,y_train)
explainer =shap.TreeExplainer(model)
shap_values =explainer.shap_values(X_train)
plt.figure(figsize=(30,30))
plt.subplot(2,1,1)
shap.summary_plot(shap_values, X_train, feature_names=X.columns, plot_type="bar")
plt.subplot(2,1,2)
shap.summary_plot(shap_values, X_train, feature_names=X.columns)
plt.show()

when i run this code, i am getting two image on different figure : one image : enter image description here

and another image : enter image description here

我想要把它们一个接一个地画出来,就像你看到的,我用了子图:

plt.subplot(2,1,1)
shap.summary_plot(shap_values, X_train, feature_names=X.columns, plot_type="bar")
plt.subplot(2,1,2)
shap.summary_plot(shap_values, X_train, feature_names=X.columns)

但它不起作用,我试图使用以下代码:

fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10,10))
shap.dependence_plot('age', shap_values[1], X_train, ax=axes[0, 0], show=False)
shap.dependence_plot('income', shap_values[1], X_train, ax=axes[0, 1], show=False)
shap.dependence_plot('score', shap_values[1], X_train, ax=axes[1, 0], show=False)
plt.show()

但是SUMMARY_PLOT没有参数ax,那么我如何使用它呢?

推荐答案

您的第一个代码示例是正确的.但你需要在第一次呼叫shap.summary_plot(..., show=False)的基础上再加show=False.使用默认值show=True时,绘图将立即显示,但也会被擦除.并创建了一个新的情节来展示第二部分.

import pandas as pd
import shap
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder

mylabel = LabelEncoder()
data = pd.read_csv("https://raw.githubusercontent.com/krishnaik06/Multiple-Linear-Regression/master/50_Startups.csv")
data['State'] = mylabel.fit_transform(data['State'])

model = RandomForestRegressor()
y = data['Profit']
X = data.drop('Profit', axis=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=1)
model.fit(X_train, y_train)
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X_train)
plt.figure(figsize=(30, 30))
plt.subplot(2, 1, 1)
shap.summary_plot(shap_values, X_train, feature_names=X.columns, plot_type="bar", show=False)
plt.subplot(2, 1, 2)
shap.summary_plot(shap_values, X_train, feature_names=X.columns, show=False)
plt.show()

shap summary_plot into subplots

Python相关问答推荐

根据条件将新值添加到下面的行或下面新创建的行中

将特定列信息移动到当前行下的新行

如何制作10,000年及以后的日期时间对象?

log 1 p numpy的意外行为

如何在python xsModel库中定义一个可选[December]字段,以产生受约束的SON模式

如何使用pytest来查看Python中是否存在class attribution属性?

在ubuntu上安装dlib时出错

改进大型数据集的框架性能

如何在Python中使用另一个数据框更改列值(列表)

从Windows Python脚本在WSL上运行Linux应用程序

lityter不让我输入左边的方括号,'

Flash只从html表单中获取一个值

如何找出Pandas 图中的连续空值(NaN)?

根据Pandas中带条件的两个列的值创建新列

Python—在嵌套列表中添加相同索引的元素,然后计算平均值

来自任务调度程序的作为系统的Python文件

有没有一种方法可以根据不同索引集的数组从2D数组的对称子矩阵高效地构造3D数组?

如何在Django查询集中生成带有值列表的带注释的字段?

多个布尔条件的`jax.lax.cond`等效项

合并Pandas中的数据帧,但处理不存在的列