Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding compile_function as execute option. #536

Merged
merged 9 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/jax-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ jobs:
- name: Install
run: |
python -m pip install --upgrade pip
python -m pip install -e '.[test]' 'jax[cpu]'
python -m pip install -e '.[test]' 'jax[cpu]' numba
python -m pip uninstall -y lithops # tests don't run on Lithops

- name: Run tests
Expand Down
40 changes: 38 additions & 2 deletions cubed/core/plan.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import atexit
import dataclasses
import inspect
import shutil
import tempfile
Expand Down Expand Up @@ -30,6 +31,8 @@
# Delete local context dirs when Python exits
CONTEXT_DIRS = set()

Decorator = Callable[[Callable], Callable]


def delete_on_exit(context_dir: str) -> None:
if context_dir not in CONTEXT_DIRS and is_local_path(context_dir):
Expand Down Expand Up @@ -200,13 +203,45 @@ def _create_lazy_zarr_arrays(self, dag):

return dag

def _compile_blockwise(self, dag, compile_function: Decorator) -> nx.MultiDiGraph:
"""Compiles functions from all blockwise ops by mutating the input dag."""
# Recommended: make a copy of the dag before calling this function.

compile_with_config = 'config' in inspect.getfullargspec(compile_function).kwonlyargs

for n in dag.nodes:
node = dag.nodes[n]

if "primitive_op" not in node:
continue

if not isinstance(node["pipeline"].config, BlockwiseSpec):
continue

if compile_with_config:
compiled = compile_function(node["pipeline"].config.function, config=node["pipeline"].config)
else:
compiled = compile_function(node["pipeline"].config.function)

# node is a blockwise primitive_op.
# maybe we should investigate some sort of optics library for frozen dataclasses...
new_pipeline = dataclasses.replace(
node["pipeline"],
config=dataclasses.replace(node["pipeline"].config, function=compiled)
)
node["pipeline"] = new_pipeline

return dag

@lru_cache
def _finalize_dag(
self, optimize_graph: bool = True, optimize_function=None
self, optimize_graph: bool = True, optimize_function=None, compile_function: Optional[Decorator] = None,
) -> nx.MultiDiGraph:
dag = self.optimize(optimize_function).dag if optimize_graph else self.dag
# create a copy since _create_lazy_zarr_arrays mutates the dag
dag = dag.copy()
if callable(compile_function):
dag = self._compile_blockwise(dag, compile_function)
dag = self._create_lazy_zarr_arrays(dag)
return nx.freeze(dag)

Expand All @@ -216,11 +251,12 @@ def execute(
callbacks=None,
optimize_graph=True,
optimize_function=None,
compile_function=None,
resume=None,
spec=None,
**kwargs,
):
dag = self._finalize_dag(optimize_graph, optimize_function)
dag = self._finalize_dag(optimize_graph, optimize_function, compile_function)

compute_id = f"compute-{datetime.now().strftime('%Y%m%dT%H%M%S.%f')}"

Expand Down
63 changes: 63 additions & 0 deletions cubed/tests/test_executor_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,3 +315,66 @@ def test_check_runtime_memory_processes(spec, executor):

# OK if we use fewer workers
c.compute(executor=executor, max_workers=max_workers // 2)


COMPILE_FUNCTIONS = [lambda fn: fn]

try:
from numba import jit as numba_jit
COMPILE_FUNCTIONS.append(numba_jit)
except ModuleNotFoundError:
pass

try:
if 'jax' in os.environ.get('CUBED_BACKEND_ARRAY_API_MODULE', ''):
from jax import jit as jax_jit
COMPILE_FUNCTIONS.append(jax_jit)

def aot(func, *, config=None):
# TODO(alxmrs): implement lowering
return jax_jit(func)

COMPILE_FUNCTIONS.append(aot)

except ModuleNotFoundError:
pass


@pytest.mark.parametrize("compile_function", COMPILE_FUNCTIONS)
def test_check_compilation(spec, executor, compile_function):
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
b = xp.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]], chunks=(2, 2), spec=spec)
c = xp.add(a, b)
assert_array_equal(
c.compute(executor=executor, compile_function=compile_function), np.array([[2, 3, 4], [5, 6, 7], [8, 9, 10]])
)


def test_compilation_can_fail(spec, executor):
def compile_function(func):
raise NotImplementedError(f"Cannot compile {func}")

a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
b = xp.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]], chunks=(2, 2), spec=spec)
c = xp.add(a, b)
try:
c.compute(executor=executor, compile_function=compile_function)
assert False, "Compile function was not called."
except NotImplementedError as e:
assert True, "Compile function was applied."
assert "add" in str(e), "Compile function was applied to add operation."


def test_compilation_with_config_can_fail(spec, executor):
def compile_function(func, *, config=None):
raise NotImplementedError(f"Cannot compile {func} with {config}")

a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
b = xp.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]], chunks=(2, 2), spec=spec)
c = xp.add(a, b)
try:
c.compute(executor=executor, compile_function=compile_function)
assert False, "Compile function was not called."
except NotImplementedError as e:
assert "add" in str(e), "Compile function was applied to add operation."
assert "BlockwiseSpec" in str(e), "Compile function was applied with a config argument."
Loading