-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsetup.py
71 lines (64 loc) · 1.9 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0
from skbuild import setup
import os
from skbuild.command.install_lib import install_lib
import glob
from setuptools import find_namespace_packages
import sys
class install_metal_libs(install_lib):
    """``install_lib`` command that additionally installs the ``ttmlir-opt`` binary.

    After the regular library installation has run, the ``ttmlir-opt``
    executable from the in-tree tt-mlir build
    (``third_party/tt-mlir/src/tt-mlir-build/bin/ttmlir-opt``) is copied into
    ``<install_dir>/tt_mlir``.
    """

    def run(self):
        """Run the standard ``install_lib`` step, then copy ``ttmlir-opt``."""
        super().run()

        install_path = os.path.join(self.install_dir, "tt_mlir")
        # Unlike os.makedirs, Command.mkpath honours --dry-run and logs the
        # action, matching the behaviour of self.copy_file below.
        self.mkpath(install_path)

        # Resolve the binary relative to this setup.py rather than the process
        # working directory, so the copy does not depend on where the build
        # was launched from.
        source_root = os.path.dirname(os.path.abspath(__file__))
        ttmlir_opt = os.path.join(
            source_root,
            "third_party",
            "tt-mlir",
            "src",
            "tt-mlir-build",
            "bin",
            "ttmlir-opt",
        )
        # copy_file raises DistutilsFileError if the binary has not been built.
        self.copy_file(ttmlir_opt, install_path)
# Compile time env vars
os.environ["DONT_OVERRIDE_INSTALL_PATH"] = "1"

# Extra arguments forwarded by scikit-build to the CMake configure step.
cmake_args = [
    "-GNinja",
    "-DBUILD_TTRT=OFF",
]

# "--code_coverage" is a flag private to this script, not a setuptools option.
# Strip every occurrence from argv (in place) before setup() parses the
# command line; otherwise a repeated flag would be rejected as unknown.
if "--code_coverage" in sys.argv:
    cmake_args += [
        "-DCODE_COVERAGE=ON",
    ]
    sys.argv[:] = [arg for arg in sys.argv if arg != "--code_coverage"]

# Read as UTF-8 explicitly: the default text encoding is locale-dependent and
# can fail on non-ASCII README content (e.g. under a C/POSIX locale).
with open("README.md", "r", encoding="utf-8") as f:
    long_description = f.read()

setup(
    name="tt_torch",
    version="0.1",
    author="Aleks Knezevic",
    author_email="aknezevic@tenstorrent.com",
    license="Apache-2.0",
    # "url" is the setuptools/distutils metadata field for the project home
    # page; an unrecognised "homepage" keyword would be dropped with a warning.
    url="https://github.com/tenstorrent/tt-torch",
    # NOTE(review): packages discovered under `where=` are named relative to
    # that directory, and no package_dir mapping is given here. This presumably
    # relies on the CMake/scikit-build install step placing torch_mlir at the
    # expected location. TODO confirm.
    packages=find_namespace_packages(include=["tt_torch*"])
    + find_namespace_packages(
        where="third_party/torch-mlir/src/torch-mlir-build/python_packages/torch_mlir"
    ),
    description="TT PyTorch FrontEnd",
    long_description=long_description,
    long_description_content_type="text/markdown",
    include_package_data=True,
    cmake_args=cmake_args,
    cmdclass={
        "install_lib": install_metal_libs,
    },
    zip_safe=False,
    install_requires=[
        # Direct-URL pin to a CPU, cxx11-ABI torch 2.5.0 wheel built for
        # CPython 3.11 on linux x86_64 only (per the wheel filename).
        "torch@https://download.pytorch.org/whl/cpu-cxx11-abi/torch-2.5.0%2Bcpu.cxx11.abi-cp311-cp311-linux_x86_64.whl",
        "numpy",
    ],
)