[PyTorch] Stop storing fused weight tensor in linear modules (NVIDIA#719)

* Support noop concat without providing full tensor

Stop storing fused buffers in linear modules.

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug noop cat func

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Construct TE modules in tests with correct dtypes

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add tolerances to numerical tests

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use plain PyTorch concat when exporting to ONNX

Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
2 people authored and pggPL committed May 16, 2024
1 parent d5c62b9 commit d4e0f80
Showing 5 changed files with 597 additions and 482 deletions.