
aten.native_layer_norm.default lowering modification to include custom compute kernel config #812

Draft
amalbasaTT wants to merge 6 commits into main from amalbasaTT/ln_lowering_mod

Conversation

amalbasaTT

Ticket

Link to Github Issue

Problem description

ttnn.layer_norm ops found in albert_v2 models have low accuracy in some occurrences. Adding a compute kernel config with fp32_dest_acc_en enabled fixes the issue.
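
For context, here is a minimal sketch of the fix at the ttnn API level. It assumes a Wormhole device and ttnn's WormholeComputeKernelConfig; the tensor shape and epsilon value are illustrative, not taken from this PR.

```python
import torch
import ttnn

device = ttnn.open_device(device_id=0)

# Illustrative albert_v2-sized activation (batch 1, seq 128, hidden 768).
x = ttnn.from_torch(
    torch.randn(1, 128, 768),
    dtype=ttnn.bfloat16,
    layout=ttnn.TILE_LAYOUT,
    device=device,
)

# fp32_dest_acc_en keeps intermediate accumulation in the destination
# registers at fp32 precision, which is what recovers the lost accuracy.
compute_kernel_config = ttnn.WormholeComputeKernelConfig(
    math_fidelity=ttnn.MathFidelity.HiFi4,
    math_approx_mode=False,
    fp32_dest_acc_en=True,
    packer_l1_acc=False,
)

out = ttnn.layer_norm(
    x,
    epsilon=1e-12,
    compute_kernel_config=compute_kernel_config,
)

ttnn.close_device(device)
```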

What's changed

  1. Added a TtnnComputeKernelConfig helper class.
  2. Added compute_kernel_config registration to the torch.fx graph module.
  3. In to_tt_pass.py, modified the lowering of torch.ops.aten.native_layer_norm.default to include the custom compute kernel config (see the sketch after this list).
  4. In tools/generate_op_accuracy_tests.py, expanded _build_code_from_aten_ttnn_graphs to include the kernel config at the beginning of the generated forward function.
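
A hypothetical sketch of how items 2 and 3 fit together (the actual code in to_tt_pass.py likely differs): `config_attr` is assumed to name a compute kernel config already registered as an attribute on the GraphModule, and the function name and tuple handling are illustrative.

```python
import operator
import torch
import ttnn

def lower_native_layer_norm(gm: torch.fx.GraphModule, node: torch.fx.Node,
                            config_attr: str) -> None:
    """Swap one aten.native_layer_norm.default node for ttnn.layer_norm,
    threading in the compute kernel config stored on the GraphModule."""
    input_node, _normalized_shape, weight, bias, eps = node.args
    with gm.graph.inserting_before(node):
        config_node = gm.graph.get_attr(config_attr)
        ttnn_node = gm.graph.call_function(
            ttnn.layer_norm,
            args=(input_node,),
            kwargs={
                "epsilon": eps,
                "weight": weight,
                "bias": bias,
                "compute_kernel_config": config_node,
            },
        )
    # aten.native_layer_norm returns (output, mean, rstd); this sketch only
    # rewires the output (getitem index 0) and assumes mean/rstd are unused.
    for user in list(node.users):
        if user.op == "call_function" and user.target is operator.getitem \
                and user.args[1] == 0:
            user.replace_all_uses_with(ttnn_node)
    gm.graph.eliminate_dead_code()
    gm.recompile()
```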

@kevinwuTT
Contributor

I think a cleaner approach when dealing with more complicated calls to ttnn functions is to create a wrapper like the ones in https://github.com/tenstorrent/pytorch2.0_ttnn/blob/main/torch_ttnn/passes/lowering/target_wrappers.py. That way we don't have to modify tools/generate_op_accuracy_tests.py, because it should automatically pick up the wrapper function.
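
For example, a hedged sketch of what such a wrapper could look like (not the actual contents of target_wrappers.py; the function name is made up):

```python
import torch
import ttnn

@torch.fx.wrap
def layer_norm_with_fp32_acc(input_tensor, epsilon, weight, bias):
    # Built inside the wrapper, so the lowering pass and the generated
    # accuracy-test code only ever see a single call_function target.
    compute_kernel_config = ttnn.WormholeComputeKernelConfig(
        math_fidelity=ttnn.MathFidelity.HiFi4,
        math_approx_mode=False,
        fp32_dest_acc_en=True,
        packer_l1_acc=False,
    )
    return ttnn.layer_norm(
        input_tensor,
        epsilon=epsilon,
        weight=weight,
        bias=bias,
        compute_kernel_config=compute_kernel_config,
    )
```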

@amalbasaTT
Author

amalbasaTT commented Mar 6, 2025

@kevinwuTT I assumed that in tools/generate_op_accuracy_tests.py we should only add variables that are common to most ops in the forward function of the generated code (like device, for example). compute_kernel_config fits that description, and I'm not sure how wrapping layer_norm would make it cleaner: if I add compute_kernel_config to the layer_norm wrapper, there will be as many identical compute_kernel_configs as there are layer_norms in the graph.
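
In other words, the generated test would hoist one shared config to the top of forward(), roughly like this (a hypothetical shape of the emitted code, not the actual output of _build_code_from_aten_ttnn_graphs):

```python
import ttnn

def forward(hidden_states_1, hidden_states_2, gamma, beta):
    # One shared config, hoisted the same way the generated code already
    # handles common values such as `device`, instead of one identical
    # config per layer_norm wrapper call.
    compute_kernel_config = ttnn.WormholeComputeKernelConfig(
        math_fidelity=ttnn.MathFidelity.HiFi4,
        math_approx_mode=False,
        fp32_dest_acc_en=True,
        packer_l1_acc=False,
    )
    layer_norm_1 = ttnn.layer_norm(
        hidden_states_1, epsilon=1e-12, weight=gamma, bias=beta,
        compute_kernel_config=compute_kernel_config,
    )
    layer_norm_2 = ttnn.layer_norm(
        hidden_states_2, epsilon=1e-12, weight=gamma, bias=beta,
        compute_kernel_config=compute_kernel_config,
    )
    return layer_norm_1, layer_norm_2
```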

@amalbasaTT force-pushed the amalbasaTT/ln_lowering_mod branch from 4971112 to 251a4c5 on March 6, 2025, 10:28