Skip to content

Commit

Permalink
DenseResourceElementsAttr as constant value E2E (#2381)
Browse files Browse the repository at this point in the history
### Ticket
Closes #2378

### Problem description
When the value of a `stalbehlo/ttir/ttnn` constant op is a
`DenseResourceElementsAttr` instead of a `DenseElementsAttr`, we fail to
lower stablehlo to ttir. We also do not yet handle the case where the
value of a constant is in a `DenseResourceElementsAttr` when generating
the flatbuffer.

### What's changed
- Legalize lowering of `stablehlo.constant` when the value type is
`DenseResourceElementsAttr`
- Handle case where value type is `DenseResourceElementsAttr` when
generating flatbuffer
- Add verifier to `ttnn.constant` to ensure that the `ElementsAttr`
holding the constant data is either `DenseResourceElementsAttr` or
`DenseElementsAttr`
  • Loading branch information
LPanosTT authored and odjuricicTT committed Mar 8, 2025
1 parent b37728f commit 995aad9
Show file tree
Hide file tree
Showing 18 changed files with 112 additions and 24 deletions.
3 changes: 3 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/CommonTypeConstraints.td"
include "mlir/IR/CommonAttrConstraints.td"

Expand Down Expand Up @@ -1604,6 +1605,8 @@ def TTNN_ConstantOp : TTNN_Op<"constant", [AllShapesMatch<["value", "result"]>]>
return wa::TTNNOperandsWorkaroundsFactory::createConstantOpOperandsWorkarounds();
}
}];

let hasVerifier = 1;
}

#endif
2 changes: 1 addition & 1 deletion lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,7 @@ class StableHLOToTTIRConstantOpConversionPattern
private:
LogicalResult checkBasicLegality(mlir::stablehlo::ConstantOp &srcOp,
ConversionPatternRewriter &rewriter) const {
if (srcOp.getValue().getShapedType().getShape().empty() &&
if (isa<DenseElementsAttr, DenseResourceElementsAttr>(srcOp.getValue()) &&
!srcOp.getValue().getElementType().isIntOrFloat()) {
return rewriter.notifyMatchFailure(srcOp, "Unsupported element type.");
}
Expand Down
15 changes: 15 additions & 0 deletions lib/Dialect/TTNN/IR/TTNNOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "ttmlir/Utils.h"

#include "mlir/Dialect/Traits.h"
#include "mlir/IR/BuiltinAttributes.h"

#include <numeric>
#include <optional>
Expand All @@ -19,6 +20,20 @@

namespace mlir::tt::ttnn {

//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//

::mlir::LogicalResult mlir::tt::ttnn::ConstantOp::verify() {

if (!isa<DenseResourceElementsAttr, DenseElementsAttr>(getValue())) {
return emitOpError("value attribute must be one of "
"DenseResourceElementsAttr or DenseElementsAttr.");
}

return success();
}

//===----------------------------------------------------------------------===//
// ClampOp
//===----------------------------------------------------------------------===//
Expand Down
15 changes: 12 additions & 3 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -944,10 +944,19 @@ ::flatbuffers::Offset<::tt::target::ttnn::ConstantOp>
createOp(FlatbufferObjectCache &cache, ttnn::ConstantOp op) {
auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer,
kHostAllocatedSize);
std::vector<uint8_t> rawVector;
if (auto data =
mlir::dyn_cast<mlir::DenseResourceElementsAttr>(op.getValue())) {
ArrayRef<char> rawData = data.getData();
rawVector = std::vector<uint8_t>(rawData.begin(), rawData.end());
} else if (auto data =
mlir::dyn_cast<mlir::DenseElementsAttr>(op.getValue())) {
ArrayRef<char> rawData = data.getRawData();
rawVector = std::vector<uint8_t>(rawData.begin(), rawData.end());
} else {
llvm_unreachable("Unknown constant value attribute type");
}

auto rawData =
mlir::dyn_cast<mlir::DenseElementsAttr>(op.getValue()).getRawData();
auto rawVector = std::vector<uint8_t>(rawData.begin(), rawData.end());
return ::tt::target::ttnn::CreateConstantOpDirect(*cache.fbb, output,
&rawVector);
}
Expand Down
18 changes: 18 additions & 0 deletions test/ttmlir/Conversion/StableHLOToTTIR/constant_op.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -339,4 +339,22 @@ module @jit_constant attributes {} {
// CHECK: return %[[CONSTANT]] : tensor<1xi32>
return %0 : tensor<i64>
}

func.func @test_dense_attr() -> tensor<1x2xbf16> {
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense_resource<dense_attr> : tensor<1x2xbf16>}> : () -> tensor<1x2xbf16>
%0 = stablehlo.constant dense_resource<dense_attr> : tensor<1x2xbf16>
// CHECK: return %{{[0-9]+}} : tensor<1x2xbf16>
return %0 : tensor<1x2xbf16>
}
}
{-#
dialect_resources: {
builtin: {
// This should encode for two bfloat16 values which are both 2.0
// 0x020000000 is a hex string blob
// 0x0040 is 2.0 in bfloat16
// 0x00400040 is 2.0, 2.0
dense_attr: "0x0200000000400040"
}
}
#-}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// RUN: not ttmlir-opt --ttir-to-ttnn-backend-pipeline %s 2>&1 | FileCheck %s

module attributes {} {
func.func @test_dense_attr() -> tensor<1x2xbf16> {
// CHECK: error: 'ttnn.constant' op value attribute must be one of DenseResourceElementsAttr or DenseElementsAttr.
%0 = "ttir.constant"() <{value = sparse<[[0, 0], [0, 1]], [2.0, 2.0]> : tensor<1x2xbf16>}> : () -> tensor<1x2xbf16>
return %0 : tensor<1x2xbf16>
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -137,4 +137,22 @@ module attributes {} {
%0 = "ttir.constant"() <{value = dense<[[[-1, 2, 3]]]> : tensor<1x1x3xi32>}> : () -> tensor<1x1x3xi32>
return %0 : tensor<1x1x3xi32>
}

func.func @test_dense_attr() -> tensor<1x2xbf16> {
%0 = "ttir.constant"() <{value = dense_resource<dense_attr> : tensor<1x2xbf16>}> : () -> tensor<1x2xbf16>
// CHECK: "ttnn.constant"
// CHECK-SAME: value = dense_resource<dense_attr>
return %0 : tensor<1x2xbf16>
}
}
{-#
dialect_resources: {
builtin: {
// This should encode for two bfloat16 values which are both 2.0
// 0x020000000 is a hex string blob
// 0x0040 is 2.0 in bfloat16
// 0x00400040 is 2.0, 2.0
dense_attr: "0x0200000000400040"
}
}
#-}
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
// REQUIRES: stablehlo
// RUN: rm -rf %t.ttnn
// RUN: rm -rf %t.mlir
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | \
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
// RUN: FileCheck --input-file=%t.mlir %s

Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
// REQUIRES: stablehlo
// RUN: rm -rf %t.ttnn
// RUN: rm -rf %t.mlir
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | \
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
// RUN: FileCheck --input-file=%t.mlir %s

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// REQUIRES: stablehlo
// RUN: rm -rf %t.ttnn
// RUN: rm -rf %t.mlir
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
// RUN: FileCheck --input-file=%t.mlir %s

module attributes {} {
func.func @test_dense_attr() -> tensor<1x2xbf16> {
// CHECK: ttnn.constant
// CHECK-SAME: dense_resource<dense_attr>
%0 = stablehlo.constant dense_resource<dense_attr> : tensor<1x2xbf16>
return %0 : tensor<1x2xbf16>
}
}
{-#
dialect_resources: {
builtin: {
// This should encode for two bfloat16 values which are both 2.0
// 0x020000000 is a hex string blob
// 0x0040 is 2.0 in bfloat16
// 0x00400040 is 2.0, 2.0
dense_attr: "0x0200000000400040"
}
}
#-}
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
// REQUIRES: stablehlo
// RUN: rm -rf %t.ttnn
// RUN: rm -rf %t.mlir
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | \
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
// RUN: FileCheck --input-file=%t.mlir %s

Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
// REQUIRES: stablehlo
// RUN: rm -rf %t.ttnn
// RUN: rm -rf %t.mlir
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | \
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
// RUN: FileCheck --input-file=%t.mlir %s

Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
// REQUIRES: stablehlo
// RUN: rm -rf %t.ttnn
// RUN: rm -rf %t.mlir
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | \
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
// RUN: FileCheck --input-file=%t.mlir %s

Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
// REQUIRES: stablehlo
// RUN: rm -rf %t.ttnn
// RUN: rm -rf %t.mlir
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | \
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
// RUN: FileCheck --input-file=%t.mlir %s

Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
// REQUIRES: stablehlo
// RUN: rm -rf %t.ttnn
// RUN: rm -rf %t.mlir
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | \
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
// RUN: FileCheck --input-file=%t.mlir %s

Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
// REQUIRES: stablehlo
// RUN: rm -rf %t.ttnn
// RUN: rm -rf %t.mlir
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | \
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
// RUN: FileCheck --input-file=%t.mlir %s

Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
// REQUIRES: stablehlo
// RUN: rm -rf %t.ttnn
// RUN: rm -rf %t.mlir
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | \
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
// RUN: FileCheck --input-file=%t.mlir %s

Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
// REQUIRES: stablehlo
// RUN: rm -rf %t.ttnn
// RUN: rm -rf %t.mlir
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | \
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
// RUN: FileCheck --input-file=%t.mlir %s

Expand Down

0 comments on commit 995aad9

Please sign in to comment.