[XLA:GPU/TMA] Adding TMA-related attributes to the function arguments.
PiperOrigin-RevId: 733256581
Moerafaat authored and Google-ML-Automation committed Mar 4, 2025
1 parent 52a89ef commit 17e913f
Showing 3 changed files with 106 additions and 57 deletions.
xla/backends/gpu/codegen/triton/transforms/BUILD (1 addition, 0 deletions)
@@ -53,6 +53,7 @@ cc_library(
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:GPUDialect",
@@ -32,7 +32,8 @@ func.func @lower_tile_extract_insert(%arg0: tensor<512x128xbf16>,
// CHECK: tt.return

// CHECK-TMA-LABEL:tt.func @lower_tile_extract_insert
-// CHECK-TMA-SAME: %[[ARG_0:.*]]: !tt.ptr<bf16>, %[[ARG_1:.*]]: !tt.ptr<bf16>
+// CHECK-TMA-SAME: %[[ARG_0:.*]]: !tt.ptr<bf16> {tt.nv_tma_desc = 1 : i32, tt.tma_descriptor = #triton_xla.tma_descriptor<global_shape = [512, 128], block_shape = [16, 64], element_byte_size = 2>},
+// CHECK-TMA-SAME: %[[ARG_1:.*]]: !tt.ptr<bf16> {tt.nv_tma_desc = 1 : i32, tt.tma_descriptor = #triton_xla.tma_descriptor<global_shape = [256, 256], block_shape = [16, 64], element_byte_size = 2>}
// CHECK-TMA: %[[DESC_0:.*]] = tt.reinterpret_tensor_descriptor %[[ARG_0]]
// CHECK-TMA: %[[DESC_1:.*]] = tt.reinterpret_tensor_descriptor %[[ARG_1]]
// CHECK-TMA: %[[LOAD:.*]] = tt.experimental_descriptor_load %[[DESC_0]]
@@ -28,10 +28,12 @@ limitations under the License.
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -98,9 +100,10 @@ bool TmaIsEnabledForDevice(
return is_cuda && device_info.cuda_compute_capability().IsAtLeastHopper();
}

-bool CanUseTMA(bool tma_enabled,
+bool CanUseTMA(::xla::EmitterLocOpBuilder& builder, bool tma_enabled,
const stream_executor::DeviceDescription& device_description,
-TiledTensorType tiled_tensor_type) {
+TiledTensorType tiled_tensor_type,
+TypedValue<RankedTensorType> tensor) {
if (!tma_enabled) {
return false;
}
@@ -112,6 +115,17 @@ bool CanUseTMA(bool tma_enabled,
return false;
}

+// We only enable TMA for inputs that have a single use.
+auto block_arg = mlir::dyn_cast<BlockArgument>(tensor);
+if (!block_arg || !block_arg.hasOneUse()) {
+return false;
+}
+auto func_op =
+mlir::dyn_cast<func::FuncOp>(block_arg.getOwner()->getParentOp());
+if (!func_op) {
+return false;
+}

// Limitations of TMA:
// - The minor dimension of the global input must be divisible by 16.
// - The block size must not exceed 256 in any dimension.
@@ -124,6 +138,23 @@ bool CanUseTMA(bool tma_enabled,
[](int64_t dim) { return dim > 256; });
}
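Note: this view collapses the body of these checks; only the trailing lambda is visible above. A minimal sketch of what the elided logic plausibly looks like, reconstructed from the comment and the visible tail; the accessor names are taken from the TmaDescriptorAttr construction later in this diff, and this is an illustration, not the committed code:

// Sketch only: reconstructed from the comment above and the visible
// trailing lambda; `tiled_tensor_type` is the CanUseTMA parameter.
auto global_shape = tiled_tensor_type.getOriginalShape();
auto block_shape = tiled_tensor_type.getTileShape();
// The minor dimension of the global input must be divisible by 16.
if (global_shape.back() % 16 != 0) {
  return false;
}
// Permit TMA only when no block dimension exceeds 256.
return !absl::c_any_of(block_shape,
                       [](int64_t dim) { return dim > 256; });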

+// TileOp is rewritten to tt.reinterpret_tensor_descriptor when TMA is used.
+// When rewriting dependent ops, such as ExtractOp and InsertOp, we need to
+// know whether TMA is in use. This function checks whether the backward
+// slice of the op contains a ReinterpretTensorDescOp, which indicates that
+// TMA is in use.
+bool IsTmaUsed(Operation* op) {
+SetVector<Operation*> backwardSlice;
+BackwardSliceOptions opt;
+mlir::getBackwardSlice(op, &backwardSlice, opt);
+for (Operation* slice_op : backwardSlice) {
+if (mlir::isa<ReinterpretTensorDescOp>(slice_op)) {
+return true;
+}
+}
+return false;
+}
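For orientation, RewriteExtract and RewriteInsert below call this predicate on the defining op of their source and destination values. A minimal usage sketch; the nullptr guard is a defensive assumption added here (getDefiningOp() returns null for block arguments), not part of the committed patterns:

// Sketch: gate a rewrite on whether the tiled value flows from TMA.
bool ShouldUseTmaLowering(mlir::Value tiled_value) {
  // Block arguments have no defining op; treat them as non-TMA.
  mlir::Operation* def = tiled_value.getDefiningOp();
  return def != nullptr && IsTmaUsed(def);
}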

void ComputeBoundaryChecks(std::vector<int32_t>& boundary_checks,
const TiledTensorType& tiled_tensor_type) {
for (auto [dim_idx, sizes] :
@@ -137,12 +168,6 @@ void ComputeBoundaryChecks(std::vector<int32_t>& boundary_checks,
}
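The loop body of ComputeBoundaryChecks is collapsed in this view. Judging from the visible structured binding and the usual meaning of tt.load/tt.store boundary checks, the elided logic is plausibly equivalent to the following sketch (an assumption, not the committed code):

// Sketch only: a dimension needs a boundary check when its tile size does
// not evenly divide the global extent, i.e. the edge tile is partial.
for (auto [dim_idx, sizes] : llvm::enumerate(
         llvm::zip(tiled_tensor_type.getOriginalShape(),
                   tiled_tensor_type.getTileShape()))) {
  auto [global_size, tile_size] = sizes;
  if (global_size % tile_size != 0) {
    boundary_checks.push_back(dim_idx);
  }
}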

struct RewriteFuncOp : mlir::OpRewritePattern<func::FuncOp> {
-RewriteFuncOp(mlir::MLIRContext* context,
-const stream_executor::DeviceDescription* device_description,
-bool tma_enabled)
-: OpRewritePattern(context),
-device_description(device_description),
-tma_enabled(tma_enabled) {}
+using OpRewritePattern::OpRewritePattern;

// Rewrite tensors<> to !tt.ptr<tensor>
@@ -181,8 +206,21 @@ struct RewriteFuncOp : mlir::OpRewritePattern<func::FuncOp> {

auto new_function_type = FunctionType::get(
op.getContext(), new_operand_types, /*result_types=*/{});
-auto new_func = rewriter.create<triton::FuncOp>(op.getLoc(), op.getName(),
-new_function_type);
+// Transfer the argument attributes from the old function to the new one.
+SmallVector<DictionaryAttr> arg_attrs;
+if (op.getArgAttrs().has_value()) {
+auto oldArgAttrsArray = op.getArgAttrs().value();
+for (int i = 0; i < oldArgAttrsArray.size(); ++i) {
+arg_attrs.push_back(
+mlir::cast<mlir::DictionaryAttr>(oldArgAttrsArray[i]));
+}
+}
+
+// Currently not propagating any function attributes to the new function.
+ArrayRef<NamedAttribute> attrs;
+auto new_func = rewriter.create<triton::FuncOp>(
+op.getLoc(), op.getName(), new_function_type, attrs, arg_attrs);
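As a side note, the same transfer could be written more compactly with mlir::FunctionOpInterface's getAllArgAttrs(); a hedged alternative sketch, not the committed code:

// Alternative sketch: getAllArgAttrs() returns an ArrayAttr of
// DictionaryAttr, or null when no argument attributes are set.
SmallVector<DictionaryAttr> arg_attrs;
if (ArrayAttr all_arg_attrs = op.getAllArgAttrs()) {
  for (Attribute attr : all_arg_attrs) {
    arg_attrs.push_back(mlir::cast<DictionaryAttr>(attr));
  }
}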

rewriter.inlineRegionBefore(op.getRegion(), new_func.getFunctionBody(),
new_func.end());
@@ -195,9 +233,6 @@ struct RewriteFuncOp : mlir::OpRewritePattern<func::FuncOp> {

return mlir::success();
}

-const stream_executor::DeviceDescription* device_description;
-const bool tma_enabled;
};

struct RewriteTile : mlir::OpRewritePattern<TileOp> {
@@ -215,27 +250,42 @@ struct RewriteTile : mlir::OpRewritePattern<TileOp> {
TileOp op, mlir::PatternRewriter& rewriter) const override {
::xla::EmitterLocOpBuilder builder(op.getLoc(), rewriter);

-// tensor -> !tt.ptr<>
-auto cast_to_tensor_ptr_type =
-builder
-.create<mlir::UnrealizedConversionCastOp>(
-GetTensorPtrType(builder,
-op.getTensor().getType().getElementType()),
-op.getTensor())
-.getResult(0);
+if (CanUseTMA(builder, tma_enabled, *device_description,
+op.getTiledTensor().getType(), op.getTensor())) {
+// Add TMA attributes to the corresponding argument in the function.
+auto block_arg = mlir::dyn_cast<BlockArgument>(op.getTensor());
+auto func_op =
+mlir::dyn_cast<func::FuncOp>(block_arg.getOwner()->getParentOp());
+func_op.setArgAttr(block_arg.getArgNumber(), "tt.nv_tma_desc",
+builder.getI32IntegerAttr(1));
+// The attribute name is prefixed with "tt"; otherwise tt.func complains
+// that the attribute is not part of the dialect. Not the best way to do
+// this, but it works for now.
+auto tiled_tensor_type = op.getTiledTensor().getType();
+func_op.setArgAttr(
+block_arg.getArgNumber(), "tt.tma_descriptor",
+builder.getAttr<TmaDescriptorAttr>(
+tiled_tensor_type.getOriginalShape(),
+tiled_tensor_type.getTileShape(),
+tiled_tensor_type.getElementType().getIntOrFloatBitWidth() / 8));
+
+// tensor -> !tt.ptr<>
+auto cast_to_tensor_ptr_type =
+builder
+.create<mlir::UnrealizedConversionCastOp>(
+GetTensorPtrType(builder,
+op.getTensor().getType().getElementType()),
+op.getTensor())
+.getResult(0);

-if (CanUseTMA(tma_enabled, *device_description,
-op.getTiledTensor().getType())) {
auto reinterpret_tensor_desc =
xg::EmitTmaDescriptor(builder, cast_to_tensor_ptr_type,
op.getTiledTensor().getType().getTileType());

// !tt.tensordesc<tensor> -> tiled_tensor
auto cast_desc_ptr_to_tiled_tensor_ptr_type =
builder.create<mlir::UnrealizedConversionCastOp>(
-GetTensorDescPtrType(builder,
-op.getTiledTensor().getType().getTileType()),
-reinterpret_tensor_desc);
+op.getTiledTensor().getType(), reinterpret_tensor_desc);

rewriter.replaceOp(op, cast_desc_ptr_to_tiled_tensor_ptr_type);
return mlir::success();
@@ -245,6 +295,15 @@ struct RewriteTile : mlir::OpRewritePattern<TileOp> {
std::vector<int32_t> dim_order(op.getSizes().size());
std::iota(dim_order.begin(), dim_order.end(), 0);

+// tensor -> !tt.ptr<>
+auto cast_to_tensor_ptr_type =
+builder
+.create<mlir::UnrealizedConversionCastOp>(
+GetTensorPtrType(builder,
+op.getTensor().getType().getElementType()),
+op.getTensor())
+.getResult(0);

auto tensor_ptr =
builder
.create<MakeTensorPtrOp>(
@@ -269,12 +328,6 @@
};

struct RewriteExtract : mlir::OpRewritePattern<ExtractOp> {
-RewriteExtract(mlir::MLIRContext* context,
-const stream_executor::DeviceDescription* device_description,
-bool tma_enabled)
-: OpRewritePattern(context),
-device_description(device_description),
-tma_enabled(tma_enabled) {}
+using OpRewritePattern::OpRewritePattern;

// Rewriting ExtractOp as tt.advance + tt.load if TMA is not enabled,
@@ -283,7 +336,7 @@ struct RewriteExtract : mlir::OpRewritePattern<ExtractOp> {
ExtractOp op, mlir::PatternRewriter& rewriter) const override {
::xla::EmitterLocOpBuilder builder(op.getLoc(), rewriter);

-if (CanUseTMA(tma_enabled, *device_description, op.getSrc().getType())) {
+if (IsTmaUsed(op.getSrc().getDefiningOp())) {
// tiled_tensor -> !tt.tensordesc<tensor>
auto cast_to_tensor_desc_ptr_type =
builder
@@ -334,18 +387,9 @@ struct RewriteExtract : mlir::OpRewritePattern<ExtractOp> {
rewriter.replaceOp(op, load);
return mlir::success();
}

-const stream_executor::DeviceDescription* device_description;
-const bool tma_enabled;
};

struct RewriteInsert : mlir::OpRewritePattern<InsertOp> {
-RewriteInsert(mlir::MLIRContext* context,
-const stream_executor::DeviceDescription* device_description,
-bool tma_enabled)
-: OpRewritePattern(context),
-device_description(device_description),
-tma_enabled(tma_enabled) {}
+using OpRewritePattern::OpRewritePattern;

// Rewriting InsertOp as tt.advance + tt.store if TMA is not enabled,
@@ -354,7 +398,7 @@ struct RewriteInsert : mlir::OpRewritePattern<InsertOp> {
InsertOp op, mlir::PatternRewriter& rewriter) const override {
::xla::EmitterLocOpBuilder builder(op.getLoc(), rewriter);

-if (CanUseTMA(tma_enabled, *device_description, op.getDst().getType())) {
+if (IsTmaUsed(op.getDst().getDefiningOp())) {
// tiled_tensor -> !tt.tensordesc<tensor>
auto cast_to_tensor_desc_ptr_type =
builder
@@ -399,9 +443,6 @@ struct RewriteInsert : mlir::OpRewritePattern<InsertOp> {

return mlir::success();
}

-const stream_executor::DeviceDescription* device_description;
-const bool tma_enabled;
};

// Rewriting tensor::InsertOp as tt.store.
@@ -471,19 +512,25 @@ struct TritonXLAExtractInsertToTritonPass
tma_enabled = tma_enabled_;

mlir::MLIRContext* mlir_context = &getContext();

+mlir::RewritePatternSet tile_pattern_set(mlir_context);
+tile_pattern_set.add<RewriteTile>(mlir_context, &device_description,
+tma_enabled);
+auto tile_result = mlir::applyPatternsGreedily(getOperation(),
+std::move(tile_pattern_set));

mlir::RewritePatternSet patterns(mlir_context);
// clang-format off
-patterns.add<RewriteScalarExtract, RewriteScalarInsert>(mlir_context);
-patterns.add<
-RewriteExtract,
-RewriteFuncOp,
-RewriteInsert,
-RewriteTile
->(mlir_context, &device_description, tma_enabled);
+patterns.add<RewriteExtract,
+RewriteFuncOp,
+RewriteInsert,
+RewriteScalarExtract,
+RewriteScalarInsert>(mlir_context);
// clang-format on
-if (mlir::failed(
-mlir::applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+auto result =
+mlir::applyPatternsGreedily(getOperation(), std::move(patterns));
+
+if (mlir::failed(tile_result) && mlir::failed(result)) {
signalPassFailure();
}
}
