
stablehlo.clamp conversion bug #2185

Open
mmanzoorTT opened this issue Feb 13, 2025 · 0 comments · May be fixed by #2268
@mmanzoorTT (Contributor)

stablehlo.clamp is lowered to either

  1. ttir.clamp, if constant values can be determined for min/max; or
  2. ttir.maximum followed by ttir.minimum, if constant values cannot be determined.

Currently, the conversion only checks whether a constant is used directly as an operand of the stablehlo.clamp op, which is not always the case: a constant can be converted and/or reshaped before being used as the min/max operand. In that case stablehlo.clamp is lowered to ttir.maximum/ttir.minimum, and this pattern then fails because ttnn.maximum/ttnn.minimum do not support implicit broadcast.
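
One possible direction (a minimal sketch only, not the existing tt-mlir API; the helper name getClampBoundConstant and its placement are assumptions): before falling back to the maximum/minimum pattern, the conversion could look through stablehlo.convert/stablehlo.reshape producers of the min/max operands and check whether the underlying value is still a constant.

// Sketch of a hypothetical helper (not the current tt-mlir API): look through
// stablehlo.convert / stablehlo.reshape producers of a clamp bound and return
// the underlying constant attribute, if any. The attribute is still in its
// original element type (i64 in the example below), so the caller would have
// to convert it before building ttir.clamp.
#include <optional>

#include "mlir/IR/Matchers.h"
#include "stablehlo/dialect/StablehloOps.h"

static std::optional<mlir::ElementsAttr> getClampBoundConstant(mlir::Value bound) {
  // Element-type conversions and reshapes of a splat constant do not change
  // the scalar min/max value we are looking for, so skip over them.
  while (mlir::Operation *def = bound.getDefiningOp()) {
    if (mlir::isa<mlir::stablehlo::ConvertOp, mlir::stablehlo::ReshapeOp>(def)) {
      bound = def->getOperand(0);
      continue;
    }
    break;
  }
  mlir::ElementsAttr attr;
  if (mlir::matchPattern(bound, mlir::m_Constant(&attr)))
    return attr; // e.g. dense<0> / dense<320> from the constants below
  return std::nullopt;
}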

A sample StableHLO IR:

func.func @main(%arg0: tensor<3234x2xf32>) -> tensor<3234x2xf32> {
  %cst = arith.constant dense<0> : tensor<1xi64>
  %cst_0 = arith.constant dense<320> : tensor<1xi64>
  %0 = stablehlo.convert %cst : (tensor<1xi64>) -> tensor<1xf32>
  %1 = stablehlo.reshape %0 : (tensor<1xf32>) -> tensor<f32>
  %2 = stablehlo.convert %cst_0 : (tensor<1xi64>) -> tensor<1xf32>
  %3 = stablehlo.reshape %2 : (tensor<1xf32>) -> tensor<f32>
  %4 = stablehlo.clamp %1, %arg0, %3 : (tensor<f32>, tensor<3234x2xf32>, tensor<f32>) -> tensor<3234x2xf32>
  return %4 : tensor<3234x2xf32>
}

Lowered to TTIR:

func.func @main(%arg0: tensor<3234x2xf32>) -> tensor<3234x2xf32> {
  %0 = ""ttir.constant""() <{value = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
  %1 = ""ttir.constant""() <{value = dense<320> : tensor<1xi32>}> : () -> tensor<1xi32>
  %2 = tensor.empty() : tensor<1xf32>
  %3 = ""ttir.typecast""(%0, %2) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<1xi32>, tensor<1xf32>) -> tensor<1xf32>
  %4 = tensor.empty() : tensor<1xf32>
  %5 = ""ttir.reshape""(%3, %4) <{shape = [1 : i32]}> : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
  %6 = tensor.empty() : tensor<1xf32>
  %7 = ""ttir.typecast""(%1, %6) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<1xi32>, tensor<1xf32>) -> tensor<1xf32>
  %8 = tensor.empty() : tensor<1xf32>
  %9 = ""ttir.reshape""(%7, %8) <{shape = [1 : i32]}> : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
  %10 = tensor.empty() : tensor<3234x2xf32>
  %11 = ""ttir.maximum""(%5, %arg0, %10) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<1xf32>, tensor<3234x2xf32>, tensor<3234x2xf32>) -> tensor<3234x2xf32>
  %12 = tensor.empty() : tensor<3234x2xf32>
  %13 = ""ttir.minimum""(%11, %9, %12) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<3234x2xf32>, tensor<1xf32>, tensor<3234x2xf32>) -> tensor<3234x2xf32>
  return %13 : tensor<3234x2xf32>
}

Error message:

Error: TT_THROW @ /__w/tt-torch/tt-torch/third_party/tt-mlir/src/tt-mlir/third_party/tt-metal/src/tt-metal/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.cpp:101: tt::exception

info:
ttnn::operations::binary::BinaryDeviceOperation: unsupported broadcast
@mmanzoorTT self-assigned this Feb 13, 2025
@mmanzoorTT added the bug (Something isn't working) and stablehlo conversion bug (Bugs in StableHLO conversion) labels Feb 13, 2025
@mmanzoorTT added this to the [Third Party] HLO + XLA milestone Feb 13, 2025
@mmanzoorTT linked pull request #2268 on Feb 24, 2025 that will close this issue