Update 2-modeling.R

This commit is contained in:
Kyle Belanger 2023-04-12 08:01:03 -04:00
parent 148a2071eb
commit e63bb47b49

View file

@ -21,7 +21,7 @@ box::use(
) )
require(ggplot2) require(ggplot2) #this is needed for autoplot to work with workflows
# globals ----------------------------------------------------------------- # globals -----------------------------------------------------------------
@ -50,6 +50,7 @@ ds_test <- rsample$testing(model_data_split)
table(ds_train$ft4_dia) %>% prop.table() table(ds_train$ft4_dia) %>% prop.table()
table(ds_test$ft4_dia) %>% prop.table() table(ds_test$ft4_dia) %>% prop.table()
# random forest classification ----------------------------------------------------------- # random forest classification -----------------------------------------------------------
# base model - No Hyper Tuning # base model - No Hyper Tuning
@ -234,10 +235,27 @@ final_rf_reg_fit <- p$fit(final_rf_reg_workflow, ds_train)
# predictions for training data
# Keep the observed FT4 / TSH values and the recorded diagnosis (ft4_dia)
# alongside the regression model's prediction (`.pred`).
final_rf_reg_predict <- ds_train %>%
  dplyr::select(FT4, TSH, ft4_dia) %>%
  dplyr::bind_cols(
    predict(final_rf_reg_fit, ds_train)
  ) %>%
  dplyr::mutate(
    # Derive a diagnosis class from TSH and the predicted value using fixed
    # cut-offs. Rows that match none of the four rules (e.g. TSH between 0.27
    # and 4.2, or a value exactly on a cut-off) fall through to NA.
    ft4_dia_pred = dplyr::case_when(
      TSH > 4.2 & `.pred` < 0.93 ~ "Hypo"
      ,TSH > 4.2 & `.pred` > 0.93 ~ "Non-Hypo"
      ,TSH < 0.27 & `.pred` > 1.7 ~ "Hyper"
      ,TSH < 0.27 & `.pred` < 1.7 ~ "Non-Hyper"
    )
    # Fix the level set and order explicitly rather than relying on whatever
    # classes happen to appear in the predictions.
    # NOTE(review): presumably these match the levels of ft4_dia -- confirm.
    ,ft4_dia_pred = factor(
      ft4_dia_pred
      ,levels = c("Hypo", "Non-Hypo", "Hyper", "Non-Hyper")
    )
  )

# classification view of the regression model: derived class vs. recorded class
ys$conf_mat(final_rf_reg_predict, truth = ft4_dia, estimate = ft4_dia_pred)
ys$accuracy(final_rf_reg_predict, truth = ft4_dia, estimate = ft4_dia_pred)

# regression view: predicted value vs. observed FT4
reg_metrics(final_rf_reg_predict, truth = FT4, estimate = .pred)
@ -246,6 +264,46 @@ ggplot(final_rf_reg_predict, aes(x = FT4, y = .pred)) +
gp2$geom_point(alpha = 0.5) + gp2$geom_point(alpha = 0.5) +
tune::coord_obs_pred() tune::coord_obs_pred()
# fitting test data
# last_fit() fits on the training portion of the split and evaluates on the
# test portion.
reg_test_results <- tune::last_fit(final_rf_reg_fit, split = model_data_split)

# Attach TSH and the recorded diagnosis from the test set to the test-set
# predictions, then derive a diagnosis class from TSH and the predicted value.
# bind_cols() pairs rows by position.
# NOTE(review): assumes the collected predictions are in ds_test row order.
ds_reg_class_pred <- reg_test_results %>%
  tune::collect_predictions() %>%
  dplyr::select(-id, -.config) %>%
  dplyr::bind_cols(dplyr::select(ds_test, TSH, ft4_dia)) %>%
  dplyr::mutate(
    # Rows matching none of the four rules are left as NA by case_when();
    # factor() then pins the level set and its order.
    ft4_dia_pred = factor(
      dplyr::case_when(
        TSH > 4.2 & .pred < 0.93 ~ "Hypo",
        TSH > 4.2 & .pred > 0.93 ~ "Non-Hypo",
        TSH < 0.27 & .pred > 1.7 ~ "Hyper",
        TSH < 0.27 & .pred < 1.7 ~ "Non-Hyper"
      ),
      levels = c("Hypo", "Non-Hypo", "Hyper", "Non-Hyper")
    )
  )

# classification metrics for the derived class against the recorded class
ys$accuracy(ds_reg_class_pred, truth = ft4_dia, estimate = ft4_dia_pred)
ys$conf_mat(ds_reg_class_pred, truth = ft4_dia, estimate = ft4_dia_pred)

# regression metrics computed by last_fit() on the test set
tune::collect_metrics(reg_test_results)

# observed vs. predicted FT4 on the test set; the dashed line is y = x
ggplot(
  tune::collect_predictions(reg_test_results),
  aes(x = FT4, y = .pred)
) +
  gp2$geom_abline(lty = 2) +
  gp2$geom_point(alpha = 0.5) +
  tune::coord_obs_pred()
# check original data
# Row count per diagnosis class in the full modelling data set.
# n() is namespace-qualified because the script calls dplyr through `dplyr::`
# rather than attaching it; a bare n() is not found unless dplyr is attached.
model_data %>%
  dplyr::group_by(ft4_dia) %>%
  dplyr::summarise(
    n = dplyr::n()
  )