Adding input sharding collection in ModuleBuilder #289

Open
wants to merge 1 commit into
base: main
from

Conversation

ajakovljevicTT
Contributor

@ajakovljevicTT ajakovljevicTT commented Feb 28, 2025

As part of adding multichip support, this PR collects information about the device sharding of the StableHLO graph's inputs during compilation. That way, we will know how to create MultiDevice input tensors at execution time.

Edit: Also added output sharding collection.

github-actions bot commented Feb 28, 2025

Tests                  | Passed ✅  | Skipped ⚠️  | Failed
TT-XLA Tests (603 ran) | 432 passed | 171 skipped | 0 failed

Test | Result
No test annotations available

@ajakovljevicTT ajakovljevicTT force-pushed the ajakovljevic/adding_input_sharding_collection branch 3 times, most recently from 260fb4d to 3f348ad Compare March 4, 2025 09:13

codecov-commenter commented Mar 4, 2025

Codecov Report

Attention: Patch coverage is 76.27119% with 14 lines in your changes missing coverage. Please review.

Project coverage is 78.65%. Comparing base (5f758e5) to head (b5925e5).
Report is 14 commits behind head on main.

✅ All tests successful. No failed tests found.

Files with missing lines     | Patch % | Lines
src/common/module_builder.cc | 71.42%  | 14 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #289      +/-   ##
==========================================
+ Coverage   77.97%   78.65%   +0.67%     
==========================================
  Files          21       21              
  Lines         990     1059      +69     
==========================================
+ Hits          772      833      +61     
- Misses        218      226       +8     

Comment on lines +156 to +229
void ModuleBuilder::collectInputShardings(
    const mlir::OwningOpRef<mlir::ModuleOp> &module) {
  DLOG_F(LOG_DEBUG, "ModuleBuilder::collectInputShardings");
  m_input_shardings.clear();

  module.get().walk([&](mlir::Operation *op) {
    mlir::func::FuncOp funcOp = mlir::dyn_cast<mlir::func::FuncOp>(op);
    mlir::ModuleOp moduleOp = mlir::dyn_cast<mlir::ModuleOp>(op);

    if (!funcOp || !funcOp.isPublic()) {
      return;
    }

    for (unsigned i = 0; i < funcOp.getNumArguments(); ++i) {

      mlir::tt::sharding_utils::MeshSharding meshSharding;

      auto shardingAttr = llvm::dyn_cast_if_present<mlir::StringAttr>(
          funcOp.getArgAttr(i, mlir::tt::sharding_utils::kXlaShardingAttr));

      if (!shardingAttr) {
        m_input_shardings.push_back(meshSharding);
        continue;
      }

      mlir::LogicalResult conversionResult =
          fillMeshShardingFromGSPMDString(shardingAttr, meshSharding);

      if (conversionResult.failed()) {
        m_status = tt_pjrt_status::kInternal;
        return;
      }
      m_input_shardings.push_back(meshSharding);
    }
  });
}

void ModuleBuilder::collectOutputShardings(
    const mlir::OwningOpRef<mlir::ModuleOp> &module) {
  DLOG_F(LOG_DEBUG, "ModuleBuilder::collectOutputShardings");
  m_output_shardings.clear();

  module.get().walk([&](mlir::Operation *op) {
    mlir::func::FuncOp funcOp = mlir::dyn_cast<mlir::func::FuncOp>(op);
    mlir::ModuleOp moduleOp = mlir::dyn_cast<mlir::ModuleOp>(op);

    if (!funcOp || !funcOp.isPublic()) {
      return;
    }

    for (unsigned i = 0; i < funcOp.getNumResults(); ++i) {

      mlir::tt::sharding_utils::MeshSharding meshSharding;

      auto shardingAttr = llvm::dyn_cast_if_present<mlir::StringAttr>(
          funcOp.getResultAttr(i, mlir::tt::sharding_utils::kXlaShardingAttr));

      if (!shardingAttr) {
        m_output_shardings.push_back(meshSharding);
        continue;
      }

      mlir::LogicalResult conversionResult =
          fillMeshShardingFromGSPMDString(shardingAttr, meshSharding);

      if (conversionResult.failed()) {
        m_status = tt_pjrt_status::kInternal;
        return;
      }
      m_output_shardings.push_back(meshSharding);
    }
  });
}
Contributor

These two look quite similar. The only differences are m_input_shardings vs. m_output_shardings and getArgAttr vs. getResultAttr. Can we extract the shared logic into a common function, as long as it doesn't get too cumbersome? (If it would take passing function pointers or something similar, let's not do it; that's overkill.)

Contributor Author

@ajakovljevicTT ajakovljevicTT Mar 5, 2025

The issue I see here is with getArgAttr and getResultAttr. I can use templates or just pass a boolean, if that makes more sense? I'm not sure about the tradeoff between clarity and code size here.

Contributor

Yeah, that's what I thought. Leave it like this.
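
For reference, the templated option weighed above could look roughly like the sketch below. It uses stand-in types so it compiles without MLIR: `MeshSharding`, `FakeFunc` and `collect_shardings` are hypothetical names, not part of tt-xla or tt-mlir, and the GSPMD parsing step is left out. The only point is that a callable can hide the getArgAttr/getResultAttr difference.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for mlir::tt::sharding_utils::MeshSharding.
struct MeshSharding {
  std::string spec; // empty means "no sharding annotation"
};

// Hypothetical stand-in for mlir::func::FuncOp: every argument and result
// slot optionally carries a sharding string.
struct FakeFunc {
  std::vector<std::optional<std::string>> arg_attrs;
  std::vector<std::optional<std::string>> result_attrs;
};

// Shared loop. `get_attr` plays the role of getArgAttr or getResultAttr and
// must return std::optional<std::string> for a given index.
template <typename AttrGetter>
std::vector<MeshSharding> collect_shardings(std::size_t count,
                                            AttrGetter get_attr) {
  std::vector<MeshSharding> shardings;
  shardings.reserve(count);
  for (std::size_t i = 0; i < count; ++i) {
    MeshSharding sharding;
    if (std::optional<std::string> attr = get_attr(i)) {
      sharding.spec = *attr;
    }
    // Unannotated slots still get a default-constructed entry.
    shardings.push_back(sharding);
  }
  return shardings;
}
```

A caller would pass a lambda such as `[&](std::size_t i) { return func.arg_attrs[i]; }` for inputs and the matching one for results. Whether that indirection is clearer than two near-identical loops is exactly the tradeoff raised in this thread.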

Contributor

You can do something like:

void ModuleBuilder::collectInputShardings(
    const mlir::OwningOpRef<mlir::ModuleOp> &module) {
  DLOG_F(LOG_DEBUG, "ModuleBuilder::collectInputShardings");
  m_input_shardings.clear();

  // Sketch only: funcOp stands for the public func::FuncOp located by the
  // module walk in the current implementation (lookup omitted here).
  std::vector<mlir::StringAttr> gspmd_attributes;
  for (unsigned i = 0; i < funcOp.getNumArguments(); ++i) {
    gspmd_attributes.push_back(llvm::dyn_cast_if_present<mlir::StringAttr>(
        funcOp.getArgAttr(i, mlir::tt::sharding_utils::kXlaShardingAttr)));
  }

  mlir::LogicalResult result =
      createShardingsFromGSPMD(gspmd_attributes, m_input_shardings);
  if (result.failed()) {
    m_status = tt_pjrt_status::kInternal;
  }
}

void ModuleBuilder::collectOutputShardings(
    const mlir::OwningOpRef<mlir::ModuleOp> &module) {
  DLOG_F(LOG_DEBUG, "ModuleBuilder::collectOutputShardings");
  m_output_shardings.clear();

  // Sketch only: funcOp stands for the public func::FuncOp located by the
  // module walk in the current implementation (lookup omitted here).
  std::vector<mlir::StringAttr> gspmd_attributes;
  for (unsigned i = 0; i < funcOp.getNumResults(); ++i) {
    gspmd_attributes.push_back(llvm::dyn_cast_if_present<mlir::StringAttr>(
        funcOp.getResultAttr(i, mlir::tt::sharding_utils::kXlaShardingAttr)));
  }

  mlir::LogicalResult result =
      createShardingsFromGSPMD(gspmd_attributes, m_output_shardings);
  if (result.failed()) {
    m_status = tt_pjrt_status::kInternal;
  }
}

mlir::LogicalResult ModuleBuilder::createShardingsFromGSPMD(
    const std::vector<mlir::StringAttr> &gspmd_attributes,
    std::vector<mlir::tt::sharding_utils::MeshSharding> &shardings) {
  for (mlir::StringAttr gspmd_attr : gspmd_attributes) {
    mlir::tt::sharding_utils::MeshSharding meshSharding;

    if (!gspmd_attr) {
      shardings.push_back(meshSharding);
      continue;
    }

    mlir::LogicalResult conversionResult =
        fillMeshShardingFromGSPMDString(gspmd_attr, meshSharding);
    if (conversionResult.failed()) {
      return llvm::LogicalResult::failure();
    }

    shardings.push_back(meshSharding);
  }

  return llvm::LogicalResult::success();
}
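
To try the shape of this suggestion outside the repo, here is a minimal sketch with stand-in types. Everything in it is an assumption made for illustration: `MeshSharding` is a placeholder struct, `fill_mesh_sharding` replaces fillMeshShardingFromGSPMDString with an invented rule (a spec must start with `{`) just so a failure path exists, and `bool` stands in for mlir::LogicalResult.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for mlir::tt::sharding_utils::MeshSharding.
struct MeshSharding {
  std::string spec;
};

// Hypothetical stand-in for fillMeshShardingFromGSPMDString. The validity
// rule is invented for this sketch and is not the real GSPMD parser.
bool fill_mesh_sharding(const std::string &gspmd, MeshSharding &sharding) {
  if (gspmd.empty() || gspmd.front() != '{') {
    return false;
  }
  sharding.spec = gspmd;
  return true;
}

// Mirrors the proposed createShardingsFromGSPMD: appends one entry per
// attribute and stops at the first conversion failure.
bool create_shardings_from_gspmd(
    const std::vector<std::optional<std::string>> &gspmd_attributes,
    std::vector<MeshSharding> &shardings) {
  for (const std::optional<std::string> &attr : gspmd_attributes) {
    MeshSharding sharding;
    if (attr.has_value() && !fill_mesh_sharding(*attr, sharding)) {
      return false;
    }
    shardings.push_back(sharding);
  }
  return true;
}
```

The two collect functions then differ only in how they gather the attribute list, and the caller maps a `false` return to tt_pjrt_status::kInternal. Note that entries appended before a failure stay in the output vector, so the caller should treat its contents as unspecified once the call fails.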

@ajakovljevicTT ajakovljevicTT force-pushed the ajakovljevic/adding_input_sharding_collection branch from 3f348ad to b5925e5 Compare March 6, 2025 16:00
@@ -16,6 +16,10 @@
// tt-mlir includes
#include "tt/runtime/types.h"

// tt-mlir includes
Contributor

merge with the group above, like:

// tt-mlir includes
#define TTMLIR_ENABLE_STABLEHLO 1
#include "tt/runtime/types.h"
#include "ttmlir/Conversion/StableHLOToTTIR/ShardingUtils.h"

@@ -28,10 +32,12 @@
#include "stablehlo/dialect/Register.h"
#include "stablehlo/dialect/Version.h"
#include "stablehlo/transforms/Passes.h"
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
Contributor

move into group below


// tt-mlir includes
#define TTMLIR_ENABLE_STABLEHLO
#define TTMLIR_ENABLE_STABLEHLO 1
Contributor

This shouldn't be needed, since you've already defined it in module_builder.h.

@@ -52,6 +65,12 @@ class ModuleBuilder {
// scalar or not.
void collectOutputTypes(const mlir::OwningOpRef<mlir::ModuleOp> &module);

// Collects the information about the sharding of specifc inputs.
Contributor

Typo: "specifc". It appears below too.

Comment on lines +148 to +150
auto error =
    meshSharding.convertGSPMDShardingToMeshSharding(shardingStr.getValue());
if (auto e = error.takeError()) {
Contributor

Is auto necessary in these places?

return;
}

for (unsigned i = 0; i < funcOp.getNumArguments(); ++i) {
Contributor

@mrakitaTT mrakitaTT Mar 7, 2025

Specify the full type instead of unsigned

funcOp.getArgAttr(i, mlir::tt::sharding_utils::kXlaShardingAttr));

if (!shardingAttr) {
m_input_shardings.push_back(meshSharding);
Contributor

Is it okay that we just put an empty sharding here?

Contributor

If it is, please add a comment explaining why.
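
One plausible reason, offered as an assumption rather than something confirmed in this thread: a default entry keeps the shardings vector index-aligned with the function arguments, so entry i always describes argument i. The sketch below illustrates that with stand-in types; `MeshSharding`, `collect_input_shardings` and the `keep_placeholders` switch are hypothetical and exist only to contrast the two behaviours.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for mlir::tt::sharding_utils::MeshSharding.
struct MeshSharding {
  std::string spec; // empty: default-constructed, no annotation
};

// Builds the shardings vector from per-argument attributes. With
// keep_placeholders == false, unannotated arguments are skipped, which
// shifts every later entry to a lower index.
std::vector<MeshSharding> collect_input_shardings(
    const std::vector<std::optional<std::string>> &arg_attrs,
    bool keep_placeholders) {
  std::vector<MeshSharding> shardings;
  for (const std::optional<std::string> &attr : arg_attrs) {
    if (!attr.has_value()) {
      if (keep_placeholders) {
        shardings.push_back(MeshSharding{});
      }
      continue;
    }
    shardings.push_back(MeshSharding{*attr});
  }
  return shardings;
}
```

With attributes for two arguments where only the second is annotated, the placeholder variant returns two entries and the annotated one sits at index 1. The variant without placeholders returns a single entry at index 0, which no longer matches the argument position.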

@@ -134,6 +142,92 @@ void ModuleBuilder::convertFromVHLOToSHLO(
printModule(mlir_module);
}

mlir::LogicalResult ModuleBuilder::fillMeshShardingFromGSPMDString(
Contributor

Move this function's implementation below the implementations of collectInputShardings and collectOutputShardings, since we follow top-down order in this file.

// Fills sharding_utils::MeshSharding object with sharding info stored in a
// StringAttribute.
mlir::LogicalResult fillMeshShardingFromGSPMDString(
mlir::StringAttr shardingStr,
Contributor

In PJRT we deviate slightly from the tt-mlir guidelines: we use snake_case for variables. Please update this in all of your changes to ModuleBuilder.

4 participants