Skip to content

Commit

Permalink
Move users of BasicDeviceList::Create() to Client::MakeDeviceList()
Browse files Browse the repository at this point in the history
IFRT is moving towards runtime-controlled device list creation. This CL moves most of explicit device list creation from `BasicDeviceLIst::Create()` to `Client::MakeDeviceList()`. Once the migration is done, `BasicDeviceList::Create()` will be reserved only for IFRT implementations and all IFRT users will be expected to use `Client::MakeDeviceList::Create()` to create device lists.

PiperOrigin-RevId: 725479063
  • Loading branch information
junwhanahn authored and Google-ML-Automation committed Feb 11, 2025
1 parent f7e80bf commit c802f9b
Show file tree
Hide file tree
Showing 8 changed files with 41 additions and 27 deletions.
4 changes: 3 additions & 1 deletion xla/python/ifrt/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ cc_library(
"//xla/tsl/platform:statusor",
"//xla/tsl/platform:test",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
Expand All @@ -293,7 +294,6 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)
Expand Down Expand Up @@ -321,6 +321,7 @@ cc_library(
"//xla/tsl/platform:statusor",
"//xla/tsl/platform:test",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/status",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
Expand Down Expand Up @@ -618,6 +619,7 @@ cc_library(
"//xla/tsl/platform:status_matchers",
"//xla/tsl/platform:statusor",
"//xla/tsl/platform:test",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
Expand Down
26 changes: 13 additions & 13 deletions xla/python/ifrt/array_impl_test_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License.
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
#include "absl/status/status.h"
#include "absl/synchronization/notification.h"
#include "absl/time/clock.h"
Expand Down Expand Up @@ -311,7 +312,7 @@ TEST(ArrayImplTest, MakeArrayFromHostBufferReplicated) {
std::iota(data->begin(), data->end(), 0);
absl::Span<Device* const> devices = client->addressable_devices();
std::shared_ptr<const Sharding> sharding = ConcreteEvenSharding::Create(
BasicDeviceList::Create(devices), MemoryKind(), shape,
client->MakeDeviceList(devices), MemoryKind(), shape,
/*shard_shape=*/shape, /*is_fully_replicated=*/true);

TF_ASSERT_OK_AND_ASSIGN(
Expand Down Expand Up @@ -375,9 +376,8 @@ TEST(ArrayImplTest, AssembleArray) {
std::vector<tsl::RCReference<Array>> arrays({array0, array1});
Shape assembled_shape({4, 3});
std::shared_ptr<const Sharding> assembled_sharding = OpaqueSharding::Create(
BasicDeviceList::Create(
{array0->sharding().devices()->devices().front(),
array1->sharding().devices()->devices().front()}),
client->MakeDeviceList({array0->sharding().devices()->devices().front(),
array1->sharding().devices()->devices().front()}),
MemoryKind());
TF_ASSERT_OK_AND_ASSIGN(
auto assembled_array,
Expand Down Expand Up @@ -423,9 +423,9 @@ TEST(ArrayImplTest, AssembleAndDisassembleArray) {
Shape assembled_shape({4, 3});
ShardingParam sharding_param(
/*dim_shards=*/{2, 1}, {/*permutation=*/{0, 1}, /*axis_sizes=*/{2, 1}});
auto ifrt_device_list = BasicDeviceList::Create(
{array0->sharding().devices()->devices().front(),
array1->sharding().devices()->devices().front()});
auto ifrt_device_list =
client->MakeDeviceList({array0->sharding().devices()->devices().front(),
array1->sharding().devices()->devices().front()});
TF_ASSERT_OK_AND_ASSIGN(
std::shared_ptr<const Sharding> sharding_param_sharding,
ShardingParamSharding::Create(std::move(sharding_param), ifrt_device_list,
Expand Down Expand Up @@ -537,7 +537,7 @@ TEST(ArrayImplTest, CopyToSameDevices) {
TEST(ArrayImplTest, CopyToDifferentDevice) {
TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient());
tsl::RCReference<DeviceList> devices =
BasicDeviceList::Create(client->addressable_devices());
client->MakeDeviceList(client->addressable_devices());

DType dtype(DType::kF32);
Shape shape({2, 3});
Expand Down Expand Up @@ -580,22 +580,22 @@ TEST(ArrayImplTest, CopyToDifferentDevice) {
SingleDeviceShardSemantics::kAddressableShards));
}

BasicDeviceList::Devices new_devices;
absl::InlinedVector<xla::ifrt::Device*, 1> new_devices;
for (auto it = devices->devices().rbegin(); it != devices->devices().rend();
++it) {
new_devices.push_back(*it);
}
TF_ASSERT_OK_AND_ASSIGN(
auto new_arrays,
client->CopyArrays(absl::MakeSpan(arrays),
BasicDeviceList::Create(new_devices), MemoryKind(),
client->MakeDeviceList(new_devices), MemoryKind(),
ArrayCopySemantics::kAlwaysCopy));

for (int i = 0; i < arrays.size(); ++i) {
TF_ASSERT_OK_AND_ASSIGN(
auto expected_sharding,
arrays[i]->sharding().WithDeviceAssignment(
BasicDeviceList::Create(new_devices), MemoryKind()));
client->MakeDeviceList(new_devices), MemoryKind()));
EXPECT_EQ(new_arrays[i]->sharding(), *expected_sharding);

TF_ASSERT_OK_AND_ASSIGN(
Expand Down Expand Up @@ -637,7 +637,7 @@ TEST(ArrayImplTest, CopyMixedSourceDevices) {
Device* new_device = client->addressable_devices().at(1);
EXPECT_THAT(client
->CopyArrays(absl::MakeSpan(arrays),
BasicDeviceList::Create({new_device}),
client->MakeDeviceList({new_device}),
MemoryKind(), ArrayCopySemantics::kAlwaysCopy)
.status(),
StatusIs(absl::StatusCode::kInvalidArgument));
Expand Down Expand Up @@ -671,7 +671,7 @@ TEST(ArrayImplTest, CopyMixedSourceMemoryKind) {
Device* new_device = client->addressable_devices().at(1);
EXPECT_THAT(client
->CopyArrays(absl::MakeSpan(arrays),
BasicDeviceList::Create({new_device}),
client->MakeDeviceList({new_device}),
MemoryKind(), ArrayCopySemantics::kAlwaysCopy)
.status(),
StatusIs(absl::StatusCode::kInvalidArgument));
Expand Down
5 changes: 5 additions & 0 deletions xla/python/ifrt/device_test_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,11 @@ std::shared_ptr<MockClient> MakeDeviceTestClient(int num_devices,
}
return it->second.get();
});
ON_CALL(*client, MakeDeviceList)
.WillByDefault([](absl::Span<Device* const> devices)
-> tsl::RCReference<DeviceList> {
return BasicDeviceList::Create(devices);
});
ON_CALL(*client, GetTopologyForDevices)
.WillByDefault(
[](const tsl::RCReference<DeviceList>&) { return nullptr; });
Expand Down
2 changes: 1 addition & 1 deletion xla/python/ifrt/ir/tests/executable_impl_test_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ IfrtIrExecutableImplTestBase::PickDevices(int count) {
absl::Span<Device* const> devices = client_->devices();
TF_RET_CHECK(count <= devices.size())
<< "Requested " << count << " devices. Only have " << devices.size();
return BasicDeviceList::Create(devices.first(count));
return client_->MakeDeviceList(devices.first(count));
}

} // namespace test_util
Expand Down
12 changes: 6 additions & 6 deletions xla/python/ifrt/ir/tests/executable_impl_test_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ module {
tsl::RCReference<Array> input,
CreateArray({data.data()}, Shape({2}), DType(DType::kS32),
ShardingParam({1}, {{0}, {1}}),
BasicDeviceList::Create({devices->devices()[0]})));
client_->MakeDeviceList({devices->devices()[0]})));

ExecuteOptions options;
options.fill_status = true;
Expand All @@ -425,7 +425,7 @@ module {
ASSERT_EQ(result.outputs.size(), 1);
ASSERT_NO_FATAL_FAILURE(AssertPerShardData<int>(
result.outputs[0], DType(DType::kS32), Shape({2}), {{1, 2}},
BasicDeviceList::Create({devices->devices()[1]})));
client_->MakeDeviceList({devices->devices()[1]})));
}

TEST_F(IfrtIrExecutableImplTest, Reshard) {
Expand Down Expand Up @@ -459,7 +459,7 @@ module {
tsl::RCReference<Array> input,
CreateArray({data.data()}, Shape({2, 2}), DType(DType::kS32),
ShardingParam({1, 1}, {{0}, {1}}),
BasicDeviceList::Create({devices->devices()[0]})));
client_->MakeDeviceList({devices->devices()[0]})));

ExecuteOptions options;
options.fill_status = true;
Expand All @@ -475,7 +475,7 @@ module {
Shape({1, 2}), {{0, 1}, {2, 3}}, devices));
ASSERT_NO_FATAL_FAILURE(AssertPerShardData<int>(
result.outputs[1], DType(DType::kS32), Shape({2, 2}), {{0, 1, 2, 3}},
BasicDeviceList::Create({devices->devices()[1]})));
client_->MakeDeviceList({devices->devices()[1]})));
}

TEST_F(IfrtIrExecutableImplTest, ZeroInput) {
Expand Down Expand Up @@ -711,10 +711,10 @@ module {
ASSERT_EQ(result.outputs.size(), 2);
ASSERT_NO_FATAL_FAILURE(AssertPerShardData<int>(
result.outputs[0], DType(DType::kS32), Shape({1, 2}), {{0, 1}},
BasicDeviceList::Create({devices->devices()[0]})));
client_->MakeDeviceList({devices->devices()[0]})));
ASSERT_NO_FATAL_FAILURE(AssertPerShardData<int>(
result.outputs[1], DType(DType::kS32), Shape({1, 2}), {{2, 3}},
BasicDeviceList::Create({devices->devices()[1]})));
client_->MakeDeviceList({devices->devices()[1]})));
}

TEST_F(IfrtIrExecutableImplTest, LoadedExecBinding) {
Expand Down
5 changes: 3 additions & 2 deletions xla/python/ifrt/remap_impl_test_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ limitations under the License.
#include <utility>
#include <vector>

#include "absl/container/inlined_vector.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
Expand Down Expand Up @@ -110,7 +111,7 @@ absl::StatusOr<tsl::RCReference<Array>> CreateArray(

std::vector<tsl::RCReference<Array>> shards;
shards.reserve(base_values.size());
BasicDeviceList::Devices devices;
absl::InlinedVector<xla::ifrt::Device*, 1> devices;
devices.reserve(device_indices.size());

for (int i = 0; i < base_values.size(); ++i) {
Expand All @@ -132,7 +133,7 @@ absl::StatusOr<tsl::RCReference<Array>> CreateArray(
}

std::shared_ptr<const Sharding> assembled_sharding =
ConcreteEvenSharding::Create(BasicDeviceList::Create(std::move(devices)),
ConcreteEvenSharding::Create(client->MakeDeviceList(devices),
MemoryKind(),
/*shape=*/shape,
/*shard_shape=*/std::move(shard_shape));
Expand Down
5 changes: 5 additions & 0 deletions xla/python/ifrt/support/sharding_conversions_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@ std::shared_ptr<MockClient> MakeTestClient(int num_devices) {
ON_CALL(*client, devices)
.WillByDefault(
[state]() -> absl::Span<Device* const> { return state->devices; });
ON_CALL(*client, MakeDeviceList)
.WillByDefault([](absl::Span<Device* const> devices)
-> tsl::RCReference<DeviceList> {
return BasicDeviceList::Create(devices);
});
return client;
}

Expand Down
9 changes: 5 additions & 4 deletions xla/python/ifrt/test_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ limitations under the License.
#include <utility>

#include "absl/base/thread_annotations.h"
#include "absl/container/inlined_vector.h"
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
Expand Down Expand Up @@ -91,7 +92,7 @@ void SetTestFilterIfNotUserSpecified(absl::string_view custom_filter) {

absl::StatusOr<tsl::RCReference<DeviceList>> GetDevices(
Client* client, absl::Span<const int> device_indices) {
BasicDeviceList::Devices devices;
absl::InlinedVector<xla::ifrt::Device*, 1> devices;
devices.reserve(device_indices.size());
const absl::Span<Device* const> client_devices = client->devices();
for (int device_index : device_indices) {
Expand All @@ -101,12 +102,12 @@ absl::StatusOr<tsl::RCReference<DeviceList>> GetDevices(
}
devices.push_back(client_devices[device_index]);
}
return BasicDeviceList::Create(std::move(devices));
return client->MakeDeviceList(devices);
}

absl::StatusOr<tsl::RCReference<DeviceList>> GetAddressableDevices(
Client* client, absl::Span<const int> device_indices) {
BasicDeviceList::Devices devices;
absl::InlinedVector<xla::ifrt::Device*, 1> devices;
devices.reserve(device_indices.size());
const absl::Span<Device* const> client_devices =
client->addressable_devices();
Expand All @@ -117,7 +118,7 @@ absl::StatusOr<tsl::RCReference<DeviceList>> GetAddressableDevices(
}
devices.push_back(client_devices[device_index]);
}
return BasicDeviceList::Create(std::move(devices));
return client->MakeDeviceList(std::move(devices));
}

} // namespace test_util
Expand Down

0 comments on commit c802f9b

Please sign in to comment.