From 1f0de32431ab386e8ba46c5fee6b0e63e9424400 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Thu, 7 Nov 2024 13:26:12 -0800 Subject: [PATCH 1/3] pass along nfeatures and feature_names to xgboost bundle --- NEWS.md | 2 ++ R/bundle_xgboost.R | 3 +++ 2 files changed, 5 insertions(+) diff --git a/NEWS.md b/NEWS.md index 90300c7..61308c2 100644 --- a/NEWS.md +++ b/NEWS.md @@ -3,6 +3,8 @@ * Added bundle method for objects from `dbarts::bart()` and, by extension, `parsnip::bart(engine = "dbarts")` (#64). +* Bundling no longer removes `nfeatures` and `feature_names` from xgboost models (#67). + # bundle 0.1.1 * Fixed bundling of recipes steps situated inside of workflows. diff --git a/R/bundle_xgboost.R b/R/bundle_xgboost.R index 32516aa..1e02f25 100644 --- a/R/bundle_xgboost.R +++ b/R/bundle_xgboost.R @@ -51,6 +51,9 @@ bundle.xgb.Booster <- function(x, ...) { num_class = !!x$params$num_class ) + res$nfeatures <- !!x$nfeatures + res$feature_names <- !!x$feature_names + res }), desc_class = class(x)[1] From 60f2405662416f4138e2193465685be3f923abce Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Fri, 8 Nov 2024 09:13:34 -0800 Subject: [PATCH 2/3] add tests --- tests/testthat/test_bundle_xgboost.R | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/testthat/test_bundle_xgboost.R b/tests/testthat/test_bundle_xgboost.R index 7489d76..dfc71d0 100644 --- a/tests/testthat/test_bundle_xgboost.R +++ b/tests/testthat/test_bundle_xgboost.R @@ -96,4 +96,8 @@ test_that("bundling + unbundling xgboost fits", { # compare predictions expect_equal(mod_preds, mod_unbundled_preds) expect_equal(mod_preds, mod_butchered_unbundled_preds) + + # verify nfeatures and feature_names are kept + expect_identical(unbundle(mod_bundle)$nfeatures, mod_fit$nfeatures) + expect_identical(unbundle(mod_bundle)$feature_names, mod_fit$feature_names) }) From eb0144e52b7a9381b1c826b80b9cff54e256348e Mon Sep 17 00:00:00 2001 From: Julia Silge Date: Fri, 8 Nov 2024 10:35:10 -0700 Subject: [PATCH 3/3] Update NEWS --- NEWS.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/NEWS.md b/NEWS.md index 61308c2..11dec93 100644 --- a/NEWS.md +++ b/NEWS.md @@ -3,7 +3,7 @@ * Added bundle method for objects from `dbarts::bart()` and, by extension, `parsnip::bart(engine = "dbarts")` (#64). -* Bundling no longer removes `nfeatures` and `feature_names` from xgboost models (#67). +* Bundling xgboost objects now takes extra steps to preserve `nfeatures` and `feature_names` (#67). # bundle 0.1.1