-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This PR adds the ability to monkey-patch imports of dask and distributed whenever those imports occur by simply installing rapids-dask-dependency. There's a tiny bit of scope creep here because this PR added real Python code to the repo for the first time, so I also added pre-commit hooks that in turn modified some unrelated files (only minimally, though). TODO: - [x] Update conda CI and packaging - [ ] Stress test extensively --------- Signed-off-by: Vyas Ramasubramani <vyasr@nvidia.com> Co-authored-by: Richard (Rick) Zamora <rzamora217@gmail.com>
- Loading branch information
Showing
20 changed files
with
234 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,3 +13,4 @@ build/ | |
wheels/ | ||
*.egg-info/ | ||
*.egg | ||
*.whl |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. | ||
|
||
repos: | ||
- repo: https://github.com/pre-commit/pre-commit-hooks | ||
rev: v4.5.0 | ||
hooks: | ||
- id: trailing-whitespace | ||
- id: end-of-file-fixer | ||
- repo: https://github.com/codespell-project/codespell | ||
rev: v2.2.6 | ||
hooks: | ||
- id: codespell | ||
- repo: https://github.com/astral-sh/ruff-pre-commit | ||
rev: v0.2.2 | ||
hooks: | ||
- id: ruff | ||
args: ["--fix"] | ||
- id: ruff-format | ||
|
||
default_language_version: | ||
python: python3 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
import rapids_dask_dependency |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,18 +1,19 @@ | ||
#!/bin/bash | ||
# Copyright (c) 2023, NVIDIA CORPORATION. | ||
# Copyright (c) 2023-2024, NVIDIA CORPORATION. | ||
|
||
set -euo pipefail | ||
|
||
source rapids-configure-sccache | ||
source rapids-date-string | ||
|
||
package_name=rapids-dask-dependency | ||
package_dir="pip/${package_name}" | ||
version=$(rapids-generate-version) | ||
|
||
sed -i "s/^version = .*/version = \"${version}\"/g" "${package_dir}/pyproject.toml" | ||
sed -i "s/^version = .*/version = \"${version}\"/g" "pyproject.toml" | ||
|
||
cd "${package_dir}" | ||
python -m pip wheel . -w dist -vvv --no-deps --disable-pip-version-check | ||
python -m pip wheel . -w dist -vv --no-deps --disable-pip-version-check | ||
|
||
RAPIDS_PY_WHEEL_NAME="${package_name}" RAPIDS_PY_WHEEL_PURE="1" rapids-upload-wheels-to-s3 dist | ||
RAPIDS_PY_WHEEL_NAME="rapids-dask-dependency" RAPIDS_PY_WHEEL_PURE="1" rapids-upload-wheels-to-s3 dist | ||
|
||
# Run tests | ||
python -m pip install $(ls dist/*.whl)[test] | ||
python -m pytest -v tests/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
#!/bin/bash | ||
# Copyright (c) 2024, NVIDIA CORPORATION. | ||
|
||
python -m pytest -v tests/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../../../tests/ |
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. | ||
|
||
from .dask_loader import DaskLoader | ||
|
||
DaskLoader.install() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. | ||
|
||
import importlib | ||
import importlib.abc | ||
import importlib.machinery | ||
import sys | ||
import warnings | ||
from contextlib import contextmanager | ||
|
||
from .patches.dask import patches as dask_patches | ||
from .patches.distributed import patches as distributed_patches | ||
|
||
original_warn = warnings.warn | ||
|
||
|
||
def _warning_with_increased_stacklevel( | ||
message, category=None, stacklevel=1, source=None, **kwargs | ||
): | ||
# Patch warnings to have the right stacklevel | ||
# Add 3 to the stacklevel to account for the 3 extra frames added by the loader: one | ||
# in this warnings function, one in the actual loader, and one in the importlib | ||
# call (not including all internal frames). | ||
original_warn(message, category, stacklevel + 3, source, **kwargs) | ||
|
||
|
||
@contextmanager | ||
def patch_warning_stacklevel(): | ||
warnings.warn = _warning_with_increased_stacklevel | ||
yield | ||
warnings.warn = original_warn | ||
|
||
|
||
class DaskLoader(importlib.abc.MetaPathFinder, importlib.abc.Loader): | ||
def create_module(self, spec): | ||
if spec.name.startswith("dask") or spec.name.startswith("distributed"): | ||
with self.disable(), patch_warning_stacklevel(): | ||
mod = importlib.import_module(spec.name) | ||
|
||
# Note: The spec does not make it clear whether we're guaranteed that spec | ||
# is not a copy of the original spec, but that is the case for now. We need | ||
# to assign this because the spec is used to update module attributes after | ||
# it is initialized by create_module. | ||
spec.origin = mod.__spec__.origin | ||
spec.submodule_search_locations = mod.__spec__.submodule_search_locations | ||
|
||
# TODO: I assume we'll want to only apply patches to specific submodules, | ||
# that'll be up to RAPIDS dask devs to decide. | ||
patches = dask_patches if "dask" in spec.name else distributed_patches | ||
for patch in patches: | ||
patch(mod) | ||
return mod | ||
|
||
def exec_module(self, _): | ||
pass | ||
|
||
@contextmanager | ||
def disable(self): | ||
sys.meta_path.remove(self) | ||
try: | ||
yield | ||
finally: | ||
sys.meta_path.insert(0, self) | ||
|
||
def find_spec(self, fullname: str, _, __=None): | ||
if ( | ||
fullname in ("dask", "distributed") | ||
or fullname.startswith("dask.") | ||
or fullname.startswith("distributed.") | ||
): | ||
return importlib.machinery.ModuleSpec( | ||
name=fullname, | ||
loader=self, | ||
# Set these parameters dynamically in create_module | ||
origin=None, | ||
loader_state=None, | ||
is_package=True, | ||
) | ||
return None | ||
|
||
@classmethod | ||
def install(cls): | ||
try: | ||
(self,) = (obj for obj in sys.meta_path if isinstance(obj, cls)) | ||
except ValueError: | ||
self = cls() | ||
sys.meta_path.insert(0, self) | ||
return self |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. | ||
|
||
from .add_patch_attr import add_patch_attr | ||
|
||
patches = [add_patch_attr] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. | ||
|
||
|
||
def add_patch_attr(mod): | ||
mod._rapids_patched = True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. | ||
|
||
from .add_patch_attr import add_patch_attr | ||
|
||
patches = [add_patch_attr] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. | ||
|
||
|
||
def add_patch_attr(mod): | ||
mod._rapids_patched = True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. | ||
import os | ||
|
||
from setuptools import setup | ||
from setuptools.command.build_py import build_py | ||
|
||
|
||
# Adapted from https://stackoverflow.com/a/71137790 | ||
class build_py_with_pth_file(build_py): # noqa: N801 | ||
"""Include the .pth file in the generated wheel.""" | ||
|
||
def run(self): | ||
super().run() | ||
|
||
fn = "_rapids_dask_dependency.pth" | ||
|
||
outfile = os.path.join(self.build_lib, fn) | ||
self.copy_file(fn, outfile, preserve_mode=0) | ||
|
||
|
||
setup(cmdclass={"build_py": build_py_with_pth_file}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from functools import wraps | ||
from multiprocessing import Process | ||
|
||
|
||
def run_test_in_subprocess(func): | ||
@wraps(func) | ||
def wrapper(*args, **kwargs): | ||
p = Process(target=func, args=args, kwargs=kwargs) | ||
p.start() | ||
p.join() | ||
|
||
return wrapper | ||
|
||
|
||
@run_test_in_subprocess | ||
def test_dask(): | ||
import dask | ||
|
||
assert hasattr(dask, "_rapids_patched") | ||
|
||
|
||
@run_test_in_subprocess | ||
def test_distributed(): | ||
import distributed | ||
|
||
assert hasattr(distributed, "_rapids_patched") |