Skip to content

Commit

Permalink
PR #22723: Fix call of overloaded Tile is ambiguous
Browse files Browse the repository at this point in the history
Imported from GitHub PR #22723

#### Fix GCC-13 Build Error in AutoSharding Due to vector<vector> vs. absl::Span Ambiguity

When building auto_sharding with GCC-13, the following build error occurred:

```
xla/hlo/experimental/auto_sharding/auto_sharding.cc:895:37: error: call of overloaded 'Tile(const xla::Shape&, <brace-enclosed initializer list>, <brace-enclosed initializer list>, const xla::spmd::DeviceMesh&)' is ambiguous
  895 |       HloSharding output_spec = Tile(shape, {i}, {j}, device_mesh);
      |                                 ~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
In file included from ./xla/hlo/experimental/auto_sharding/cluster_environment.h:33,
                 from ./xla/hlo/experimental/auto_sharding/auto_sharding.h:41:
./xla/hlo/experimental/auto_sharding/auto_sharding_util.h:499:13: note: candidate: 'xla::HloSharding xla::spmd::Tile(const xla::Shape&, absl::lts_20230802::Span<const long int>, const std::vector<std::vector<long int> >&, const DeviceMesh&)'
  499 | HloSharding Tile(const Shape& tensor_shape,
      |             ^~~~
./xla/hlo/experimental/auto_sharding/auto_sharding_util.h:504:13: note: candidate: 'xla::HloSharding xla::spmd::Tile(const xla::Shape&, absl::lts_20230802::Span<const long int>, absl::lts_20230802::Span<const long int>, const DeviceMesh&)'
  504 | HloSharding Tile(const Shape& tensor_shape,
      |             ^~~~
```

#### Solution:
To resolve the ambiguity between `std::vector<std::vector<int64_t>>` and `absl::Span<const int64_t>` in `Tile()`, I introduced an overloaded Tile() function that takes `std::initializer_list<int64_t> mesh_dims`.

Now, expressions like the following now compile successfully with GCC-13:

```
Tile(shape, {0}, {0}, device_mesh);
```

#### Additional Changes
- Removed the `Tile()` declaration from `auto_sharding.h` since it is already declared in `auto_sharding_util.h`.

Copybara import of the project:

--
a28207f by Alexander Pivovarov <apivovarov@gmail.com>:

Fix call of overloaded Tile is ambiguous

Merging this change closes #22723

COPYBARA_INTEGRATE_REVIEW=#22723 from apivovarov:fix_Tile_init_list a28207f
PiperOrigin-RevId: 728206576
  • Loading branch information
apivovarov authored and Google-ML-Automation committed Feb 18, 2025
1 parent 8e9c867 commit f56dc78
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 10 deletions.
4 changes: 0 additions & 4 deletions xla/hlo/experimental/auto_sharding/auto_sharding.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,6 @@ class AutoSharding : public HloModulePass {
namespace spmd {
// Function declarations.
// Their comments can be found in their definitions in *.cc files.
HloSharding Tile(const Shape& shape, absl::Span<const int64_t> tensor_dims,
absl::Span<const int64_t> mesh_dims,
const DeviceMesh& device_mesh);

std::vector<double> CommunicationReshardingCostVector(
const StrategyGroup& strategy_group, const Shape& shape,
const HloSharding& required_sharding,
Expand Down
19 changes: 13 additions & 6 deletions xla/hlo/experimental/auto_sharding/auto_sharding_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1671,10 +1671,15 @@ HloSharding Tile(const Shape& tensor_shape,
for (int i = 0; i < mesh_dims.size(); ++i) {
mesh_dims_general[i].push_back(mesh_dims[i]);
}
if (device_mesh.IsIota()) {
return TileV2(tensor_shape, tensor_dims, mesh_dims_general, device_mesh);
}
return TileV1(tensor_shape, tensor_dims, mesh_dims_general, device_mesh);
return Tile(tensor_shape, tensor_dims, mesh_dims_general, device_mesh);
}

HloSharding Tile(const Shape& tensor_shape,
absl::Span<const int64_t> tensor_dims,
std::initializer_list<int64_t> mesh_dims,
const DeviceMesh& device_mesh) {
return Tile(tensor_shape, tensor_dims, absl::Span<const int64_t>(mesh_dims),
device_mesh);
}

AliasMap BuildAliasMap(const HloModule* module,
Expand Down Expand Up @@ -1858,7 +1863,8 @@ absl::Status CheckAliasSetCompatibility(const AliasSet& alias_set,
"tensors and may result in large memory consumption: "
<< "(" << instructions.at(src_strategy_group->instruction_id)->name()
<< ", " << instructions.at(dst_strategy_group->instruction_id)->name()
<< ")" << "\n"
<< ")"
<< "\n"
<< "(" << src_strategy_group->node_idx << ", "
<< dst_strategy_group->node_idx << ")\n"
<< src_strategy_group->ToString() << "\n"
Expand Down Expand Up @@ -1922,7 +1928,8 @@ absl::StatusOr<AliasCompatibility> ComputeAliasCompatibility(
<< instructions.at(src_strategy_group->instruction_id)->name()
<< ", "
<< instructions.at(dst_strategy_group->instruction_id)->name()
<< ")" << "\n"
<< ")"
<< "\n"
<< "(" << src_strategy_group->node_idx << ", "
<< dst_strategy_group->node_idx << ")\n"
<< src_strategy_group->ToString() << "\n"
Expand Down
5 changes: 5 additions & 0 deletions xla/hlo/experimental/auto_sharding/auto_sharding_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,11 @@ HloSharding Tile(const Shape& tensor_shape,
absl::Span<const int64_t> mesh_dims,
const DeviceMesh& device_mesh);

HloSharding Tile(const Shape& tensor_shape,
absl::Span<const int64_t> tensor_dims,
std::initializer_list<int64_t> mesh_dims,
const DeviceMesh& device_mesh);

AliasMap BuildAliasMap(const HloModule* module,
const HloInputOutputAliasConfig& alias_config);

Expand Down

0 comments on commit f56dc78

Please sign in to comment.