From 9a5b3a1bbca6790602ec3291da850fc4485cc807 Mon Sep 17 00:00:00 2001 From: Adam Yang Date: Tue, 29 Oct 2024 10:17:35 -0700 Subject: [PATCH] [DXIL] Add GroupMemoryBarrierWithGroupSync intrinsic (#111884) fixes #112974 partially fixes #70103 ### Changes - Added new tablegen based way of lowering dx intrinsics to DXIL ops. - Added int_dx_group_memory_barrier_with_group_sync intrinsic in IntrinsicsDirectX.td - Added expansion for int_dx_group_memory_barrier_with_group_sync in DXILIntrinsicExpansion.cpp - Added DXIL backend test case ### Related PRs * [[clang][HLSL] Add GroupMemoryBarrierWithGroupSync intrinsic #111883](https://github.com/llvm/llvm-project/pull/111883) * [[SPIRV] Add GroupMemoryBarrierWithGroupSync intrinsic #111888](https://github.com/llvm/llvm-project/pull/111888) --- llvm/include/llvm/IR/IntrinsicsDirectX.td | 2 + llvm/lib/Target/DirectX/DXIL.td | 54 ++++++++ llvm/lib/Target/DirectX/DXILOpLowering.cpp | 45 +++++-- .../group_memory_barrier_with_group_sync.ll | 8 ++ llvm/utils/TableGen/DXILEmitter.cpp | 122 ++++++++++++++++-- 5 files changed, 209 insertions(+), 22 deletions(-) create mode 100644 llvm/test/CodeGen/DirectX/group_memory_barrier_with_group_sync.ll diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td index e30d37f69f781..dada426368995 100644 --- a/llvm/include/llvm/IR/IntrinsicsDirectX.td +++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td @@ -92,4 +92,6 @@ def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, L def int_dx_splitdouble : DefaultAttrsIntrinsic<[llvm_anyint_ty, LLVMMatchType<0>], [LLVMScalarOrSameVectorWidth<0, llvm_double_ty>], [IntrNoMem]>; def int_dx_radians : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>; + +def int_dx_group_memory_barrier_with_group_sync : DefaultAttrsIntrinsic<[], [], []>; } diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td index 1e8dc63ffa257..263ca50011aa7 100644 --- 
a/llvm/lib/Target/DirectX/DXIL.td +++ b/llvm/lib/Target/DirectX/DXIL.td @@ -294,6 +294,43 @@ class Attributes attrs> { list op_attrs = attrs; } +class DXILConstant { + int value = value_; +} + +defset list BarrierModes = { + def BarrierMode_DeviceMemoryBarrier : DXILConstant<2>; + def BarrierMode_DeviceMemoryBarrierWithGroupSync : DXILConstant<3>; + def BarrierMode_GroupMemoryBarrier : DXILConstant<8>; + def BarrierMode_GroupMemoryBarrierWithGroupSync : DXILConstant<9>; + def BarrierMode_AllMemoryBarrier : DXILConstant<10>; + def BarrierMode_AllMemoryBarrierWithGroupSync : DXILConstant<11>; +} + +// Intrinsic arg selection +class Arg { + int index = -1; + DXILConstant value; + bit is_i8 = 0; + bit is_i32 = 0; +} +class ArgSelect : Arg { + let index = index_; +} +class ArgI32 : Arg { + let value = value_; + let is_i32 = 1; +} +class ArgI8 : Arg { + let value = value_; + let is_i8 = 1; +} + +class IntrinsicSelect args_> { + Intrinsic intrinsic = intrinsic_; + list args = args_; +} + // Abstraction DXIL Operation class DXILOp { // A short description of the operation @@ -308,6 +345,9 @@ class DXILOp { // LLVM Intrinsic DXIL Operation maps to Intrinsic LLVMIntrinsic = ?; + // Non-trivial LLVM Intrinsics DXIL Operation maps to + list intrinsic_selects = []; + // Result type of the op DXILOpParamType result; @@ -829,3 +869,17 @@ def WaveGetLaneIndex : DXILOp<111, waveGetLaneIndex> { let stages = [Stages]; let attributes = [Attributes]; } + +def Barrier : DXILOp<80, barrier> { + let Doc = "inserts a memory barrier in the shader"; + let intrinsic_selects = [ + IntrinsicSelect< + int_dx_group_memory_barrier_with_group_sync, + [ ArgI32 ]>, + ]; + + let arguments = [Int32Ty]; + let result = VoidTy; + let stages = [Stages]; + let attributes = [Attributes]; +} diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp index 8acc9c1efa08c..b5cf1654181c6 100644 --- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp +++ 
b/llvm/lib/Target/DirectX/DXILOpLowering.cpp @@ -106,17 +106,43 @@ class OpLowerer { return false; } - [[nodiscard]] - bool replaceFunctionWithOp(Function &F, dxil::OpCode DXILOp) { + struct ArgSelect { + enum class Type { + Index, + I8, + I32, + }; + Type Type = Type::Index; + int Value = -1; + }; + + [[nodiscard]] bool replaceFunctionWithOp(Function &F, dxil::OpCode DXILOp, + ArrayRef ArgSelects) { bool IsVectorArgExpansion = isVectorArgExpansion(F); return replaceFunction(F, [&](CallInst *CI) -> Error { - SmallVector Args; OpBuilder.getIRB().SetInsertPoint(CI); - if (IsVectorArgExpansion) { - SmallVector NewArgs = argVectorFlatten(CI, OpBuilder.getIRB()); - Args.append(NewArgs.begin(), NewArgs.end()); - } else + SmallVector Args; + if (ArgSelects.size()) { + for (const ArgSelect &A : ArgSelects) { + switch (A.Type) { + case ArgSelect::Type::Index: + Args.push_back(CI->getArgOperand(A.Value)); + break; + case ArgSelect::Type::I8: + Args.push_back(OpBuilder.getIRB().getInt8((uint8_t)A.Value)); + break; + case ArgSelect::Type::I32: + Args.push_back(OpBuilder.getIRB().getInt32(A.Value)); + break; + default: + llvm_unreachable("Invalid type of intrinsic arg select."); + } + } + } else if (IsVectorArgExpansion) { + Args = argVectorFlatten(CI, OpBuilder.getIRB()); + } else { Args.append(CI->arg_begin(), CI->arg_end()); + } Expected OpCall = OpBuilder.tryCreateOp(DXILOp, Args, CI->getName(), F.getReturnType()); @@ -583,9 +609,10 @@ class OpLowerer { switch (ID) { default: continue; -#define DXIL_OP_INTRINSIC(OpCode, Intrin) \ +#define DXIL_OP_INTRINSIC(OpCode, Intrin, ...) 
\ case Intrin: \ - HasErrors |= replaceFunctionWithOp(F, OpCode); \ + HasErrors |= \ + replaceFunctionWithOp(F, OpCode, ArrayRef{__VA_ARGS__}); \ break; #include "DXILOperation.inc" case Intrinsic::dx_handle_fromBinding: diff --git a/llvm/test/CodeGen/DirectX/group_memory_barrier_with_group_sync.ll b/llvm/test/CodeGen/DirectX/group_memory_barrier_with_group_sync.ll new file mode 100644 index 0000000000000..baf93d4e177f0 --- /dev/null +++ b/llvm/test/CodeGen/DirectX/group_memory_barrier_with_group_sync.ll @@ -0,0 +1,8 @@ +; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library < %s | FileCheck %s + +define void @test_group_memory_barrier_with_group_sync() { +entry: + ; CHECK: call void @dx.op.barrier(i32 80, i32 9) + call void @llvm.dx.group.memory.barrier.with.group.sync() + ret void +} \ No newline at end of file diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp index 467a6163ae3b0..8594233244638 100644 --- a/llvm/utils/TableGen/DXILEmitter.cpp +++ b/llvm/utils/TableGen/DXILEmitter.cpp @@ -32,6 +32,20 @@ using namespace llvm::dxil; namespace { +struct DXILArgSelect { + enum class Type { + Index, + I32, + I8, + }; + Type Type = Type::Index; + int Value = -1; +}; +struct DXILIntrinsicSelect { + StringRef Intrinsic; + SmallVector Args; +}; + struct DXILOperationDesc { std::string OpName; // name of DXIL operation int OpCode; // ID of DXIL operation @@ -42,8 +56,7 @@ struct DXILOperationDesc { SmallVector OverloadRecs; SmallVector StageRecs; SmallVector AttrRecs; - StringRef Intrinsic; // The llvm intrinsic map to OpName. Default is "" which - // means no map exists + SmallVector IntrinsicSelects; SmallVector ShaderStages; // shader stages to which this applies, empty for all. int OverloadParamIndex; // Index of parameter with overload type. @@ -71,6 +84,21 @@ static void AscendingSortByVersion(std::vector &Recs) { }); } +/// Take a `int_{intrinsic_name}` and return just the intrinsic_name part if +/// available. 
Otherwise return the empty string. +static StringRef GetIntrinsicName(const RecordVal *RV) { + if (RV && RV->getValue()) { + if (const DefInit *DI = dyn_cast(RV->getValue())) { + auto *IntrinsicDef = DI->getDef(); + auto DefName = IntrinsicDef->getName(); + assert(DefName.starts_with("int_") && "invalid intrinsic name"); + // Remove the int_ from intrinsic name. + return DefName.substr(4); + } + } + return ""; +} + /// Construct an object using the DXIL Operation records specified /// in DXIL.td. This serves as the single source of reference of /// the information extracted from the specified Record R, for @@ -157,14 +185,63 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) { OpName); } - const RecordVal *RV = R->getValue("LLVMIntrinsic"); - if (RV && RV->getValue()) { - if (const DefInit *DI = dyn_cast(RV->getValue())) { - auto *IntrinsicDef = DI->getDef(); - auto DefName = IntrinsicDef->getName(); - assert(DefName.starts_with("int_") && "invalid intrinsic name"); - // Remove the int_ from intrinsic name. 
- Intrinsic = DefName.substr(4); + { + DXILIntrinsicSelect IntrSelect; + IntrSelect.Intrinsic = GetIntrinsicName(R->getValue("LLVMIntrinsic")); + if (IntrSelect.Intrinsic.size()) + IntrinsicSelects.emplace_back(std::move(IntrSelect)); + } + + auto IntrinsicSelectRecords = R->getValueAsListOfDefs("intrinsic_selects"); + if (IntrinsicSelectRecords.size()) { + if (IntrinsicSelects.size()) { + PrintFatalError( + R, Twine("LLVMIntrinsic and intrinsic_selects cannot be both " + "defined for DXIL operation - ") + + OpName); + } else { + for (const Record *R : IntrinsicSelectRecords) { + DXILIntrinsicSelect IntrSelect; + IntrSelect.Intrinsic = GetIntrinsicName(R->getValue("intrinsic")); + auto Args = R->getValueAsListOfDefs("args"); + for (const Record *Arg : Args) { + bool IsI8 = Arg->getValueAsBit("is_i8"); + bool IsI32 = Arg->getValueAsBit("is_i32"); + int Index = Arg->getValueAsInt("index"); + const Record *ValueRec = Arg->getValueAsOptionalDef("value"); + + DXILArgSelect ArgSelect; + if (IsI8) { + if (!ValueRec) { + PrintFatalError(R, Twine("'value' must be defined for i8 " + "ArgSelect for DXIL operation - ") + + OpName); + } + ArgSelect.Type = DXILArgSelect::Type::I8; + ArgSelect.Value = ValueRec->getValueAsInt("value"); + } else if (IsI32) { + if (!ValueRec) { + PrintFatalError(R, Twine("'value' must be defined for i32 " + "ArgSelect for DXIL operation - ") + + OpName); + } + ArgSelect.Type = DXILArgSelect::Type::I32; + ArgSelect.Value = ValueRec->getValueAsInt("value"); + } else { + if (Index < 0) { + PrintFatalError( + R, Twine("Index in ArgSelect must be equal to or " + "greater than 0 for DXIL operation - ") + + OpName); + } + ArgSelect.Type = DXILArgSelect::Type::Index; + ArgSelect.Value = Index; + } + + IntrSelect.Args.emplace_back(std::move(ArgSelect)); + } + IntrinsicSelects.emplace_back(std::move(IntrSelect)); + } } } } @@ -377,10 +454,29 @@ static void emitDXILIntrinsicMap(ArrayRef Ops, OS << "#ifdef DXIL_OP_INTRINSIC\n"; OS << "\n"; for (const auto &Op : 
Ops) { - if (Op.Intrinsic.empty()) + if (Op.IntrinsicSelects.empty()) { continue; - OS << "DXIL_OP_INTRINSIC(dxil::OpCode::" << Op.OpName - << ", Intrinsic::" << Op.Intrinsic << ")\n"; + } + for (const DXILIntrinsicSelect &MappedIntr : Op.IntrinsicSelects) { + OS << "DXIL_OP_INTRINSIC(dxil::OpCode::" << Op.OpName + << ", Intrinsic::" << MappedIntr.Intrinsic; + for (const DXILArgSelect &ArgSelect : MappedIntr.Args) { + OS << ", (ArgSelect { "; + switch (ArgSelect.Type) { + case DXILArgSelect::Type::Index: + OS << "ArgSelect::Type::Index, "; + break; + case DXILArgSelect::Type::I8: + OS << "ArgSelect::Type::I8, "; + break; + case DXILArgSelect::Type::I32: + OS << "ArgSelect::Type::I32, "; + break; + } + OS << ArgSelect.Value << "})"; + } + OS << ")\n"; + } } OS << "\n"; OS << "#undef DXIL_OP_INTRINSIC\n";