Update Triton to c75c6b034756629b891e7b2df406f634552331d5 #223

Merged: 9 commits, Feb 6, 2025
backend/compiler.py (2 changes: 1 addition & 1 deletion)

@@ -167,7 +167,7 @@ def parse_options(self, opts) -> Any:
         args.update({k: opts[k] for k in CPUOptions.__dataclass_fields__.keys() if k in opts})
         return CPUOptions(**args)

-    def get_codegen_implementation(self):
+    def get_codegen_implementation(self, options):
         codegen_fns = {"min_dot_size": lambda lhsType, rhsType: (1, 1, 1)}
         return codegen_fns
@@ -215,8 +215,8 @@ struct MakeTensorPtrConverter
     SmallVector<Value> newOffsets;
     for (auto [offset, stride] :
          llvm::zip(pointerState.offsets, pointerState.strides)) {
-      auto mulOp = rewriter.create<arith::MulIOp>(loc, offset.get<Value>(),
-                                                  stride.get<Value>());
+      auto mulOp = rewriter.create<arith::MulIOp>(loc, cast<Value>(offset),
+                                                  cast<Value>(stride));
       newOffsets.push_back(mulOp.getResult());
     }

@@ -435,7 +435,7 @@ struct LoadConverter : public OpConversionPattern<triton::LoadOp> {
       Value dimi = dyn_cast<Value>(mstate.dims[i]);
       if (!dimi) {
         dimi = rewriter.create<arith::ConstantOp>(
-            loc, cast<IntegerAttr>(mstate.dims[i].get<Attribute>()));
+            loc, cast<IntegerAttr>(cast<Attribute>(mstate.dims[i])));
       }

       auto cmpOp = rewriter.create<arith::CmpIOp>(

@@ -1236,9 +1236,10 @@ struct ReduceConverter : public OpConversionPattern<triton::ReduceOp> {
   }

   bool requiresF32Conversion(const Type elemType, Operation *redOp) const {
+    unsigned width =
+        cast<FloatType>(Float32Type::get(elemType.getContext())).getWidth();
     return isa<FloatType>(elemType) &&
-           elemType.getIntOrFloatBitWidth() <
-               Float32Type::get(elemType.getContext()).getWidth() &&
+           elemType.getIntOrFloatBitWidth() < width &&
            isa<arith::AddFOp>(redOp);
   }

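For context, a minimal sketch of the widening that a check like requiresF32Conversion typically gates. This is an inference about the surrounding converter, not code from this PR: the assumption is that sub-f32 addf reductions are computed in f32 and narrowed afterwards, and the helper names widenOperand and narrowResult are illustrative only.

    #include "mlir/Dialect/Arith/IR/Arith.h"
    #include "mlir/IR/Builders.h"

    // Assumed pattern: extend f16/bf16 operands to f32 before the reduce
    // combine region runs, so the accumulation happens in higher precision.
    mlir::Value widenOperand(mlir::OpBuilder &b, mlir::Location loc,
                             mlir::Value input) {
      return b.create<mlir::arith::ExtFOp>(loc, b.getF32Type(), input);
    }

    // Truncate the f32 accumulator back to the original element type.
    mlir::Value narrowResult(mlir::OpBuilder &b, mlir::Location loc,
                             mlir::Value acc, mlir::Type elemType) {
      return b.create<mlir::arith::TruncFOp>(loc, elemType, acc);
    }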
lib/Analysis/OpFoldResultUtils.cpp (6 changes: 3 additions & 3 deletions)

@@ -16,8 +16,8 @@
 namespace mlir {

 std::optional<int64_t> getIntAttr(const OpFoldResult ofr) {
-  if (ofr.is<Attribute>() && isa<IntegerAttr>(ofr.get<Attribute>()))
-    return dyn_cast<IntegerAttr>(ofr.get<Attribute>()).getInt();
+  if (isa<Attribute>(ofr) && isa<IntegerAttr>(cast<Attribute>(ofr)))
+    return dyn_cast<IntegerAttr>(cast<Attribute>(ofr)).getInt();

   return std::nullopt;
 }

@@ -185,7 +185,7 @@ OpFoldResult mulOFRValue(const OpFoldResult lhs, const Value rhs,

   // 2. if lhs is not constant
   assert(!lhsIntAttr);
-  auto mulOp = b.create<arith::MulIOp>(loc, lhs.get<Value>(), rhs);
+  auto mulOp = b.create<arith::MulIOp>(loc, cast<Value>(lhs), rhs);
   return mulOp.getResult();
 }

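The .is<T>()/.get<T>() rewrites throughout this PR track an upstream LLVM deprecation: the member-style accessors on llvm::PointerUnion give way to the free-function casting API (isa, cast, dyn_cast), and mlir::OpFoldResult is such a union of Attribute and Value. A minimal sketch of the migration, assuming a recent MLIR; the helper name is illustrative:

    #include "mlir/IR/OpDefinition.h" // OpFoldResult
    #include "mlir/IR/Value.h"

    using namespace mlir;

    // Before: ofr.is<Value>() and ofr.get<Value>()
    // After:  isa<Value>(ofr) and cast<Value>(ofr)
    Value getValueOrNull(OpFoldResult ofr) {
      if (isa<Value>(ofr))
        return cast<Value>(ofr); // safe: checked above
      return Value();            // the union holds an Attribute instead
    }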
lib/Analysis/PtrAnalysis.cpp (12 changes: 6 additions & 6 deletions)

@@ -862,12 +862,12 @@ void PtrAnalysis::rewriteAdvanceOp(
           op.getLoc(), rewriter.getIndexAttr(0));
       offsetValue = constOp.getResult();
     } else {
-      offsetValue = offset.get<Value>();
+      offsetValue = cast<Value>(offset);
     }
     auto castOp = rewriter.create<arith::IndexCastOp>(
         loc, rewriter.getIndexType(), increment);
     auto mulOp = rewriter.create<arith::MulIOp>(loc, castOp.getResult(),
-                                                stride.get<Value>());
+                                                cast<Value>(stride));
     auto addOp =
         rewriter.create<arith::AddIOp>(loc, mulOp.getResult(), offsetValue);
     newOffsets.push_back(addOp.getResult());

@@ -999,15 +999,15 @@ void PtrAnalysis::rewriteYieldOp(
           op.getLoc(), rewriter.getIndexAttr(0));
       operands.push_back(constOp.getResult());
     } else {
-      operands.push_back(s.get<Value>());
+      operands.push_back(cast<Value>(s));
     }
   }

   for (auto s : state.strides) {
     assert(!getIntAttr(s) && "PtrState strides for yield within for "
                              "loop not expected to be "
                              "attribute.");
-    operands.push_back(s.get<Value>());
+    operands.push_back(cast<Value>(s));
   }
 }

@@ -1171,7 +1171,7 @@ void PtrAnalysis::rewriteForOp(
         newInitArgs.push_back(constOp.getResult());
         state.offsets[j] = constOp.getResult();
       } else {
-        newInitArgs.push_back(s.get<Value>());
+        newInitArgs.push_back(cast<Value>(s));
       }
     }

@@ -1183,7 +1183,7 @@
         newInitArgs.push_back(constOp.getResult());
         state.strides[j] = constOp.getResult();
       } else {
-        newInitArgs.push_back(s.get<Value>());
+        newInitArgs.push_back(cast<Value>(s));
       }
     }
lib/AnalysisStructured/PtrAnalysis.cpp (8 changes: 4 additions & 4 deletions)

@@ -793,12 +793,12 @@ LogicalResult PtrAnalysis::rewriteAdvanceOp(triton::AdvanceOp op) {
           loc, builder.getIndexAttr(offsetIntAttr.value()));
       offsetValue = constOp.getResult();
     } else {
-      offsetValue = offset.get<Value>();
+      offsetValue = cast<Value>(offset);
     }
     auto castOp = builder.create<arith::IndexCastOp>(
         loc, builder.getIndexType(), increment);
     auto mulOp = builder.create<arith::MulIOp>(loc, castOp.getResult(),
-                                               stride.get<Value>());
+                                               cast<Value>(stride));
     auto addOp =
         builder.create<arith::AddIOp>(loc, mulOp.getResult(), offsetValue);
     newOffsets.push_back(addOp.getResult());

@@ -1029,7 +1029,7 @@ PtrAnalysis::rewriteGetStructuredStateOp(tts::GetStructuredStateOp op) {
           op.getLoc(), builder.getIndexAttr(sIntAttr.value()));
       replacements.push_back(constOp.getResult());
     } else {
-      replacements.push_back(s.get<Value>());
+      replacements.push_back(cast<Value>(s));
     }
   }

@@ -1040,7 +1040,7 @@ PtrAnalysis::rewriteGetStructuredStateOp(tts::GetStructuredStateOp op) {
           op.getLoc(), builder.getIndexAttr(sIntAttr.value()));
       replacements.push_back(constOp.getResult());
     } else {
-      replacements.push_back(s.get<Value>());
+      replacements.push_back(cast<Value>(s));
     }
   }
 }
@@ -81,7 +81,7 @@ class TritonArithToLinalgPass

     tensor::populateDecomposeTensorConcatPatterns(patterns);

-    if (failed(applyPatternsAndFoldGreedily(moduleOp, std::move(patterns)))) {
+    if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) {
       return failure();
     }
     return success();

@@ -103,7 +103,7 @@ class TritonArithToLinalgPass
     {
       RewritePatternSet patterns(&getContext());
       populateTritonArithToLinalgCanonicalizationPatterns(patterns);
-      if (failed(applyPatternsAndFoldGreedily(moduleOp, std::move(patterns)))) {
+      if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) {
        signalPassFailure();
      }
    }
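The driver rename here (and in TritonToLinalgPass.cpp below) follows upstream MLIR, where applyPatternsAndFoldGreedily is deprecated in favor of applyPatternsGreedily; folding remains enabled by default and, as far as the upstream change goes, is toggled through GreedyRewriteConfig rather than a separate entry point. A minimal sketch of the drop-in replacement; the function name is illustrative:

    #include "mlir/IR/BuiltinOps.h"     // ModuleOp
    #include "mlir/IR/PatternMatch.h"   // RewritePatternSet
    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    using namespace mlir;

    // Same operands and LogicalResult contract as the old entry point.
    LogicalResult runPatterns(ModuleOp moduleOp, RewritePatternSet &&patterns) {
      return applyPatternsGreedily(moduleOp, std::move(patterns));
    }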
lib/Conversion/TritonToLinalg/TritonToLinalgPass.cpp (2 changes: 1 addition & 1 deletion)

@@ -100,7 +100,7 @@ class TritonToLinalgPass : public TritonToLinalgBase<TritonToLinalgPass> {
     {
       RewritePatternSet patterns(&getContext());
       populateTritonToLinalgCanonicalizationPatterns(patterns);
-      if (failed(applyPatternsAndFoldGreedily(moduleOp, std::move(patterns)))) {
+      if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) {
         signalPassFailure();
       }
     }
python/examples/conftest.py (9 changes: 7 additions & 2 deletions)

@@ -91,12 +91,17 @@ def device(request):
     # tt.gather not supported yet
     "test_gather",
     "test_gather_warp_shuffle",
-    # device 'cpu' does not have 'index
+    # device 'cpu' does not have 'index'
     "test_zero_strided_tensors",
     # hard-coded with 'ttg' attributes
     "test_convert_mma2mma",
     "test_local_load_store",
-    "test_local_load_store_mma"
+    "test_local_load_store_mma",
+    "test_convert_warp_local",
+    # hard-code to use 'cuda' device
+    "test_scan_1d",
+    "test_tma_load_block_shape_err",
+    "test_tma_store_block_shape_err"
 }

 # probably different version of MLIR on the nightly build machine is complaining
@@ -19,13 +19,13 @@ module {
     %subview = memref.subview %reinterpret_cast[0] [%9] [1] : memref<1024xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
     %subview_0 = memref.subview %alloc[0] [%9] [1] : memref<1024xf32> to memref<?xf32, strided<[1]>>
     memref.copy %subview, %subview_0 : memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1]>>
-    %10 = bufferization.to_tensor %alloc restrict writable : memref<1024xf32>
+    %10 = bufferization.to_tensor %alloc restrict writable : memref<1024xf32> to tensor<1024xf32>
     %reinterpret_cast_1 = memref.reinterpret_cast %1 to offset: [%5], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>>
     %alloc_2 = memref.alloc() : memref<1024xf32>
     %subview_3 = memref.subview %reinterpret_cast_1[0] [%9] [1] : memref<1024xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
     %subview_4 = memref.subview %alloc_2[0] [%9] [1] : memref<1024xf32> to memref<?xf32, strided<[1]>>
     memref.copy %subview_3, %subview_4 : memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1]>>
-    %11 = bufferization.to_tensor %alloc_2 restrict writable : memref<1024xf32>
+    %11 = bufferization.to_tensor %alloc_2 restrict writable : memref<1024xf32> to tensor<1024xf32>
     %12 = arith.addf %10, %11 : tensor<1024xf32>
     %reinterpret_cast_5 = memref.reinterpret_cast %0 to offset: [%5], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>>
     %extracted_slice = tensor.extract_slice %12[0] [%9] [1] : tensor<1024xf32> to tensor<?xf32>
@@ -15,7 +15,7 @@ module {
     %8 = tt.splat %arg3 : i32 -> tensor<1024xi32>
     %9 = arith.cmpi slt, %7, %8 : tensor<1024xi32>
     %cast = memref.cast %2 : memref<*xf32> to memref<?xf32>
-    %10 = bufferization.to_tensor %cast restrict : memref<?xf32>
+    %10 = bufferization.to_tensor %cast restrict : memref<?xf32> to tensor<?xf32>
     %11 = tensor.empty() : tensor<1024xf32>
     %12 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%7, %9 : tensor<1024xi32>, tensor<1024xi1>) outs(%11 : tensor<1024xf32>) {
     ^bb0(%in: i32, %in_2: i1, %out: f32):

@@ -30,7 +30,7 @@ module {
       linalg.yield %17 : f32
     } -> tensor<1024xf32>
     %cast_0 = memref.cast %1 : memref<*xf32> to memref<?xf32>
-    %13 = bufferization.to_tensor %cast_0 restrict : memref<?xf32>
+    %13 = bufferization.to_tensor %cast_0 restrict : memref<?xf32> to tensor<?xf32>
     %14 = tensor.empty() : tensor<1024xf32>
     %15 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%7, %9 : tensor<1024xi32>, tensor<1024xi1>) outs(%14 : tensor<1024xf32>) {
     ^bb0(%in: i32, %in_2: i1, %out: f32):
test/Conversion/TritonToLinalg/addptr_for_accumulation.mlir (7 changes: 3 additions & 4 deletions)

@@ -72,7 +72,7 @@ module {
 // CHECK: %[[VAL_16:.*]] = bufferization.to_tensor %[[VAL_15]] restrict writable : memref<4x256xbf16>
 // CHECK: %[[VAL_17:.*]] = arith.index_cast %[[VAL_3]] : i32 to index
 // CHECK: %[[VAL_18:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_17]]], sizes: [4, 256], strides: {{\[}}%[[VAL_9]], %[[VAL_8]]] : memref<*xbf16> to memref<4x256xbf16, strided<[?, ?], offset: ?>>
-// CHECK: %[[VAL_19:.*]]:4 = scf.for %[[VAL_20:.*]] = %[[VAL_12]] to %[[VAL_11]] step %[[VAL_10]] iter_args(%[[VAL_21:.*]] = %[[VAL_16]], %[[VAL_22:.*]] = %[[VAL_18]], %[[VAL_23:.*]] = %[[VAL_17]], %[[VAL_24:.*]] = %[[VAL_12]]) -> (tensor<4x256xbf16>, memref<4x256xbf16, strided<[?, ?], offset: ?>>, index, index) {
+// CHECK: %[[VAL_19:.*]]:3 = scf.for %[[VAL_20:.*]] = %[[VAL_12]] to %[[VAL_11]] step %[[VAL_10]] iter_args(%[[VAL_21:.*]] = %[[VAL_16]], %[[VAL_22:.*]] = %[[VAL_18]], %[[VAL_23:.*]] = %[[VAL_17]]) -> (tensor<4x256xbf16>, memref<4x256xbf16, strided<[?, ?], offset: ?>>, index) {
 // CHECK: %[[VAL_25:.*]] = memref.alloc() : memref<4x256xbf16>
 // CHECK: memref.copy %[[VAL_22]], %[[VAL_25]] : memref<4x256xbf16, strided<[?, ?], offset: ?>> to memref<4x256xbf16>
 // CHECK: %[[VAL_26:.*]] = bufferization.to_tensor %[[VAL_25]] restrict writable : memref<4x256xbf16>
@@ -81,10 +81,9 @@
 // CHECK: %[[VAL_31:.*]] = arith.addf %[[VAL_28]], %[[VAL_29]] : bf16
 // CHECK: linalg.yield %[[VAL_31]] : bf16
 // CHECK: } -> tensor<4x256xbf16>
-// CHECK: %[[VAL_32:.*]] = arith.addi %[[VAL_23]], %[[VAL_10]] : index
-// CHECK: %[[VAL_33:.*]] = arith.addi %[[VAL_32]], %[[VAL_24]] : index
+// CHECK: %[[VAL_33:.*]] = arith.addi %[[VAL_23]], %[[VAL_10]] : index
 // CHECK: %[[VAL_34:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_33]]], sizes: [4, 256], strides: {{\[}}%[[VAL_9]], %[[VAL_8]]] : memref<*xbf16> to memref<4x256xbf16, strided<[?, ?], offset: ?>>
-// CHECK: scf.yield %[[VAL_35:.*]], %[[VAL_34]], %[[VAL_33]], %[[VAL_12]] : tensor<4x256xbf16>, memref<4x256xbf16, strided<[?, ?], offset: ?>>, index, index
+// CHECK: scf.yield %[[VAL_35:.*]], %[[VAL_34]], %[[VAL_33]] : tensor<4x256xbf16>, memref<4x256xbf16, strided<[?, ?], offset: ?>>, index
 // CHECK: }
 // CHECK: %[[VAL_36:.*]] = arith.index_cast %[[VAL_3]] : i32 to index
 // CHECK: %[[VAL_37:.*]] = memref.reinterpret_cast %[[VAL_2]] to offset: {{\[}}%[[VAL_36]]], sizes: [4, 256], strides: [1, %[[VAL_8]]] : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>>
test/Conversion/TritonToLinalg/block_ptr_advance.mlir (25 changes: 11 additions & 14 deletions)

@@ -36,7 +36,6 @@ module {
 // CHECK: module {
 // CHECK: func.func @matmul_kernel_with_block_pointers_01234567891011(%arg0: memref<*xbf16>, %arg1: memref<*xbf16>, %arg2: memref<*xbf16>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32, %arg14: i32, %arg15: i32, %arg16: i32, %arg17: i32, %arg18: i32, %arg19: i32) {
 // CHECK: %c64 = arith.constant 64 : index
-// CHECK: %c0 = arith.constant 0 : index
 // CHECK: %c256_i32 = arith.constant 256 : i32
 // CHECK: %c0_i32 = arith.constant 0 : i32
 // CHECK: %c64_i32 = arith.constant 64 : i32
@@ -51,7 +50,7 @@
 // CHECK: %7 = arith.addi %5, %6 : index
 // CHECK: %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [%7], sizes: [128, 64], strides: [%3, %4] : memref<*xbf16> to memref<128x64xbf16, strided<[?, ?], offset: ?>>
 // CHECK: %reinterpret_cast_0 = memref.reinterpret_cast %arg0 to offset: [%5], sizes: [128, 64], strides: [%3, %4] : memref<*xbf16> to memref<128x64xbf16, strided<[?, ?], offset: ?>>
-// CHECK: %8:7 = scf.for %arg20 = %c0_i32 to %arg5 step %c64_i32 iter_args(%arg21 = %1, %arg22 = %reinterpret_cast, %arg23 = %reinterpret_cast_0, %arg24 = %7, %arg25 = %c0, %arg26 = %5, %arg27 = %c0) -> (tensor<128x64xbf16>, memref<128x64xbf16, strided<[?, ?], offset: ?>>, memref<128x64xbf16, strided<[?, ?], offset: ?>>, index, index, index, index) : i32 {
+// CHECK: %8:5 = scf.for %arg20 = %c0_i32 to %arg5 step %c64_i32 iter_args(%arg21 = %1, %arg22 = %reinterpret_cast, %arg23 = %reinterpret_cast_0, %arg24 = %7, %arg25 = %5) -> (tensor<128x64xbf16>, memref<128x64xbf16, strided<[?, ?], offset: ?>>, memref<128x64xbf16, strided<[?, ?], offset: ?>>, index, index) : i32 {
 // CHECK: %alloc = memref.alloc() : memref<128x64xbf16>
 // CHECK: memref.copy %arg22, %alloc : memref<128x64xbf16, strided<[?, ?], offset: ?>> to memref<128x64xbf16>
 // CHECK: %17 = bufferization.to_tensor %alloc restrict writable : memref<128x64xbf16>
@@ -60,23 +59,21 @@
 // CHECK: %18 = bufferization.to_tensor %alloc_2 restrict writable : memref<128x64xbf16>
 // CHECK: %19 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%17, %18 : tensor<128x64xbf16>, tensor<128x64xbf16>) outs(%17 : tensor<128x64xbf16>) {
 // CHECK: ^bb0(%in: bf16, %in_5: bf16, %out: bf16):
-// CHECK: %27 = arith.addf %in, %in_5 : bf16
-// CHECK: linalg.yield %27 : bf16
+// CHECK: %25 = arith.addf %in, %in_5 : bf16
+// CHECK: linalg.yield %25 : bf16
 // CHECK: } -> tensor<128x64xbf16>
 // CHECK: %20 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg21, %19 : tensor<128x64xbf16>, tensor<128x64xbf16>) outs(%arg21 : tensor<128x64xbf16>) {
 // CHECK: ^bb0(%in: bf16, %in_5: bf16, %out: bf16):
-// CHECK: %27 = arith.addf %in, %in_5 : bf16
-// CHECK: linalg.yield %27 : bf16
+// CHECK: %25 = arith.addf %in, %in_5 : bf16
+// CHECK: linalg.yield %25 : bf16
 // CHECK: } -> tensor<128x64xbf16>
 // CHECK: %21 = arith.muli %4, %c64 : index
-// CHECK: %22 = arith.addi %21, %arg25 : index
-// CHECK: %23 = arith.addi %arg24, %22 : index
-// CHECK: %reinterpret_cast_3 = memref.reinterpret_cast %arg0 to offset: [%23], sizes: [128, 64], strides: [%3, %4] : memref<*xbf16> to memref<128x64xbf16, strided<[?, ?], offset: ?>>
-// CHECK: %24 = arith.muli %3, %c64 : index
-// CHECK: %25 = arith.addi %24, %arg26 : index
-// CHECK: %26 = arith.addi %25, %arg27 : index
-// CHECK: %reinterpret_cast_4 = memref.reinterpret_cast %arg0 to offset: [%26], sizes: [128, 64], strides: [%3, %4] : memref<*xbf16> to memref<128x64xbf16, strided<[?, ?], offset: ?>>
-// CHECK: scf.yield %20, %reinterpret_cast_3, %reinterpret_cast_4, %23, %c0, %26, %c0 : tensor<128x64xbf16>, memref<128x64xbf16, strided<[?, ?], offset: ?>>, memref<128x64xbf16, strided<[?, ?], offset: ?>>, index, index, index, index
+// CHECK: %22 = arith.addi %arg24, %21 : index
+// CHECK: %reinterpret_cast_3 = memref.reinterpret_cast %arg0 to offset: [%22], sizes: [128, 64], strides: [%3, %4] : memref<*xbf16> to memref<128x64xbf16, strided<[?, ?], offset: ?>>
+// CHECK: %23 = arith.muli %3, %c64 : index
+// CHECK: %24 = arith.addi %23, %arg25 : index
+// CHECK: %reinterpret_cast_4 = memref.reinterpret_cast %arg0 to offset: [%24], sizes: [128, 64], strides: [%3, %4] : memref<*xbf16> to memref<128x64xbf16, strided<[?, ?], offset: ?>>
+// CHECK: scf.yield %20, %reinterpret_cast_3, %reinterpret_cast_4, %22, %24 : tensor<128x64xbf16>, memref<128x64xbf16, strided<[?, ?], offset: ?>>, memref<128x64xbf16, strided<[?, ?], offset: ?>>, index, index
 // CHECK: }
 // CHECK: %9 = arith.muli %arg13, %c256_i32 : i32
 // CHECK: %10 = arith.index_cast %arg12 : i32 to index
tools/triton-shared-opt/CMakeLists.txt (1 change: 1 addition & 0 deletions)

@@ -9,6 +9,7 @@ target_link_libraries(triton-shared-opt PRIVATE
   TritonAnalysis
   TritonTransforms
   TritonGPUTransforms
+  TritonTestDialectTritonGPU
   TritonSharedAnalysis
   ${dialect_libs}
   ${conversion_libs}
triton (submodule updated: 292 files changed)