Skip to content

Commit

Permalink
#0: Add missing MeshBuffer APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
tt-dma committed Mar 7, 2025
1 parent c11b5f3 commit 3635dc0
Show file tree
Hide file tree
Showing 4 changed files with 330 additions and 61 deletions.
8 changes: 8 additions & 0 deletions tt_metal/api/tt-metalium/buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,14 @@ class Buffer final {

} // namespace v0

std::tuple<std::vector<std::vector<uint32_t>>, std::vector<std::array<uint32_t, 2>>> core_to_host_pages(
const uint32_t total_pages,
const uint32_t pages_per_shard,
const uint32_t num_shards,
const TensorMemoryLayout layout,
const std::array<uint32_t, 2>& page_shape,
const std::array<uint32_t, 2>& shard_shape,
const std::array<uint32_t, 2>& tensor2d_size);
// Declared here; defined elsewhere. Takes the buffer by const reference and returns a
// BufferPageMapping by value.
BufferPageMapping generate_buffer_page_mapping(const Buffer &buffer);

inline namespace v0 {
Expand Down
44 changes: 42 additions & 2 deletions tt_metal/api/tt-metalium/mesh_buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ class MeshBuffer {

// Throws an exception if the corresponding MeshDevice is already deallocated
MeshDevice* device() const;
// Returns the stored allocator_ pointer unchanged; no null check is performed here.
Allocator* allocator() const { return allocator_; }
// Declared here; defined out of line.
DeviceAddr size() const;
// Returns the stored device_local_size_ value.
DeviceAddr device_local_size() const { return device_local_size_; }
// Returns the stored address_ value.
DeviceAddr address() const { return address_; }
Expand All @@ -102,8 +103,39 @@ class MeshBuffer {
// Declared here; defined out of line.
// NOTE(review): the name suggests the byte size of a single element -- confirm.
uint32_t datum_size_bytes() const;
// Declared here; defined out of line. Returns a Shape2D by value.
Shape2D physical_shard_shape() const;
// Declared here; defined out of line. Returns a pair of flags.
// NOTE(review): presumably one flag per mesh dimension marking replication -- confirm the order.
std::pair<bool, bool> replicated_dims() const;
// Returns the page size recorded in the device-local buffer config.
// Only one `page_size() const` may be declared: member functions that differ solely in their
// return type cannot be overloaded, so the narrower uint32_t declaration is dropped in favour
// of the DeviceAddr one.
DeviceAddr page_size() const { return device_local_config_.page_size; }
// Non-const mutator taking the new page size by value; defined out of line.
// NOTE(review): presumably updates device_local_config_.page_size -- confirm whether any
// derived state is refreshed as well.
void set_page_size(DeviceAddr page_size);
// Number of whole pages in the device-local allocation; yields 0 when the page size is 0,
// so the division below is never performed with a zero divisor.
uint32_t num_pages() const {
    const auto bytes_per_page = page_size();
    if (bytes_per_page == 0) {
        return 0;
    }
    return device_local_size_ / bytes_per_page;
}
// Declared here; defined out of line.
uint32_t num_dev_pages() const;

// Returns the buffer_type field of the device-local config.
BufferType buffer_type() const { return device_local_config_.buffer_type; }
// Declared here; defined out of line.
CoreType core_type() const;

// Predicates declared here and defined out of line.
// NOTE(review): presumably derived from buffer_type() -- confirm.
bool is_l1() const;
bool is_dram() const;
bool is_trace() const;

// Region validity checks, declared here and defined out of line; the region is taken by
// const reference.
bool is_valid_region(const BufferRegion& region) const;
bool is_valid_partial_region(const BufferRegion& region) const;

// Returns the buffer_layout field of the device-local config.
TensorMemoryLayout buffer_layout() const { return device_local_config_.buffer_layout; }

// Calls .value() on device_local_config_.bottom_up.
// NOTE(review): if bottom_up is a std::optional, this throws when it is empty -- confirm.
bool bottom_up() const { return device_local_config_.bottom_up.value(); }

// Address queries keyed by (bank_id, page_index); declared here, defined out of line.
DeviceAddr page_address(uint32_t bank_id, uint32_t page_index) const;
DeviceAddr bank_local_page_address(uint32_t bank_id, uint32_t page_index) const;
// Alignment and aligned-size queries; declared here, defined out of line.
uint32_t alignment() const;
DeviceAddr aligned_page_size() const;
DeviceAddr aligned_size() const;
DeviceAddr aligned_size_per_bank() const;

// Declared here; defined out of line.
DeviceAddr sharded_page_address(uint32_t bank_id, uint32_t page_index) const;
// Returns a ShardSpecBuffer by value.
ShardSpecBuffer shard_spec() const;
// Non-const mutator; takes the new shard spec by const reference.
void set_shard_spec(const ShardSpecBuffer& shard_spec);
// May return an empty optional.
// NOTE(review): presumably empty when the buffer is not sharded -- confirm.
std::optional<uint32_t> num_cores() const;
// Non-const; returns a reference to a shared_ptr to an immutable BufferPageMapping.
// NOTE(review): being non-const, this presumably builds and caches buffer_page_mapping_ on
// first use -- confirm against the definition.
const std::shared_ptr<const BufferPageMapping>& get_buffer_page_mapping();
// May return an empty optional.
std::optional<SubDeviceId> sub_device_id() const;
// Returns the stored unique_id_ value.
size_t unique_id() const { return unique_id_; }

private:
// Creates an owning `MeshBuffer`, backed by an allocation made through `backing_buffer`.
Expand All @@ -112,13 +144,15 @@ class MeshBuffer {
const DeviceLocalBufferConfig& device_local_config,
DeviceAddr device_local_size,
MeshDevice* mesh_device,
size_t unique_id,
std::shared_ptr<Buffer> backing_buffer) :
buffers_(MeshShape(mesh_device->shape()), nullptr),
config_(config),
device_local_config_(device_local_config),
mesh_device_(mesh_device->shared_from_this()),
address_(backing_buffer->address()),
device_local_size_(device_local_size),
unique_id_(unique_id),
state_(OwnedBufferState{std::move(backing_buffer)}) {}

// Creates a non-owning `MeshBuffer` as "view" over an existing `address`.
Expand All @@ -127,16 +161,19 @@ class MeshBuffer {
const DeviceLocalBufferConfig& device_local_config,
DeviceAddr address,
DeviceAddr device_local_size,
MeshDevice* mesh_device) :
MeshDevice* mesh_device,
size_t unique_id) :
buffers_(MeshShape(mesh_device->shape()), /*fill_value=*/nullptr),
config_(config),
device_local_config_(device_local_config),
mesh_device_(mesh_device->shared_from_this()),
address_(address),
device_local_size_(device_local_size),
unique_id_(unique_id),
state_(ExternallyOwnedState{}) {}

// Declared here; defined out of line.
// NOTE(review): presumably fills buffers_ with the per-device buffers -- confirm.
void initialize_device_buffers();
// Declared here; defined out of line.
bool is_sharded() const;
// Mesh-level configuration, copied from the constructor argument.
MeshBufferConfig config_;
// Per-device configuration; the inline accessors read page_size, buffer_type, buffer_layout
// and bottom_up from it.
DeviceLocalBufferConfig device_local_config_;
// Weak reference to the mesh device, obtained from mesh_device->shared_from_this() in the
// constructors.
std::weak_ptr<MeshDevice> mesh_device_;
Expand All @@ -156,6 +193,9 @@ class MeshBuffer {
// Empty tag type; one of the alternatives of MeshBufferState.
struct DeallocatedState {};
// The buffer holds exactly one of these alternatives at a time; the constructors start in
// either OwnedBufferState or ExternallyOwnedState.
using MeshBufferState = std::variant<OwnedBufferState, ExternallyOwnedState, DeallocatedState>;
MeshBufferState state_;
// Identifier passed in by the caller at construction; exposed through unique_id().
size_t unique_id_;
// Raw allocator pointer returned by allocator(). Defaulted to nullptr because the constructor
// initialiser lists in this header do not set it, which would otherwise leave the pointer
// indeterminate until something assigns it.
Allocator* allocator_ = nullptr;
// Shared pointer to an immutable page mapping; default-constructed (empty) until assigned.
std::shared_ptr<const BufferPageMapping> buffer_page_mapping_;
};

} // namespace tt::tt_metal::distributed
Loading

0 comments on commit 3635dc0

Please sign in to comment.