diff --git a/ML/2-modeling.R b/ML/2-modeling.R index 7159f27..0d27fbc 100644 --- a/ML/2-modeling.R +++ b/ML/2-modeling.R @@ -28,8 +28,6 @@ set.seed(070823) #set seed for reproducible research model_data <- readr$read_rds(here("ML","data-unshared","model_data.RDS")) - - # split data -------------------------------------------------------------- model_data_split <- rsample$initial_split( @@ -46,37 +44,44 @@ table(ds_train$ft4_dia) %>% prop.table() table(ds_test$ft4_dia) %>% prop.table() +class_train <- ds_train %>% dplyr::select(-FT4) # training data for classification models +reg_train <- ds_train %>% dplyr::select(-ft4_dia) # training data for reg models predicting result # random forest ----------------------------------------------------------- +# base model - No Hyper Tuning rf_model <- p$rand_forest(trees = 1900) %>% - p$set_engine("ranger") %>% p$set_mode("regression") + p$set_engine("ranger") %>% p$set_mode("classification") -rf_recipe <- r$recipe(FT4 ~ . , data = ds_train) %>% +rf_recipe <- r$recipe(ft4_dia ~ . , data = class_train) %>% r$update_role(subject_id, new_role = "id") %>% r$update_role(charttime, new_role = "time") %>% - r$update_role(ft4_dia, new_role = "class") %>% r$step_impute_bag(r$all_predictors()) - rf_workflow <- wf$workflow() %>% wf$add_model(rf_model) %>% wf$add_recipe(rf_recipe) -rf_fit <- p$fit(rf_workflow, ds_train) +rf_fit <- p$fit(rf_workflow, class_train) -rf_predict <- ds_train %>% - dplyr::select(FT4) %>% - dplyr::bind_cols(predict(rf_fit, ds_train)) +rf_predict <- class_train %>% + dplyr::select(ft4_dia) %>% + dplyr::bind_cols( + predict(rf_fit, class_train) + ,predict(rf_fit, class_train, type = "prob") + ) +conf_mat_rf <- ys$conf_mat(rf_predict, ft4_dia, .pred_class) -gp2$ggplot(rf_predict, gp2$aes(x = FT4, y = .pred)) + - gp2$geom_point() +# explainer_rf <- DALEXtra::explain_tidymodels( +# rf_fit +# ,data = class_train +# ,y = class_train$ft4_dia +# ) -ys$rmse(rf_predict, FT4, .pred) +# this takes awhile to run +#vip_lm <- DALEX::model_parts(explainer_rf) -metrics <- ys$metric_set(ys$rmse, ys$rsq, ys$mae) - -metrics(rf_predict, FT4, .pred) +#plot(vip_lm)