Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#0: Add missing MeshBuffer APIs #18817

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions tt_metal/api/tt-metalium/buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,14 @@ class Buffer final {

} // namespace v0

// Declaration only; the definition is not visible here.
// Takes the total page count, pages per shard, shard count, the tensor memory layout, and the
// 2-element page, shard, and 2D-tensor extents. Returns a tuple of
//   - a vector of uint32_t vectors, and
//   - a vector of 2-element uint32_t arrays.
// NOTE(review): presumably the first element lists host page ids per core and the second holds
// a per-core shape -- confirm against the definition.
std::tuple<std::vector<std::vector<uint32_t>>, std::vector<std::array<uint32_t, 2>>> core_to_host_pages(
const uint32_t total_pages,
const uint32_t pages_per_shard,
const uint32_t num_shards,
const TensorMemoryLayout layout,
const std::array<uint32_t, 2>& page_shape,
const std::array<uint32_t, 2>& shard_shape,
const std::array<uint32_t, 2>& tensor2d_size);
// Returns a BufferPageMapping by value for the given `buffer`; declaration only, the definition
// is not visible here.
BufferPageMapping generate_buffer_page_mapping(const Buffer &buffer);

inline namespace v0 {
Expand Down
44 changes: 42 additions & 2 deletions tt_metal/api/tt-metalium/mesh_buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ class MeshBuffer {

// Throws an exception if the corresponding MeshDevice is already deallocated
MeshDevice* device() const;
// Returns the raw pointer stored in `allocator_`.
Allocator* allocator() const {
    return allocator_;
}
// Declared here, defined out of line.
DeviceAddr size() const;
// Returns `device_local_size_`, the per-device size passed to the constructor.
DeviceAddr device_local_size() const {
    return device_local_size_;
}
// Returns `address_`: the backing buffer's address for an owning MeshBuffer, or the address
// passed in directly for a non-owning view.
DeviceAddr address() const {
    return address_;
}
Expand All @@ -102,8 +103,39 @@ class MeshBuffer {
// Declarations only; the definitions are not visible here.
// NOTE(review): presumably the size in bytes of one element -- confirm.
uint32_t datum_size_bytes() const;
// NOTE(review): presumably the shape of a single shard as laid out on a device -- confirm.
Shape2D physical_shard_shape() const;
// Returns a pair of flags.
// NOTE(review): presumably one flag per mesh dimension, true where data is replicated -- confirm.
std::pair<bool, bool> replicated_dims() const;
uint32_t page_size() const { return device_local_config_.page_size; }
DeviceAddr page_size() const { return device_local_config_.page_size; }
void set_page_size(DeviceAddr page_size);
uint32_t num_pages() const { return page_size() == 0 ? 0 : device_local_size_ / page_size(); }
uint32_t num_dev_pages() const;
Copy link
Contributor Author

@tt-dma tt-dma Mar 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is used in one op: ttnn/cpp/ttnn/operations/data_movement/sharded/reshard/device/reshard_program_factory.cpp


// Buffer type read from the device-local configuration.
BufferType buffer_type() const {
    return device_local_config_.buffer_type;
}
// Declared here, defined out of line.
CoreType core_type() const;
Copy link
Contributor Author

@tt-dma tt-dma Mar 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also used only in ttnn/cpp/ttnn/operations/data_movement/sharded/reshard/device/reshard_program_factory.cpp.


// Boolean predicates; declarations only, the definitions are not visible here.
// NOTE(review): presumably these classify the buffer by its buffer type -- confirm.
bool is_l1() const;
bool is_dram() const;
bool is_trace() const;

// Checks on a BufferRegion; declarations only.
// NOTE(review): the exact validity rules live in the definitions -- not visible here.
bool is_valid_region(const BufferRegion& region) const;
bool is_valid_partial_region(const BufferRegion& region) const;
Copy link
Contributor Author

@tt-dma tt-dma Mar 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These two are only used in tests/tt_metal/tt_metal/api/test_buffer_region.cpp and tt_metal/impl/buffers/dispatch.cpp


// Memory layout read from the device-local configuration.
TensorMemoryLayout buffer_layout() const {
    return device_local_config_.buffer_layout;
}

// Reads `bottom_up` from the device-local configuration through `.value()`.
// NOTE(review): if `bottom_up` is a std::optional, `.value()` throws std::bad_optional_access
// when it is empty -- confirm it is always populated before this accessor is called.
bool bottom_up() const {
    const auto& bottom_up_setting = device_local_config_.bottom_up;
    return bottom_up_setting.value();
}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not used in ttnn, but used in allocator, global_semaphore, and lightmetal.


// Address lookups keyed by bank id and page index; declarations only, definitions not visible here.
// NOTE(review): presumably `bank_local_page_address` is relative to the bank while
// `page_address` is absolute -- confirm against the definitions.
DeviceAddr page_address(uint32_t bank_id, uint32_t page_index) const;
DeviceAddr bank_local_page_address(uint32_t bank_id, uint32_t page_index) const;
Copy link
Contributor Author

@tt-dma tt-dma Mar 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These two are used in the Metal API and in ttnn reports.

// Alignment-related queries; declarations only, the definitions are not visible here.
// NOTE(review): presumably the `aligned_*` values are the page size / total size / per-bank size
// rounded up to `alignment()` -- confirm against the definitions.
uint32_t alignment() const;
DeviceAddr aligned_page_size() const;
DeviceAddr aligned_size() const;
DeviceAddr aligned_size_per_bank() const;

// Sharding-related members; declarations only, the definitions are not visible here.
DeviceAddr sharded_page_address(uint32_t bank_id, uint32_t page_index) const;
// Returns the shard spec by value.
ShardSpecBuffer shard_spec() const;
// Non-const mutator taking the new shard spec by const reference.
void set_shard_spec(const ShardSpecBuffer& shard_spec);
// Optional result: may hold no value.
// NOTE(review): presumably empty for buffers that are not sharded -- confirm.
std::optional<uint32_t> num_cores() const;
// Non-const; returns a reference to a shared_ptr<const BufferPageMapping>.
// NOTE(review): presumably fills `buffer_page_mapping_` on first use -- confirm.
const std::shared_ptr<const BufferPageMapping>& get_buffer_page_mapping();
// Optional result: may hold no value.
std::optional<SubDeviceId> sub_device_id() const;
// Identifier supplied to the constructor and stored in `unique_id_`.
size_t unique_id() const {
    return unique_id_;
}
Copy link
Contributor Author

@tt-dma tt-dma Mar 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is only used in lightmetal and ttnn/cpp/ttnn/graph/graph_processor.cpp.


private:
// Creates an owning `MeshBuffer`, backed by an allocation made through `backing_buffer`.
Expand All @@ -112,13 +144,15 @@ class MeshBuffer {
const DeviceLocalBufferConfig& device_local_config,
DeviceAddr device_local_size,
MeshDevice* mesh_device,
size_t unique_id,
std::shared_ptr<Buffer> backing_buffer) :
buffers_(MeshShape(mesh_device->shape()), nullptr),
config_(config),
device_local_config_(device_local_config),
mesh_device_(mesh_device->shared_from_this()),
address_(backing_buffer->address()),
device_local_size_(device_local_size),
unique_id_(unique_id),
state_(OwnedBufferState{std::move(backing_buffer)}) {}

// Creates a non-owning `MeshBuffer` as "view" over an existing `address`.
Expand All @@ -127,16 +161,19 @@ class MeshBuffer {
const DeviceLocalBufferConfig& device_local_config,
DeviceAddr address,
DeviceAddr device_local_size,
MeshDevice* mesh_device) :
MeshDevice* mesh_device,
size_t unique_id) :
buffers_(MeshShape(mesh_device->shape()), /*fill_value=*/nullptr),
config_(config),
device_local_config_(device_local_config),
mesh_device_(mesh_device->shared_from_this()),
address_(address),
device_local_size_(device_local_size),
unique_id_(unique_id),
state_(ExternallyOwnedState{}) {}

// Declared here, defined out of line.
void initialize_device_buffers();
// Declared here, defined out of line.
bool is_sharded() const;
// Copy of the mesh-level configuration passed to the constructor.
MeshBufferConfig config_;
// Copy of the per-device configuration passed to the constructor; the inline accessors
// page_size(), buffer_type(), buffer_layout() and bottom_up() read from it.
DeviceLocalBufferConfig device_local_config_;
// Weak reference to the MeshDevice, obtained via shared_from_this() in the constructors, so
// holding a MeshBuffer does not by itself keep the device alive.
std::weak_ptr<MeshDevice> mesh_device_;
Expand All @@ -156,6 +193,9 @@ class MeshBuffer {
// Empty tag type used as one alternative of MeshBufferState.
// NOTE(review): presumably marks a buffer whose storage has been released -- confirm.
struct DeallocatedState {};
// The buffer is in exactly one of these states at a time.
using MeshBufferState = std::variant<OwnedBufferState, ExternallyOwnedState, DeallocatedState>;
MeshBufferState state_;
// Identifier passed to the constructor.
// NOTE(review): declared after `state_` but listed before it in the constructors' initializer
// lists; members are initialized in declaration order, so compilers may warn (-Wreorder).
size_t unique_id_;
// Defaults to nullptr: the constructors visible in this header do not set it, and without a
// default `allocator()` would return an indeterminate pointer.
Allocator* allocator_ = nullptr;
// Starts out empty (default-constructed shared_ptr).
std::shared_ptr<const BufferPageMapping> buffer_page_mapping_;
};

} // namespace tt::tt_metal::distributed
Loading
Loading