updats
This commit is contained in:
parent
7ed1ad7b00
commit
148a2071eb
2 changed files with 18 additions and 20 deletions
|
@ -17,9 +17,12 @@ box::use(
|
|||
,ys = yardstick
|
||||
,d = dials
|
||||
,rsamp = rsample
|
||||
,tune
|
||||
)
|
||||
|
||||
|
||||
require(ggplot2)
|
||||
|
||||
|
||||
# globals -----------------------------------------------------------------
|
||||
|
||||
|
@ -85,15 +88,8 @@ rf_best_params <- tibble::tibble(
|
|||
,min_n = 2
|
||||
)
|
||||
|
||||
rf_best_params_screen <-
|
||||
tibble::tibble(
|
||||
mtry = 7
|
||||
,trees = 763
|
||||
,min_n = 15
|
||||
)
|
||||
|
||||
final_rf_workflow <- rf_workflow %>%
|
||||
tune::finalize_workflow(rf_best_params_screen)
|
||||
tune::finalize_workflow(rf_best_params)
|
||||
|
||||
# Final Fit training data
|
||||
|
||||
|
@ -201,7 +197,7 @@ reg_metrics <- ys$metric_set(ys$rmse, ys$rsq, ys$mae)
|
|||
rf_reg_tune_model <- p$rand_forest(trees = tune(), mtry = tune(), min_n = tune()) %>%
|
||||
p$set_engine("ranger") %>% p$set_mode("regression")
|
||||
|
||||
rf_reg_recipe <- r$recipe(FT4 ~ . , data = reg_train) %>%
|
||||
rf_reg_recipe <- r$recipe(FT4 ~ . , data = ds_train) %>%
|
||||
r$step_rm(ft4_dia) %>%
|
||||
r$step_impute_bag(r$all_predictors())
|
||||
|
||||
|
@ -212,9 +208,9 @@ rf_reg_workflow <- wf$workflow() %>%
|
|||
|
||||
|
||||
rf_reg_param <- p$extract_parameter_set_dials(rf_reg_tune_model) %>%
|
||||
update(mtry = d$finalize(d$mtry(), reg_train))
|
||||
update(mtry = d$finalize(d$mtry(), ds_train))
|
||||
|
||||
data_fold_reg <- rsamp$vfold_cv(reg_train, v = 5)
|
||||
data_fold_reg <- rsamp$vfold_cv(ds_train, v = 5)
|
||||
|
||||
# takes around 1 hr to run grid search. saving best params manaually
|
||||
# rf_reg_tune <- rf_reg_workflow %>%
|
||||
|
@ -232,22 +228,23 @@ rf_reg_best_params <- tibble::tibble(
|
|||
final_rf_reg_workflow <- rf_reg_workflow %>%
|
||||
tune::finalize_workflow(rf_reg_best_params)
|
||||
|
||||
final_rf_reg_fit <- p$fit(final_rf_reg_workflow, reg_train)
|
||||
final_rf_reg_fit <- p$fit(final_rf_reg_workflow, ds_train)
|
||||
|
||||
|
||||
#
|
||||
final_rf_reg_predict <- reg_train %>%
|
||||
dplyr::select(FT4) %>%
|
||||
# predictions for training data
|
||||
|
||||
final_rf_reg_predict <- ds_train %>%
|
||||
dplyr::select(FT4, TSH) %>%
|
||||
dplyr::bind_cols(
|
||||
predict(final_rf_reg_fit, reg_train)
|
||||
predict(final_rf_reg_fit, ds_train)
|
||||
)
|
||||
|
||||
reg_metrics(final_rf_reg_predict, truth = FT4, estimate = .pred)
|
||||
|
||||
|
||||
reg_test_results <-
|
||||
final_rf_reg_workflow %>%
|
||||
tune::last_fit()
|
||||
ggplot(final_rf_reg_predict, aes(x = FT4, y = .pred)) +
|
||||
gp2$geom_abline(lty = 2) +
|
||||
gp2$geom_point(alpha = 0.5) +
|
||||
tune::coord_obs_pred()
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@ box::use(
|
|||
,d = dials
|
||||
,rsamp = rsample
|
||||
,tune
|
||||
,workflowsets
|
||||
)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue