This commit is contained in:
Kyle Belanger 2023-04-09 19:59:56 -04:00
parent 7ed1ad7b00
commit 148a2071eb
2 changed files with 18 additions and 20 deletions

View file

@ -17,9 +17,12 @@ box::use(
,ys = yardstick ,ys = yardstick
,d = dials ,d = dials
,rsamp = rsample ,rsamp = rsample
,tune
) )
require(ggplot2)
# globals ----------------------------------------------------------------- # globals -----------------------------------------------------------------
@ -85,15 +88,8 @@ rf_best_params <- tibble::tibble(
,min_n = 2 ,min_n = 2
) )
rf_best_params_screen <-
tibble::tibble(
mtry = 7
,trees = 763
,min_n = 15
)
final_rf_workflow <- rf_workflow %>% final_rf_workflow <- rf_workflow %>%
tune::finalize_workflow(rf_best_params_screen) tune::finalize_workflow(rf_best_params)
# Final Fit training data # Final Fit training data
@ -201,7 +197,7 @@ reg_metrics <- ys$metric_set(ys$rmse, ys$rsq, ys$mae)
rf_reg_tune_model <- p$rand_forest(trees = tune(), mtry = tune(), min_n = tune()) %>% rf_reg_tune_model <- p$rand_forest(trees = tune(), mtry = tune(), min_n = tune()) %>%
p$set_engine("ranger") %>% p$set_mode("regression") p$set_engine("ranger") %>% p$set_mode("regression")
rf_reg_recipe <- r$recipe(FT4 ~ . , data = reg_train) %>% rf_reg_recipe <- r$recipe(FT4 ~ . , data = ds_train) %>%
r$step_rm(ft4_dia) %>% r$step_rm(ft4_dia) %>%
r$step_impute_bag(r$all_predictors()) r$step_impute_bag(r$all_predictors())
@ -212,9 +208,9 @@ rf_reg_workflow <- wf$workflow() %>%
rf_reg_param <- p$extract_parameter_set_dials(rf_reg_tune_model) %>% rf_reg_param <- p$extract_parameter_set_dials(rf_reg_tune_model) %>%
update(mtry = d$finalize(d$mtry(), reg_train)) update(mtry = d$finalize(d$mtry(), ds_train))
data_fold_reg <- rsamp$vfold_cv(reg_train, v = 5) data_fold_reg <- rsamp$vfold_cv(ds_train, v = 5)
# takes around 1 hr to run grid search. saving best params manaually # takes around 1 hr to run grid search. saving best params manaually
# rf_reg_tune <- rf_reg_workflow %>% # rf_reg_tune <- rf_reg_workflow %>%
@ -232,22 +228,23 @@ rf_reg_best_params <- tibble::tibble(
final_rf_reg_workflow <- rf_reg_workflow %>% final_rf_reg_workflow <- rf_reg_workflow %>%
tune::finalize_workflow(rf_reg_best_params) tune::finalize_workflow(rf_reg_best_params)
final_rf_reg_fit <- p$fit(final_rf_reg_workflow, reg_train) final_rf_reg_fit <- p$fit(final_rf_reg_workflow, ds_train)
# # predictions for training data
final_rf_reg_predict <- reg_train %>%
dplyr::select(FT4) %>% final_rf_reg_predict <- ds_train %>%
dplyr::select(FT4, TSH) %>%
dplyr::bind_cols( dplyr::bind_cols(
predict(final_rf_reg_fit, reg_train) predict(final_rf_reg_fit, ds_train)
) )
reg_metrics(final_rf_reg_predict, truth = FT4, estimate = .pred) reg_metrics(final_rf_reg_predict, truth = FT4, estimate = .pred)
ggplot(final_rf_reg_predict, aes(x = FT4, y = .pred)) +
reg_test_results <- gp2$geom_abline(lty = 2) +
final_rf_reg_workflow %>% gp2$geom_point(alpha = 0.5) +
tune::last_fit() tune::coord_obs_pred()

View file

@ -18,6 +18,7 @@ box::use(
,d = dials ,d = dials
,rsamp = rsample ,rsamp = rsample
,tune ,tune
,workflowsets
) )