From 26fa556b62397cb189246489b16fd91388778703 Mon Sep 17 00:00:00 2001 From: Hongfei Li Date: Wed, 15 May 2024 07:45:37 +0000 Subject: [PATCH] test lambdarank_unbiased --- jvm-packages/pom.xml | 11 +++++++++++ .../ml/dmlc/xgboost4j/scala/spark/PerTest.scala | 2 +- .../scala/spark/XGBoostRegressorSuite.scala | 14 ++++++++++++++ 3 files changed, 26 insertions(+), 1 deletion(-) diff --git a/jvm-packages/pom.xml b/jvm-packages/pom.xml index b57b569db926..3d4e8cdd73e9 100644 --- a/jvm-packages/pom.xml +++ b/jvm-packages/pom.xml @@ -72,6 +72,17 @@ xgboost4j-spark xgboost4j-flink + + + + org.apache.maven.plugins + maven-surefire-plugin + + true + + + + diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala index 24bc00e1824e..ac9eba37a17d 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PerTest.scala @@ -59,7 +59,7 @@ trait PerTest extends BeforeAndAfterEach { self: AnyFunSuite => private def getOrCreateSession = synchronized { if (currentSession == null) { currentSession = sparkSessionBuilder.getOrCreate() - currentSession.sparkContext.setLogLevel("ERROR") + currentSession.sparkContext.setLogLevel("INFO") } currentSession } diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala index 1bdea7a827bd..3e241f5001be 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala @@ -132,6 +132,20 @@ class XGBoostRegressorSuite extends AnyFunSuite with PerTest with TmpFolderPerSu assert(testDF.count() === prediction.length) } + test("ranking: test position bias") { + val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "0", "verbosity" -> "3", + "objective" -> "rank:ndcg", "num_workers" -> numWorkers, "num_round" -> 5, + "group_col" -> "group", "tree_method" -> treeMethod, "lambdarank_unbiased" -> true, "eval_metric" -> "ndcg") + + val trainingDF = buildDataFrameWithGroup(Ranking.train) + val testDF = buildDataFrame(Ranking.test) + val model = new XGBoostRegressor(paramMap).fit(trainingDF) + + val prediction = model.transform(testDF).collect() + println("hello---------hongfei") + assert(testDF.count() === prediction.length) + } + test("use weight") { val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers,