Skip to content

Commit

Permalink
test lambdarank_unbiased
Browse files Browse the repository at this point in the history
  • Loading branch information
tubihongfeili committed May 15, 2024
1 parent 82d846b commit 26fa556
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 1 deletion.
11 changes: 11 additions & 0 deletions jvm-packages/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,17 @@
<module>xgboost4j-spark</module>
<module>xgboost4j-flink</module>
</modules>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<configuration>
<skipTests>true</skipTests>
</configuration>
</plugin>
</plugins>
</build>
</profile>

<profile>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ trait PerTest extends BeforeAndAfterEach { self: AnyFunSuite =>
private def getOrCreateSession = synchronized {
if (currentSession == null) {
currentSession = sparkSessionBuilder.getOrCreate()
currentSession.sparkContext.setLogLevel("ERROR")
currentSession.sparkContext.setLogLevel("INFO")
}
currentSession
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,20 @@ class XGBoostRegressorSuite extends AnyFunSuite with PerTest with TmpFolderPerSu
assert(testDF.count() === prediction.length)
}

test("ranking: test position bias") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "0", "verbosity" -> "3",
"objective" -> "rank:ndcg", "num_workers" -> numWorkers, "num_round" -> 5,
"group_col" -> "group", "tree_method" -> treeMethod, "lambdarank_unbiased" -> true, "eval_metric" -> "ndcg")

val trainingDF = buildDataFrameWithGroup(Ranking.train)
val testDF = buildDataFrame(Ranking.test)
val model = new XGBoostRegressor(paramMap).fit(trainingDF)

val prediction = model.transform(testDF).collect()
println("hello---------hongfei")
assert(testDF.count() === prediction.length)
}

test("use weight") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
"objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers,
Expand Down

0 comments on commit 26fa556

Please sign in to comment.