diff --git a/ML/2-modeling.R b/ML/2-modeling.R index 59aa63a..010df32 100644 --- a/ML/2-modeling.R +++ b/ML/2-modeling.R @@ -17,9 +17,12 @@ box::use( ,ys = yardstick ,d = dials ,rsamp = rsample + ,tune ) +require(ggplot2) + # globals ----------------------------------------------------------------- @@ -85,15 +88,8 @@ rf_best_params <- tibble::tibble( ,min_n = 2 ) -rf_best_params_screen <- - tibble::tibble( - mtry = 7 - ,trees = 763 - ,min_n = 15 - ) - final_rf_workflow <- rf_workflow %>% - tune::finalize_workflow(rf_best_params_screen) + tune::finalize_workflow(rf_best_params) # Final Fit training data @@ -201,7 +197,7 @@ reg_metrics <- ys$metric_set(ys$rmse, ys$rsq, ys$mae) rf_reg_tune_model <- p$rand_forest(trees = tune(), mtry = tune(), min_n = tune()) %>% p$set_engine("ranger") %>% p$set_mode("regression") -rf_reg_recipe <- r$recipe(FT4 ~ . , data = reg_train) %>% +rf_reg_recipe <- r$recipe(FT4 ~ . , data = ds_train) %>% r$step_rm(ft4_dia) %>% r$step_impute_bag(r$all_predictors()) @@ -212,9 +208,9 @@ rf_reg_workflow <- wf$workflow() %>% rf_reg_param <- p$extract_parameter_set_dials(rf_reg_tune_model) %>% - update(mtry = d$finalize(d$mtry(), reg_train)) + update(mtry = d$finalize(d$mtry(), ds_train)) -data_fold_reg <- rsamp$vfold_cv(reg_train, v = 5) +data_fold_reg <- rsamp$vfold_cv(ds_train, v = 5) # takes around 1 hr to run grid search. saving best params manaually # rf_reg_tune <- rf_reg_workflow %>% @@ -232,22 +228,23 @@ rf_reg_best_params <- tibble::tibble( final_rf_reg_workflow <- rf_reg_workflow %>% tune::finalize_workflow(rf_reg_best_params) -final_rf_reg_fit <- p$fit(final_rf_reg_workflow, reg_train) +final_rf_reg_fit <- p$fit(final_rf_reg_workflow, ds_train) -# -final_rf_reg_predict <- reg_train %>% - dplyr::select(FT4) %>% +# predictions for training data + +final_rf_reg_predict <- ds_train %>% + dplyr::select(FT4, TSH) %>% dplyr::bind_cols( - predict(final_rf_reg_fit, reg_train) + predict(final_rf_reg_fit, ds_train) ) reg_metrics(final_rf_reg_predict, truth = FT4, estimate = .pred) - -reg_test_results <- - final_rf_reg_workflow %>% - tune::last_fit() +ggplot(final_rf_reg_predict, aes(x = FT4, y = .pred)) + + gp2$geom_abline(lty = 2) + + gp2$geom_point(alpha = 0.5) + + tune::coord_obs_pred() diff --git a/ML/3_model_outputs.R b/ML/3_model_outputs.R index 08f9c78..bd0f943 100644 --- a/ML/3_model_outputs.R +++ b/ML/3_model_outputs.R @@ -18,6 +18,7 @@ box::use( ,d = dials ,rsamp = rsample ,tune + ,workflowsets )