Skip to content

Commit

Permalink
Pass in Physical Device IDs to DeviceBoundThreadPool ctor
Browse files Browse the repository at this point in the history
  - Store a Physical Device ID to Thread ID table in the class
  - Allows accurate thread_id lookups given the device ID
  • Loading branch information
tt-asaigal committed Mar 8, 2025
1 parent e3f1a4a commit 6da7594
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 8 deletions.
27 changes: 25 additions & 2 deletions tt_metal/common/thread_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <numa.h>
#include <semaphore>

#include <tt-metalium/device.hpp>
#include "tt_metal/common/thread_pool.hpp"
#include "tt_metal/llrt/tt_cluster.hpp"

Expand Down Expand Up @@ -305,20 +306,35 @@ class DistributedBoostThreadPool : public ThreadPool {
// Allows enqueuing tasks tied to specific devices.
class DeviceBoundThreadPool : public ThreadPool {
public:
// Constructor accepting the physical device IDs this pool is bound to. One worker thread is
// spawned per device; each thread is pinned to a CPU core on the NUMA node "closest" to the
// device it serves, and the device-id -> thread-id mapping is recorded for later lookups.
DeviceBoundThreadPool(const std::vector<tt::tt_metal::IDevice*>& physical_devices, uint32_t logical_cpu_offset) {
    num_workers_ = physical_devices.size();
    workers_.reserve(num_workers_);
    uint32_t worker_idx = 0;
    for (auto* device : physical_devices) {
        workers_.emplace_back(std::make_unique<NumaAwareExecutor>(device->id(), logical_cpu_offset));
        phys_device_to_thread_id_[device->id()] = worker_idx;
        ++worker_idx;
    }
}
// Constructor accepting the number of threads to spawn. The threads in this pool will be bound to a specific CPU
// core but they are not guaranteed to be "close" to any physical device.
// Constructor accepting the number of threads to spawn. Each thread is bound to a specific CPU
// core, but no NUMA-locality guarantee is made with respect to any physical device. The i-th
// "device id" is simply mapped to the i-th worker.
DeviceBoundThreadPool(uint32_t thread_count, uint32_t logical_cpu_offset) {
    num_workers_ = thread_count;
    workers_.reserve(thread_count);
    for (uint32_t worker_idx = 0; worker_idx < thread_count; ++worker_idx) {
        workers_.emplace_back(std::make_unique<NumaAwareExecutor>(worker_idx, logical_cpu_offset));
        phys_device_to_thread_id_[worker_idx] = worker_idx;
    }
}

void enqueue(std::function<void()>&& f, std::optional<uint32_t> device_idx = std::nullopt) override {
// If the user does not provide the Device ID tied to this task, determine the thread to use
// based on the internally stored thread_idx. Tasks will get round-robined across threads,
// when relying on the thread_idx.
workers_[device_idx.value_or(thread_idx_ % num_workers_)]->enqueue(std::move(f));
++thread_idx_;
// If the device id is specified, use the thread tied to the device.
uint32_t thread_id =
device_idx.has_value() ? phys_device_to_thread_id_[device_idx.value()] : ((thread_idx_++) % num_workers_);
workers_[thread_id]->enqueue(std::move(f));
}

void wait() override {
Expand All @@ -335,6 +351,8 @@ class DeviceBoundThreadPool : public ThreadPool {
// Monotonically increasing counter used to round-robin tasks when no device id is supplied.
uint32_t thread_idx_ = 0;
// Store the number of workers to avoid repeated lookups
uint32_t num_workers_ = 0;
// Mapping between the physical device id and its associated thread
std::unordered_map<uint32_t, uint32_t> phys_device_to_thread_id_;
};

} // namespace thread_pool_impls
Expand All @@ -351,4 +369,9 @@ std::shared_ptr<ThreadPool> create_device_bound_thread_pool(int num_threads, uin
return std::make_shared<thread_pool_impls::DeviceBoundThreadPool>(num_threads, logical_cpu_offset);
}

// Factory for a device-bound pool: one NUMA-aware worker per supplied physical device.
std::shared_ptr<ThreadPool> create_device_bound_thread_pool(
    const std::vector<tt::tt_metal::IDevice*>& physical_devices, uint32_t logical_cpu_offset) {
    auto pool = std::make_shared<thread_pool_impls::DeviceBoundThreadPool>(physical_devices, logical_cpu_offset);
    return pool;
}

} // namespace tt::tt_metal
14 changes: 13 additions & 1 deletion tt_metal/common/thread_pool.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,15 @@
#include <memory>
#include <optional>
#include <thread>

#include <vector>
namespace tt::tt_metal {

inline namespace v0 {

class IDevice;

} // namespace v0

class ThreadPool {
public:
virtual ~ThreadPool() = default;
Expand All @@ -20,6 +26,12 @@ class ThreadPool {

std::shared_ptr<ThreadPool> create_boost_thread_pool(int num_threads);
std::shared_ptr<ThreadPool> create_distributed_boost_thread_pool(int num_threads);
// API accepting the number of threads to spawn in the pool. Will bind each thread to a CPU core, but the
// binding strategy will not be NUMA aware. Used for testing and benchmarking host-code.
std::shared_ptr<ThreadPool> create_device_bound_thread_pool(int num_threads, uint32_t logical_cpu_offset = 0);
// API accepting the physical devices the pool will be bound to. The threads will be bound to CPU cores in a
// NUMA aware manner (will be "closest" to the device it serves). Used for production data-paths.
std::shared_ptr<ThreadPool> create_device_bound_thread_pool(
    const std::vector<tt::tt_metal::IDevice*>& physical_devices, uint32_t logical_cpu_offset = 0);

} // namespace tt::tt_metal
12 changes: 7 additions & 5 deletions tt_metal/distributed/mesh_device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,13 @@ MeshDeviceID generate_unique_mesh_id() {
return next_id++;
}

// Select the default thread pool implementation for a mesh device.
// NOTE(review): this merged hunk interleaved the pre-change signature/calls (num_threads-based)
// with the post-change ones (device-list-based); only the intended post-change version is kept.
// When TT_MESH_BOOST_THREAD_POOL is set in the environment, fall back to the plain boost pool
// (one thread per device, no NUMA binding); otherwise bind the pool to the physical devices.
std::shared_ptr<ThreadPool> create_default_thread_pool(
    const std::vector<IDevice*>& physical_devices, uint32_t logical_cpu_offset = 0) {
    // Bind the thread-pool to the physical devices being used.
    if (tt::parse_env("TT_MESH_BOOST_THREAD_POOL", false)) {
        return create_boost_thread_pool(physical_devices.size());
    } else {
        return create_device_bound_thread_pool(physical_devices, logical_cpu_offset);
    }
}

Expand Down Expand Up @@ -148,8 +150,8 @@ MeshDevice::MeshDevice(
view_(std::move(mesh_device_view)),
mesh_id_(generate_unique_mesh_id()),
parent_mesh_(std::move(parent_mesh)),
dispatch_thread_pool_(create_default_thread_pool(view_->shape().mesh_size())),
reader_thread_pool_(create_default_thread_pool(view_->shape().mesh_size(), view_->shape().mesh_size())) {}
dispatch_thread_pool_(create_default_thread_pool(scoped_devices_->root_devices())),
reader_thread_pool_(create_default_thread_pool(scoped_devices_->root_devices(), view_->shape().mesh_size())) {}

std::shared_ptr<MeshDevice> MeshDevice::create(
const MeshDeviceConfig& config,
Expand Down

0 comments on commit 6da7594

Please sign in to comment.