Skip to content

Commit

Permalink
Fix d0 + (d0 // -c) * c.
Browse files Browse the repository at this point in the history
Currently, this is rewritten to d0 mod -c. However, we do not support
modulo with a negative RHS in our lowering passes, so this triggers
undefined behavior.

It would be better to not have these ad hoc simplifications at all, but
I guess that ship has sailed.
  • Loading branch information
jreiffers committed Sep 6, 2024
1 parent ddf40e0 commit 53b8b22
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
5 changes: 4 additions & 1 deletion mlir/lib/IR/AffineExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -760,8 +760,11 @@ static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {

llrhs = lrBinOpExpr.getLHS();
rlrhs = lrBinOpExpr.getRHS();
auto rlrhsConstOpExpr = dyn_cast<AffineConstantExpr>(rlrhs);
// We don't support modulo with a negative RHS.
bool isPositiveRhs = rlrhsConstOpExpr && rlrhsConstOpExpr.getValue() > 0;

if (lhs == llrhs && rlrhs == -rrhs) {
if (isPositiveRhs && lhs == llrhs && rlrhs == -rrhs) {
return lhs % rlrhs;
}
return nullptr;
Expand Down
17 changes: 17 additions & 0 deletions mlir/unittests/IR/AffineExprTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@

using namespace mlir;

static std::string toString(AffineExpr expr) {
std::string s;
llvm::raw_string_ostream ss(s);
ss << expr;
return s;
}

// Test creating AffineExprs using the overloaded binary operators.
TEST(AffineExprTest, constructFromBinaryOperators) {
MLIRContext ctx;
Expand Down Expand Up @@ -112,3 +119,13 @@ TEST(AffineExprTest, divisorOfNegativeFloorDiv) {
OpBuilder b(&ctx);
ASSERT_EQ(b.getAffineDimExpr(0).floorDiv(-1).getLargestKnownDivisor(), 1);
}

TEST(AffineExprTest, d0PlusD0FloorDivNeg2) {
// Regression test for a bug where this was rewritten to d0 mod -2. We do not
// support a negative RHS for mod in LowerAffinePass.
MLIRContext ctx;
OpBuilder b(&ctx);
auto d0 = b.getAffineDimExpr(0);
auto sum = d0 + d0.floorDiv(-2) * 2;
ASSERT_EQ(toString(sum), "d0 + (d0 floordiv -2) * 2");
}

0 comments on commit 53b8b22

Please sign in to comment.