我在R中使用了Keras API的一个简单应用程序.根据set.seed(value)的值,它有时会收敛,有时不会收敛.我假设是因为种子设定了最初随机的权重.如果一开始没有收敛,我通常可以通过更改种子值让它收敛到不同的运行,但我必须手动监视/停止它.如果模型在指定时间后没有收敛(例如,在600秒后停止,并使用不同的种子值重新启动),如何停止Keras.

  set.seed(42)
  x <- as.matrix(train_data)
  y <- as.matrix(train_targets)
  
  model = keras_model_sequential() %>%
    layer_dense(units=64, kernel_regularizer=regularizer_l2(0.001), activation="relu", input_shape=dim(train_data)[[2]]) %>%
    layer_dense(units=32, kernel_regularizer=regularizer_l2(0.001), activation = "relu") %>%
    layer_dense(units=1, activation="linear")
  
  model %>% compile(
    loss = "mse", 
    optimizer = "rmsprop",
    metrics = list("mae")
  )
  
  model %>% fit(x, y, epochs = 50,verbose = 0)

推荐答案

一种 Select 是定义一个调用自身的函数,可能在执行之前执行设置种子之类的操作.

library(keras)
d <- dataset_mnist()
x_train <- d$train$x
y_train <- d$train$y
x_test <- d$test$x
y_test <- d$test$y

x_train <- array_reshape(x_train, c(nrow(x_train), 784))
x_test <- array_reshape(x_test, c(nrow(x_test), 784))
x_train <- x_train / 255
x_test <- x_test / 255
y_train <- to_categorical(y_train, 10)
y_test <- to_categorical(y_test, 10)

model <- keras_model_sequential() 
model %>% 
  layer_dense(units = 256, activation = 'relu', input_shape = c(784)) %>% 
  layer_dropout(rate = 0.4) %>% 
  layer_dense(units = 128, activation = 'relu') %>%
  layer_dropout(rate = 0.3) %>%
  layer_dense(units = 10, activation = 'softmax')

model %>% compile(
  loss = 'categorical_crossentropy',
  optimizer = optimizer_rmsprop(),
  metrics = c('accuracy')
)

我们可以创建一个递归函数,在超时后调用它自己.

timed_fit <- function(t = 5) {
  Sys.sleep(1)
  set.seed(t)
  message("seed set to ", t)
  setTimeLimit(cpu = t, elapsed = t, transient = TRUE)
  on.exit({setTimeLimit(cpu = Inf, elapsed = Inf, transient = FALSE)})
  tryCatch({
    model %>% fit(
      x_train, y_train, 
      epochs = 4, batch_size = 128, 
      validation_split = 0.2
    )
  }, error = function(e) {
    if (grepl("reached elapsed time limit|reached CPU time limit", e$message)) {
      message("\n timed out!\n") # or set another seed, continue
      timed_fit(t + 10)
    } else {
      # error not related to timeout
      stop(e)
    }
  })
}

timed_fit()

插入Sys.sleep(1)是为了避免出现错误,尽管出现了错误,但不会正确中断流程.

enter image description here

R相关问答推荐

基于不同组的列的相关性

从开始时间和结束时间导出时间

在发布到PowerBI Service时,是否可以使用R脚本作为PowerBI的数据源?

如何在modelsummary中重命名统计数据?

如何使用STAT_SUMMARY向ggplot2中的密度图添加垂直线

如何改变x轴比例的列在面

在ggplot中为不同几何体使用不同的 colored颜色 比例

可以替代与NSE一起使用的‘any_of()’吗?

如何将R中数据帧中的任何Nas替换为最后4个值

基于R中的间隔扩展数据集行

从R中的对数正态分布生成随机数的正确方法

根据另一列中的值和条件查找新列的值

在纵向数据集中创建新行

如何将一列中的值拆分到R中各自的列中

提高圣彼得堡模拟的速度

删除数据帧中特定行号之间的每第三行和第四行

如何计算每12行的平均数?

创建新列,其中S列的值取决于该行S值是否与其他行冗余

使用其他DF中的文件名将列表中的每个元素保存到文件中

向数据添加标签