Skip to content

Commit

Permalink
[XLA:GPU] Fix slow compile time for jax-ml/jax#26162
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 728125830
  • Loading branch information
pifon2a authored and Google-ML-Automation committed Feb 18, 2025
1 parent 6b7233b commit 4c43694
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions xla/codegen/emitters/ir/xla_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/DialectImplementation.h" // IWYU pragma: keep
#include "mlir/IR/OpImplementation.h" // IWYU pragma: keep
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/InliningUtils.h"
#include "xla/codegen/emitters/ir/xla_ops.h"

Expand Down Expand Up @@ -69,10 +76,12 @@ struct XlaInlinerInterface : public mlir::DialectInlinerInterface {
for (auto call : region->getOps<PureCallOp>()) {
callee_calls.insert(call.getCallee());
}
llvm::SmallDenseSet<llvm::StringRef> same_region_calls;
for (auto call : call->getParentRegion()->getOps<PureCallOp>()) {
if (callee_calls.contains(call.getCallee())) {
return true;
}
same_region_calls.insert(call.getCallee());
}
if (llvm::set_is_subset(same_region_calls, callee_calls)) {
return true;
}

constexpr int kMaxOperationsToInline = 8;
Expand Down

0 comments on commit 4c43694

Please sign in to comment.