我正在try 计算我的xgBoost模型的Shap值:

X_train <- as.matrix(train_data[, !names(train_data) %in% c("min_column")])
Y_train<-train_data$min_column
Y_train <- as.integer(as.factor(Y_train)) - 1
Y_train = as.matrix(Y_train)
num_class <- length(unique(Y_train)

> params <- list(booster = "gbtree", 
                 objective = "multi:softmax",
                 eta=0.01, 
                 gamma=0.01, 
                 max_depth=2, 
                 subsample=1, 
                 num_class=num_class)

> xgb1 <- xgb.train(data = X_train, label = Y_train, verbose = FALSE, params = params, nrounds = 10)


> shap_values <- SHAPforxgboost::shap.values(xgb_model = xgb1, X_train = X_train)

但是,我总是收到以下错误:

Error in `colnames<-`(`*tmp*`, value = c(colnames(X_train), "BIAS")) : 
      attempt to set 'colnames' on an object with less than two dimensions

我发现,当我没有在参数列表中指定与多类分类模型(booster = "gbtree" + objective = "multi:softmax"+ num_class)相关的参数时,shap.Values函数就会起作用.

但如果我不在参数中指定它,我不确定模型是否会识别我想要的多类分类模型.

有没有人知道该怎么做?

推荐答案

{SHAPforxgBoost}不支持多输出模型.这是我们制作{shapviz}的原因之一:

library(xgboost)
library(shapviz)

params <- list(objective = "multi:softprob", num_class = 3, learning_rate = 0.2)
X_pred <- data.matrix(iris[, -5])
dtrain <- xgboost::xgb.DMatrix(X_pred, label = as.integer(iris[, 5]) - 1)
fit <- xgboost::xgb.train(
  params = params, 
  data = dtrain, 
  nrounds = 100
)

shap_values <- shapviz(fit, X_pred = X_pred)
names(shap_values) <- levels(iris$Species)

# Analyze one class
sv_importance(shap_values$versicolor, kind = "bee")
sv_dependence(shap_values$versicolor, v = colnames(iris[, -5]))

enter image description here enter image description here

# Or all together
sv_importance(shap_values, kind = "bee", alpha = 0.2)
sv_dependence(shap_values, v = "Sepal.Width")

enter image description here enter image description here

Shap交互也起作用:

shap_inter <- shapviz(fit, X_pred = X_pred, interactions = TRUE)
names(shap_inter) <- levels(iris$Species)

sv_interaction(shap_inter$versicolor, kind = "bee")
sv_dependence(
  shap_inter$versicolor, 
  v = "Sepal.Width",
  color_var = colnames(iris[, -5]),
  interactions = TRUE
)

enter image description here enter image description here

R相关问答推荐

大规模重新标记haven标签数据

如何求解arg必须为NULL或deSolve包的ode函数中的字符向量错误

如何在RMarkdown LaTex PDF输出中包含英语和阿拉伯语?

Highcharter多次钻取不起作用,使用不同方法

par函数中的缩写,比如mgp,mar,mai是如何被破译的?

如何调整曲线图中的y轴标签?

如何在分组条形图中移动相关列?

将二进制数据库转换为频率表

有没有办法使用ggText,<;Sub>;&;<;sup>;将上标和下标添加到同一元素?

如何根据R中其他变量的类别汇总值?

将多个列合并为一个列的有效方法是什么?

`-`是否也用于数据帧,有时使用引用调用?

如何根据未知数的多列排除重复行

使用ifElse语句在ggploy中设置aes y值

以任意顺序提取具有多个可能匹配项的组匹配项

如何从嵌套数据中自动创建命名对象?在R中

在不带max()的data.table中按组查找最后一个元素

R:水平旋转图

打印的.txt文件,将值显示为&Quot;Num&Quot;而不是值

通过不完全重叠的多个柱连接