From 9cb7fe3e137e7d1f59fd86092bfb634b485c5d89 Mon Sep 17 00:00:00 2001 From: Kyle Belanger Date: Thu, 2 Feb 2023 13:23:28 -0500 Subject: [PATCH] Update 2-modeling.R --- ML/2-modeling.R | 43 ++++++++++++++++++++++++++++--------------- 1 file changed, 28 insertions(+), 15 deletions(-) diff --git a/ML/2-modeling.R b/ML/2-modeling.R index 0d27fbc..1245e7e 100644 --- a/ML/2-modeling.R +++ b/ML/2-modeling.R @@ -12,8 +12,10 @@ box::use( ,rsample ,r = recipes ,wf = workflows - ,p = parsnip + ,p = parsnip[tune] ,ys = yardstick + ,d = dials + ,rsamp = rsample ) @@ -47,11 +49,11 @@ table(ds_test$ft4_dia) %>% prop.table() class_train <- ds_train %>% dplyr::select(-FT4) # training data for classification models reg_train <- ds_train %>% dplyr::select(-ft4_dia) # training data for reg models predicting result -# random forest ----------------------------------------------------------- +# random forest classification ----------------------------------------------------------- # base model - No Hyper Tuning -rf_model <- p$rand_forest(trees = 1900) %>% +rf__base_model <- p$rand_forest() %>% p$set_engine("ranger") %>% p$set_mode("classification") rf_recipe <- r$recipe(ft4_dia ~ . , data = class_train) %>% @@ -61,27 +63,38 @@ rf_recipe <- r$recipe(ft4_dia ~ . , data = class_train) %>% rf_workflow <- wf$workflow() %>% - wf$add_model(rf_model) %>% + wf$add_model(rf__base_model) %>% wf$add_recipe(rf_recipe) -rf_fit <- p$fit(rf_workflow, class_train) +rf_base_fit <- p$fit(rf_workflow, class_train) rf_predict <- class_train %>% dplyr::select(ft4_dia) %>% dplyr::bind_cols( - predict(rf_fit, class_train) - ,predict(rf_fit, class_train, type = "prob") + predict(rf_base_fit, class_train) + ,predict(rf_base_fit, class_train, type = "prob") ) conf_mat_rf <- ys$conf_mat(rf_predict, ft4_dia, .pred_class) -# explainer_rf <- DALEXtra::explain_tidymodels( -# rf_fit -# ,data = class_train -# ,y = class_train$ft4_dia -# ) -# this takes awhile to run -#vip_lm <- DALEX::model_parts(explainer_rf) -#plot(vip_lm) +rf_pred <- dplyr::select(class_train, -ft4_dia, -subject_id, -charttime) + +rf_tuning_model <- p$rand_forest(trees = tune(), mtry = tune(), min_n = tune()) %>% + p$set_engine("ranger") %>% p$set_mode("classification") + +rf_param <- p$extract_parameter_set_dials(rf_tuning_model) + +rf_param <- rf_param %>% update(mtry = d$finalize(d$mtry(), rf_pred)) + +data_fold <- rsamp$vfold_cv(class_train, v = 5) + +rf_workflow <- wf$update_model(rf_workflow, rf_tuning_model) + +rf_tune <- rf_workflow %>% + tune::tune_grid( + data_fold + ,grid = rf_param %>% d$grid_regular() + ) +