This commit is contained in:
Kyle Belanger 2023-03-17 16:43:13 -04:00
parent add4d4a9f6
commit 38c07f1060

View file

@ -54,6 +54,8 @@ ds_test <- ds_test %>% dplyr::select(-ft4_dia)
data_folds <- rsamp$vfold_cv(ds_train, repeats = 5) data_folds <- rsamp$vfold_cv(ds_train, repeats = 5)
pred <- dplyr::select(ds_train, -FT4, -subject_id, -charttime)
# recipes ------------------------------------------------------------------ # recipes ------------------------------------------------------------------
@ -62,8 +64,9 @@ normalized_rec <- r$recipe(FT4 ~ ., data = ds_train) %>%
r$update_role(subject_id, new_role = "id") %>% r$update_role(subject_id, new_role = "id") %>%
r$update_role(charttime, new_role = "time") %>% r$update_role(charttime, new_role = "time") %>%
r$step_impute_bag(r$all_predictors()) %>% r$step_impute_bag(r$all_predictors()) %>%
r$step_BoxCox(r$all_numeric()) %>% r$step_dummy(r$all_nominal_predictors()) %>%
r$step_corr(r$all_numeric_predictors()) %>% r$step_corr(r$all_numeric_predictors()) %>%
r$step_log(r$all_numeric()) %>%
r$step_normalize(r$all_numeric()) r$step_normalize(r$all_numeric())
@ -99,7 +102,7 @@ rf_spec <-
xgb_spec <- xgb_spec <-
p$boost_tree(tree_depth = tune(), learn_rate = tune(), loss_reduction = tune(), p$boost_tree(tree_depth = tune(), learn_rate = tune(), loss_reduction = tune(),
min_n = tune(), sample_size = tune(), trees = tune()) %>% min_n = tune(), sample_size = tune(), trees = tune()) %>%
p$set_engine("xgboost") %>% p$set_engine("xgboost") %>%
p$set_mode("regression") p$set_mode("regression")
@ -111,6 +114,13 @@ nnet_param <-
update(hidden_units = d$hidden_units(c(1, 27))) update(hidden_units = d$hidden_units(c(1, 27)))
rf_parma <-
rf_spec %>%
tune$extract_parameter_set_dials() %>%
update(mtry = d$finalize(d$mtry(), pred))
# workflows --------------------------------------------------------------- # workflows ---------------------------------------------------------------
@ -131,3 +141,23 @@ forests <-
all_workflows <- all_workflows <-
dplyr::bind_rows(normalized, forests) %>% dplyr::bind_rows(normalized, forests) %>%
dplyr::mutate(wflow_id = gsub("(forests_)|(normalized_)", "", wflow_id)) dplyr::mutate(wflow_id = gsub("(forests_)|(normalized_)", "", wflow_id))
# grid search -------------------------------------------------------------
grid_ctrl <-
tune$control_grid(
save_pred = TRUE,
parallel_over = "everything",
save_workflow = TRUE
)
grid_results <-
all_workflows %>%
workflowsets::workflow_map(
seed = 070823
,resamples = data_folds
,grid = 25
,control = grid_ctrl
)