Skip to content

Commit

Permalink
[SYCLomatic] Enable the migration of kernel function pointer with new…
Browse files Browse the repository at this point in the history
… introduced helper class: kernel_launcher and wrapper_register (#2561)

Key changes during migration: 
  1.  For each CUDA kernel,  generate a  kernel_wrapper() (which actually is SYCL code to submit a kernel) if the CUDA kernel has been called indirectly. (eg. by function pointer)
  2. The kernel_wrapper() is launched by launch() member function of new helper class kernel_launcher. 
  3.  If the CUDA kernel is called by a raw pointer,  then the new helper class wrapper_register is used to register the map relationship between raw pointer and real kernel_wrappper().

Note: No change for migration of direct kernel call.

Signed-off-by: intwanghao <hao3.wang@intel.com>
  • Loading branch information
intwanghao authored Dec 19, 2024
1 parent d236fce commit c0852a0
Show file tree
Hide file tree
Showing 18 changed files with 988 additions and 217 deletions.
1 change: 1 addition & 0 deletions clang/lib/DPCT/ASTTraversal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ REGISTER_RULE(EventAPICallRule, PassKind::PK_Migration)
REGISTER_RULE(ProfilingEnableOnDemandRule, PassKind::PK_Analysis)
REGISTER_RULE(StreamAPICallRule, PassKind::PK_Migration)
REGISTER_RULE(KernelCallRule, PassKind::PK_Analysis)
REGISTER_RULE(KernelCallRefRule, PassKind::PK_Migration)
REGISTER_RULE(DeviceFunctionDeclRule, PassKind::PK_Analysis)
REGISTER_RULE(MemVarRefMigrationRule, PassKind::PK_Migration)
REGISTER_RULE(ConstantMemVarMigrationRule, PassKind::PK_Migration)
Expand Down
324 changes: 231 additions & 93 deletions clang/lib/DPCT/AnalysisInfo.cpp

Large diffs are not rendered by default.

73 changes: 37 additions & 36 deletions clang/lib/DPCT/AnalysisInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ class KernelCallExpr;
class DeviceFunctionInfo;
class CallFunctionExpr;
class DeviceFunctionDecl;
class DeviceFunctionDeclInModule;
class MemVarInfo;
class VarInfo;
class ExplicitInstantiationDecl;
Expand Down Expand Up @@ -239,6 +238,12 @@ struct RnnBackwardFuncInfo {
std::vector<std::string> FuncArgs;
};

struct DeviceFunctionInfoForWrapper {
std::vector<std::pair<std::string, std::string>> ParametersInfo;
std::vector<std::pair<std::string, std::string>> TemplateParametersInfo;
std::shared_ptr<KernelCallExpr> KernelForWrapper;
};

// <function name, Info>
using HDFuncInfoMap = std::unordered_map<std::string, HostDeviceFuncInfo>;
// <file path, <Offset, Info>>
Expand Down Expand Up @@ -1000,10 +1005,14 @@ class DpctGlobalInfo {
return Cur.get<TargetTy>();
});
}
template <class TargetTy, class NodeTy>
template <class TargetTy, class NodeTy, class... SkipNodeTy>
static auto findParent(const NodeTy *Node) {
return findAncestor<TargetTy>(
Node, [](const DynTypedNode &Cur) -> bool { return true; });
return findAncestor<TargetTy>(Node, [](const DynTypedNode &Cur) -> bool {
if ((... || Cur.get<SkipNodeTy>())) {
return false;
}
return true;
});
}

template <typename TargetTy, typename NodeTy>
Expand Down Expand Up @@ -1136,8 +1145,6 @@ class DpctGlobalInfo {
std::shared_ptr<DeviceFunctionDecl> insertDeviceFunctionDecl(
const FunctionDecl *Specialization, const FunctionTypeLoc &FTL,
const ParsedAttributes &Attrs, const TemplateArgumentListInfo &TAList);
std::shared_ptr<DeviceFunctionDecl>
insertDeviceFunctionDeclInModule(const FunctionDecl *FD);

// Build kernel and device function declaration replacements and store
// them.
Expand Down Expand Up @@ -1343,6 +1350,8 @@ class DpctGlobalInfo {
static bool useNoQueueDevice() {
return getHelperFuncPreference(HelperFuncPreference::NoQueueDevice);
}
static void setCVersionCUDALaunchUsed() { CVersionCUDALaunchUsedFlag = true; }
static bool isCVersionCUDALaunchUsed() { return CVersionCUDALaunchUsedFlag; }
static void setUseSYCLCompat(bool Flag = true) { UseSYCLCompatFlag = Flag; }
static bool useSYCLCompat() { return UseSYCLCompatFlag; }
static bool useEnqueueBarrier() {
Expand Down Expand Up @@ -1665,6 +1674,7 @@ class DpctGlobalInfo {
static unsigned HelperFuncPreferenceFlag;
static bool AnalysisModeFlag;
static bool UseSYCLCompatFlag;
static bool CVersionCUDALaunchUsedFlag;
static unsigned int ColorOption;
static std::unordered_map<int, std::shared_ptr<DeviceFunctionInfo>>
CubPlaceholderIndexMap;
Expand Down Expand Up @@ -2568,7 +2578,8 @@ class DeviceFunctionDecl {
LinkDecl(D, List, Info);
}
void setFuncInfo(std::shared_ptr<DeviceFunctionInfo> Info);

void insertWrapper();
void collectInfoForWrapper(const FunctionDecl *FD);
virtual ~DeviceFunctionDecl() = default;

protected:
Expand All @@ -2591,7 +2602,10 @@ class DeviceFunctionDecl {
bool IsDefFilePathNeeded = false;
std::vector<std::shared_ptr<TextureObjectInfo>> TextureObjectList;
FormatInfo FormatInformation;

bool HasBody = false;
size_t DeclEnd;
std::map<int, std::string> TemplateParameterDefaultValueMap;
std::map<int, std::string> ParameterDefaultValueMap;
static std::shared_ptr<DeviceFunctionInfo> &getFuncInfo(const FunctionDecl *);
static std::unordered_map<std::string, std::shared_ptr<DeviceFunctionInfo>>
FuncInfoMap;
Expand Down Expand Up @@ -2622,32 +2636,6 @@ class ExplicitInstantiationDecl : public DeviceFunctionDecl {
std::string getExtraParameters(LocInfo LI) override;
};

class DeviceFunctionDeclInModule : public DeviceFunctionDecl {
void insertWrapper();
bool HasBody = false;
size_t DeclEnd;
std::string FuncName;
std::vector<std::pair<std::string, std::string>> ParametersInfo;
std::shared_ptr<KernelCallExpr> Kernel;
void buildParameterInfo(const FunctionDecl *FD);
void buildWrapperInfo(const FunctionDecl *FD);
void buildCallInfo(const FunctionDecl *FD);
std::vector<std::pair<std::string, std::string>> &getParametersInfo() {
return ParametersInfo;
}

public:
DeviceFunctionDeclInModule(unsigned Offset,
const clang::tooling::UnifiedPath &FilePathIn,
const FunctionTypeLoc &FTL,
const ParsedAttributes &Attrs,
const FunctionDecl *FD);
DeviceFunctionDeclInModule(unsigned Offset,
const clang::tooling::UnifiedPath &FilePathIn,
const FunctionDecl *FD);
void emplaceReplacement() override;
};

// device function info includes parameters num, memory variable and call
// expression in the function.
class DeviceFunctionInfo {
Expand Down Expand Up @@ -2747,6 +2735,13 @@ class DeviceFunctionInfo {
bool isParameterReferenced(unsigned int Index);
void setParameterReferencedStatus(unsigned int Index, bool IsReferenced);
std::string getFunctionName() { return FunctionName; }
void collectInfoForWrapper(const FunctionDecl *FD);
void setModuleUsed() { ModuleUsed = true; }
bool isModuleUsed() { return ModuleUsed; }
std::shared_ptr<DeviceFunctionInfoForWrapper>
getDeviceFunctionInfoForWrapper() {
return DFInfoForWrapper;
}

private:
void mergeCalledTexObj(
Expand Down Expand Up @@ -2779,12 +2774,15 @@ class DeviceFunctionInfo {
bool CallGroupFunctionInControlFlow = false;
bool HasCheckedCallGroupFunctionInControlFlow = false;
OverloadedOperatorKind OO_Kind = OverloadedOperatorKind::OO_None;
bool ModuleUsed = false;
std::shared_ptr<DeviceFunctionInfoForWrapper> DFInfoForWrapper = nullptr;
};

class KernelCallExpr : public CallFunctionExpr {
public:
bool IsInMacroDefine = false;
bool NeedLambda = false;
bool IsForWrapper = false;
bool NeedDefaultRetValue = false;

private:
Expand Down Expand Up @@ -2857,8 +2855,10 @@ class KernelCallExpr : public CallFunctionExpr {
const std::pair<clang::tooling::UnifiedPath, unsigned> &LocInfo,
const CallExpr *, bool IsAssigned = false);
static std::shared_ptr<KernelCallExpr>
buildForWrapper(clang::tooling::UnifiedPath, const FunctionDecl *,
std::shared_ptr<DeviceFunctionInfo>);
buildForWrapper(clang::tooling::UnifiedPath, const FunctionDecl *);
void setTemplateArgsStrForWrapper(std::string Str) {
TemplateArgsStrForWrapper = std::move(Str);
}
unsigned int GridDim = 3;
unsigned int BlockDim = 3;
void setEmitSizeofWarningFlag(bool Flag) { EmitSizeofWarning = Flag; }
Expand Down Expand Up @@ -2963,6 +2963,7 @@ class KernelCallExpr : public CallFunctionExpr {
OuterStmtsList OuterStmts;
StmtList KernelStmts;
std::string KernelArgs;
std::string TemplateArgsStrForWrapper;
int TotalArgsSize = 0;
bool EmitSizeofWarning = false;
unsigned int SizeOfHighestDimension = 0;
Expand Down
Loading

0 comments on commit c0852a0

Please sign in to comment.