updats
This commit is contained in:
parent
7ed1ad7b00
commit
148a2071eb
2 changed files with 18 additions and 20 deletions
|
@ -17,9 +17,12 @@ box::use(
|
||||||
,ys = yardstick
|
,ys = yardstick
|
||||||
,d = dials
|
,d = dials
|
||||||
,rsamp = rsample
|
,rsamp = rsample
|
||||||
|
,tune
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
require(ggplot2)
|
||||||
|
|
||||||
|
|
||||||
# globals -----------------------------------------------------------------
|
# globals -----------------------------------------------------------------
|
||||||
|
|
||||||
|
@ -85,15 +88,8 @@ rf_best_params <- tibble::tibble(
|
||||||
,min_n = 2
|
,min_n = 2
|
||||||
)
|
)
|
||||||
|
|
||||||
rf_best_params_screen <-
|
|
||||||
tibble::tibble(
|
|
||||||
mtry = 7
|
|
||||||
,trees = 763
|
|
||||||
,min_n = 15
|
|
||||||
)
|
|
||||||
|
|
||||||
final_rf_workflow <- rf_workflow %>%
|
final_rf_workflow <- rf_workflow %>%
|
||||||
tune::finalize_workflow(rf_best_params_screen)
|
tune::finalize_workflow(rf_best_params)
|
||||||
|
|
||||||
# Final Fit training data
|
# Final Fit training data
|
||||||
|
|
||||||
|
@ -201,7 +197,7 @@ reg_metrics <- ys$metric_set(ys$rmse, ys$rsq, ys$mae)
|
||||||
rf_reg_tune_model <- p$rand_forest(trees = tune(), mtry = tune(), min_n = tune()) %>%
|
rf_reg_tune_model <- p$rand_forest(trees = tune(), mtry = tune(), min_n = tune()) %>%
|
||||||
p$set_engine("ranger") %>% p$set_mode("regression")
|
p$set_engine("ranger") %>% p$set_mode("regression")
|
||||||
|
|
||||||
rf_reg_recipe <- r$recipe(FT4 ~ . , data = reg_train) %>%
|
rf_reg_recipe <- r$recipe(FT4 ~ . , data = ds_train) %>%
|
||||||
r$step_rm(ft4_dia) %>%
|
r$step_rm(ft4_dia) %>%
|
||||||
r$step_impute_bag(r$all_predictors())
|
r$step_impute_bag(r$all_predictors())
|
||||||
|
|
||||||
|
@ -212,9 +208,9 @@ rf_reg_workflow <- wf$workflow() %>%
|
||||||
|
|
||||||
|
|
||||||
rf_reg_param <- p$extract_parameter_set_dials(rf_reg_tune_model) %>%
|
rf_reg_param <- p$extract_parameter_set_dials(rf_reg_tune_model) %>%
|
||||||
update(mtry = d$finalize(d$mtry(), reg_train))
|
update(mtry = d$finalize(d$mtry(), ds_train))
|
||||||
|
|
||||||
data_fold_reg <- rsamp$vfold_cv(reg_train, v = 5)
|
data_fold_reg <- rsamp$vfold_cv(ds_train, v = 5)
|
||||||
|
|
||||||
# takes around 1 hr to run grid search. saving best params manaually
|
# takes around 1 hr to run grid search. saving best params manaually
|
||||||
# rf_reg_tune <- rf_reg_workflow %>%
|
# rf_reg_tune <- rf_reg_workflow %>%
|
||||||
|
@ -232,22 +228,23 @@ rf_reg_best_params <- tibble::tibble(
|
||||||
final_rf_reg_workflow <- rf_reg_workflow %>%
|
final_rf_reg_workflow <- rf_reg_workflow %>%
|
||||||
tune::finalize_workflow(rf_reg_best_params)
|
tune::finalize_workflow(rf_reg_best_params)
|
||||||
|
|
||||||
final_rf_reg_fit <- p$fit(final_rf_reg_workflow, reg_train)
|
final_rf_reg_fit <- p$fit(final_rf_reg_workflow, ds_train)
|
||||||
|
|
||||||
|
|
||||||
#
|
# predictions for training data
|
||||||
final_rf_reg_predict <- reg_train %>%
|
|
||||||
dplyr::select(FT4) %>%
|
final_rf_reg_predict <- ds_train %>%
|
||||||
|
dplyr::select(FT4, TSH) %>%
|
||||||
dplyr::bind_cols(
|
dplyr::bind_cols(
|
||||||
predict(final_rf_reg_fit, reg_train)
|
predict(final_rf_reg_fit, ds_train)
|
||||||
)
|
)
|
||||||
|
|
||||||
reg_metrics(final_rf_reg_predict, truth = FT4, estimate = .pred)
|
reg_metrics(final_rf_reg_predict, truth = FT4, estimate = .pred)
|
||||||
|
|
||||||
|
ggplot(final_rf_reg_predict, aes(x = FT4, y = .pred)) +
|
||||||
reg_test_results <-
|
gp2$geom_abline(lty = 2) +
|
||||||
final_rf_reg_workflow %>%
|
gp2$geom_point(alpha = 0.5) +
|
||||||
tune::last_fit()
|
tune::coord_obs_pred()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -18,6 +18,7 @@ box::use(
|
||||||
,d = dials
|
,d = dials
|
||||||
,rsamp = rsample
|
,rsamp = rsample
|
||||||
,tune
|
,tune
|
||||||
|
,workflowsets
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue