diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 51cffac808768..e8d7430fb3d5c 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -597,11 +597,9 @@ void SCEVUnknown::allUsesReplacedWith(Value *New) { /// /// Since we do not continue running this routine on expression trees once we /// have seen unequal values, there is no need to track them in the cache. -static int -CompareValueComplexity(EquivalenceClasses &EqCacheValue, - const LoopInfo *const LI, Value *LV, Value *RV, - unsigned Depth) { - if (Depth > MaxValueCompareDepth || EqCacheValue.isEquivalent(LV, RV)) +static int CompareValueComplexity(const LoopInfo *const LI, Value *LV, + Value *RV, unsigned Depth) { + if (Depth > MaxValueCompareDepth) return 0; // Order pointer values after integer values. This helps SCEVExpander form @@ -660,15 +658,13 @@ CompareValueComplexity(EquivalenceClasses &EqCacheValue, return (int)LNumOps - (int)RNumOps; for (unsigned Idx : seq(LNumOps)) { - int Result = - CompareValueComplexity(EqCacheValue, LI, LInst->getOperand(Idx), - RInst->getOperand(Idx), Depth + 1); + int Result = CompareValueComplexity(LI, LInst->getOperand(Idx), + RInst->getOperand(Idx), Depth + 1); if (Result != 0) return Result; } } - EqCacheValue.unionSets(LV, RV); return 0; } @@ -679,7 +675,6 @@ CompareValueComplexity(EquivalenceClasses &EqCacheValue, // not know if they are equivalent for sure. static std::optional CompareSCEVComplexity(EquivalenceClasses &EqCacheSCEV, - EquivalenceClasses &EqCacheValue, const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) { // Fast-path: SCEVs are uniqued so we can do a quick equality check. @@ -705,8 +700,8 @@ CompareSCEVComplexity(EquivalenceClasses &EqCacheSCEV, const SCEVUnknown *LU = cast(LHS); const SCEVUnknown *RU = cast(RHS); - int X = CompareValueComplexity(EqCacheValue, LI, LU->getValue(), - RU->getValue(), Depth + 1); + int X = + CompareValueComplexity(LI, LU->getValue(), RU->getValue(), Depth + 1); if (X == 0) EqCacheSCEV.unionSets(LHS, RHS); return X; @@ -773,8 +768,8 @@ CompareSCEVComplexity(EquivalenceClasses &EqCacheSCEV, return (int)LNumOps - (int)RNumOps; for (unsigned i = 0; i != LNumOps; ++i) { - auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LOps[i], - ROps[i], DT, Depth + 1); + auto X = CompareSCEVComplexity(EqCacheSCEV, LI, LOps[i], ROps[i], DT, + Depth + 1); if (X != 0) return X; } @@ -802,12 +797,10 @@ static void GroupByComplexity(SmallVectorImpl &Ops, if (Ops.size() < 2) return; // Noop EquivalenceClasses EqCacheSCEV; - EquivalenceClasses EqCacheValue; // Whether LHS has provably less complexity than RHS. auto IsLessComplex = [&](const SCEV *LHS, const SCEV *RHS) { - auto Complexity = - CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LHS, RHS, DT); + auto Complexity = CompareSCEVComplexity(EqCacheSCEV, LI, LHS, RHS, DT); return Complexity && *Complexity < 0; }; if (Ops.size() == 2) { diff --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp index a6a5ffda3cb70..76e6095636305 100644 --- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp +++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp @@ -1625,4 +1625,40 @@ TEST_F(ScalarEvolutionsTest, ForgetValueWithOverflowInst) { }); } +TEST_F(ScalarEvolutionsTest, ComplexityComparatorIsStrictWeakOrdering) { + // Regression test for a case where caching of equivalent values caused the + // comparator to get inconsistent. + LLVMContext C; + SMDiagnostic Err; + std::unique_ptr M = parseAssemblyString(R"( + define i32 @foo(i32 %arg0) { + %1 = add i32 %arg0, 1 + %2 = add i32 %arg0, 1 + %3 = xor i32 %2, %1 + %4 = add i32 %3, %2 + %5 = add i32 %arg0, 1 + %6 = xor i32 %5, %arg0 + %7 = add i32 %arg0, %6 + %8 = add i32 %5, %7 + %9 = xor i32 %8, %7 + %10 = add i32 %9, %8 + %11 = xor i32 %10, %9 + %12 = add i32 %11, %10 + %13 = xor i32 %12, %11 + %14 = add i32 %12, %13 + %15 = add i32 %14, %4 + ret i32 %15 + })", + Err, C); + + ASSERT_TRUE(M && "Could not parse module?"); + ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!"); + + runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) { + // When _LIBCPP_HARDENING_MODE == _LIBCPP_HARDENING_MODE_DEBUG, this will + // crash if the comparator has the specific caching bug. + SE.getSCEV(F.getEntryBlock().getTerminator()->getOperand(0)); + }); +} + } // end namespace llvm