Move Python Bindings from being defined with Pybind -> Nanobind #2379

Merged (5 commits) on Mar 7, 2025. Diff shows changes from 4 commits.
38 changes: 20 additions & 18 deletions include/ttmlir/Bindings/Python/TTMLIRModule.h
@@ -6,7 +6,8 @@
 #define TTMLIR_BINDINGS_PYTHON_TTMLIRMODULE_H

 #include "mlir-c/Bindings/Python/Interop.h"
-#include "mlir/Bindings/Python/PybindAdaptors.h"
+#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
 #include "mlir/CAPI/IR.h"
 #include "mlir/InitAllDialects.h"
 #include "mlir/InitAllPasses.h"
@@ -21,48 +22,49 @@
 #include "ttmlir/RegisterAll.h"
 #include "llvm/Support/CommandLine.h"

+#include <nanobind/stl/variant.h>
 #include <variant>

-namespace py = pybind11;
+namespace nb = nanobind;

 namespace mlir::ttmlir::python {

 template <typename T>
-py::class_<T> tt_attribute_class(py::module &m, const char *class_name) {
-  py::class_<T> cls(m, class_name);
+nb::class_<T> tt_attribute_class(nb::module_ &m, const char *class_name) {
+  nb::class_<T> cls(m, class_name);
   cls.def_static("maybe_downcast",
-                 [](MlirAttribute attr) -> std::variant<T, py::object> {
+                 [](MlirAttribute attr) -> std::variant<T, nb::object> {
                    auto res = mlir::dyn_cast<T>(unwrap(attr));
                    if (res) {
                      return res;
                    }
-                   return py::none();
+                   return nb::none();
                  });
   return cls;
 }

 template <typename T>
-py::class_<T> tt_type_class(py::module &m, const char *class_name) {
-  py::class_<T> cls(m, class_name);
+nb::class_<T> tt_type_class(nb::module_ &m, const char *class_name) {
+  nb::class_<T> cls(m, class_name);
   cls.def_static("maybe_downcast",
-                 [](MlirType type) -> std::variant<T, py::object> {
+                 [](MlirType type) -> std::variant<T, nb::object> {
                    auto res = mlir::dyn_cast<T>(unwrap(type));
                    if (res) {
                      return res;
                    }
-                   return py::none();
+                   return nb::none();
                  });
   return cls;
 }

-void populateTTModule(py::module &m);
-void populateTTIRModule(py::module &m);
-void populateTTKernelModule(py::module &m);
-void populateTTNNModule(py::module &m);
-void populateOverridesModule(py::module &m);
-void populateOptimizerOverridesModule(py::module &m);
-void populatePassesModule(py::module &m);
-void populateUtilModule(py::module &m);
+void populateTTModule(nb::module_ &m);
+void populateTTIRModule(nb::module_ &m);
+void populateTTKernelModule(nb::module_ &m);
+void populateTTNNModule(nb::module_ &m);
+void populateOverridesModule(nb::module_ &m);
+void populateOptimizerOverridesModule(nb::module_ &m);
+void populatePassesModule(nb::module_ &m);
+void populateUtilModule(nb::module_ &m);
 } // namespace mlir::ttmlir::python

 #endif // TTMLIR_BINDINGS_PYTHON_TTMLIRMODULE_H
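The `maybe_downcast` helpers above return `std::variant<T, nb::object>`, which is why the diff adds `#include <nanobind/stl/variant.h>`: nanobind ships its STL type casters as opt-in headers rather than enabling them implicitly. Below is a minimal self-contained sketch of the same idiom; `FooAttr` and the integer stand-in for `mlir::dyn_cast` are illustrative, not from this PR.

```cpp
// Minimal sketch of the maybe_downcast idiom, outside MLIR.
// FooAttr and the "raw >= 0" check stand in for the real
// MlirAttribute / mlir::dyn_cast machinery.
#include <nanobind/nanobind.h>
#include <nanobind/stl/variant.h> // opt-in caster for std::variant

#include <variant>

namespace nb = nanobind;

struct FooAttr {
  int value = 0;
};

NB_MODULE(downcast_example, m) {
  nb::class_<FooAttr> cls(m, "FooAttr");
  cls.def_static("maybe_downcast",
                 [](int raw) -> std::variant<FooAttr, nb::object> {
                   if (raw >= 0) {
                     return FooAttr{raw}; // "downcast" succeeded
                   }
                   return nb::none(); // surfaces as None in Python
                 });
}
```

From Python, `FooAttr.maybe_downcast(3)` returns a `FooAttr` while `FooAttr.maybe_downcast(-1)` returns `None`, which is the same contract the TTMLIR helpers use to report a failed `dyn_cast`.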
11 changes: 6 additions & 5 deletions include/ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h
@@ -70,14 +70,15 @@ class OptimizerOverridesHandler {

   // Wrapper methods we use to expose the adders to the python bindings
   std::unordered_map<std::string, InputLayoutOverrideParams>
-  getInputLayoutOverridesPybindWrapper() const;
+  getInputLayoutOverridesNanobindWrapper() const;
   std::unordered_map<std::string, OutputLayoutOverrideParams>
-  getOutputLayoutOverridesPybindWrapper() const;
+  getOutputLayoutOverridesNanobindWrapper() const;

   // Wrapper methods we use to expose the adders to the python bindings
-  void addInputLayoutOverridePybindWrapper(std::string, std::vector<int64_t> &);
-  void addOutputLayoutOverridePybindWrapper(std::string,
-                                            OutputLayoutOverrideParams);
+  void addInputLayoutOverrideNanobindWrapper(std::string,
+                                             std::vector<int64_t> &);
+  void addOutputLayoutOverrideNanobindWrapper(std::string,
+                                              OutputLayoutOverrideParams);

 private:
   // Flags for enabling/disabling the optimizer passes
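These `*NanobindWrapper` methods exist because, as the .cpp diff below shows, the core handler API traffics in `llvm::StringRef` and `llvm::SmallVector`, for which nanobind has no casters; the wrappers accept STL types that nanobind converts automatically. A sketch of the pattern under that assumption, with a simplified `Handler` and stubbed body:

```cpp
// Sketch of the wrapper layer: STL-facing signature in, LLVM
// containers out. Binding the wrapper additionally needs the opt-in
// <nanobind/stl/string.h> and <nanobind/stl/vector.h> casters.
#include <cstdint>
#include <string>
#include <vector>

#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"

class Handler {
public:
  // Core API: not directly bindable (no StringRef caster in nanobind).
  void addInputLayoutOverride(llvm::StringRef opName,
                              llvm::SmallVector<int64_t> &operandIdxes) {
    // Real implementation records the override; elided in this sketch.
    (void)opName;
    (void)operandIdxes;
  }

  // Binding-facing wrapper: converts, then forwards.
  void addInputLayoutOverrideNanobindWrapper(std::string opName,
                                             std::vector<int64_t> &operandIdxes) {
    llvm::SmallVector<int64_t> idxes(operandIdxes.begin(),
                                     operandIdxes.end());
    addInputLayoutOverride(opName, idxes); // StringRef views the std::string
  }
};
```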
3 changes: 3 additions & 0 deletions include/ttmlir/Target/Utils/MLIRToFlatbuffer.h
@@ -31,6 +31,9 @@ struct GoldenTensor {
                std::vector<std::uint8_t> &&_data)
       : name(name), shape(shape), strides(strides), dtype(dtype),
         data(std::move(_data)) {}
+
+  // Create an explicit empty constructor
+  GoldenTensor() {}
 };

 inline ::tt::target::OOBVal toFlatbuffer(FlatbufferObjectCache &,
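The explicit empty constructor is plausibly needed so the binding layer can default-construct `GoldenTensor`, for example via `nb::init<>()` or an STL container caster: declaring any other constructor suppresses the implicit default one. A simplified sketch of that failure mode and fix (not the real `GoldenTensor`):

```cpp
// Declaring the string constructor suppresses the implicit default
// constructor; without the explicit GoldenTensor() below,
// .def(nb::init<>()) has no default constructor to call.
#include <string>
#include <utility>

#include <nanobind/nanobind.h>
#include <nanobind/stl/string.h> // opt-in caster for std::string

namespace nb = nanobind;

struct GoldenTensor {
  std::string name;

  explicit GoldenTensor(std::string n) : name(std::move(n)) {}

  // Create an explicit empty constructor (mirrors the PR's change)
  GoldenTensor() {}
};

NB_MODULE(golden_example, m) {
  nb::class_<GoldenTensor>(m, "GoldenTensor")
      .def(nb::init<>())            // needs the default constructor
      .def(nb::init<std::string>()) // forwards to the string constructor
      .def_rw("name", &GoldenTensor::name);
}
```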
8 changes: 4 additions & 4 deletions lib/Dialect/TTNN/Utils/OptimizerOverrides.cpp
@@ -83,7 +83,7 @@ OptimizerOverridesHandler::getOutputLayoutOverrides() const {
 }

 std::unordered_map<std::string, InputLayoutOverrideParams>
-OptimizerOverridesHandler::getInputLayoutOverridesPybindWrapper() const {
+OptimizerOverridesHandler::getInputLayoutOverridesNanobindWrapper() const {
   std::unordered_map<std::string, InputLayoutOverrideParams>
       inputLayoutOverridesWrapper;
   for (auto &entry : inputLayoutOverrides) {
@@ -93,7 +93,7 @@ OptimizerOverridesHandler::getInputLayoutOverridesPybindWrapper() const {
 }

 std::unordered_map<std::string, OutputLayoutOverrideParams>
-OptimizerOverridesHandler::getOutputLayoutOverridesPybindWrapper() const {
+OptimizerOverridesHandler::getOutputLayoutOverridesNanobindWrapper() const {
   std::unordered_map<std::string, OutputLayoutOverrideParams>
       outputLayoutOverridesWrapper;
   for (auto &entry : outputLayoutOverrides) {
@@ -190,15 +190,15 @@ void OptimizerOverridesHandler::addOutputLayoutOverride(
       std::move(grid), bufferType, tensorMemoryLayout, memoryLayout, dataType};
 }

-void OptimizerOverridesHandler::addInputLayoutOverridePybindWrapper(
+void OptimizerOverridesHandler::addInputLayoutOverrideNanobindWrapper(
     std::string opName, std::vector<int64_t> &operandIdxes) {
   StringRef opNameStringRef(opName);
   SmallVector<int64_t> operandIdxesSmallVector(operandIdxes.begin(),
                                                operandIdxes.end());
   addInputLayoutOverride(opNameStringRef, operandIdxesSmallVector);
 }

-void OptimizerOverridesHandler::addOutputLayoutOverridePybindWrapper(
+void OptimizerOverridesHandler::addOutputLayoutOverrideNanobindWrapper(
     std::string opName, OutputLayoutOverrideParams overrideParams) {
   StringRef opNameStringRef(opName);
   addOutputLayoutOverride(opNameStringRef, overrideParams);
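The getter wrappers run the conversion in the other direction: the entry-based loops above suggest the overrides live in an LLVM string map, whose non-owning keys must be copied into an STL map before crossing into Python. A sketch assuming `llvm::StringMap` storage, with an illustrative `Params` stand-in:

```cpp
// StringRef keys are copied into owning std::string keys so the result
// can be returned to Python; binding the return type needs the opt-in
// <nanobind/stl/unordered_map.h> caster.
#include <string>
#include <unordered_map>

#include "llvm/ADT/StringMap.h"

struct Params { // stand-in for InputLayoutOverrideParams
  int dummy = 0;
};

std::unordered_map<std::string, Params>
toStlMap(const llvm::StringMap<Params> &overrides) {
  std::unordered_map<std::string, Params> wrapper;
  for (const auto &entry : overrides) {
    // getKey() returns a non-owning StringRef; .str() makes a copy.
    wrapper[entry.getKey().str()] = entry.getValue();
  }
  return wrapper;
}
```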
1 change: 1 addition & 0 deletions python/CMakeLists.txt
@@ -117,6 +117,7 @@ declare_mlir_python_extension(TTMLIRPythonExtensions.Main
     MLIRTestToLLVMIRTranslation
     MLIRVCIXToLLVMIRTranslation
     MLIRX86VectorToLLVMIRTranslation
+  PYTHON_BINDINGS_LIBRARY nanobind
 )

 set(TTMLIR_PYTHON_SOURCES
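This one-line CMake change is what actually flips the extension's binding library: upstream MLIR's `declare_mlir_python_extension` accepts a `PYTHON_BINDINGS_LIBRARY` argument selecting between `pybind11` and `nanobind`, with pybind11 as the historical default, so the move to nanobind has to be requested explicitly here.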
69 changes: 35 additions & 34 deletions python/OptimizerOverrides.cpp
@@ -8,11 +8,11 @@

 namespace mlir::ttmlir::python {

-void populateOptimizerOverridesModule(py::module &m) {
+void populateOptimizerOverridesModule(nb::module_ &m) {

-  py::class_<tt::ttnn::OptimizerOverridesHandler>(m,
+  nb::class_<tt::ttnn::OptimizerOverridesHandler>(m,
                                                   "OptimizerOverridesHandler")
-      .def(py::init<>())
+      .def(nb::init<>())

       .def("set_enable_optimizer",
            &tt::ttnn::OptimizerOverridesHandler::setEnableOptimizer)
@@ -56,47 +56,48 @@ void populateOptimizerOverridesModule(py::module &m) {

       .def("get_input_layout_overrides",
            &tt::ttnn::OptimizerOverridesHandler::
-               getInputLayoutOverridesPybindWrapper)
+               getInputLayoutOverridesNanobindWrapper)
       .def("get_output_layout_overrides",
            &tt::ttnn::OptimizerOverridesHandler::
-               getOutputLayoutOverridesPybindWrapper)
+               getOutputLayoutOverridesNanobindWrapper)

-      .def("add_input_layout_override", &tt::ttnn::OptimizerOverridesHandler::
-                                            addInputLayoutOverridePybindWrapper)
+      .def("add_input_layout_override",
+           &tt::ttnn::OptimizerOverridesHandler::
+               addInputLayoutOverrideNanobindWrapper)
       .def("add_output_layout_override",
            &tt::ttnn::OptimizerOverridesHandler::
-               addOutputLayoutOverridePybindWrapper)
+               addOutputLayoutOverrideNanobindWrapper)

       .def("to_string", &tt::ttnn::OptimizerOverridesHandler::toString);

-  py::enum_<mlir::tt::MemoryLayoutAnalysisPolicyType>(
+  nb::enum_<mlir::tt::MemoryLayoutAnalysisPolicyType>(
       m, "MemoryLayoutAnalysisPolicyType")
      .value("DFSharding", mlir::tt::MemoryLayoutAnalysisPolicyType::DFSharding)
      .value("GreedyL1Interleaved",
             mlir::tt::MemoryLayoutAnalysisPolicyType::GreedyL1Interleaved)
      .value("BFInterleaved",
             mlir::tt::MemoryLayoutAnalysisPolicyType::BFInterleaved);

-  py::enum_<mlir::tt::ttnn::BufferType>(m, "BufferType")
+  nb::enum_<mlir::tt::ttnn::BufferType>(m, "BufferType")
      .value("DRAM", mlir::tt::ttnn::BufferType::DRAM)
      .value("L1", mlir::tt::ttnn::BufferType::L1)
      .value("SystemMemory", mlir::tt::ttnn::BufferType::SystemMemory)
      .value("L1Small", mlir::tt::ttnn::BufferType::L1Small)
      .value("Trace", mlir::tt::ttnn::BufferType::Trace);

-  py::enum_<mlir::tt::ttnn::Layout>(m, "Layout")
+  nb::enum_<mlir::tt::ttnn::Layout>(m, "Layout")
      .value("RowMajor", mlir::tt::ttnn::Layout::RowMajor)
      .value("Tile", mlir::tt::ttnn::Layout::Tile)
      .value("Invalid", mlir::tt::ttnn::Layout::Invalid);

-  py::enum_<mlir::tt::ttnn::TensorMemoryLayout>(m, "TensorMemoryLayout")
+  nb::enum_<mlir::tt::ttnn::TensorMemoryLayout>(m, "TensorMemoryLayout")
      .value("Interleaved", mlir::tt::ttnn::TensorMemoryLayout::Interleaved)
      .value("SingleBank", mlir::tt::ttnn::TensorMemoryLayout::SingleBank)
      .value("HeightSharded", mlir::tt::ttnn::TensorMemoryLayout::HeightSharded)
      .value("WidthSharded", mlir::tt::ttnn::TensorMemoryLayout::WidthSharded)
      .value("BlockSharded", mlir::tt::ttnn::TensorMemoryLayout::BlockSharded);

-  py::enum_<mlir::tt::DataType>(m, "DataType")
+  nb::enum_<mlir::tt::DataType>(m, "DataType")
      .value("Float32", mlir::tt::DataType::Float32)
      .value("Float16", mlir::tt::DataType::Float16)
      .value("BFloat16", mlir::tt::DataType::BFloat16)
@@ -111,10 +112,10 @@ void populateOptimizerOverridesModule(py::module &m) {
      .value("UInt8", mlir::tt::DataType::UInt8)
      .value("Int32", mlir::tt::DataType::Int32);

-  py::class_<mlir::tt::ttnn::InputLayoutOverrideParams>(
+  nb::class_<mlir::tt::ttnn::InputLayoutOverrideParams>(
       m, "InputLayoutOverrideParams")
-      .def(py::init<>())
-      .def_property(
+      .def(nb::init<>())
+      .def_prop_rw(
           "operand_idxes",
           [](const mlir::tt::ttnn::InputLayoutOverrideParams &obj) {
             // Getter: Convert SmallVector to std::vector
@@ -128,10 +129,10 @@ void populateOptimizerOverridesModule(py::module &m) {
             obj.operandIdxes.append(input.begin(), input.end());
           });

-  py::class_<mlir::tt::ttnn::OutputLayoutOverrideParams>(
+  nb::class_<mlir::tt::ttnn::OutputLayoutOverrideParams>(
       m, "OutputLayoutOverrideParams")
-      .def(py::init<>())
-      .def_property(
+      .def(nb::init<>())
+      .def_prop_rw(
           "grid",
           [](const mlir::tt::ttnn::OutputLayoutOverrideParams &obj) {
             // Getter: Convert SmallVector to std::vector
@@ -151,20 +152,20 @@ void populateOptimizerOverridesModule(py::module &m) {
             }
             obj.grid->append(input.begin(), input.end());
           })
-      .def_readwrite("buffer_type",
-                     &mlir::tt::ttnn::OutputLayoutOverrideParams::bufferType)
-      .def_readwrite(
-          "tensor_memory_layout",
-          &mlir::tt::ttnn::OutputLayoutOverrideParams::tensorMemoryLayout)
-      .def_readwrite("memory_layout",
-                     &mlir::tt::ttnn::OutputLayoutOverrideParams::memoryLayout)
-      .def_readwrite("data_type",
-                     &mlir::tt::ttnn::OutputLayoutOverrideParams::dataType)
+      .def_rw("buffer_type",
+              &mlir::tt::ttnn::OutputLayoutOverrideParams::bufferType)
+      .def_rw("tensor_memory_layout",
+              &mlir::tt::ttnn::OutputLayoutOverrideParams::tensorMemoryLayout)
+      .def_rw("memory_layout",
+              &mlir::tt::ttnn::OutputLayoutOverrideParams::memoryLayout)
+      .def_rw("data_type",
+              &mlir::tt::ttnn::OutputLayoutOverrideParams::dataType)
       .def("set_buffer_type_from_str",
            [](mlir::tt::ttnn::OutputLayoutOverrideParams &obj,
               const std::string &value) {
-             if (auto bufferType = mlir::tt::ttnn::symbolizeBufferType(value)) {
-               obj.bufferType = bufferType;
+             if (auto bufferType_ =
+                     mlir::tt::ttnn::symbolizeBufferType(value)) {
+               obj.bufferType = bufferType_;
              } else {
                throw std::invalid_argument("Invalid buffer type: " + value);
              }
@@ -183,17 +184,17 @@
       .def("set_memory_layout_from_str",
            [](mlir::tt::ttnn::OutputLayoutOverrideParams &obj,
               const std::string &value) {
-             if (auto memoryLayout = mlir::tt::ttnn::symbolizeLayout(value)) {
-               obj.memoryLayout = memoryLayout;
+             if (auto memoryLayout_ = mlir::tt::ttnn::symbolizeLayout(value)) {
+               obj.memoryLayout = memoryLayout_;
              } else {
                throw std::invalid_argument("Invalid memory layout: " + value);
              }
            })
       .def("set_data_type_from_str",
            [](mlir::tt::ttnn::OutputLayoutOverrideParams &obj,
               const std::string &value) {
-             if (auto dataType = mlir::tt::DataTypeStringToEnum(value)) {
-               obj.dataType = dataType;
+             if (auto dataType_ = mlir::tt::DataTypeStringToEnum(value)) {
+               obj.dataType = dataType_;
              } else {
                throw std::invalid_argument("Invalid data type: " + value);
              }
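Most of this file is mechanical renaming; the member-binding API mapping used above is `py::init` to `nb::init`, `def_property` to `def_prop_rw`, and `def_readwrite` to `def_rw`. A compact, self-contained sketch of all three on an illustrative struct (not a type from this PR):

```cpp
// pybind11 -> nanobind member-binding renames:
//   py::init<>      -> nb::init<>
//   .def_property   -> .def_prop_rw
//   .def_readwrite  -> .def_rw
#include <cstdint>
#include <vector>

#include <nanobind/nanobind.h>
#include <nanobind/stl/vector.h> // opt-in caster for std::vector

namespace nb = nanobind;

struct Params {
  std::vector<int64_t> operandIdxes;
  int dataType = 0;
};

NB_MODULE(params_example, m) {
  nb::class_<Params>(m, "Params")
      .def(nb::init<>())
      .def_rw("data_type", &Params::dataType)
      .def_prop_rw(
          "operand_idxes",
          // Getter: copy out as std::vector (becomes a Python list).
          [](const Params &p) { return p.operandIdxes; },
          // Setter: replace the stored vector wholesale.
          [](Params &p, const std::vector<int64_t> &v) {
            p.operandIdxes = v;
          });
}
```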
4 changes: 2 additions & 2 deletions python/Overrides.cpp
@@ -6,11 +6,11 @@

 namespace mlir::ttmlir::python {

-void populateOverridesModule(py::module &m) {
+void populateOverridesModule(nb::module_ &m) {

   m.def(
       "get_ptr", [](void *op) { return reinterpret_cast<uintptr_t>(op); },
-      py::arg("op").noconvert());
+      nb::arg("op").noconvert());
 }

 } // namespace mlir::ttmlir::python
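`nb::arg` keeps pybind11's spelling here, and `.noconvert()` exists in nanobind with the same meaning: implicit conversions are disabled for that argument, so the caller must pass something that already maps to `void *` (e.g. a capsule). A standalone version of the same binding:

```cpp
// nb::arg("op").noconvert() rejects implicit conversions for `op`,
// so only values nanobind can map directly to void * are accepted.
#include <cstdint>

#include <nanobind/nanobind.h>

namespace nb = nanobind;

NB_MODULE(overrides_example, m) {
  m.def(
      "get_ptr", [](void *op) { return reinterpret_cast<uintptr_t>(op); },
      nb::arg("op").noconvert());
}
```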