Problem generating SHAP values for a classification problem with tidymodels

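After training my model by following the steps at https://github.com/ModelOriented/kernelshap, I cannot compute SHAP values for a classification problem. My target variable has to be a factor. Running the kernelshap function always returns:

#Error in check_pred(pred_fun(object, X, ...), n = n) : Predictions must be numeric

I looked for a workaround by using a matrix instead of a data frame and by extracting the model with extract_fit_parsnip, but the problem persists. Is there a way to reproduce the example from the GitHub page, but for classification? Example code follows.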



library(tidyverse)
library(tidymodels)


Default <- ISLR::Default


Default <- Default %>%
  mutate(
    default = factor(
      case_when(
        default == "Yes" ~ 1,
        default == "No" ~ 0
      ),
      levels = c(1, 0)
    )
  ) # changing to factor, otherwise the model will not work


# train/test split
split <- initial_split(Default)
tr <- training(split)
te <- testing(split)

rec <- recipe(default~., data = Default) %>% 
  step_dummy(all_nominal_predictors())




spec <- boost_tree() %>% 
  set_mode("classification") %>% 
  set_engine("xgboost")


wf <- workflow() %>% 
  add_model(spec) %>% 
  add_recipe(rec)

# model fit
mod <- fit(wf,tr)



library(kernelshap)

x <- rec %>% 
  prep %>% 
  bake(te %>% 
         slice_sample(n = 50)) %>% 
  select(-default) %>% 
  as.data.frame()
bg <- rec %>% prep %>% 
  bake(te %>% slice_sample(n = 10)) %>% 
  mutate(default = as.numeric(as.character(default))) %>% 
  as.data.frame()

# test for prediction
predict(mod, te)

# extract the parsnip model from the tidymodels workflow
md <- extract_fit_parsnip(mod)

# this version works
kernelshap(md$fit,
           X = x %>% as.matrix(), # it works when the data is passed as a matrix
           bg_X = bg %>% as.matrix()
)

# this version does not work
kernelshap(mod, 
           X = te %>% 
             select(-default) %>%  # remove target var
             slice_sample(n = 50) %>% 
             as.data.frame(), 
           bg_X = te %>% 
             slice_sample(n = 50) %>% 
             as.data.frame()
           
)

################################### error message:

#Error in check_pred(pred_fun(object, X, ...), n = n) : 
#  Predictions must be numeric
######################################

# toy example from github page using tidymodels


library(tidymodels)
library(kernelshap)

iris_recipe <- iris %>%
  recipe(Sepal.Length ~ .)

reg <- linear_reg() %>%
  set_engine("lm")

iris_wf <- workflow() %>%
  add_recipe(iris_recipe) %>%
  add_model(reg)

fit <- iris_wf %>%
  fit(iris)

ks <- kernelshap(fit, iris[, -1], bg_X = iris)
ks

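Recommended answer

{kernelshap} is designed to work well with tidymodels. In your case, you can simply pass the fitted workflow and add type = "prob", so that predict() returns numeric class probabilities instead of a factor: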

library(kernelshap)
library(shapviz)

x <- c("student", "balance", "income")
ks <- kernelshap(
  mod, 
  X = head(Default, 1000),    # Assuming random row order
  bg_X = head(Default, 200),  # Assuming random row order
  type = "prob",              # Predictions must be numeric
  feature_names = x           # Or use X = head(Default[x], 1000)
)

sv <- shapviz(ks)             # Contains one shapviz object per class
sv_dependence(sv$.pred_1, v = x)
sv_importance(sv$.pred_1, kind = "bee", show_numbers = TRUE)
sv_importance(sv$.pred_1)

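The type = "prob" argument matters because kernelshap forwards extra arguments to the prediction function, and predict() on a classification workflow returns a factor column by default. A minimal sketch of the difference, using a small stand-in data set (two iris species and a logistic regression, neither of which comes from the question; the names dat, wf_demo and fit_demo are illustrative only):

```r
library(tidymodels)

# Illustrative two-class data: drop one iris species (assumption, not the Default data)
dat <- iris[iris$Species != "setosa", ]
dat$Species <- droplevels(dat$Species)

wf_demo <- workflow() %>%
  add_formula(Species ~ Sepal.Length + Sepal.Width) %>%
  add_model(logistic_reg() %>% set_engine("glm"))

fit_demo <- fit(wf_demo, dat)

# Default prediction type: one factor column, which kernelshap rejects
pred_class <- predict(fit_demo, head(dat))

# type = "prob": one numeric column per class, which kernelshap can explain
pred_prob <- predict(fit_demo, head(dat), type = "prob")

sapply(pred_class, class)
sapply(pred_prob, class)
```

With probabilities there is one prediction column per class, which is why shapviz(ks) in the answer holds one object per class.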

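Comments

  • Since your model is fitted with XGBoost, TreeSHAP would be the more natural choice, but it is somewhat more involved to use through tidymodels.
  • I would suggest ordinal (integer) encoding instead of dummy encoding.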
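Both comments can be sketched together. The following is one possible way to do it, not the method used in the answer: it refits the model with step_integer() from recipes in place of step_dummy(), then hands the underlying xgboost booster and the baked predictor matrix to shapviz(), which computes TreeSHAP values for xgboost models. The names rec_int, wf_int, mod_int, X_pred and sv_tree are made up for this example, the factor levels are set differently from the question, and the code assumes package versions in which shapviz accepts the booster that parsnip fits.

```r
library(tidymodels)
library(shapviz)

Default <- ISLR::Default
# Assumption for this sketch: keep the labels and put "Yes" first
Default$default <- factor(Default$default, levels = c("Yes", "No"))

set.seed(1)
split <- initial_split(Default)
tr <- training(split)
te <- testing(split)

# Integer encoding of factor predictors instead of dummy variables
rec_int <- recipe(default ~ ., data = tr) %>%
  step_integer(all_nominal_predictors())

wf_int <- workflow() %>%
  add_recipe(rec_int) %>%
  add_model(
    boost_tree() %>%
      set_mode("classification") %>%
      set_engine("xgboost")
  )

mod_int <- fit(wf_int, tr)

# Apply the trained recipe to new data and keep only the predictors, as a matrix
X_pred <- bake(
  extract_recipe(mod_int),
  new_data = te,
  all_predictors(),
  composition = "matrix"
)

# TreeSHAP on the raw xgboost model; X holds the original values for plotting
sv_tree <- shapviz(
  extract_fit_engine(mod_int),
  X_pred = X_pred,
  X = te[colnames(X_pred)]
)

sv_importance(sv_tree)
```

The SHAP values here are on the log-odds scale of the booster rather than on the probability scale, so check which class they refer to before comparing them with the kernelshap results.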

Parallel usage

library(doFuture)

options(doFuture.rng.onMisuse = "ignore")  # To suppress some warning on random seeds

# Set up parallel backend
registerDoFuture()
plan(multisession, workers = 4)  # Windows
# plan(multicore, workers = 4)   # Linux, macOS, Solaris

x <- c("student", "balance", "income")
ks <- kernelshap(
  mod, 
  X = head(Default, 1000),   
  bg_X = head(Default, 200),
  type = "prob",
  feature_names = x,
  parallel = TRUE,
  parallel_args = list(.packages = "tidymodels")
)

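The last argument is required because the generic stats::predict() hides which package provides the underlying prediction method, so the parallel workers have to be told explicitly to load tidymodels.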
