Update 2-modeling.R
This commit is contained in:
parent
e63bb47b49
commit
2c5049d538
1 changed files with 43 additions and 4 deletions
|
@ -53,7 +53,6 @@ table(ds_test$ft4_dia) %>% prop.table()
|
|||
|
||||
# random forest classification -----------------------------------------------------------
|
||||
|
||||
# base model - No Hyper Tuning
|
||||
|
||||
rf_recipe <- r$recipe(ft4_dia ~ . , data = ds_train) %>%
|
||||
r$step_rm(FT4) %>%
|
||||
|
@ -280,8 +279,8 @@ ds_reg_class_pred <- reg_test_results %>%
|
|||
,TSH > 4.2 & `.pred` > 0.93 ~ "Non-Hypo"
|
||||
,TSH < 0.27 & `.pred` > 1.7 ~ "Hyper"
|
||||
,TSH < 0.27 & `.pred` < 1.7 ~ "Non-Hyper"
|
||||
)
|
||||
) %>%
|
||||
)
|
||||
) %>%
|
||||
dplyr::mutate(dplyr::across(
|
||||
ft4_dia_pred
|
||||
, ~factor(., levels = c("Hypo", "Non-Hypo","Hyper", "Non-Hyper")
|
||||
|
@ -302,8 +301,48 @@ ggplot(reg_test_results %>% tune::collect_predictions() , aes(x = FT4, y = .pred
|
|||
|
||||
# check orginal data
|
||||
|
||||
model_data %>% dplyr::group_by(ft4_dia) %>%
|
||||
model_data %>%
|
||||
dplyr::mutate(tsh_level = ifelse(TSH > 4.2, "high", "low")) %>%
|
||||
dplyr::group_by(tsh_level, ft4_dia) %>%
|
||||
dplyr::summarise(
|
||||
n = n()
|
||||
) %>%
|
||||
mutate(freq = n / sum(n))
|
||||
|
||||
|
||||
|
||||
# nnet reg ----------------------------------------------------------------
|
||||
|
||||
normalized_rec <- recipes::recipe(FT4 ~ ., data = ds_train) %>%
|
||||
recipes::step_rm(ft4_dia) %>%
|
||||
recipes::step_impute_bag(recipes::all_predictors()) %>%
|
||||
# recipes::step_corr(recipes::all_numeric_predictors()) %>%
|
||||
recipes::step_normalize(recipes::all_numeric_predictors() , -anchor_age) %>%
|
||||
recipes::step_dummy(gender)
|
||||
|
||||
nnet_spec <-
|
||||
p$mlp(hidden_units = tune(), penalty = tune(), epochs = tune()) %>%
|
||||
p$set_engine("nnet", MaxNWts = 2600) %>%
|
||||
p$set_mode("regression")
|
||||
|
||||
nnet_param <-
|
||||
nnet_spec %>%
|
||||
tune$extract_parameter_set_dials() %>%
|
||||
update(hidden_units = d$hidden_units(c(1, 27)))
|
||||
|
||||
nnet_reg_workflow <- wf$workflow() %>%
|
||||
wf$add_model(nnet_spec) %>%
|
||||
wf$add_recipe(normalized_rec)
|
||||
|
||||
data_fold_reg <- rsamp$vfold_cv(ds_train, v = 5)
|
||||
|
||||
nnet_reg_tune <- nnet_reg_workflow %>%
|
||||
tune::tune_grid(
|
||||
data_fold_reg
|
||||
,grid = nnet_param %>% d$grid_regular()
|
||||
,verbose = TRUE
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue