From 5ecf5cb619bae64cf8f84fb5f8d61f2f324ff89d Mon Sep 17 00:00:00 2001 From: Collin Tod Date: Wed, 5 Mar 2025 18:48:20 +0000 Subject: [PATCH] unify dtype api --- runtime/include/tt/runtime/detail/ttmetal.h | 1 - runtime/include/tt/runtime/detail/ttnn.h | 1 - runtime/lib/runtime.cpp | 4 ++-- runtime/lib/ttmetal/runtime.cpp | 4 ---- runtime/lib/ttnn/runtime.cpp | 7 +------ 5 files changed, 3 insertions(+), 14 deletions(-) diff --git a/runtime/include/tt/runtime/detail/ttmetal.h b/runtime/include/tt/runtime/detail/ttmetal.h index d7ff26fc40..74e81c66f6 100644 --- a/runtime/include/tt/runtime/detail/ttmetal.h +++ b/runtime/include/tt/runtime/detail/ttmetal.h @@ -35,7 +35,6 @@ std::vector getShape(::tt::runtime::Tensor tensor); std::vector getStride(::tt::runtime::Tensor tensor); std::uint32_t getElementSize(::tt::runtime::Tensor tensor); std::uint32_t getVolume(::tt::runtime::Tensor tensor); -target::DataType getDtype(::tt::runtime::Tensor tensor); size_t getNumAvailableDevices(); diff --git a/runtime/include/tt/runtime/detail/ttnn.h b/runtime/include/tt/runtime/detail/ttnn.h index cfea0d1344..a2f0ff6d17 100644 --- a/runtime/include/tt/runtime/detail/ttnn.h +++ b/runtime/include/tt/runtime/detail/ttnn.h @@ -99,7 +99,6 @@ std::vector getShape(::tt::runtime::Tensor tensor); std::vector getStride(::tt::runtime::Tensor tensor); std::uint32_t getElementSize(::tt::runtime::Tensor tensor); std::uint32_t getVolume(::tt::runtime::Tensor tensor); -target::DataType getDtype(::tt::runtime::Tensor tensor); size_t getNumAvailableDevices(); diff --git a/runtime/lib/runtime.cpp b/runtime/lib/runtime.cpp index 38560eba62..bcf8b7786c 100644 --- a/runtime/lib/runtime.cpp +++ b/runtime/lib/runtime.cpp @@ -579,13 +579,13 @@ std::vector Tensor::getStride() { target::DataType Tensor::getDtype() { #if defined(TT_RUNTIME_ENABLE_TTNN) if (getCurrentRuntime() == DeviceRuntime::TTNN) { - return ::tt::runtime::ttnn::getDtype(*this); + return ::tt::runtime::ttnn::getTensorDataType(*this); } #endif #if defined(TT_RUNTIME_ENABLE_TTMETAL) if (getCurrentRuntime() == DeviceRuntime::TTMetal) { - return ::tt::runtime::ttmetal::getDtype(*this); + return ::tt::runtime::ttmetal::getTensorDataType(*this); } #endif LOG_FATAL("runtime is not enabled"); diff --git a/runtime/lib/ttmetal/runtime.cpp b/runtime/lib/ttmetal/runtime.cpp index fe6050c4e6..cd14e16a9b 100644 --- a/runtime/lib/ttmetal/runtime.cpp +++ b/runtime/lib/ttmetal/runtime.cpp @@ -387,9 +387,5 @@ std::uint32_t getVolume(::tt::runtime::Tensor tensor) { LOG_WARNING("getVolume not implemented for metal runtime"); return 0; } -target::DataType getDtype(::tt::runtime::Tensor tensor) { - LOG_WARNING("getDtype not implemented for metal runtime"); - return {}; -} } // namespace tt::runtime::ttmetal diff --git a/runtime/lib/ttnn/runtime.cpp b/runtime/lib/ttnn/runtime.cpp index 261bdcb1ff..ee2fbd39f0 100644 --- a/runtime/lib/ttnn/runtime.cpp +++ b/runtime/lib/ttnn/runtime.cpp @@ -575,7 +575,7 @@ std::vector getDataBuffer(::tt::runtime::Tensor tensor) { // Need to `memcpy` in each case because the vector will go out of scope if we // wait until after the switch case - switch (getDtype(tensor)) { + switch (getTensorDataType(tensor)) { case target::DataType::BFP_BFloat4: case target::DataType::BFP_BFloat8: case target::DataType::Float32: @@ -639,11 +639,6 @@ std::uint32_t getVolume(::tt::runtime::Tensor tensor) { return ttnnTensor->volume(); } -target::DataType getDtype(::tt::runtime::Tensor tensor) { - auto ttnnTensor = static_cast<::ttnn::Tensor *>(tensor.handle.get()); - return utils::fromTTNNDataType(ttnnTensor->dtype()); -} - std::vector submit(Device deviceHandle, Binary executableHandle, std::uint32_t programIndex, std::vector const &inputHandles) {