diff --git a/include/ttmlir/Target/Common/debug_info.fbs b/include/ttmlir/Target/Common/debug_info.fbs index e250c13b60..75cd5ec6ab 100644 --- a/include/ttmlir/Target/Common/debug_info.fbs +++ b/include/ttmlir/Target/Common/debug_info.fbs @@ -10,9 +10,14 @@ table GoldenTensor { data: [uint8]; } +table GoldenDevice { + device: uint32; + value: GoldenTensor; +} + table GoldenKV { key: string; - value: GoldenTensor; + value: [GoldenDevice]; } table GoldenInfo { diff --git a/include/ttmlir/Target/TTMetal/TTMetalToFlatbuffer.h b/include/ttmlir/Target/TTMetal/TTMetalToFlatbuffer.h index a6f8b80855..dbaa9204ca 100644 --- a/include/ttmlir/Target/TTMetal/TTMetalToFlatbuffer.h +++ b/include/ttmlir/Target/TTMetal/TTMetalToFlatbuffer.h @@ -15,7 +15,16 @@ namespace mlir::tt::ttmetal { // stream. LogicalResult translateTTMetalToFlatbuffer( Operation *op, llvm::raw_ostream &os, - std::unordered_map goldenMap = {}); + /* golden map has following structure + { + loc: { + device_id: GoldenTensor + } + } + */ + std::unordered_map> + goldenMap = {}); } // namespace mlir::tt::ttmetal #endif diff --git a/include/ttmlir/Target/TTNN/TTNNToFlatbuffer.h b/include/ttmlir/Target/TTNN/TTNNToFlatbuffer.h index 0712f08208..25e3e4f243 100644 --- a/include/ttmlir/Target/TTNN/TTNNToFlatbuffer.h +++ b/include/ttmlir/Target/TTNN/TTNNToFlatbuffer.h @@ -14,7 +14,7 @@ namespace mlir::tt::ttnn { // Convert a TTNNIR operation to a flatbuffer std::shared_ptr ttnnToFlatbuffer( Operation *op, - const std::unordered_map &goldenMap = {}, + const std::unordered_map> &goldenMap = {}, const std::vector> &moduleCache = {}); // Convert a TTNNIR operation to a flatbuffer @@ -22,7 +22,7 @@ std::shared_ptr ttnnToFlatbuffer( // mlir translation framework LogicalResult translateTTNNToFlatbuffer( Operation *op, llvm::raw_ostream &os, - const std::unordered_map &goldenMap = {}, + const std::unordered_map> &goldenMap = {}, const std::vector> &moduleCache = {}); } // namespace mlir::tt::ttnn diff --git a/lib/Target/TTMetal/TTMetalToFlatbuffer.cpp b/lib/Target/TTMetal/TTMetalToFlatbuffer.cpp index 933e9b3390..83f3c02ccc 100644 --- a/lib/Target/TTMetal/TTMetalToFlatbuffer.cpp +++ b/lib/Target/TTMetal/TTMetalToFlatbuffer.cpp @@ -258,7 +258,10 @@ Value getOperandThroughDPSOps(Value value) { } static std::shared_ptr translateModuleToFlatbuffer( - Operation *op, std::unordered_map goldenMap) { + Operation *op, + std::unordered_map> + goldenMap) { ::flatbuffers::FlatBufferBuilder fbb; FlatbufferObjectCache cache(&fbb); @@ -452,7 +455,9 @@ static std::shared_ptr translateModuleToFlatbuffer( LogicalResult translateTTMetalToFlatbuffer( Operation *op, llvm::raw_ostream &os, - std::unordered_map goldenMap) { + std::unordered_map> + goldenMap) { std::shared_ptr data = translateModuleToFlatbuffer(op, goldenMap); std::size_t size = ::flatbuffers::GetSizePrefixedBufferLength( static_cast(data.get())); diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp index d49e1d65f4..f1dbf1a37f 100644 --- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp +++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp @@ -1444,7 +1444,9 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op, std::shared_ptr ttnnToFlatbuffer( Operation *op, - const std::unordered_map &goldenMap, + const std::unordered_map> + &goldenMap, const std::vector> &moduleCache) { ModuleOp module = dyn_cast(op); assert(module && "Expected ModuleOp as top level operation"); @@ -1469,13 +1471,24 @@ std::shared_ptr ttnnToFlatbuffer( std::vector<::flatbuffers::Offset<::tt::target::GoldenKV>> goldenKVList; goldenKVList.reserve(goldenMap.size()); - for (auto element : goldenMap) { - std::vector dataTensor = element.second.convertDataToVector(); - auto goldenTensor = ::tt::target::CreateGoldenTensorDirect( - fbb, element.second.name.c_str(), &element.second.shape, - &element.second.strides, element.second.dtype, &dataTensor); + for (auto locMap : goldenMap) { + std::vector<::flatbuffers::Offset<::tt::target::GoldenDevice>> + goldenDeviceList; + goldenDeviceList.reserve(locMap.second.size()); + + for (auto tensorMap : locMap.second) { + std::vector dataTensor = + tensorMap.second.convertDataToVector(); + auto goldenTensor = ::tt::target::CreateGoldenTensorDirect( + fbb, tensorMap.second.name.c_str(), &tensorMap.second.shape, + &tensorMap.second.strides, tensorMap.second.dtype, &dataTensor); + auto goldenDevice = + ::tt::target::CreateGoldenDevice(fbb, tensorMap.first, goldenTensor); + goldenDeviceList.push_back(goldenDevice); + } + auto goldenKV = ::tt::target::CreateGoldenKVDirect( - fbb, element.first.c_str(), goldenTensor); + fbb, locMap.first.c_str(), &goldenDeviceList); goldenKVList.push_back(goldenKV); } @@ -1522,7 +1535,7 @@ std::shared_ptr ttnnToFlatbuffer( LogicalResult translateTTNNToFlatbuffer( Operation *op, llvm::raw_ostream &os, - const std::unordered_map &goldenMap, + const std::unordered_map> &goldenMap, const std::vector> &moduleCache) { std::shared_ptr data = ttnnToFlatbuffer(op, goldenMap, moduleCache); std::size_t size = ::flatbuffers::GetSizePrefixedBufferLength( diff --git a/python/Passes.cpp b/python/Passes.cpp index fcba3a2728..6e6fe5c36a 100644 --- a/python/Passes.cpp +++ b/python/Passes.cpp @@ -175,8 +175,7 @@ void populatePassesModule(py::module &m) { m.def( "ttnn_to_flatbuffer_file", [](MlirModule module, std::string &filepath, - const std::unordered_map - &goldenMap = {}, + const std::unordered_map> &goldenMap = {}, const std::vector> &moduleCache = {}) { mlir::Operation *moduleOp = unwrap(mlirModuleGetOperation(module)); @@ -201,7 +200,10 @@ void populatePassesModule(py::module &m) { m.def("ttmetal_to_flatbuffer_file", [](MlirModule module, std::string &filepath, - std::unordered_map goldenMap) { + std::unordered_map< + std::string, + std::unordered_map> + &goldenMap) { mlir::Operation *moduleOp = unwrap(mlirModuleGetOperation(module)); std::error_code fileError; llvm::raw_fd_ostream file(filepath, fileError); diff --git a/python/test_infra/ttir_builder.py b/python/test_infra/ttir_builder.py index 5d673bcb89..faa6704a65 100644 --- a/python/test_infra/ttir_builder.py +++ b/python/test_infra/ttir_builder.py @@ -171,14 +171,20 @@ def generate_input_golden( def get_golden_map(self) -> Dict: golden_info = {} for name, golden_tensor in self.id_golden_map.items(): + golden_device_info = {} golden_tensor = golden_tensor.contiguous() - golden_info[name] = create_golden_tensor( + + # for now, assume all golden tensors live on device 0 (todo: tapspatel - extend to multichip) + golden_device_info[0] = create_golden_tensor( name, list(golden_tensor.tensor.shape), list(golden_tensor.tensor.stride()), DataType.Float32, golden_tensor.tensor.data_ptr(), ) + + golden_info[name] = golden_device_info + return golden_info # ----- Private helpers ----- diff --git a/runtime/include/tt/runtime/detail/ttmetal.h b/runtime/include/tt/runtime/detail/ttmetal.h index 4f14f6b5b4..c3790f5f1f 100644 --- a/runtime/include/tt/runtime/detail/ttmetal.h +++ b/runtime/include/tt/runtime/detail/ttmetal.h @@ -65,10 +65,12 @@ std::string getOpDebugString(OpContext opContextHandle); std::string getOpLocInfo(OpContext opContextHandle); -Tensor getOpOutputTensor(OpContext opContextHandle, - CallbackContext programContextHandle); +std::unordered_map +getOpOutputTensor(OpContext opContextHandle, + CallbackContext programContextHandle); -std::vector getTensorData(Tensor tensor); +std::unordered_map> +getTensorData(std::unordered_map tensor_map); using InputBuffer = std::tuple, diff --git a/runtime/include/tt/runtime/detail/ttnn.h b/runtime/include/tt/runtime/detail/ttnn.h index abd2343360..5706be5b92 100644 --- a/runtime/include/tt/runtime/detail/ttnn.h +++ b/runtime/include/tt/runtime/detail/ttnn.h @@ -130,10 +130,12 @@ std::string getOpDebugString(OpContext opContextHandle); std::string getOpLocInfo(OpContext opContextHandle); -Tensor getOpOutputTensor(OpContext opContextHandle, - CallbackContext programContextHandle); +std::unordered_map +getOpOutputTensor(OpContext opContextHandle, + CallbackContext programContextHandle); -std::vector getTensorData(Tensor tensor); +std::unordered_map> +getTensorData(std::unordered_map tensor_map); std::vector submit(Device deviceHandle, Binary executableHandle, std::uint32_t programIndex, diff --git a/runtime/include/tt/runtime/runtime.h b/runtime/include/tt/runtime/runtime.h index 7e94a506fa..aa2e87dfa8 100644 --- a/runtime/include/tt/runtime/runtime.h +++ b/runtime/include/tt/runtime/runtime.h @@ -117,10 +117,12 @@ std::string getOpDebugString(OpContext opContextHandle); std::string getOpLocInfo(OpContext opContextHandle); -Tensor getOpOutputTensor(OpContext opContextHandle, - CallbackContext programContextHandle); +std::unordered_map +getOpOutputTensor(OpContext opContextHandle, + CallbackContext programContextHandle); -std::vector getTensorData(Tensor tensor); +std::unordered_map> +getTensorData(std::unordered_map tensor_map); std::vector submit(Device deviceHandle, Binary executableHandle, std::uint32_t programIndex, diff --git a/runtime/include/tt/runtime/types.h b/runtime/include/tt/runtime/types.h index e0f4e7b219..44baebd585 100644 --- a/runtime/include/tt/runtime/types.h +++ b/runtime/include/tt/runtime/types.h @@ -64,6 +64,7 @@ struct RuntimeCheckedObjectImpl { std::shared_ptr handle; ::tt::runtime::DeviceRuntime associatedRuntime; + RuntimeCheckedObjectImpl() = default; RuntimeCheckedObjectImpl(std::shared_ptr handle, ::tt::runtime::DeviceRuntime runtime) : handle(handle), associatedRuntime(runtime) {} @@ -128,7 +129,8 @@ struct Binary : public Flatbuffer { std::vector getProgramInputs(std::uint32_t programIndex) const; std::vector getProgramOutputs(std::uint32_t programIndex) const; - const ::tt::target::GoldenTensor *getDebugInfoGolden(std::string &loc) const; + std::unordered_map + getDebugInfoGolden(std::string &loc) const; }; struct Device : public detail::RuntimeCheckedObjectImpl { @@ -142,6 +144,7 @@ struct Event : public detail::RuntimeCheckedObjectImpl { struct Tensor : public detail::RuntimeCheckedObjectImpl { std::shared_ptr data; Event event; + Tensor() = default; Tensor(std::shared_ptr handle, std::shared_ptr data, DeviceRuntime runtime) : detail::RuntimeCheckedObjectImpl(handle, runtime), data(data), diff --git a/runtime/lib/binary.cpp b/runtime/lib/binary.cpp index fd0037389f..b0d3fdecdf 100644 --- a/runtime/lib/binary.cpp +++ b/runtime/lib/binary.cpp @@ -111,21 +111,25 @@ std::vector getProgramOutputs(Flatbuffer binary, return outputs; } -const ::tt::target::GoldenTensor *getDebugInfoGolden(Flatbuffer binary, - std::string &loc) { +std::unordered_map +getDebugInfoGolden(Flatbuffer binary, std::string &loc) { + std::unordered_map + goldenTensorDeviceMap; + auto const *programs = getBinary(binary)->programs(); for (auto const *program : *programs) { for (const ::tt::target::GoldenKV *goldenKV : *program->debug_info()->golden_info()->golden_map()) { if (std::string(goldenKV->key()->c_str()) == loc) { - return goldenKV->value(); - ; + for (const ::tt::target::GoldenDevice *goldenDevice : + *goldenKV->value()) { + goldenTensorDeviceMap[goldenDevice->device()] = goldenDevice->value(); + } } } } - LOG_WARNING("Golden information not found"); - return nullptr; + return goldenTensorDeviceMap; } } // namespace ttnn @@ -198,10 +202,10 @@ std::vector getProgramOutputs(Flatbuffer binary, return outputs; } -const ::tt::target::GoldenTensor *getDebugInfoGolden(Flatbuffer binary, - std::string &loc) { +std::unordered_map +getDebugInfoGolden(Flatbuffer binary, std::string &loc) { LOG_WARNING("Debug golden information not enabled for metal yet!"); - return nullptr; + return {}; } } // namespace metal @@ -368,7 +372,7 @@ Binary::getProgramOutputs(std::uint32_t programIndex) const { LOG_FATAL("Unsupported binary format"); } -const ::tt::target::GoldenTensor * +std::unordered_map Binary::getDebugInfoGolden(std::string &loc) const { if (::tt::target::ttnn::SizePrefixedTTNNBinaryBufferHasIdentifier( handle.get())) { diff --git a/runtime/lib/runtime.cpp b/runtime/lib/runtime.cpp index 887c3390c0..1c63a1dadd 100644 --- a/runtime/lib/runtime.cpp +++ b/runtime/lib/runtime.cpp @@ -437,8 +437,9 @@ std::string getOpLocInfo(OpContext opContextHandle) { throw std::runtime_error("runtime is not enabled"); } -Tensor getOpOutputTensor(OpContext opContextHandle, - CallbackContext programContextHandle) { +std::unordered_map +getOpOutputTensor(OpContext opContextHandle, + CallbackContext programContextHandle) { #if defined(TT_RUNTIME_ENABLE_TTNN) if (getCurrentRuntime() == DeviceRuntime::TTNN) { return ::tt::runtime::ttnn::getOpOutputTensor(opContextHandle, @@ -455,16 +456,17 @@ Tensor getOpOutputTensor(OpContext opContextHandle, LOG_FATAL("runtime is not enabled"); } -std::vector getTensorData(Tensor tensor) { +std::unordered_map> +getTensorData(std::unordered_map tensor_map) { #if defined(TT_RUNTIME_ENABLE_TTNN) if (getCurrentRuntime() == DeviceRuntime::TTNN) { - return ::tt::runtime::ttnn::getTensorData(tensor); + return ::tt::runtime::ttnn::getTensorData(tensor_map); } #endif #if defined(TT_RUNTIME_ENABLE_TTMETAL) if (getCurrentRuntime() == DeviceRuntime::TTMetal) { - return ::tt::runtime::ttmetal::getTensorData(tensor); + return ::tt::runtime::ttmetal::getTensorData(tensor_map); } #endif diff --git a/runtime/lib/ttmetal/runtime.cpp b/runtime/lib/ttmetal/runtime.cpp index 1b09e71153..2db7de755a 100644 --- a/runtime/lib/ttmetal/runtime.cpp +++ b/runtime/lib/ttmetal/runtime.cpp @@ -30,10 +30,6 @@ static ::tt::target::metal::TTMetalBinary const *getBinary(Flatbuffer binary) { return ::tt::target::metal::GetSizePrefixedTTMetalBinary(binary.handle.get()); } -static Tensor createNullTensor() { - return Tensor(nullptr, nullptr, DeviceRuntime::TTMetal); -} - static tt::runtime::MemoryView createMemoryView(tt::tt_metal::detail::MemoryView const &memoryView) { return tt::runtime::MemoryView{ @@ -348,14 +344,16 @@ std::string getOpLocInfo(OpContext opContextHandle) { return ""; } -Tensor getOpOutputTensor(OpContext opContextHandle, - CallbackContext programContextHandle) { +std::unordered_map +getOpOutputTensor(OpContext opContextHandle, + CallbackContext programContextHandle) { // Not implemented LOG_WARNING("obtaining op output tensor for metal runtime not implemented"); - return createNullTensor(); + return {}; } -std::vector getTensorData(Tensor tensor) { +std::unordered_map> +getTensorData(std::unordered_map tensor_map) { // Not implemented LOG_WARNING("obtaining tensor data for metal runtime not implemented"); return {}; diff --git a/runtime/lib/ttnn/runtime.cpp b/runtime/lib/ttnn/runtime.cpp index 13c1cf67e7..987a20ddfd 100644 --- a/runtime/lib/ttnn/runtime.cpp +++ b/runtime/lib/ttnn/runtime.cpp @@ -83,10 +83,6 @@ createOwnedTensor(std::shared_ptr data, ::ttnn::Layout::ROW_MAJOR); } -static Tensor createNullTensor() { - return Tensor(nullptr, nullptr, DeviceRuntime::TTNN); -} - static DeviceVariant getTargetDevice(::ttnn::MeshDevice &meshDevice) { if (meshDevice.num_devices() == 1) { return std::ref(*(meshDevice.get_device_index(0))); @@ -407,8 +403,9 @@ std::string getOpLocInfo(OpContext opContextHandle) { return std::string(opContext.loc_info()->c_str()); } -Tensor getOpOutputTensor(OpContext opContextHandle, - CallbackContext programContextHandle) { +std::unordered_map +getOpOutputTensor(OpContext opContextHandle, + CallbackContext programContextHandle) { auto const &programContext = programContextHandle.as( DeviceRuntime::TTNN); @@ -417,6 +414,7 @@ Tensor getOpOutputTensor(OpContext opContextHandle, const ttnn::ProgramTensorPool &tensorPool = programContext.getTensorPool(); std::int32_t globalId{-1}; const ::ttnn::Tensor *outPtr = nullptr; + std::unordered_map opOutputTensorMap; switch (opContext.type_type()) { case ::tt::target::ttnn::OpType::GetDeviceOp: { @@ -509,7 +507,7 @@ Tensor getOpOutputTensor(OpContext opContextHandle, } case ::tt::target::ttnn::OpType::DeallocateOp: { LOG_WARNING("getting output tensor for DeallocateOp is not supported"); - return createNullTensor(); + return opOutputTensorMap; } default: { LOG_FATAL("Unsupported operation type"); @@ -520,28 +518,66 @@ Tensor getOpOutputTensor(OpContext opContextHandle, outPtr = &tensorPool.at(globalId); } else { LOG_WARNING("Output tensor not found in tensor pool"); - return createNullTensor(); + return opOutputTensorMap; + } + + if (outPtr->storage_type() == StorageType::MULTI_DEVICE) { + const auto &tensor_storage = + std::get((*outPtr).get_storage()); + for (unsigned long i = 0; i < tensor_storage.ordered_device_ids.size(); + ++i) { + auto device_id = tensor_storage.ordered_device_ids[i]; + ::ttnn::Tensor ttnnTensor = ::ttnn::Tensor{ + DeviceStorage{tensor_storage.get_buffer_for_device_id(device_id)}, + tensor_storage.specs.at(device_id)}; + std::shared_ptr<::ttnn::Tensor> hostTensor = + std::make_shared<::ttnn::Tensor>(::ttnn::to_layout( + ::ttnn::from_device(ttnnTensor), ::ttnn::Layout::ROW_MAJOR, + std::nullopt, std::nullopt, + static_cast<::ttnn::IDevice *>(nullptr))); + opOutputTensorMap[device_id] = + Tensor(std::static_pointer_cast(hostTensor), nullptr, + DeviceRuntime::TTNN); + } + } else if (outPtr->storage_type() == StorageType::DEVICE) { + std::shared_ptr<::ttnn::Tensor> hostTensor = + std::make_shared<::ttnn::Tensor>(::ttnn::to_layout( + ::ttnn::from_device(*outPtr), ::ttnn::Layout::ROW_MAJOR, + std::nullopt, std::nullopt, + static_cast<::ttnn::IDevice *>(nullptr))); + opOutputTensorMap[0] = Tensor(std::static_pointer_cast(hostTensor), + nullptr, DeviceRuntime::TTNN); + } else { + LOG_WARNING("Unsupported storage type of output tensor. Cannot acquire " + "from device"); + return opOutputTensorMap; } - std::shared_ptr<::ttnn::Tensor> hostTensor = - std::make_shared<::ttnn::Tensor>(::ttnn::to_layout( - ::ttnn::from_device(*outPtr), ::ttnn::Layout::ROW_MAJOR, std::nullopt, - std::nullopt, static_cast<::ttnn::IDevice *>(nullptr))); - - return Tensor(std::static_pointer_cast(hostTensor), nullptr, - DeviceRuntime::TTNN); + return opOutputTensorMap; } -std::vector getTensorData(Tensor tensor) { - const ::ttnn::Tensor *nnTensor = - static_cast<::ttnn::Tensor *>(tensor.handle.get()); - if (nnTensor == nullptr) { - return {}; +std::unordered_map> +getTensorData(std::unordered_map tensor_map) { + std::unordered_map> tensor_data_map; + + for (auto tensor_element : tensor_map) { + auto device_id = tensor_element.first; + auto tensor = tensor_element.second; + + const ::ttnn::Tensor *nnTensor = + static_cast<::ttnn::Tensor *>(tensor.handle.get()); + + if (nnTensor == nullptr) { + continue; + } + + void *dataPtr = ::tt::tt_metal::get_raw_host_data_ptr(*nnTensor); + tensor_data_map[device_id] = + std::vector(static_cast(dataPtr), + static_cast(dataPtr) + nnTensor->volume()); } - void *dataPtr = ::tt::tt_metal::get_raw_host_data_ptr(*nnTensor); - return std::vector(static_cast(dataPtr), - static_cast(dataPtr) + nnTensor->volume()); + return tensor_data_map; } std::vector submit(Device deviceHandle, Binary executableHandle, diff --git a/runtime/tools/python/ttrt/common/callback.py b/runtime/tools/python/ttrt/common/callback.py index 93a9af267b..6c3951c3ec 100644 --- a/runtime/tools/python/ttrt/common/callback.py +++ b/runtime/tools/python/ttrt/common/callback.py @@ -204,22 +204,29 @@ def golden(callback_runtime_config, binary, program_context, op_context): loc = ttrt.runtime.get_op_loc_info(op_context) - op_golden_tensor = binary.get_debug_info_golden(loc) + op_golden_tensor_map = binary.get_debug_info_golden(loc) - if op_golden_tensor is None: + if len(op_golden_tensor_map) == 0: logging.debug("Golden tensor is None - skipping golden comparison") return - op_output_tensor = ttrt.runtime.get_op_output_tensor(op_context, program_context) + op_output_tensor_map = ttrt.runtime.get_op_output_tensor( + op_context, program_context + ) - if len(op_output_tensor) == 0: + if len(op_output_tensor_map) == 0: logging.debug("Output tensor is empty - skipping golden comparison") return + op_golden_tensor = op_golden_tensor_map[ + 0 + ] # todo: tapspatel - currently it's supported for single device, extend to multi-device dtype = ttrt_datatype_to_torch_dtype(op_golden_tensor.dtype) - golden_tensor_torch = torch.frombuffer(op_golden_tensor, dtype=dtype).flatten() + op_output_tensor = op_output_tensor_map[ + 0 + ] # todo: tapspatel - currently it's supported for single device, extend to multi-device output_tensor_torch = torch.tensor(op_output_tensor, dtype=dtype).flatten() if callback_runtime_config.save_golden_tensors: diff --git a/runtime/tools/python/ttrt/runtime/module.cpp b/runtime/tools/python/ttrt/runtime/module.cpp index 3bb1d88299..44f0d96efd 100644 --- a/runtime/tools/python/ttrt/runtime/module.cpp +++ b/runtime/tools/python/ttrt/runtime/module.cpp @@ -174,9 +174,10 @@ PYBIND11_MODULE(_C, m) { "get_op_output_tensor", [](tt::runtime::OpContext &opContextHandle, tt::runtime::CallbackContext &programContextHandle) { - tt::runtime::Tensor tensor = tt::runtime::getOpOutputTensor( - opContextHandle, programContextHandle); - return tt::runtime::getTensorData(tensor); + std::unordered_map tensor_map = + tt::runtime::getOpOutputTensor(opContextHandle, + programContextHandle); + return tt::runtime::getTensorData(tensor_map); }, "Get the input tensor of the op"); m.def("get_op_debug_str", &tt::runtime::getOpDebugString,