我正在自学如何使用优秀的tidyModels包集合来练习机器学习.

在下面的例子中,我基本上是在try 复制Julie Sigle在这里(https://juliasilge.com/blog/water-sources/)发表的关于使用Ranger程序包预测水源的博客文章.

我没有在那篇博客中使用她的数据集,而是使用内置的钻石数据集作为练习.

当我try 根据预测绘制真相图时,我可以重新创建除yardmark::roc_curv()之外的所有集合.

我得到的错误如下所示

Error in `dplyr::summarise()`:
! Problem while computing `.estimate = metric_fn(...)`.
ℹ The error occurred in group 1: id = "Fold01".
Caused by error in `validate_class()`:
! `estimate` should be a numeric but a factor was supplied.

虽然数据集和转换步骤不同,但以下步骤大致对应于上面链接中的内容.

我认识到,从统计学上讲,可能有更有效或更好的方法来做到这一点,但我只是试图更熟悉这些工具和包,并获得使用它们的经验.

library(tidyverse)
library(tidymodels)

# set a outcome variable that I want to try and predict (e.g. price is above $10,000)
diamonds <- diamonds %>% 
  mutate(high_price_indicator=if_else(price>10000,"high","low"))

#split data sets
data_split <- rsample::initial_split(diamonds,strata = high_price_indicator)

training_split <- rsample::training(data_split)
testing_split <- rsample::testing(data_split)

# cross fold 
diamonds_fold <- rsample::vfold_cv(training_split,strata=high_price_indicator)

#choose model, set engine and mode
rf_spec <- parsnip::rand_forest(trees = 1000) %>% 
  set_mode("classification") %>% 
  set_engine("ranger")

#set recipe and do some transformations - not sure if the error is here
rec <- recipes::recipe(high_price_indicator ~., data=training_split) %>%
  recipes::step_normalize(all_numeric_predictors()) %>% 
  step_zv(all_predictors(),) %>% 
  step_dummy(c("cut","color","clarity"),one_hot = TRUE)


# create the workflow

workflow <- workflow() %>% 
  add_model(rf_spec) %>% 
  add_recipe(rec)

# fit workflow to cross folded data and save predictions
fit_folds <- tune::fit_resamples(workflow,
    resamples = diamonds_fold,
    control = control_resamples(save_pred = TRUE)
  )

# this is where I get the error
collect_predictions(fit_folds) %>%
  group_by(id) %>%
  roc_curve(high_price_indicator, .pred_class) %>%
  autoplot()

感谢任何人的指导!

下面是我的步骤.如果有人能帮助我理解我在将预测与结果变量进行对比时的错误之处,我将不胜感激.

推荐答案

好的,我想通了.我试图将两个分类变量相互对比,但ROC_CUVE需要一个真值列和一个包含概率的列.

通过取消重新抽样表fit_folds中的.predictions列的嵌套,您可以看到有三列结果为.pred_high.pred_low.pred_class.highlow标签对应于high_price_indicator列.

.pred_class具有预测的特征结果,.pred_low.pred_high具有概率结果.在Julia Silge的示例中,这些列表示为.pred_npred_y.

所以,当你在真值列上画出一个数字概率列时,你就得到了这个图.

以下是代码

collect_predictions(fit_folds) %>%
  group_by(id) %>%
  roc_curve(high_price_indicator,.pred_high) %>%
  autoplot()

R相关问答推荐

是否有R代码来判断一个组中的所有值是否与另一个组中的所有值相同?

在边界外添加注释或标题

根据列表中项目的名称多次合并数据框和列表

基于现有类创建类的打印方法(即,打印tibles更长时间)

如何修复R码的置换部分?

用derrr在R中查找组间的重复项

如何在ggplot中标记qqplot上的点?

为什么我的基准测试会随着样本量的增加而出现一些波动?

将二进制数据库转换为频率表

计算两列中满足特定条件连续行之间的平均值

在保留列表元素属性的同时替换列表元素

为什么这个表格格罗布不打印?

如何阻止围堵地理密度图?

如何根据未知数的多列排除重复行

为R中的16组参数生成10000个样本的有效方法是什么?

Ggplot2如何找到存储在对象中的残差和拟合值?

R-如何在ggplot2中显示具有不同x轴值(日期)的多行?

使用显式二元谓词子集化sfc对象时出错

将字符变量出现次数不相等的字符框整形为pivot_wider,而不删除重复名称或嵌套字符变量

如果y中存在x中的值,则将y行中的多个值复制到相应的x行中