Refactor TTNN verification and add tests for conv2d op
jserbedzijaTT committed Mar 9, 2025
1 parent c16d622 commit 50380f2
Showing 11 changed files with 890 additions and 90 deletions.
43 changes: 23 additions & 20 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -1119,7 +1119,7 @@ def TTNN_Conv2dOp : TTNN_Op<"conv2d"> {
Applies a 2D convolution over an input image composed of several input planes.

Inputs:
- `input` (AnyRankedTensor): expected in the following format (N, H_in, W_in, C) where:
- `input` (AnyRankedTensor): expected in the following flattened format (1, 1, N * H_in * W_in, C) where:
- N is the batch size
- H_in is the height of the input planes
- W_in is the width of the input planes
@@ -1131,39 +1131,42 @@ def TTNN_Conv2dOp : TTNN_Op<"conv2d"> {
- G is the number of groups
- K_H is the height of the kernel
- K_W is the width of the kernel
- `output` (AnyRankedTensor): expected in the following format (N, H_out, W_out, O) where:
- `H_out = (H_in + 2 * pH - dH * (K_H - 1) - 1) / sH + 1`
- `W_out = (W_in + 2 * pW - dW * (K_W - 1) - 1) / sW + 1`

Attributes:
- `in_channels` (i32): The number of input channels.
- `out_channels` (i32): The number of output channels.
- `batch_size` (i32): The batch size.
- `input_height` (i32): The input height.
- `input_width` (i32): The input width.
- `stride` (i32 | array<2xi32>):
- i32: Same stride for height and width dimensions (sH = sW = value).
- array<2xi32>: [sH, sW] where sH is stride for height and sW is stride for width.
- `padding` (i32 | array<2xi32>):
- i32: Same padding for all sides (pH = pW = value).
- array<2xi32>: [pH, pW] where pH is padding for height (top/bottom) and pW is padding for width (left/right).
- `dilation` (i32 | array<2xi32>): Spacing between kernel elements.
- i32: Same dilation for height and width dimensions (dH = dW = value).
- array<2xi32>: [dH, dW] where dH is dilation for height and dW is dilation for width.
- `kernel_size` (array<2xi32>): [K_H, K_W] where K_H is the kernel height and K_W is the kernel width.
- `stride` (array<2xi32>): [sH, sW] where sH is stride for height and sW is stride for width.
- `padding` (array<2xi32>): [pH, pW] where pH is padding for height (top/bottom) and pW is padding for width (left/right).
- `dilation` (array<2xi32>): [dH, dW] where dH is dilation for height and dW is dilation for width.
- `groups` (i32): Number of blocked connections from input channels to output channels. Input and output channels must both be divisible by groups.

Outputs:
- `result` (AnyRankedTensor): returned in the following flattened format (1, 1, N * H_out * W_out, O) where:
- `H_out = (H_in + 2 * pH - dH * (K_H - 1) - 1) / sH + 1`
- `W_out = (W_in + 2 * pW - dW * (K_W - 1) - 1) / sW + 1`

Example:
%input = tensor.empty() : () -> tensor<1x32x32x64xbf16>
%input = tensor.empty() : () -> tensor<1x1x1024x64xbf16>
%weight = tensor.empty() : () -> tensor<64x64x3x3xbf16>
%bias = tensor.empty() : () -> tensor<1x1x1x64xbf16>
%output = tensor.empty() : () -> tensor<1x30x30x64xbf16>
%0 = "ttnn.conv2d"(%input, %weight, %bias, %output)
%device = "ttnn.get_device"() <{mesh_shape = #ttnn<mesh_shape 1x1>}> : () -> !tt.device<#device>
%0 = "ttnn.conv2d"(%input, %weight, %bias, %device)
<{
stride = 1: i32,
padding = 0: i32,
dilation = 1: i32,
in_channels = 64: i32,
out_channels = 64: i32,
batch_size = 1: i32,
input_height = 32: i32,
input_width = 32: i32,
kernel_size = array<i32: 3, 3>,
stride = array<i32: 1, 1>,
padding = array<i32: 0, 0>,
dilation = array<i32: 1, 1>,
groups = 1: i32
}> : (tensor<1x32x32x64xbf16>, tensor<64x64x3x3xbf16>, tensor<1x1x1x64xbf16>, tensor<1x30x30x64xbf16>) -> tensor<1x30x30x64xbf16>
}> : (tensor<1x1x1024x64xbf16>, tensor<64x64x3x3xbf16>, tensor<1x1x1x64xbf16>, !tt.device<#device>) -> tensor<1x1x900x64xbf16>
}];

let arguments = (ins AnyRankedTensor:$input,
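For reference, here is a minimal standalone C++ sketch (not part of this commit) of the output-shape arithmetic documented above; the constants mirror the op example's 32x32 input, 3x3 kernel, stride 1, zero padding, and dilation 1:

#include <array>
#include <cstdint>
#include <iostream>

int main() {
  // Example parameters from the conv2d documentation above.
  const int32_t N = 1, H_in = 32, W_in = 32, O = 64;
  const int32_t K_H = 3, K_W = 3;
  const int32_t sH = 1, sW = 1, pH = 0, pW = 0, dH = 1, dW = 1;

  // H_out = (H_in + 2*pH - dH*(K_H - 1) - 1) / sH + 1, and likewise for W_out.
  const int32_t H_out = (H_in + 2 * pH - dH * (K_H - 1) - 1) / sH + 1;
  const int32_t W_out = (W_in + 2 * pW - dW * (K_W - 1) - 1) / sW + 1;

  // TTNN returns the result flattened to (1, 1, N * H_out * W_out, O).
  const std::array<std::int64_t, 4> resultShape{
      1, 1, static_cast<std::int64_t>(N) * H_out * W_out, O};
  std::cout << H_out << " x " << W_out << " -> flattened dim "
            << resultShape[2] << "\n"; // prints: 30 x 30 -> flattened dim 900
}

This reproduces the shapes in the example: a tensor<1x1x1024x64xbf16> input (32 * 32 = 1024) and a tensor<1x1x900x64xbf16> result (30 * 30 = 900).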
91 changes: 47 additions & 44 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
@@ -199,26 +199,27 @@ ::mlir::LogicalResult mlir::tt::ttir::Conv2dOp::verify() {
}
}

uint32_t batchSize = inputType.getDimSize(0);
if (batchSize != outputType.getDimSize(0)) {
return emitOpError(
"First dimension of the input tensor must match the first dimension of "
"the output tensor, got: " +
std::to_string(batchSize) + " and " +
std::to_string(outputType.getDimSize(0)));
constexpr unsigned int BATCH_DIM = 0, HEIGHT_DIM = 1, WIDTH_DIM = 2,
CHANNEL_DIM = 3;
uint32_t batchSize = inputType.getDimSize(BATCH_DIM);
if (batchSize != outputType.getDimSize(BATCH_DIM)) {
return emitOpError()
<< "Batch size from the input tensor (" << batchSize
<< ") must match the first dimension of the output tensor ("
<< outputType.getDimSize(BATCH_DIM) << ")";
}

uint32_t inputHeight = inputType.getDimSize(1);
uint32_t inputWidth = inputType.getDimSize(2);
uint32_t inChannels = inputType.getDimSize(3);
uint32_t outChannels = outputType.getDimSize(3);
uint32_t inputHeight = inputType.getDimSize(HEIGHT_DIM);
uint32_t inputWidth = inputType.getDimSize(WIDTH_DIM);
uint32_t inChannels = inputType.getDimSize(CHANNEL_DIM);
uint32_t outChannels = outputType.getDimSize(CHANNEL_DIM);

auto stride = ttmlir::utils::getPairOfInteger<int32_t>(getStride());
if (auto error = stride.takeError()) {
return emitOpError() << llvm::toString(std::move(error)) << " for stride";
}
if (stride->first < 1 || stride->second < 1) {
return emitOpError("Stride values must be greater than 0");
return emitOpError("Stride attribute values must be greater than 0");
}

auto padding = ttmlir::utils::getQuadrupleOfInteger<int32_t>(getPadding());
@@ -229,7 +230,8 @@ ::mlir::LogicalResult mlir::tt::ttir::Conv2dOp::verify() {
auto [paddingTop, paddingLeft, paddingBottom, paddingRight] = *padding;
if (paddingTop < 0 || paddingBottom < 0 || paddingLeft < 0 ||
paddingRight < 0) {
return emitOpError("Padding values must be greater or equal than 0");
return emitOpError(
"Padding attribute values must be greater than or equal to 0");
}
int32_t verticalPadding = paddingTop + paddingBottom;
int32_t horizontalPadding = paddingLeft + paddingRight;
@@ -239,29 +241,31 @@ ::mlir::LogicalResult mlir::tt::ttir::Conv2dOp::verify() {
return emitOpError() << llvm::toString(std::move(error)) << " for dilation";
}
if (dilation->first < 1 || dilation->second < 1) {
return emitOpError("Dilation values must be greater than 0");
return emitOpError("Dilation attribute values must be greater than 0");
}

llvm::SmallVector<int32_t, 2> kernelSize = {
static_cast<int32_t>(weightType.getDimSize(2)),
static_cast<int32_t>(weightType.getDimSize(3))};
constexpr unsigned int WEIGHT_OUT_CHANNEL_DIM = 0, WEIGHT_IN_CHANNEL_DIM = 1;
constexpr unsigned int WEIGHT_KERNEL_HEIGHT_DIM = 2,
WEIGHT_KERNEL_WIDTH_DIM = 3;
llvm::SmallVector<int32_t, 2> kernelSize{
static_cast<int32_t>(weightType.getDimSize(WEIGHT_KERNEL_HEIGHT_DIM)),
static_cast<int32_t>(weightType.getDimSize(WEIGHT_KERNEL_WIDTH_DIM))};

llvm::SmallVector<uint32_t, 2> paddedInputSize = {
llvm::SmallVector<uint32_t, 2> paddedInputSize{
inputHeight + verticalPadding, inputWidth + horizontalPadding};
llvm::SmallVector<uint32_t, 2> effectiveKernelSize = {
llvm::SmallVector<uint32_t, 2> effectiveKernelSize{
static_cast<uint32_t>(kernelSize[0] +
(kernelSize[0] - 1) * (dilation->first - 1)),
static_cast<uint32_t>(kernelSize[1] +
(kernelSize[1] - 1) * (dilation->second - 1))};
if (paddedInputSize[0] < effectiveKernelSize[0] ||
paddedInputSize[1] < effectiveKernelSize[1]) {
return emitOpError(
"Calculated padded input size per channel: (" +
std::to_string(paddedInputSize[0]) + " x " +
std::to_string(paddedInputSize[1]) + "). Kernel size: (" +
std::to_string(effectiveKernelSize[0]) + " x " +
std::to_string(effectiveKernelSize[1]) +
"). Kernel size can't be greater than actual input size");
return emitOpError()
<< "Calculated padded input size per channel: ("
<< paddedInputSize[0] << " x " << paddedInputSize[1]
<< "). Kernel size: (" << effectiveKernelSize[0] << " x "
<< effectiveKernelSize[1]
<< "). Kernel size can't be greater than actual input size";
}

uint32_t groups = getGroups();
@@ -279,31 +283,30 @@ ::mlir::LogicalResult mlir::tt::ttir::Conv2dOp::verify() {
<< groups << " groups";
}

llvm::ArrayRef<std::int64_t> kernelShape = weightType.getShape();
if (outChannels != kernelShape[0]) {
llvm::ArrayRef<std::int64_t> weightShape = weightType.getShape();
if (outChannels != weightShape[WEIGHT_OUT_CHANNEL_DIM]) {
return emitOpError() << "Number of output channels from output tensor must "
"match the first dimension of the weight tensor. "
<< "Got " << outChannels << " output channels and "
<< kernelShape[0] << " in the weight tensor";
<< weightShape[WEIGHT_OUT_CHANNEL_DIM]
<< " in the weight tensor";
}

if (inChannels / groups != kernelShape[1]) {
if (inChannels / groups != weightShape[WEIGHT_IN_CHANNEL_DIM]) {
return emitOpError() << "Number of input channels per group must match "
"the second dimension of the weight tensor. "
<< "Got " << (inChannels / groups)
<< " input channels per group and " << kernelShape[1]
<< " input channels per group and "
<< weightShape[WEIGHT_IN_CHANNEL_DIM]
<< " in the weight tensor";
}

if (bias) {
if (bias->getDimSize(bias->getRank() - 1) != outChannels) {
return emitOpError() << "Mismatch in bias tensor dimensions. "
<< "Bias tensor has "
<< bias->getDimSize(bias->getRank() - 1)
<< " channels, "
<< "but the output tensor has " << outChannels
<< " channels";
}
if (bias && bias->getDimSize(CHANNEL_DIM) != outChannels) {
return emitOpError() << "Mismatch in bias tensor dimensions. "
<< "Bias tensor has " << bias->getDimSize(CHANNEL_DIM)
<< " channels, "
<< "but the output tensor has " << outChannels
<< " channels";
}

int32_t calculatedHOut = (inputHeight + verticalPadding -
@@ -314,15 +317,15 @@ ::mlir::LogicalResult mlir::tt::ttir::Conv2dOp::verify() {
dilation->second * (kernelSize[1] - 1) - 1) /
stride->second +
1;
if (calculatedHOut != outputType.getDimSize(1) ||
calculatedWOut != outputType.getDimSize(2)) {
if (calculatedHOut != outputType.getDimSize(HEIGHT_DIM) ||
calculatedWOut != outputType.getDimSize(WIDTH_DIM)) {
return emitOpError()
<< "Mismatch between calculated and got output height and width. "
<< "Calculated: (" << calculatedHOut << " x " << calculatedWOut
<< "). "
<< "Got output tensor height and width: ("
<< outputType.getDimSize(1) << " x " << outputType.getDimSize(2)
<< ")";
<< outputType.getDimSize(HEIGHT_DIM) << " x "
<< outputType.getDimSize(WIDTH_DIM) << ")";
}

return success();
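The refactored verifier above rejects kernels whose dilated extent exceeds the padded input, and channel counts that are not divisible by the group count. A standalone C++ sketch of those two checks (hypothetical helper names, not the verifier's real API):

#include <cstdint>
#include <cstdio>

// The effective kernel extent under dilation d is k + (k - 1) * (d - 1),
// i.e. d * (k - 1) + 1, and it must fit inside the padded input.
static bool kernelFitsPaddedInput(int32_t inputDim, int32_t totalPadding,
                                  int32_t kernel, int32_t dilation) {
  const int32_t padded = inputDim + totalPadding;
  const int32_t effectiveKernel = kernel + (kernel - 1) * (dilation - 1);
  return padded >= effectiveKernel;
}

// Both input and output channels must be divisible by the group count.
static bool channelsDivisibleByGroups(int32_t inChannels, int32_t outChannels,
                                      int32_t groups) {
  return groups > 0 && inChannels % groups == 0 && outChannels % groups == 0;
}

int main() {
  // 4x4 input, 3x3 kernel, dilation 2 -> effective kernel 5x5: too big.
  std::printf("fits: %d\n", kernelFitsPaddedInput(4, 0, 3, 2)); // fits: 0
  // Same input with one pixel of padding on each side (total 2): fits.
  std::printf("fits: %d\n", kernelFitsPaddedInput(4, 2, 3, 2)); // fits: 1
  std::printf("groups ok: %d\n", channelsDivisibleByGroups(64, 64, 4)); // 1
}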
