diff --git a/setup.py b/setup.py index d50a9b8706..d442aec872 100644 --- a/setup.py +++ b/setup.py @@ -445,6 +445,12 @@ def setup_pytorch_extension() -> setuptools.Extension: sources = [ src_dir / "common.cu", src_dir / "ts_fp8_op.cpp", + # We need to compile system.cpp because the pytorch extension uses + # transformer_engine::getenv. This is a workaround to avoid direct + # linking with libtransformer_engine.so, as the pre-built PyTorch + # wheel from conda or PyPI was not built with CXX11_ABI, and will + # cause undefined symbol issues. + root_path / "transformer_engine" / "common" / "util" / "system.cpp", ] + \ _all_files_in_dir(extensions_dir)