Runtime support for CPU hoist ops #2152

Open · wants to merge 24 commits into base: main
Changes from 11 commits
14 changes: 7 additions & 7 deletions lib/RegisterAll.cpp
@@ -51,6 +51,13 @@ void mlir::tt::registerAllDialects(mlir::DialectRegistry &registry) {
mlir::stablehlo::registerAllDialects(registry);
mlir::sdy::registerAllDialects(registry);
#endif
}

void mlir::tt::registerAllExtensions(mlir::DialectRegistry &registry) {
// Both the inliner for TTIRDialect and FuncDialect must be registered
// since we use a combination of TTIRDialect and FuncDialect in the IR.
mlir::func::registerInlinerExtension(registry);
LLVM::registerInlinerInterface(registry);
// Registering BufferizableOpInterface for each dialect (including
// intermediate dialects) is required to convert types to memrefs during
// lowering.
@@ -61,13 +68,6 @@ void mlir::tt::registerAllDialects(mlir::DialectRegistry &registry) {
registry);
tensor::registerBufferizableOpInterfaceExternalModels(registry);
vector::registerBufferizableOpInterfaceExternalModels(registry);
LLVM::registerInlinerInterface(registry);
}

void mlir::tt::registerAllExtensions(mlir::DialectRegistry &registry) {
// Both the inliner for TTIRDialect and FuncDialect must be registered
// since we use a combination of TTIRDialect and FuncDialect in the IR.
mlir::func::registerInlinerExtension(registry);
}

void mlir::tt::registerAllPasses() {
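For context (not part of this diff), a minimal sketch of how these entry points are typically wired into an MLIRContext, using only the registration hooks declared in ttmlir/RegisterAll.h plus standard MLIR registry calls:

// Sketch: attach the project's dialects and extensions to a context before
// parsing IR or running passes. Illustrative only; not code from this PR.
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/MLIRContext.h"
#include "ttmlir/RegisterAll.h"

void setupTTMLIRContext(mlir::MLIRContext &context) {
  mlir::DialectRegistry registry;
  mlir::tt::registerAllDialects(registry);   // dialects themselves
  mlir::tt::registerAllExtensions(registry); // inliner + bufferization models
  context.appendDialectRegistry(registry);
}

With the split this diff introduces, callers that only need the dialects can skip the extension registration, while pipelines that inline or bufferize register both.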
7 changes: 7 additions & 0 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -1742,6 +1742,13 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
return createOperation(cache, createOp(cache, constantOp), debugString,
locInfo);
}
if (auto callOp = dyn_cast<func::CallOp>(op); callOp) {
// TODO (#2355): Here dylib_id is hardcoded to 0. In the long run, we want
// to support multiple dylibs per flatbuffer, but the exact schema is not so
// clear.
return createOperation(cache, createCpuOp(cache, callOp, 0), debugString,
locInfo);
}

llvm_unreachable("unhandled op in emitTTNNOperation");
}
14 changes: 13 additions & 1 deletion python/Passes.cpp
@@ -2,9 +2,11 @@
//
// SPDX-License-Identifier: Apache-2.0

#include "ttmlir/Conversion/Passes.h"
#include "mlir/InitAllTranslations.h"
#include "mlir/Target/LLVMIR/Dialect/All.h"

#include "ttmlir/Bindings/Python/TTMLIRModule.h"
#include "ttmlir/Conversion/Passes.h"
#include "ttmlir/RegisterAll.h"
#include "ttmlir/Target/TTKernel/TTKernelToCpp.h"
#include "ttmlir/Target/TTMetal/TTMetalToFlatbuffer.h"
@@ -183,6 +185,16 @@ void populatePassesModule(py::module &m) {
{}) {
mlir::Operation *moduleOp = unwrap(mlirModuleGetOperation(module));

// Create a dialect registry and register all necessary dialects and
// translations
mlir::DialectRegistry registry;

// Register all LLVM IR translations
registerAllToLLVMIRTranslations(registry);

// Apply the registry to the module's context
moduleOp->getContext()->appendDialectRegistry(registry);

std::error_code fileError;
llvm::raw_fd_ostream file(filepath, fileError);

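As background on why the translation registration matters (a sketch using upstream MLIR APIs; the helper name is illustrative): MLIR's LLVM IR export fails unless every dialect in the module has a registered LLVM-IR translation interface on the context, which is exactly what the registry block above provides.

// Sketch: translate an MLIR module (already lowered to the LLVM dialect) to
// LLVM IR. The translation interfaces must be attached to the module's
// context first, as the binding above now ensures. Not code from this PR.
#include <memory>

#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/Operation.h"
#include "mlir/Target/LLVMIR/Dialect/All.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "llvm/IR/Module.h"

std::unique_ptr<llvm::Module> lowerToLLVMIR(mlir::Operation *moduleOp,
                                            llvm::LLVMContext &llvmCtx) {
  mlir::DialectRegistry registry;
  mlir::registerAllToLLVMIRTranslations(registry);
  moduleOp->getContext()->appendDialectRegistry(registry);
  // Returns nullptr (after emitting diagnostics) if a translation is missing.
  return mlir::translateModuleToLLVMIR(moduleOp, llvmCtx);
}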
29 changes: 17 additions & 12 deletions python/TTMLIRModule.cpp
@@ -10,19 +10,24 @@ PYBIND11_MODULE(_ttmlir, m) {
m.def(
"register_dialect",
[](MlirContext context, bool load) {
MlirDialectHandle tt_handle = mlirGetDialectHandle__tt__();
MlirDialectHandle ttir_handle = mlirGetDialectHandle__ttir__();
MlirDialectHandle ttkernel_handle = mlirGetDialectHandle__ttkernel__();
MlirDialectHandle ttnn_handle = mlirGetDialectHandle__ttnn__();
mlirDialectHandleRegisterDialect(tt_handle, context);
mlirDialectHandleRegisterDialect(ttir_handle, context);
mlirDialectHandleRegisterDialect(ttkernel_handle, context);
mlirDialectHandleRegisterDialect(ttnn_handle, context);
// Create a dialect registry
mlir::DialectRegistry registry;

// Register all dialects including LLVM dialect
mlir::tt::registerAllDialects(registry);

// Register all extensions (interfaces, etc.)
mlir::tt::registerAllExtensions(registry);

// Get the MLIRContext from MlirContext
mlir::MLIRContext *mlirContext = unwrap(context);

// Append the registry to the context
mlirContext->appendDialectRegistry(registry);

// If load is true, load all available dialects
if (load) {
mlirDialectHandleLoadDialect(tt_handle, context);
mlirDialectHandleLoadDialect(ttir_handle, context);
mlirDialectHandleLoadDialect(ttkernel_handle, context);
mlirDialectHandleLoadDialect(ttnn_handle, context);
mlirContext->loadAllAvailableDialects();
}
},
py::arg("context"), py::arg("load") = true);
1 change: 1 addition & 0 deletions python/test_infra/test_utils.py
@@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

import sys
import os
import inspect
import torch
53 changes: 52 additions & 1 deletion python/test_infra/ttir_builder.py
@@ -345,6 +345,45 @@ def empty(

return op

def zeros(
self,
shape: Shape,
data_type: Optional[Type] = None,
) -> OpView:
"""Convenience wrapper constructing `ttir.ZerosOp`."""
dtype = data_type if data_type is not None else self._default_dtype
with self._ctx, self._loc:
# Create the result type
result_type = self.ranked_tensor_type(shape, dtype)

# Convert shape to a list if it's a tuple
shape_list = list(shape) if isinstance(shape, tuple) else shape

# Create the shape attribute
from ttmlir.ir import DenseI32ArrayAttr

shape_attr = DenseI32ArrayAttr.get(shape_list, context=self._ctx)

# Create the operation
from ttmlir.ir import Operation

# Create the operation directly using the low-level API to ensure it matches the IR format
op = Operation.create(
"ttir.zeros", # Operation name
results=[result_type], # Result types
operands=[], # No operands
attributes={"shape": shape_attr}, # Shape attribute
loc=self._loc, # Location
)

# Wrap in the appropriate OpView
# op = ttir.ZerosOp(operation)

# Generate and store random golden tensor
self.generate_and_store_random_golden(op)

return op

# ----- TTIR op factories -----
def _organize_eltwise_ttir(
self, inputs: List[Operand], output: OpView, output_shape: Optional[Shape]
@@ -359,11 +398,13 @@ def op_proxy(
op_golden_function: Callable,
op_ttir_function: Callable,
inputs: List[Operand],
unit_attrs: List[str] = None,
organize_ttir_args: Optional[Callable] = None,
organize_golden_args: Optional[Callable] = None,
output_shape: Optional[Shape] = None,
golden_kwargs: dict = {},
ttir_kwargs: dict = {},
use_zeros: bool = False,
) -> Any:
"""
Provides a general interface for proxy-ing OPs and creating them.
@@ -422,7 +463,10 @@ def organize_golden_args(inputs: List[Operand], output: OpView, output_shape: Op

# Use the golden output to determine proper output shape unless otherwise specified
output_shape = golden.tensor.shape if not output_shape else output_shape
output = self.empty(output_shape)
if use_zeros:
output = self.zeros(output_shape)
else:
output = self.empty(output_shape)

id = self.get_next_global_id()
loc = get_loc_of_extra_file_callee(id=id)
@@ -433,6 +477,13 @@ def organize_golden_args(inputs: List[Operand], output: OpView, output_shape: Op
**ttir_kwargs,
)

# Add unit attributes if specified
if unit_attrs:
from ttmlir.ir import UnitAttr

for attr_name in unit_attrs:
op.operation.attributes[attr_name] = UnitAttr.get(self._ctx)

self.id_golden_map[str(loc)] = golden
self._store_golden(op, golden)
self._override_golden(output, golden)
35 changes: 35 additions & 0 deletions runtime/include/tt/runtime/detail/dylib.h
@@ -0,0 +1,35 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TT_RUNTIME_DETAIL_DYLIB_H
#define TT_RUNTIME_DETAIL_DYLIB_H

#include "tt/runtime/types.h"

#include <cstring>
#include <dlfcn.h>
#include <iostream>
#include <string>
#include <sys/mman.h>
#include <sys/syscall.h>
#include <unistd.h>
// Linux memfd_create syscall number, if not available in <sys/mman.h>
#ifndef MFD_CLOEXEC
#define MFD_CLOEXEC 0x0001U
#endif
#ifndef SYS_memfd_create
#define SYS_memfd_create 319
#endif
#include <stdint.h>
namespace tt::runtime::common {
void *loadLibraryFromMemory(const uint8_t *data, size_t size);

DylibHandleMap openDylibHandles(
const ::flatbuffers::Vector<::flatbuffers::Offset<tt::target::DynamicLib>>
*dylibs);

void closeDylibHandles(DylibHandleMap handles);
} // namespace tt::runtime::common

#endif
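The header only declares the loaders. A plausible implementation of loadLibraryFromMemory on Linux (and the reason for the memfd_create fallbacks above) is to back the ELF image with an anonymous in-memory file and dlopen it through /proc/self/fd; the sketch below shows that approach under that assumption and is not the code in this PR.

// Sketch: load a shared object from a memory buffer via memfd_create and
// dlopen("/proc/self/fd/<fd>"). Minimal error handling; illustrative only.
#include <cstddef>
#include <cstdint>
#include <dlfcn.h>
#include <string>
#include <sys/mman.h>
#include <sys/syscall.h>
#include <unistd.h>

static void *loadLibraryFromMemorySketch(const uint8_t *data, size_t size) {
  int fd = static_cast<int>(syscall(SYS_memfd_create, "dylib", MFD_CLOEXEC));
  if (fd < 0) {
    return nullptr;
  }
  size_t written = 0;
  while (written < size) {
    ssize_t n = write(fd, data + written, size - written);
    if (n < 0) {
      close(fd);
      return nullptr;
    }
    written += static_cast<size_t>(n);
  }
  std::string path = "/proc/self/fd/" + std::to_string(fd);
  void *handle = dlopen(path.c_str(), RTLD_NOW | RTLD_LOCAL);
  close(fd); // dlopen keeps its own mapping of the object alive.
  return handle;
}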
24 changes: 24 additions & 0 deletions runtime/include/tt/runtime/detail/strides.h
@@ -0,0 +1,24 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TT_RUNTIME_DETAIL_STRIDES_H
#define TT_RUNTIME_DETAIL_STRIDES_H

#include <cassert>
#include <concepts>
#include <vector>

namespace tt::runtime::common {

template <std::integral T>
inline std::vector<T> calculateStride(const std::vector<T> &shape) {
assert(!shape.empty());
std::vector<T> stride(shape.size(), 1);
for (size_t i = shape.size() - 1; i > 0; i--) {
stride[i - 1] = stride[i] * shape[i];
}
return stride;
}
} // namespace tt::runtime::common

#endif
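The helper returns standard row-major (innermost-fastest) strides measured in elements, mirroring the calculation previously duplicated in binary.cpp. A quick usage sketch:

// Sketch: calculateStride({2, 3, 4}) yields {12, 4, 1}. Illustrative only.
#include <cassert>
#include <cstdint>
#include <vector>
#include "tt/runtime/detail/strides.h"

void strideExample() {
  std::vector<uint32_t> shape = {2, 3, 4};
  std::vector<uint32_t> stride = tt::runtime::common::calculateStride(shape);
  assert((stride == std::vector<uint32_t>{12, 4, 1}));
}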
1 change: 1 addition & 0 deletions runtime/include/tt/runtime/types.h
@@ -96,6 +96,7 @@ struct TensorDesc {
};

using DeviceIds = std::vector<int>;
using DylibHandleMap = std::unordered_map<uint32_t, void *>;

struct Flatbuffer : public detail::ObjectImpl {
using detail::ObjectImpl::ObjectImpl;
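A hedged sketch of how a DylibHandleMap entry would be consumed at dispatch time: the keys are the dylib ids baked into the flatbuffer (hardcoded to 0 for now, per the TODO in TTNNToFlatbuffer.cpp above) and the values are dlopen handles, so symbol lookup is a plain dlsym. The helper name, and the assumption that the alias lives in namespace tt::runtime, are illustrative.

// Sketch: resolve a hoisted CPU function from an opened dylib handle. The
// caller casts the pointer to the real ABI, which is not shown in this hunk.
#include <cstdint>
#include <dlfcn.h>
#include "tt/runtime/types.h"

static void *resolveHoistedSymbol(const tt::runtime::DylibHandleMap &handles,
                                  uint32_t dylibId, const char *symbol) {
  auto it = handles.find(dylibId);
  return it == handles.end() ? nullptr : dlsym(it->second, symbol);
}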
24 changes: 5 additions & 19 deletions runtime/lib/binary.cpp
@@ -7,6 +7,7 @@
#include "flatbuffers/idl.h"

#include "tt/runtime/detail/logger.h"
#include "tt/runtime/detail/strides.h"
#include "tt/runtime/types.h"
#include "tt/runtime/utils.h"
#include "ttmlir/Target/Common/system_desc_bfbs_generated.h"
@@ -36,21 +37,6 @@ static std::string asJson(void const *fbb, uint8_t const *binarySchema,
return text;
}

static std::vector<uint32_t>
calculateStride(std::vector<uint32_t> const &shape) {
// TODO(bug #2045): Our current stride calculation is incorrect for tilized
// tensors. The current solution is to remove stride entirely from the
// flatbuffer and calculate the stride in runtime assuming using the default
// method ignoring details like grid, layout etc. Once we have a more
// sophisticated way for handling this, we can remove this workaround.
LOG_ASSERT(!shape.empty());
std::vector<uint32_t> stride(shape.size(), 1);
for (size_t i = shape.size() - 1; i > 0; i--) {
stride[i - 1] = stride[i] * shape[i];
}
return stride;
}

namespace ttnn {

::tt::target::ttnn::TTNNBinary const *getBinary(Flatbuffer binary) {
@@ -85,7 +71,7 @@ std::vector<TensorDesc> getProgramInputs(Flatbuffer binary,
TensorDesc desc;
desc.shape = {input->desc()->shape()->begin(),
input->desc()->shape()->end()};
desc.stride = calculateStride(desc.shape);
desc.stride = common::calculateStride(desc.shape);
desc.itemsize = ::tt::runtime::utils::dataTypeElementSize(
input->desc()->layout()->memory_desc()->data_type());
desc.dataType = input->desc()->layout()->memory_desc()->data_type();
@@ -102,7 +88,7 @@ std::vector<TensorDesc> getProgramOutputs(Flatbuffer binary,
TensorDesc desc;
desc.shape = {output->desc()->shape()->begin(),
output->desc()->shape()->end()};
desc.stride = calculateStride(desc.shape);
desc.stride = common::calculateStride(desc.shape);
desc.itemsize = ::tt::runtime::utils::dataTypeElementSize(
output->desc()->layout()->memory_desc()->data_type());
desc.dataType = output->desc()->layout()->memory_desc()->data_type();
@@ -169,7 +155,7 @@ std::vector<TensorDesc> getProgramInputs(Flatbuffer binary,
TensorDesc desc;
desc.shape = {input->desc()->shape()->begin(),
input->desc()->shape()->end()};
desc.stride = calculateStride(desc.shape);
desc.stride = common::calculateStride(desc.shape);
desc.itemsize = utils::dataTypeElementSize(
input->desc()->layout()->memory_desc()->data_type());
desc.dataType = input->desc()->layout()->memory_desc()->data_type();
@@ -189,7 +175,7 @@ std::vector<TensorDesc> getProgramOutputs(Flatbuffer binary,
TensorDesc desc;
desc.shape = {output->desc()->shape()->begin(),
output->desc()->shape()->end()};
desc.stride = calculateStride(desc.shape);
desc.stride = common::calculateStride(desc.shape);
desc.itemsize = utils::dataTypeElementSize(
output->desc()->layout()->memory_desc()->data_type());
desc.dataType = output->desc()->layout()->memory_desc()->data_type();
7 changes: 7 additions & 0 deletions runtime/lib/common/CMakeLists.txt
@@ -48,3 +48,10 @@ target_include_directories(TTRuntimeWorkarounds
PUBLIC
${PROJECT_SOURCE_DIR}/runtime/include
)

add_library(TTRuntimeDylibs STATIC dylib.cpp)
set_property(TARGET TTRuntimeDylibs PROPERTY CXX_STANDARD 20)
target_include_directories(TTRuntimeDylibs
PUBLIC
${PROJECT_SOURCE_DIR}/runtime/include
)