From 2c5049d5380c03256679ac638572f10dd778d8b1 Mon Sep 17 00:00:00 2001 From: Kyle Belanger Date: Wed, 19 Apr 2023 06:26:37 -0400 Subject: [PATCH] Update 2-modeling.R --- ML/2-modeling.R | 47 +++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 43 insertions(+), 4 deletions(-) diff --git a/ML/2-modeling.R b/ML/2-modeling.R index 34a245f..218ffe3 100644 --- a/ML/2-modeling.R +++ b/ML/2-modeling.R @@ -53,7 +53,6 @@ table(ds_test$ft4_dia) %>% prop.table() # random forest classification ----------------------------------------------------------- -# base model - No Hyper Tuning rf_recipe <- r$recipe(ft4_dia ~ . , data = ds_train) %>% r$step_rm(FT4) %>% @@ -280,8 +279,8 @@ ds_reg_class_pred <- reg_test_results %>% ,TSH > 4.2 & `.pred` > 0.93 ~ "Non-Hypo" ,TSH < 0.27 & `.pred` > 1.7 ~ "Hyper" ,TSH < 0.27 & `.pred` < 1.7 ~ "Non-Hyper" - ) - ) %>% + ) + ) %>% dplyr::mutate(dplyr::across( ft4_dia_pred , ~factor(., levels = c("Hypo", "Non-Hypo","Hyper", "Non-Hyper") @@ -302,8 +301,48 @@ ggplot(reg_test_results %>% tune::collect_predictions() , aes(x = FT4, y = .pred # check orginal data -model_data %>% dplyr::group_by(ft4_dia) %>% +model_data %>% + dplyr::mutate(tsh_level = ifelse(TSH > 4.2, "high", "low")) %>% + dplyr::group_by(tsh_level, ft4_dia) %>% dplyr::summarise( n = n() + ) %>% + mutate(freq = n / sum(n)) + + + +# nnet reg ---------------------------------------------------------------- + +normalized_rec <- recipes::recipe(FT4 ~ ., data = ds_train) %>% + recipes::step_rm(ft4_dia) %>% + recipes::step_impute_bag(recipes::all_predictors()) %>% + # recipes::step_corr(recipes::all_numeric_predictors()) %>% + recipes::step_normalize(recipes::all_numeric_predictors() , -anchor_age) %>% + recipes::step_dummy(gender) + +nnet_spec <- + p$mlp(hidden_units = tune(), penalty = tune(), epochs = tune()) %>% + p$set_engine("nnet", MaxNWts = 2600) %>% + p$set_mode("regression") + +nnet_param <- + nnet_spec %>% + tune$extract_parameter_set_dials() %>% + update(hidden_units = d$hidden_units(c(1, 27))) + +nnet_reg_workflow <- wf$workflow() %>% + wf$add_model(nnet_spec) %>% + wf$add_recipe(normalized_rec) + +data_fold_reg <- rsamp$vfold_cv(ds_train, v = 5) + +nnet_reg_tune <- nnet_reg_workflow %>% + tune::tune_grid( + data_fold_reg + ,grid = nnet_param %>% d$grid_regular() + ,verbose = TRUE ) + + +