Adding input sharding collection in ModuleBuilder #289
base: main
Conversation
Force-pushed from 260fb4d to 3f348ad
Codecov Report

✅ All tests successful. No failed tests found.

@@ Coverage Diff @@
##             main     #289      +/-   ##
==========================================
+ Coverage   77.97%   78.65%   +0.67%
==========================================
  Files          21       21
  Lines         990     1059      +69
==========================================
+ Hits          772      833      +61
- Misses        218      226       +8

☔ View full report in Codecov by Sentry.
void ModuleBuilder::collectInputShardings(
    const mlir::OwningOpRef<mlir::ModuleOp> &module) {
  DLOG_F(LOG_DEBUG, "ModuleBuilder::collectInputShardings");
  m_input_shardings.clear();

  module.get().walk([&](mlir::Operation *op) {
    mlir::func::FuncOp funcOp = mlir::dyn_cast<mlir::func::FuncOp>(op);
    mlir::ModuleOp moduleOp = mlir::dyn_cast<mlir::ModuleOp>(op);

    if (!funcOp || !funcOp.isPublic()) {
      return;
    }

    for (unsigned i = 0; i < funcOp.getNumArguments(); ++i) {
      mlir::tt::sharding_utils::MeshSharding meshSharding;

      auto shardingAttr = llvm::dyn_cast_if_present<mlir::StringAttr>(
          funcOp.getArgAttr(i, mlir::tt::sharding_utils::kXlaShardingAttr));

      if (!shardingAttr) {
        m_input_shardings.push_back(meshSharding);
        continue;
      }

      mlir::LogicalResult conversionResult =
          fillMeshShardingFromGSPMDString(shardingAttr, meshSharding);

      if (conversionResult.failed()) {
        m_status = tt_pjrt_status::kInternal;
        return;
      }
      m_input_shardings.push_back(meshSharding);
    }
  });
}

void ModuleBuilder::collectOutputShardings(
    const mlir::OwningOpRef<mlir::ModuleOp> &module) {
  DLOG_F(LOG_DEBUG, "ModuleBuilder::collectOutputShardings");
  m_output_shardings.clear();

  module.get().walk([&](mlir::Operation *op) {
    mlir::func::FuncOp funcOp = mlir::dyn_cast<mlir::func::FuncOp>(op);
    mlir::ModuleOp moduleOp = mlir::dyn_cast<mlir::ModuleOp>(op);

    if (!funcOp || !funcOp.isPublic()) {
      return;
    }

    for (unsigned i = 0; i < funcOp.getNumResults(); ++i) {
      mlir::tt::sharding_utils::MeshSharding meshSharding;

      auto shardingAttr = llvm::dyn_cast_if_present<mlir::StringAttr>(
          funcOp.getResultAttr(i, mlir::tt::sharding_utils::kXlaShardingAttr));

      if (!shardingAttr) {
        m_output_shardings.push_back(meshSharding);
        continue;
      }

      mlir::LogicalResult conversionResult =
          fillMeshShardingFromGSPMDString(shardingAttr, meshSharding);

      if (conversionResult.failed()) {
        m_status = tt_pjrt_status::kInternal;
        return;
      }
      m_output_shardings.push_back(meshSharding);
    }
  });
}
These two look quite similar. The only differences are m_input_shardings vs m_output_shardings and getArgAttr instead of getResultAttr. Can we somehow extract this into a common function, if it is not too cumbersome? (Passing function pointers or something like that would be overkill, let's not do that.)
The issue I see here is with getArgAttr and getResultAttr. I could use templates or just pass a boolean if that makes more sense? I'm not sure about the tradeoff between clarity and code size here.
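For illustration, a minimal sketch of the boolean variant mentioned here (the collectShardings helper and its is_input flag are hypothetical, not part of the PR):

void ModuleBuilder::collectShardings(
    const mlir::OwningOpRef<mlir::ModuleOp> &module, bool is_input) {
  std::vector<mlir::tt::sharding_utils::MeshSharding> &shardings =
      is_input ? m_input_shardings : m_output_shardings;
  shardings.clear();

  module.get().walk([&](mlir::func::FuncOp funcOp) {
    if (!funcOp.isPublic()) {
      return;
    }
    unsigned count =
        is_input ? funcOp.getNumArguments() : funcOp.getNumResults();
    for (unsigned i = 0; i < count; ++i) {
      mlir::tt::sharding_utils::MeshSharding meshSharding;
      // The flag is consulted at every access point; this is the clarity
      // cost being weighed against the duplication of the two walks.
      auto shardingAttr = llvm::dyn_cast_if_present<mlir::StringAttr>(
          is_input
              ? funcOp.getArgAttr(i,
                                  mlir::tt::sharding_utils::kXlaShardingAttr)
              : funcOp.getResultAttr(
                    i, mlir::tt::sharding_utils::kXlaShardingAttr));
      if (!shardingAttr) {
        shardings.push_back(meshSharding);
        continue;
      }
      if (fillMeshShardingFromGSPMDString(shardingAttr, meshSharding)
              .failed()) {
        m_status = tt_pjrt_status::kInternal;
        return;
      }
      shardings.push_back(meshSharding);
    }
  });
}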
Yeah, that's what I thought. Leave it like this.
You can do something like:
void ModuleBuilder::collectInputShardings(
    const mlir::OwningOpRef<mlir::ModuleOp> &module) {
  DLOG_F(LOG_DEBUG, "ModuleBuilder::collectInputShardings");
  m_input_shardings.clear();
  std::vector<mlir::StringAttr> gspmd_attributes;
  // (walk the module to find the public funcOp, as in the current code)
  for (unsigned i = 0; i < funcOp.getNumArguments(); ++i) {
    gspmd_attributes.push_back(llvm::dyn_cast_if_present<mlir::StringAttr>(
        funcOp.getArgAttr(i, mlir::tt::sharding_utils::kXlaShardingAttr)));
  }
  mlir::LogicalResult result =
      createShardingsFromGSPMD(gspmd_attributes, m_input_shardings);
  if (result.failed()) {
    m_status = tt_pjrt_status::kInternal;
  }
}

void ModuleBuilder::collectOutputShardings(
    const mlir::OwningOpRef<mlir::ModuleOp> &module) {
  DLOG_F(LOG_DEBUG, "ModuleBuilder::collectOutputShardings");
  m_output_shardings.clear();
  std::vector<mlir::StringAttr> gspmd_attributes;
  // (walk the module to find the public funcOp, as in the current code)
  for (unsigned i = 0; i < funcOp.getNumResults(); ++i) {
    gspmd_attributes.push_back(llvm::dyn_cast_if_present<mlir::StringAttr>(
        funcOp.getResultAttr(i, mlir::tt::sharding_utils::kXlaShardingAttr)));
  }
  mlir::LogicalResult result =
      createShardingsFromGSPMD(gspmd_attributes, m_output_shardings);
  if (result.failed()) {
    m_status = tt_pjrt_status::kInternal;
  }
}

mlir::LogicalResult ModuleBuilder::createShardingsFromGSPMD(
    const std::vector<mlir::StringAttr> &gspmd_attributes,
    std::vector<mlir::tt::sharding_utils::MeshSharding> &shardings) {
  for (mlir::StringAttr gspmd_attr : gspmd_attributes) {
    mlir::tt::sharding_utils::MeshSharding meshSharding;
    if (!gspmd_attr) {
      shardings.push_back(meshSharding);
      continue;
    }
    mlir::LogicalResult conversionResult =
        fillMeshShardingFromGSPMDString(gspmd_attr, meshSharding);
    if (conversionResult.failed()) {
      return mlir::failure();
    }
    shardings.push_back(meshSharding);
  }
  return mlir::success();
}
Force-pushed from 3f348ad to b5925e5
@@ -16,6 +16,10 @@
// tt-mlir includes
#include "tt/runtime/types.h"

// tt-mlir includes
merge with the group above, like:
// tt-mlir includes
#define TTMLIR_ENABLE_STABLEHLO 1
#include "tt/runtime/types.h"
#include "ttmlir/Conversion/StableHLOToTTIR/ShardingUtils.h"
@@ -28,10 +32,12 @@
#include "stablehlo/dialect/Register.h"
#include "stablehlo/dialect/Version.h"
#include "stablehlo/transforms/Passes.h"
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
move into group below
// tt-mlir includes
-#define TTMLIR_ENABLE_STABLEHLO
+#define TTMLIR_ENABLE_STABLEHLO 1
shouldn't be needed since you've already defined it in module_builder.h
@@ -52,6 +65,12 @@ class ModuleBuilder {
// scalar or not.
void collectOutputTypes(const mlir::OwningOpRef<mlir::ModuleOp> &module);

// Collects the information about the sharding of specifc inputs.
typo specifc, below too
auto error =
    meshSharding.convertGSPMDShardingToMeshSharding(shardingStr.getValue());
if (auto e = error.takeError()) {
Is auto necessary in these places?
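For comparison, the same lines with the types spelled out (this assumes convertGSPMDShardingToMeshSharding returns llvm::Expected<bool>; verify against the actual declaration in ShardingUtils.h):

llvm::Expected<bool> error =
    meshSharding.convertGSPMDShardingToMeshSharding(shardingStr.getValue());
if (llvm::Error e = error.takeError()) {
  // conversion failed; handle/log the error here
}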
  return;
}

for (unsigned i = 0; i < funcOp.getNumArguments(); ++i) {
Specify the full type instead of unsigned
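One possible reading of this, as a sketch (uint32_t is an assumption about which full type is meant; size_t would satisfy the comment as well):

for (uint32_t i = 0; i < funcOp.getNumArguments(); ++i) {
  // ... loop body unchanged ...
}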
    funcOp.getArgAttr(i, mlir::tt::sharding_utils::kXlaShardingAttr));

if (!shardingAttr) {
  m_input_shardings.push_back(meshSharding);
Is it okay that we just put the empty sharding here?
If it is, please add a comment explaining why
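Something along these lines would address the request (the wording is only a suggestion; the assumption that a default-constructed MeshSharding represents "no sharding" should be double-checked against MeshSharding's defaults):

if (!shardingAttr) {
  // This argument carries no GSPMD sharding attribute. Record a
  // default-constructed MeshSharding as a placeholder so that
  // m_input_shardings stays index-aligned with the function arguments.
  m_input_shardings.push_back(meshSharding);
  continue;
}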
@@ -134,6 +142,92 @@ void ModuleBuilder::convertFromVHLOToSHLO(
  printModule(mlir_module);
}

mlir::LogicalResult ModuleBuilder::fillMeshShardingFromGSPMDString(
Move this function implementation below the implementations of collectInputShardings and collectOutputShardings, since we follow top-down order in this file.
// Fills sharding_utils::MeshSharding object with sharding info stored in a
// StringAttribute.
mlir::LogicalResult fillMeshShardingFromGSPMDString(
    mlir::StringAttr shardingStr,
In PJRT we follow a slight deviation from the tt-mlir guidelines: we use snake_case for variables. Please update this in all the ModuleBuilder changes.
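Applied to the locals in this PR, that would look like this (renames shown for illustration only):

for (unsigned i = 0; i < funcOp.getNumArguments(); ++i) {
  // snake_case per the PJRT convention (was: meshSharding, shardingAttr).
  mlir::tt::sharding_utils::MeshSharding mesh_sharding;
  mlir::StringAttr sharding_attr = llvm::dyn_cast_if_present<mlir::StringAttr>(
      funcOp.getArgAttr(i, mlir::tt::sharding_utils::kXlaShardingAttr));
  // ...
}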
As part of adding multichip support, this adds collection of information about the device sharding of the inputs of the StableHLO graph during compilation. That way, we will know how to create MultiDevice input tensors when executing.
Edit: Also added output sharding collection.