From e63bb47b491cfeab9797c4470bfcfd064fc899bc Mon Sep 17 00:00:00 2001 From: Kyle Belanger Date: Wed, 12 Apr 2023 08:01:03 -0400 Subject: [PATCH] Update 2-modeling.R --- ML/2-modeling.R | 62 +++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 60 insertions(+), 2 deletions(-) diff --git a/ML/2-modeling.R b/ML/2-modeling.R index 010df32..34a245f 100644 --- a/ML/2-modeling.R +++ b/ML/2-modeling.R @@ -21,7 +21,7 @@ box::use( ) -require(ggplot2) +require(ggplot2) #this is needed for autoplot to work with workflows # globals ----------------------------------------------------------------- @@ -50,6 +50,7 @@ ds_test <- rsample$testing(model_data_split) table(ds_train$ft4_dia) %>% prop.table() table(ds_test$ft4_dia) %>% prop.table() + # random forest classification ----------------------------------------------------------- # base model - No Hyper Tuning @@ -234,10 +235,27 @@ final_rf_reg_fit <- p$fit(final_rf_reg_workflow, ds_train) # predictions for training data final_rf_reg_predict <- ds_train %>% - dplyr::select(FT4, TSH) %>% + dplyr::select(FT4, TSH, ft4_dia) %>% dplyr::bind_cols( predict(final_rf_reg_fit, ds_train) + ) %>% + dplyr::mutate( + ft4_dia_pred = dplyr::case_when( + TSH > 4.2 & `.pred` < 0.93 ~ "Hypo" + ,TSH > 4.2 & `.pred` > 0.93 ~ "Non-Hypo" + ,TSH < 0.27 & `.pred` > 1.7 ~ "Hyper" + ,TSH < 0.27 & `.pred` < 1.7 ~ "Non-Hyper" + ) + ) %>% + dplyr::mutate(dplyr::across( + ft4_dia_pred + , ~factor(., levels = c("Hypo", "Non-Hypo","Hyper", "Non-Hyper") + ) ) + ) + +ys$conf_mat(final_rf_reg_predict,truth = ft4_dia ,estimate = ft4_dia_pred) +ys$accuracy(final_rf_reg_predict,truth = ft4_dia, estimate = ft4_dia_pred) reg_metrics(final_rf_reg_predict, truth = FT4, estimate = .pred) @@ -246,6 +264,46 @@ ggplot(final_rf_reg_predict, aes(x = FT4, y = .pred)) + gp2$geom_point(alpha = 0.5) + tune::coord_obs_pred() +# fitting test data + +reg_test_results <- + final_rf_reg_fit %>% + tune::last_fit(split = model_data_split) + +ds_reg_class_pred <- 
reg_test_results %>% + tune::collect_predictions() %>% + dplyr::select(-id, -.config) %>% + dplyr::bind_cols(ds_test %>% dplyr::select(TSH, ft4_dia)) %>% + dplyr::mutate( + ft4_dia_pred = dplyr::case_when( + TSH > 4.2 & `.pred` < 0.93 ~ "Hypo" + ,TSH > 4.2 & `.pred` > 0.93 ~ "Non-Hypo" + ,TSH < 0.27 & `.pred` > 1.7 ~ "Hyper" + ,TSH < 0.27 & `.pred` < 1.7 ~ "Non-Hyper" + ) + ) %>% + dplyr::mutate(dplyr::across( + ft4_dia_pred + , ~factor(., levels = c("Hypo", "Non-Hypo","Hyper", "Non-Hyper") + ) + ) + ) + +ys$accuracy(ds_reg_class_pred,truth = ft4_dia, estimate = ft4_dia_pred) +ys$conf_mat(ds_reg_class_pred,truth = ft4_dia ,estimate = ft4_dia_pred) + +tune::collect_metrics(reg_test_results) + +ggplot(reg_test_results %>% tune::collect_predictions() , aes(x = FT4, y = .pred)) + + gp2$geom_abline(lty = 2) + + gp2$geom_point(alpha = 0.5) + + tune::coord_obs_pred() +# check original data + +model_data %>% dplyr::group_by(ft4_dia) %>% + dplyr::summarise( + n = n() + )