Skip to content

Commit

Permalink
unify dtype api
Browse files Browse the repository at this point in the history
  • Loading branch information
ctodTT committed Mar 5, 2025
1 parent a59fdfe commit 5ecf5cb
Show file tree
Hide file tree
Showing 5 changed files with 3 additions and 14 deletions.
1 change: 0 additions & 1 deletion runtime/include/tt/runtime/detail/ttmetal.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ std::vector<std::uint32_t> getShape(::tt::runtime::Tensor tensor);
std::vector<std::uint32_t> getStride(::tt::runtime::Tensor tensor);
std::uint32_t getElementSize(::tt::runtime::Tensor tensor);
std::uint32_t getVolume(::tt::runtime::Tensor tensor);
target::DataType getDtype(::tt::runtime::Tensor tensor);

size_t getNumAvailableDevices();

Expand Down
1 change: 0 additions & 1 deletion runtime/include/tt/runtime/detail/ttnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ std::vector<std::uint32_t> getShape(::tt::runtime::Tensor tensor);
std::vector<std::uint32_t> getStride(::tt::runtime::Tensor tensor);
std::uint32_t getElementSize(::tt::runtime::Tensor tensor);
std::uint32_t getVolume(::tt::runtime::Tensor tensor);
target::DataType getDtype(::tt::runtime::Tensor tensor);

size_t getNumAvailableDevices();

Expand Down
4 changes: 2 additions & 2 deletions runtime/lib/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -579,13 +579,13 @@ std::vector<std::uint32_t> Tensor::getStride() {
target::DataType Tensor::getDtype() {
#if defined(TT_RUNTIME_ENABLE_TTNN)
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::getDtype(*this);
return ::tt::runtime::ttnn::getTensorDataType(*this);
}
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
return ::tt::runtime::ttmetal::getDtype(*this);
return ::tt::runtime::ttmetal::getTensorDataType(*this);
}
#endif
LOG_FATAL("runtime is not enabled");
Expand Down
4 changes: 0 additions & 4 deletions runtime/lib/ttmetal/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -387,9 +387,5 @@ std::uint32_t getVolume(::tt::runtime::Tensor tensor) {
LOG_WARNING("getVolume not implemented for metal runtime");
return 0;
}
target::DataType getDtype(::tt::runtime::Tensor tensor) {
LOG_WARNING("getDtype not implemented for metal runtime");
return {};
}

} // namespace tt::runtime::ttmetal
7 changes: 1 addition & 6 deletions runtime/lib/ttnn/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,7 @@ std::vector<std::byte> getDataBuffer(::tt::runtime::Tensor tensor) {

// Need to `memcpy` in each case because the vector will go out of scope if we
// wait until after the switch case
switch (getDtype(tensor)) {
switch (getTensorDataType(tensor)) {
case target::DataType::BFP_BFloat4:
case target::DataType::BFP_BFloat8:
case target::DataType::Float32:
Expand Down Expand Up @@ -639,11 +639,6 @@ std::uint32_t getVolume(::tt::runtime::Tensor tensor) {
return ttnnTensor->volume();
}

target::DataType getDtype(::tt::runtime::Tensor tensor) {
auto ttnnTensor = static_cast<::ttnn::Tensor *>(tensor.handle.get());
return utils::fromTTNNDataType(ttnnTensor->dtype());
}

std::vector<Tensor> submit(Device deviceHandle, Binary executableHandle,
std::uint32_t programIndex,
std::vector<Tensor> const &inputHandles) {
Expand Down

0 comments on commit 5ecf5cb

Please sign in to comment.