让我们假设我们有以下简化的代码:
import pandas as pd
import shap
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
# Reproduction script: fit a random-forest regressor on the 50_Startups data,
# compute SHAP values for the training split, and try to draw two SHAP summary
# plots stacked in a single matplotlib figure.

# Encoder that maps the labels in the 'State' column to integer codes.
mylabel =LabelEncoder()
# The CSV is fetched over the network from GitHub on every run.
data =pd.read_csv("https://raw.githubusercontent.com/krishnaik06/Multiple-Linear-Regression/master/50_Startups.csv")
data['State'] =mylabel.fit_transform(data['State'])
print(data.head())
model =RandomForestRegressor()
# 'Profit' is the regression target; every other column is used as a feature.
y =data['Profit']
X =data.drop('Profit',axis=1)
# Hold out 10% of the rows; random_state fixes the split for reproducibility.
X_train, X_test, y_train, y_test = train_test_split(X,y,test_size=0.1,random_state=1)
model.fit(X_train,y_train)
# SHAP values are computed on the training rows only (X_test is not used below).
explainer =shap.TreeExplainer(model)
shap_values =explainer.shap_values(X_train)
# One large figure intended to hold both summary plots, one above the other.
plt.figure(figsize=(30,30))
# Top slot of the 2x1 grid: bar-style summary plot.
plt.subplot(2,1,1)
# NOTE(review): summary_plot is called without show=False. Presumably it
# displays (and finishes) the current figure itself, which would explain the
# two separate figures reported for this script -- confirm against the shap
# documentation before changing.
shap.summary_plot(shap_values, X_train, feature_names=X.columns, plot_type="bar")
# Bottom slot of the 2x1 grid: default summary plot.
plt.subplot(2,1,2)
shap.summary_plot(shap_values, X_train, feature_names=X.columns)
plt.show()
When I run this code, I get two images, each drawn in its own separate figure instead of both appearing in one figure.
我想把这两张图画在同一个图形(figure)中,上下依次排列。如你所见,我使用了子图:
# Excerpt of the plotting calls from the script above: select a slot in a
# 2-row, 1-column grid, then call summary_plot for that slot.
plt.subplot(2,1,1)
shap.summary_plot(shap_values, X_train, feature_names=X.columns, plot_type="bar")
plt.subplot(2,1,2)
shap.summary_plot(shap_values, X_train, feature_names=X.columns)
但它不起作用,我试图使用以下代码:
# Alternative attempt: create a 2x2 grid of axes up front and pass one axes
# object to each dependence_plot call through its `ax` argument.
# NOTE(review): the feature names 'age', 'income', 'score' and the
# shap_values[1] indexing do not appear anywhere else in this file --
# presumably adapted from a different example; confirm before reuse.
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10,10))
# show=False is passed on every call; presumably this defers display until
# the single plt.show() at the end -- confirm against the shap documentation.
shap.dependence_plot('age', shap_values[1], X_train, ax=axes[0, 0], show=False)
shap.dependence_plot('income', shap_values[1], X_train, ax=axes[0, 1], show=False)
shap.dependence_plot('score', shap_values[1], X_train, ax=axes[1, 0], show=False)
# axes[1, 1] is never drawn on, so the fourth panel stays empty.
plt.show()
但是 `summary_plot` 没有 `ax` 参数,那么我该如何把它画到指定的子图上呢?