From b6ac8340510fd55caaae0368cc34b91ae413a41c Mon Sep 17 00:00:00 2001 From: walkerjameschris Date: Thu, 27 Feb 2025 23:07:46 -0500 Subject: [PATCH] Adding unit test for training on slices of ldb.Dataset --- R-package/tests/testthat/test_basic.R | 36 +++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/R-package/tests/testthat/test_basic.R b/R-package/tests/testthat/test_basic.R index cb43ba613be9..e821bf93ee94 100644 --- a/R-package/tests/testthat/test_basic.R +++ b/R-package/tests/testthat/test_basic.R @@ -3861,3 +3861,39 @@ test_that("Evaluation metrics aren't printed as a single-element vector", { }) expect_false(grepl("[1] \"[1]", log_txt, fixed = TRUE)) }) + +test_that("lgb.train() can use slices of an lgb.Dataset for train and valid data", { + + data("iris") + + ds <- lgb.Dataset( + as.matrix(iris[, 1:3]) + , label = iris[, 4] + ) + + train <- lgb.slice.Dataset(ds, seq(1, 100)) + test <- lgb.slice.Dataset(ds, seq(101, 150)) + + test_mat <- as.matrix(iris[101:150, 1:3]) + test_label <- iris[101:150, 4] + + params <- list( + metric = "l2" + , objective = "regression" + , num_threads = .LGB_MAX_THREADS + ) + + model <- lgb.train( + params = params + , data = train + , nrounds = 1L + , verbose = .LGB_VERBOSITY + , valids = list(test = test) + ) + + y_hat <- predict(model, newdata = test_mat) + model_l2 <- model$eval_valid()[[1]]$value + independent_l2 <- mean((test_label - y_hat) ** 2) + + expect_equal(model_l2, independent_l2, tolerance = 0.0001) +})