diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding.h b/xla/hlo/experimental/auto_sharding/auto_sharding.h index 40cbc6295b0bc..2034dd6897241 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding.h +++ b/xla/hlo/experimental/auto_sharding/auto_sharding.h @@ -128,10 +128,6 @@ class AutoSharding : public HloModulePass { namespace spmd { // Function declarations. // Their comments can be found in their definitions in *.cc files. -HloSharding Tile(const Shape& shape, absl::Span tensor_dims, - absl::Span mesh_dims, - const DeviceMesh& device_mesh); - std::vector CommunicationReshardingCostVector( const StrategyGroup& strategy_group, const Shape& shape, const HloSharding& required_sharding, diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc index dc5f976eef7cc..6e95746e39e61 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc @@ -1671,10 +1671,15 @@ HloSharding Tile(const Shape& tensor_shape, for (int i = 0; i < mesh_dims.size(); ++i) { mesh_dims_general[i].push_back(mesh_dims[i]); } - if (device_mesh.IsIota()) { - return TileV2(tensor_shape, tensor_dims, mesh_dims_general, device_mesh); - } - return TileV1(tensor_shape, tensor_dims, mesh_dims_general, device_mesh); + return Tile(tensor_shape, tensor_dims, mesh_dims_general, device_mesh); +} + +HloSharding Tile(const Shape& tensor_shape, + absl::Span tensor_dims, + std::initializer_list mesh_dims, + const DeviceMesh& device_mesh) { + return Tile(tensor_shape, tensor_dims, absl::Span(mesh_dims), + device_mesh); } AliasMap BuildAliasMap(const HloModule* module, @@ -1858,7 +1863,8 @@ absl::Status CheckAliasSetCompatibility(const AliasSet& alias_set, "tensors and may result in large memory consumption: " << "(" << instructions.at(src_strategy_group->instruction_id)->name() << ", " << instructions.at(dst_strategy_group->instruction_id)->name() - << ")" << "\n" + << ")" + << "\n" << "(" << src_strategy_group->node_idx << ", " << dst_strategy_group->node_idx << ")\n" << src_strategy_group->ToString() << "\n" @@ -1922,7 +1928,8 @@ absl::StatusOr ComputeAliasCompatibility( << instructions.at(src_strategy_group->instruction_id)->name() << ", " << instructions.at(dst_strategy_group->instruction_id)->name() - << ")" << "\n" + << ")" + << "\n" << "(" << src_strategy_group->node_idx << ", " << dst_strategy_group->node_idx << ")\n" << src_strategy_group->ToString() << "\n" diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_util.h b/xla/hlo/experimental/auto_sharding/auto_sharding_util.h index 862aefd775506..2c8677e1e7e45 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_util.h +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_util.h @@ -506,6 +506,11 @@ HloSharding Tile(const Shape& tensor_shape, absl::Span mesh_dims, const DeviceMesh& device_mesh); +HloSharding Tile(const Shape& tensor_shape, + absl::Span tensor_dims, + std::initializer_list mesh_dims, + const DeviceMesh& device_mesh); + AliasMap BuildAliasMap(const HloModule* module, const HloInputOutputAliasConfig& alias_config);