From 4a87662c6d69e45414080753bbc3b28d4b2c0733 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 27 Jan 2025 20:25:30 -0800 Subject: [PATCH] test fix --- tests/cpp/test_rope.cpp | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/cpp/test_rope.cpp b/tests/cpp/test_rope.cpp index cfec6182fc5..fceecea80c1 100644 --- a/tests/cpp/test_rope.cpp +++ b/tests/cpp/test_rope.cpp @@ -946,10 +946,11 @@ TEST_F(RopeTest, EndingRepeat) { EXPECT_EQ( ref_tv->getLoopDomain().at(0)->extent()->evaluate().as(), 2L); - IdModel id_model(scheduled_fusion, /*build_graphs=*/false); - const auto& graph = id_model.buildExactGraph(); + IdModel id_model(scheduled_fusion, /*build_graphs=*/true); + const auto& graph = id_model.idGraph(IdMappingMode::EXACT); - const auto ref_loop = graph.toGroups(ref_tv->getLoopDomain()); + const auto ref_loop = + graph.toGroups(getLoopIds(ref_tv->definition(), id_model)); // The other tensors, except for the pad output, should be fully inlined into // the reference tensor. @@ -957,7 +958,7 @@ TEST_F(RopeTest, EndingRepeat) { if (tv->isFusionInput()) { continue; } - auto tv_loop = graph.toGroups(tv->getLoopDomain()); + auto tv_loop = graph.toGroups(getLoopIds(tv->definition(), id_model)); if (tv->definition() != nullptr && tv->definition()->isA()) { ValGroups ref_groups{ref_loop.begin() + 1, ref_loop.end()}; // In the case of pad, the loop domain of the output tensor @@ -965,7 +966,8 @@ TEST_F(RopeTest, EndingRepeat) { // without the outermost ID. EXPECT_EQ(tv_loop, ref_groups); } else { - EXPECT_EQ(tv_loop, ref_loop) << tv->toString(); + EXPECT_EQ(tv_loop, ref_loop) + << tv->toString() << ", reference: " << ref_tv->toString(); EXPECT_EQ(tv->getLoopDomain().size(), tv->getComputeAtPosition()) << tv->toString(); }