#machine-learning #survival-analysis #mlr3
Вопрос:
Я использую пакет mlr3proba для анализа выживаемости методами машинного обучения.
Мой набор данных содержит факторные, числовые и целочисленные признаки.
Для предобработки данных перед обучением нейросетевых моделей deephit и deepsurv я использовал шаги конвейера «scale» (масштабирование) и «encode» (кодирование), как показано в коде ниже:
# Survival task built from the user's `dataset`.
# NOTE(review): `time` and `status` are passed as bare symbols; presumably they
# are variables holding the column names - confirm against the asker's data.
task.mlr <- TaskSurv$new(id = "id", backend = dataset, time = time, event = status)
# Tuning setup: 5-fold CV, the `surv.cindex` measure, random search capped at
# 30 evaluations.
inner.rsmp <- rsmp("cv", folds = 5)
measure <- msr("surv.cindex")
tuner <- tnr("random_search")
terminator <- trm("evals", n_evals = 30)
deephit.learner <- lrn("surv.deephit", optimizer = "adam", epochs = 50)
# Search space declared with the learner's bare parameter ids.
nn.search_space <- ps(dropout = p_dbl(lower = 0, upper = 1),alpha = p_dbl(lower = 0, upper = 1))
# The learner variable is overwritten by a pipeline: encode -> scale -> deephit.
# NOTE(review): the pipeline exposes parameters under prefixed ids (the error
# quoted below suggests 'encode.method', 'scale.center', ...), so the bare
# `dropout` / `alpha` ids in the search space no longer match.
deephit.learner <- po("encode") %>>% po("scale") %>>% po("learner", deephit.learner)
deephit.instance <- TuningInstanceSingleCrit$new(
task = task.mlr,
learner = deephit.learner,
search_space = nn.search_space,
resampling = inner.rsmp,
measure = measure,
terminator = terminator
)
# This call is the line that raises the assertion error.
tuner$optimize(deephit.instance)
Но при запуске последней строки возникает следующая ошибка:
Error in self$assert(xs):
Assertion on 'xs' failed: Parameter 'dropout' not available. Did you mean 'encode.method'/'encode.affect_columns' / 'scale.center'?.
Заранее благодарю за помощь.
Ответ №1:
Привет, спасибо за использование mlr3proba! Причина в том, что при включении модели в конвейер имена её параметров меняются: к ним добавляется префикс с идентификатором шага (например, `dropout` становится `surv.deephit.dropout`) — это видно в примере ниже. Есть несколько вариантов решения: можно указать в пространстве поиска новые имена параметров, которые они получают после оборачивания в конвейер (вариант 1 ниже); можно сначала задать диапазоны настройки для самой модели (learner), а затем обернуть её в конвейер (вариант 2 ниже); либо можно использовать AutoTuner и обернуть в конвейер уже его. Последний вариант я использую в этом руководстве.
library(mlr3proba)
library(mlr3)
library(paradox)
library(mlr3tuning)
library(mlr3extralearners)
library(mlr3pipelines)
# Shared tuning setup used by both options below:
# random search with a budget of two evaluations ...
tuner <- tnr("random_search")
terminator <- trm("evals", n_evals = 2)
# ... scored with the `surv.cindex` measure on a single holdout split ...
measure <- msr("surv.cindex")
inner.rsmp <- rsmp("holdout")
# ... of the built-in "rats" task.
task.mlr <- tsk("rats")
###########
# Option 1
###########
# Keep the bare learner under its own name, then wrap it in the preprocessing
# pipeline (encode -> scale -> learner).
base.deephit <- lrn("surv.deephit", optimizer = "adam", epochs = 50)
deephit.learner <- po("encode") %>>%
  po("scale") %>>%
  po("learner", base.deephit)
# List the parameter ids as exposed by the pipeline; each id carries the id of
# the pipeline step it belongs to as a prefix.
deephit.learner$param_set$ids()
#> [1] "encode.method" "encode.affect_columns"
#> [3] "scale.center" "scale.scale"
#> [5] "scale.robust" "scale.affect_columns"
#> [7] "surv.deephit.frac" "surv.deephit.cuts"
#> [9] "surv.deephit.cutpoints" "surv.deephit.scheme"
#> [11] "surv.deephit.cut_min" "surv.deephit.num_nodes"
#> [13] "surv.deephit.batch_norm" "surv.deephit.dropout"
#> [15] "surv.deephit.activation" "surv.deephit.custom_net"
#> [17] "surv.deephit.device" "surv.deephit.mod_alpha"
#> [19] "surv.deephit.sigma" "surv.deephit.shrink"
#> [21] "surv.deephit.optimizer" "surv.deephit.rho"
#> [23] "surv.deephit.eps" "surv.deephit.lr"
#> [25] "surv.deephit.weight_decay" "surv.deephit.learning_rate"
#> [27] "surv.deephit.lr_decay" "surv.deephit.betas"
#> [29] "surv.deephit.amsgrad" "surv.deephit.lambd"
#> [31] "surv.deephit.alpha" "surv.deephit.t0"
#> [33] "surv.deephit.momentum" "surv.deephit.centered"
#> [35] "surv.deephit.etas" "surv.deephit.step_sizes"
#> [37] "surv.deephit.dampening" "surv.deephit.nesterov"
#> [39] "surv.deephit.batch_size" "surv.deephit.epochs"
#> [41] "surv.deephit.verbose" "surv.deephit.num_workers"
#> [43] "surv.deephit.shuffle" "surv.deephit.best_weights"
#> [45] "surv.deephit.early_stopping" "surv.deephit.min_delta"
#> [47] "surv.deephit.patience" "surv.deephit.interpolate"
#> [49] "surv.deephit.inter_scheme" "surv.deephit.sub"
# Search space keyed by the prefixed ids reported by the pipeline's param set.
nn.search_space <- ps(
  surv.deephit.dropout = p_dbl(lower = 0, upper = 1),
  surv.deephit.alpha = p_dbl(lower = 0, upper = 1)
)

# Tuning instance: what to tune (task, learner, search space) followed by
# how to evaluate it (resampling, measure, termination budget).
deephit.instance <- TuningInstanceSingleCrit$new(
  task = task.mlr,
  learner = deephit.learner,
  search_space = nn.search_space,
  resampling = inner.rsmp,
  measure = measure,
  terminator = terminator
)

# Run the random search; the best configuration is printed when it finishes.
tuner$optimize(deephit.instance)
#> INFO [08:15:29.770] [bbotk] Starting to optimize 2 parameter(s) with '<OptimizerRandomSearch>' and '<TerminatorEvals> [n_evals=2]'
#> INFO [08:15:29.841] [bbotk] Evaluating 1 configuration(s)
#> INFO [08:15:30.115] [mlr3] Running benchmark with 1 resampling iterations
#> INFO [08:15:30.314] [mlr3] Applying learner 'encode.scale.surv.deephit' on task 'rats' (iter 1/1)
#> INFO [08:15:39.997] [mlr3] Finished benchmark
#> INFO [08:15:40.296] [bbotk] Result of batch 1:
#> INFO [08:15:40.302] [bbotk] surv.deephit.dropout surv.deephit.alpha surv.harrell_c
#> INFO [08:15:40.302] [bbotk] 0.06494213 0.7109244 0.7516212
#> INFO [08:15:40.302] [bbotk] uhash
#> INFO [08:15:40.302] [bbotk] 27794d84-ba46-4900-8835-de24fcda8c7f
#> INFO [08:15:40.307] [bbotk] Evaluating 1 configuration(s)
#> INFO [08:15:40.395] [mlr3] Running benchmark with 1 resampling iterations
#> INFO [08:15:40.406] [mlr3] Applying learner 'encode.scale.surv.deephit' on task 'rats' (iter 1/1)
#> INFO [08:15:41.807] [mlr3] Finished benchmark
#> INFO [08:15:41.903] [bbotk] Result of batch 2:
#> INFO [08:15:41.905] [bbotk] surv.deephit.dropout surv.deephit.alpha surv.harrell_c
#> INFO [08:15:41.905] [bbotk] 0.05524693 0.2895437 0.7749676
#> INFO [08:15:41.905] [bbotk] uhash
#> INFO [08:15:41.905] [bbotk] 013795a3-766c-48f9-a3fe-2aae5d4cad48
#> INFO [08:15:41.918] [bbotk] Finished optimizing after 2 evaluation(s)
#> INFO [08:15:41.919] [bbotk] Result:
#> INFO [08:15:41.920] [bbotk] surv.deephit.dropout surv.deephit.alpha learner_param_vals x_domain
#> INFO [08:15:41.920] [bbotk] 0.05524693 0.2895437 <list[6]> <list[2]>
#> INFO [08:15:41.920] [bbotk] surv.harrell_c
#> INFO [08:15:41.920] [bbotk] 0.7749676
#> surv.deephit.dropout surv.deephit.alpha learner_param_vals x_domain
#> 1: 0.05524693 0.2895437 <list[6]> <list[2]>
#> surv.harrell_c
#> 1: 0.7749676
###########
# Option 2
###########
# Declare the tuning ranges on the bare learner first and only then wrap it in
# the preprocessing pipeline: the search space is derived from the to_tune()
# tokens, so the prefixed parameter ids never have to be spelled out by hand.
deephit.learner <- lrn("surv.deephit", optimizer = "adam", epochs = 50)
# Set the tokens one at a time. Assigning a whole list to `param_set$values`
# replaces every existing value, which would silently discard
# optimizer = "adam" and epochs = 50 given to the constructor above.
# NOTE: the log recorded below predates this change; it reports
# `learner_param_vals` as <list[4]> instead of <list[6]>.
deephit.learner$param_set$values$dropout <- to_tune(0, 1)
deephit.learner$param_set$values$alpha <- to_tune(0, 1)
# Pipeline: encode -> scale -> learner.
deephit.learner <- po("encode") %>>%
  po("scale") %>>%
  po("learner", deephit.learner)
# Convert the Graph into a GraphLearner before handing it to tune_nested().
deephit.learner <- GraphLearner$new(deephit.learner)
# Nested resampling: holdout for the inner tuning loop and for the outer
# performance estimate, random search with two evaluations.
tuned.deephit <- tune_nested(
  method = "random_search",
  task = task.mlr,
  learner = deephit.learner,
  inner_resampling = rsmp("holdout"),
  outer_resampling = rsmp("holdout"),
  measure = msr("surv.cindex"),
  term_evals = 2
)
#> INFO [08:15:43.167] [mlr3] Applying learner 'encode.scale.surv.deephit.tuned' on task 'rats' (iter 1/1)
#> INFO [08:15:43.477] [bbotk] Starting to optimize 2 parameter(s) with '<OptimizerRandomSearch>' and '<TerminatorRunTime> [secs=2]'
#> INFO [08:15:43.495] [bbotk] Evaluating 1 configuration(s)
#> INFO [08:15:43.565] [mlr3] Running benchmark with 1 resampling iterations
#> INFO [08:15:43.575] [mlr3] Applying learner 'encode.scale.surv.deephit' on task 'rats' (iter 1/1)
#> INFO [08:15:44.969] [mlr3] Finished benchmark
#> INFO [08:15:45.058] [bbotk] Result of batch 1:
#> INFO [08:15:45.064] [bbotk] surv.deephit.dropout surv.deephit.alpha surv.harrell_c
#> INFO [08:15:45.064] [bbotk] 0.3492627 0.2304623 0.6745362
#> INFO [08:15:45.064] [bbotk] uhash
#> INFO [08:15:45.064] [bbotk] 4ce96658-4d4a-4835-9d9f-a93398471aed
#> INFO [08:15:45.069] [bbotk] Evaluating 1 configuration(s)
#> INFO [08:15:45.127] [mlr3] Running benchmark with 1 resampling iterations
#> INFO [08:15:45.136] [mlr3] Applying learner 'encode.scale.surv.deephit' on task 'rats' (iter 1/1)
#> INFO [08:15:46.064] [mlr3] Finished benchmark
#> INFO [08:15:46.171] [bbotk] Result of batch 2:
#> INFO [08:15:46.176] [bbotk] surv.deephit.dropout surv.deephit.alpha surv.harrell_c
#> INFO [08:15:46.176] [bbotk] 0.1118406 0.7810053 0.6020236
#> INFO [08:15:46.176] [bbotk] uhash
#> INFO [08:15:46.176] [bbotk] 6a065d27-a7e0-4e72-8e1e-6151408510cf
#> INFO [08:15:46.186] [bbotk] Finished optimizing after 2 evaluation(s)
#> INFO [08:15:46.187] [bbotk] Result:
#> INFO [08:15:46.191] [bbotk] surv.deephit.dropout surv.deephit.alpha learner_param_vals x_domain
#> INFO [08:15:46.191] [bbotk] 0.3492627 0.2304623 <list[4]> <list[2]>
#> INFO [08:15:46.191] [bbotk] surv.harrell_c
#> INFO [08:15:46.191] [bbotk] 0.6745362
Создано 2021-04-26 пакетом reprex (v0.3.0)