updates
This commit is contained in:
parent
66faa69393
commit
c55d1379d8
3 changed files with 129 additions and 127 deletions
|
@ -178,26 +178,3 @@ saveRDS(
|
|||
screen_workflows, here::here("ML", "outputs", "workflowscreen.rds")
|
||||
,compress = TRUE)
|
||||
|
||||
# grid search -------------------------------------------------------------
|
||||
|
||||
|
||||
|
||||
|
||||
grid_ctrl <-
|
||||
tune$control_grid(
|
||||
save_pred = TRUE,
|
||||
parallel_over = "everything",
|
||||
save_workflow = TRUE
|
||||
)
|
||||
|
||||
|
||||
|
||||
grid_results <-
|
||||
all_workflows %>%
|
||||
workflowsets::workflow_map(
|
||||
seed = 070823
|
||||
,resamples = data_folds
|
||||
,grid = 25
|
||||
,control = grid_ctrl
|
||||
,verbose = TRUE
|
||||
)
|
||||
|
|
206
ML/2-modeling.R
206
ML/2-modeling.R
|
@ -1,3 +1,4 @@
|
|||
# Random Forests are the best models for both types finalize grid searchs
|
||||
rm(list = ls(all.names = TRUE)) # Clear the memory of variables from previous run.
|
||||
cat("\014") # Clear the console
|
||||
|
||||
|
@ -27,7 +28,8 @@ set.seed(070823) #set seed for reproducible research
|
|||
|
||||
# load-data ---------------------------------------------------------------
|
||||
|
||||
model_data <- readr$read_rds(here("ML","data-unshared","model_data.RDS"))
|
||||
model_data <- readr$read_rds(here("ML","data-unshared","model_data.RDS")) %>%
|
||||
dplyr::select(-subject_id, -charttime)
|
||||
|
||||
|
||||
# split data --------------------------------------------------------------
|
||||
|
@ -45,52 +47,30 @@ ds_test <- rsample$testing(model_data_split)
|
|||
table(ds_train$ft4_dia) %>% prop.table()
|
||||
table(ds_test$ft4_dia) %>% prop.table()
|
||||
|
||||
|
||||
class_train <- ds_train %>% dplyr::select(-FT4) # training data for classification models
|
||||
reg_train <- ds_train %>% dplyr::select(-ft4_dia) # training data for reg models predicting result
|
||||
|
||||
# random forest classification -----------------------------------------------------------
|
||||
|
||||
# base model - No Hyper Tuning
|
||||
|
||||
rf__base_model <- p$rand_forest() %>%
|
||||
p$set_engine("ranger") %>% p$set_mode("classification")
|
||||
|
||||
rf_recipe <- r$recipe(ft4_dia ~ . , data = class_train) %>%
|
||||
r$update_role(subject_id, new_role = "id") %>%
|
||||
r$update_role(charttime, new_role = "time") %>%
|
||||
rf_recipe <- r$recipe(ft4_dia ~ . , data = ds_train) %>%
|
||||
r$step_rm(FT4) %>%
|
||||
r$step_impute_bag(r$all_predictors())
|
||||
|
||||
|
||||
rf_workflow <- wf$workflow() %>%
|
||||
wf$add_model(rf__base_model) %>%
|
||||
wf$add_recipe(rf_recipe)
|
||||
|
||||
rf_base_fit <- p$fit(rf_workflow, class_train)
|
||||
|
||||
rf_predict <- class_train %>%
|
||||
dplyr::select(ft4_dia) %>%
|
||||
dplyr::bind_cols(
|
||||
predict(rf_base_fit, class_train)
|
||||
,predict(rf_base_fit, class_train, type = "prob")
|
||||
)
|
||||
|
||||
conf_mat_rf <- ys$conf_mat(rf_predict, ft4_dia, .pred_class)
|
||||
|
||||
|
||||
|
||||
rf_pred <- dplyr::select(class_train, -ft4_dia, -subject_id, -charttime)
|
||||
|
||||
rf_tuning_model <- p$rand_forest(trees = tune(), mtry = tune(), min_n = tune()) %>%
|
||||
p$set_engine("ranger") %>% p$set_mode("classification")
|
||||
|
||||
|
||||
rf_workflow <- wf$workflow() %>%
|
||||
wf$add_model(rf_tuning_model) %>%
|
||||
wf$add_recipe(rf_recipe)
|
||||
|
||||
|
||||
rf_param <- p$extract_parameter_set_dials(rf_tuning_model)
|
||||
|
||||
rf_param <- rf_param %>% update(mtry = d$finalize(d$mtry(), rf_pred))
|
||||
rf_param <- rf_param %>% update(mtry = d$finalize(d$mtry(), ds_train))
|
||||
|
||||
data_fold <- rsamp$vfold_cv(ds_train, v = 5)
|
||||
|
||||
data_fold <- rsamp$vfold_cv(class_train, v = 5)
|
||||
|
||||
rf_workflow <- wf$update_model(rf_workflow, rf_tuning_model)
|
||||
|
||||
# takes around 1 hr to run grid search. saving best params manaually
|
||||
# rf_tune <- rf_workflow %>%
|
||||
|
@ -105,82 +85,100 @@ rf_best_params <- tibble::tibble(
|
|||
,min_n = 2
|
||||
)
|
||||
|
||||
final_rf_workflow <- rf_workflow %>%
|
||||
tune::finalize_workflow(rf_best_params)
|
||||
|
||||
final_rf_fit <- p$fit(final_rf_workflow, class_train)
|
||||
|
||||
final_rf_predict <- class_train %>%
|
||||
dplyr::select(ft4_dia) %>%
|
||||
dplyr::bind_cols(
|
||||
predict(final_rf_fit, class_train)
|
||||
,predict(final_rf_fit, class_train, type = "prob")
|
||||
rf_best_params_screen <-
|
||||
tibble::tibble(
|
||||
mtry = 7
|
||||
,trees = 763
|
||||
,min_n = 15
|
||||
)
|
||||
|
||||
final_rf_workflow <- rf_workflow %>%
|
||||
tune::finalize_workflow(rf_best_params_screen)
|
||||
|
||||
# Final Fit training data
|
||||
|
||||
final_rf_fit <- p$fit(final_rf_workflow, ds_train)
|
||||
|
||||
final_rf_predict <- ds_train %>%
|
||||
dplyr::select(ft4_dia) %>%
|
||||
dplyr::bind_cols(
|
||||
predict(final_rf_fit, ds_train)
|
||||
,predict(final_rf_fit, ds_train, type = "prob")
|
||||
)
|
||||
|
||||
ys$accuracy(final_rf_predict,truth = ft4_dia, estimate = .pred_class )
|
||||
|
||||
final_conf_rf <- ys$conf_mat(final_rf_predict, ft4_dia, .pred_class)
|
||||
|
||||
# fitting test data
|
||||
|
||||
class_test_results <-
|
||||
final_rf_fit %>%
|
||||
tune::last_fit(split = model_data_split)
|
||||
|
||||
class_test_result_conf_matrix <- ys$conf_mat(
|
||||
class_test_results %>% tune::collect_predictions()
|
||||
,truth = ft4_dia
|
||||
,estimate = .pred_class
|
||||
)
|
||||
|
||||
|
||||
|
||||
# random forest regression ------------------------------------------------
|
||||
#
|
||||
# reg_metrics <- ys$metric_set(ys$rmse, ys$rsq, ys$mae)
|
||||
#
|
||||
# rf_base_reg_model <- p$rand_forest() %>%
|
||||
# p$set_engine("ranger") %>% p$set_mode("regression")
|
||||
#
|
||||
# rf_reg_recipe <- r$recipe(FT4 ~ . , data = reg_train) %>%
|
||||
# r$update_role(subject_id, new_role = "id") %>%
|
||||
# r$update_role(charttime, new_role = "time") %>%
|
||||
# r$step_impute_bag(r$all_predictors())
|
||||
#
|
||||
#
|
||||
# rf_reg_workflow <- wf$workflow() %>%
|
||||
# wf$add_model(rf_base_reg_model) %>%
|
||||
# wf$add_recipe(rf_reg_recipe)
|
||||
#
|
||||
# rf_base_reg_fit <- p$fit(rf_reg_workflow, reg_train)
|
||||
#
|
||||
# rf_reg_predict <- reg_train %>%
|
||||
# dplyr::select(FT4) %>%
|
||||
# dplyr::bind_cols(
|
||||
# predict(rf_base_reg_fit, reg_train)
|
||||
reg_metrics <- ys$metric_set(ys$rmse, ys$rsq, ys$mae)
|
||||
|
||||
rf_reg_tune_model <- p$rand_forest(trees = tune(), mtry = tune(), min_n = tune()) %>%
|
||||
p$set_engine("ranger") %>% p$set_mode("regression")
|
||||
|
||||
rf_reg_recipe <- r$recipe(FT4 ~ . , data = reg_train) %>%
|
||||
r$step_rm(ft4_dia) %>%
|
||||
r$step_impute_bag(r$all_predictors())
|
||||
|
||||
|
||||
rf_reg_workflow <- wf$workflow() %>%
|
||||
wf$add_model(rf_reg_tune_model) %>%
|
||||
wf$add_recipe(rf_reg_recipe)
|
||||
|
||||
|
||||
rf_reg_param <- p$extract_parameter_set_dials(rf_reg_tune_model) %>%
|
||||
update(mtry = d$finalize(d$mtry(), reg_train))
|
||||
|
||||
data_fold_reg <- rsamp$vfold_cv(reg_train, v = 5)
|
||||
|
||||
# takes around 1 hr to run grid search. saving best params manaually
|
||||
# rf_reg_tune <- rf_reg_workflow %>%
|
||||
# tune::tune_grid(
|
||||
# data_fold_reg
|
||||
# ,grid = rf_reg_param %>% d$grid_regular()
|
||||
# )
|
||||
|
||||
rf_reg_best_params <- tibble::tibble(
|
||||
mtry = 8
|
||||
,trees = 1000
|
||||
,min_n = 2
|
||||
)
|
||||
|
||||
final_rf_reg_workflow <- rf_reg_workflow %>%
|
||||
tune::finalize_workflow(rf_reg_best_params)
|
||||
|
||||
final_rf_reg_fit <- p$fit(final_rf_reg_workflow, reg_train)
|
||||
|
||||
|
||||
#
|
||||
# reg_metrics(rf_reg_predict, truth = FT4, estimate = .pred)
|
||||
#
|
||||
# rf_reg_tune_model <- p$rand_forest(trees = tune(), mtry = tune(), min_n = tune()) %>%
|
||||
# p$set_engine("ranger") %>% p$set_mode("regression")
|
||||
#
|
||||
# rf_reg_pred <- dplyr::select(reg_train, -FT4, -subject_id, -charttime)
|
||||
#
|
||||
# rf_reg_param <- p$extract_parameter_set_dials(rf_reg_tune_model) %>%
|
||||
# update(mtry = d$finalize(d$mtry(), rf_reg_pred))
|
||||
#
|
||||
# data_fold_reg <- rsamp$vfold_cv(reg_train, v = 5)
|
||||
#
|
||||
# rf_reg_workflow <- wf$update_model(rf_reg_workflow, rf_reg_tune_model)
|
||||
#
|
||||
# # takes around 1 hr to run grid search. saving best params manaually
|
||||
# # rf_reg_tune <- rf_reg_workflow %>%
|
||||
# # tune::tune_grid(
|
||||
# # data_fold_reg
|
||||
# # ,grid = rf_reg_param %>% d$grid_regular()
|
||||
# # )
|
||||
#
|
||||
# rf_reg_best_params <- tibble::tibble(
|
||||
# mtry = 8
|
||||
# ,trees = 1000
|
||||
# ,min_n = 2
|
||||
# )
|
||||
#
|
||||
# final_rf_reg_workflow <- rf_reg_workflow %>%
|
||||
# tune::finalize_workflow(rf_reg_best_params)
|
||||
#
|
||||
# final_rf_reg_fit <- p$fit(final_rf_reg_workflow, reg_train)
|
||||
#
|
||||
# final_rf_reg_predict <- reg_train %>%
|
||||
# dplyr::select(FT4) %>%
|
||||
# dplyr::bind_cols(
|
||||
# predict(final_rf_reg_fit, reg_train)
|
||||
# )
|
||||
#
|
||||
# reg_metrics(final_rf_reg_predict, truth = FT4, estimate = .pred)
|
||||
final_rf_reg_predict <- reg_train %>%
|
||||
dplyr::select(FT4) %>%
|
||||
dplyr::bind_cols(
|
||||
predict(final_rf_reg_fit, reg_train)
|
||||
)
|
||||
|
||||
reg_metrics(final_rf_reg_predict, truth = FT4, estimate = .pred)
|
||||
|
||||
|
||||
reg_test_results <-
|
||||
final_rf_reg_workflow %>%
|
||||
tune::last_fit()
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -30,6 +30,7 @@ set.seed(070823) #set seed for reproducible research
|
|||
# load-data ---------------------------------------------------------------
|
||||
|
||||
screen_workflows_reg <- readr::read_rds(here("ML","outputs","workflowscreen_reg.rds"))
|
||||
screen_workflows_class <- readr::read_rds(here("ML","outputs","workflowscreen_class.rds"))
|
||||
|
||||
|
||||
|
||||
|
@ -56,6 +57,32 @@ ggplot2::autoplot(
|
|||
ggplot2::scale_color_manual(values = rep("black", times = 5)) +
|
||||
ggplot2::theme(legend.position = "none")
|
||||
|
||||
class_results <- screen_workflows_class %>%
|
||||
workflowsets::rank_results()
|
||||
|
||||
|
||||
ggplot2::autoplot(
|
||||
screen_workflows_class
|
||||
,rank_metric = "roc_auc"
|
||||
,metric = "roc_auc"
|
||||
,select_best = TRUE
|
||||
) +
|
||||
ggplot2::geom_text(ggplot2::aes(y = mean, label = wflow_id)
|
||||
# ,angle = 90
|
||||
,hjust = -0.2
|
||||
) +
|
||||
ggplot2::theme_bw() +
|
||||
ggplot2::scale_color_manual(values = rep("black", times = 5)) +
|
||||
ggplot2::theme(legend.position = "none")
|
||||
|
||||
|
||||
|
||||
# best results ------------------------------------------------------------
|
||||
|
||||
best_class_result <-
|
||||
screen_workflows_class %>%
|
||||
workflowsets::extract_workflow_set_result("forests_RF") %>%
|
||||
tune::select_best(metric = "accuracy")
|
||||
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue