From 50f63b8a2c5a14ba330043137cd2e7c6d5b1b2fa Mon Sep 17 00:00:00 2001 From: nstarman Date: Sun, 31 Dec 2023 12:13:38 -0800 Subject: [PATCH 1/3] move array api files to own folder Signed-off-by: nstarman --- src/array_api_jax_compat/__init__.py | 46 ++-------------- .../_array_api/__init__.py | 54 +++++++++++++++++++ .../{ => _array_api}/_constants.py | 0 .../{ => _array_api}/_creation_functions.py | 6 +-- .../{ => _array_api}/_data_type_functions.py | 4 +- .../_elementwise_functions.py | 2 +- .../{ => _array_api}/_indexing_functions.py | 2 +- .../_linear_algebra_functions.py | 2 +- .../_manipulation_functions.py | 2 +- .../{ => _array_api}/_searching_functions.py | 2 +- .../{ => _array_api}/_set_functions.py | 2 +- .../{ => _array_api}/_sorting_functions.py | 2 +- .../_statistical_functions.py | 4 +- .../{ => _array_api}/_utility_functions.py | 2 +- .../{ => _array_api}/fft.py | 2 +- .../{ => _array_api}/linalg.py | 4 +- 16 files changed, 76 insertions(+), 60 deletions(-) create mode 100644 src/array_api_jax_compat/_array_api/__init__.py rename src/array_api_jax_compat/{ => _array_api}/_constants.py (100%) rename src/array_api_jax_compat/{ => _array_api}/_creation_functions.py (97%) rename src/array_api_jax_compat/{ => _array_api}/_data_type_functions.py (89%) rename src/array_api_jax_compat/{ => _array_api}/_elementwise_functions.py (99%) rename src/array_api_jax_compat/{ => _array_api}/_indexing_functions.py (82%) rename src/array_api_jax_compat/{ => _array_api}/_linear_algebra_functions.py (93%) rename src/array_api_jax_compat/{ => _array_api}/_manipulation_functions.py (97%) rename src/array_api_jax_compat/{ => _array_api}/_searching_functions.py (93%) rename src/array_api_jax_compat/{ => _array_api}/_set_functions.py (92%) rename src/array_api_jax_compat/{ => _array_api}/_sorting_functions.py (92%) rename src/array_api_jax_compat/{ => _array_api}/_statistical_functions.py (95%) rename src/array_api_jax_compat/{ => _array_api}/_utility_functions.py (91%) rename src/array_api_jax_compat/{ => _array_api}/fft.py (98%) rename src/array_api_jax_compat/{ => _array_api}/linalg.py (97%) diff --git a/src/array_api_jax_compat/__init__.py b/src/array_api_jax_compat/__init__.py index 9e24447..217b5a3 100644 --- a/src/array_api_jax_compat/__init__.py +++ b/src/array_api_jax_compat/__init__.py @@ -10,53 +10,15 @@ from typing import Any -from jax.experimental.array_api import __array_api_version__ from jaxtyping import install_import_hook with install_import_hook("array_api_jax_compat", None): - from . import ( - _constants, - _creation_functions, - _data_type_functions, - _elementwise_functions, - _indexing_functions, - _linear_algebra_functions, - _manipulation_functions, - _searching_functions, - _set_functions, - _sorting_functions, - _statistical_functions, - _utility_functions, - fft, - linalg, - ) - from ._constants import * - from ._creation_functions import * - from ._data_type_functions import * - from ._elementwise_functions import * - from ._indexing_functions import * - from ._linear_algebra_functions import * - from ._manipulation_functions import * - from ._searching_functions import * - from ._set_functions import * - from ._sorting_functions import * - from ._statistical_functions import * - from ._utility_functions import * + from . import _array_api + from ._array_api import * from ._version import version as __version__ -__all__ = ["__version__", "__array_api_version__", "fft", "linalg"] -__all__ += _constants.__all__ -__all__ += _creation_functions.__all__ -__all__ += _data_type_functions.__all__ -__all__ += _elementwise_functions.__all__ -__all__ += _indexing_functions.__all__ -__all__ += _linear_algebra_functions.__all__ -__all__ += _manipulation_functions.__all__ -__all__ += _searching_functions.__all__ -__all__ += _set_functions.__all__ -__all__ += _sorting_functions.__all__ -__all__ += _statistical_functions.__all__ -__all__ += _utility_functions.__all__ +__all__ = ["__version__"] +__all__ += _array_api.__all__ def __getattr__(name: str) -> Any: # TODO: fuller annotation diff --git a/src/array_api_jax_compat/_array_api/__init__.py b/src/array_api_jax_compat/_array_api/__init__.py new file mode 100644 index 0000000..2f4429f --- /dev/null +++ b/src/array_api_jax_compat/_array_api/__init__.py @@ -0,0 +1,54 @@ +"""Copyright (c) 2023 Nathaniel Starkman. All rights reserved. + +array-api-jax-compat: Array-API JAX compatibility +""" + +# pylint: disable=redefined-builtin + + +from __future__ import annotations + +from jax.experimental.array_api import __array_api_version__ + +from . import ( + _constants, + _creation_functions, + _data_type_functions, + _elementwise_functions, + _indexing_functions, + _linear_algebra_functions, + _manipulation_functions, + _searching_functions, + _set_functions, + _sorting_functions, + _statistical_functions, + _utility_functions, + fft, + linalg, +) +from ._constants import * +from ._creation_functions import * +from ._data_type_functions import * +from ._elementwise_functions import * +from ._indexing_functions import * +from ._linear_algebra_functions import * +from ._manipulation_functions import * +from ._searching_functions import * +from ._set_functions import * +from ._sorting_functions import * +from ._statistical_functions import * +from ._utility_functions import * + +__all__ = ["__array_api_version__", "fft", "linalg"] +__all__ += _constants.__all__ +__all__ += _creation_functions.__all__ +__all__ += _data_type_functions.__all__ +__all__ += _elementwise_functions.__all__ +__all__ += _indexing_functions.__all__ +__all__ += _linear_algebra_functions.__all__ +__all__ += _manipulation_functions.__all__ +__all__ += _searching_functions.__all__ +__all__ += _set_functions.__all__ +__all__ += _sorting_functions.__all__ +__all__ += _statistical_functions.__all__ +__all__ += _utility_functions.__all__ diff --git a/src/array_api_jax_compat/_constants.py b/src/array_api_jax_compat/_array_api/_constants.py similarity index 100% rename from src/array_api_jax_compat/_constants.py rename to src/array_api_jax_compat/_array_api/_constants.py diff --git a/src/array_api_jax_compat/_creation_functions.py b/src/array_api_jax_compat/_array_api/_creation_functions.py similarity index 97% rename from src/array_api_jax_compat/_creation_functions.py rename to src/array_api_jax_compat/_array_api/_creation_functions.py index e47c48c..cfcb7a0 100644 --- a/src/array_api_jax_compat/_creation_functions.py +++ b/src/array_api_jax_compat/_array_api/_creation_functions.py @@ -29,9 +29,9 @@ from jax.experimental import array_api from quax import Value -from ._dispatch import dispatcher -from ._types import DType -from ._utils import quaxify +from array_api_jax_compat._dispatch import dispatcher +from array_api_jax_compat._types import DType +from array_api_jax_compat._utils import quaxify T = TypeVar("T") diff --git a/src/array_api_jax_compat/_data_type_functions.py b/src/array_api_jax_compat/_array_api/_data_type_functions.py similarity index 89% rename from src/array_api_jax_compat/_data_type_functions.py rename to src/array_api_jax_compat/_array_api/_data_type_functions.py index 6b9fa4f..7e092ed 100644 --- a/src/array_api_jax_compat/_data_type_functions.py +++ b/src/array_api_jax_compat/_array_api/_data_type_functions.py @@ -5,8 +5,8 @@ from jax.experimental.array_api._data_type_functions import FInfo, IInfo from quax import Value -from ._types import DType -from ._utils import quaxify +from array_api_jax_compat._types import DType +from array_api_jax_compat._utils import quaxify @quaxify diff --git a/src/array_api_jax_compat/_elementwise_functions.py b/src/array_api_jax_compat/_array_api/_elementwise_functions.py similarity index 99% rename from src/array_api_jax_compat/_elementwise_functions.py rename to src/array_api_jax_compat/_array_api/_elementwise_functions.py index 68c68a1..1ed0a00 100644 --- a/src/array_api_jax_compat/_elementwise_functions.py +++ b/src/array_api_jax_compat/_array_api/_elementwise_functions.py @@ -64,7 +64,7 @@ from jax.experimental import array_api from quax import Value -from ._utils import quaxify +from array_api_jax_compat._utils import quaxify @quaxify diff --git a/src/array_api_jax_compat/_indexing_functions.py b/src/array_api_jax_compat/_array_api/_indexing_functions.py similarity index 82% rename from src/array_api_jax_compat/_indexing_functions.py rename to src/array_api_jax_compat/_array_api/_indexing_functions.py index 4c0354d..3d6654f 100644 --- a/src/array_api_jax_compat/_indexing_functions.py +++ b/src/array_api_jax_compat/_array_api/_indexing_functions.py @@ -3,7 +3,7 @@ from jax.experimental import array_api from quax import Value -from ._utils import quaxify +from array_api_jax_compat._utils import quaxify @quaxify diff --git a/src/array_api_jax_compat/_linear_algebra_functions.py b/src/array_api_jax_compat/_array_api/_linear_algebra_functions.py similarity index 93% rename from src/array_api_jax_compat/_linear_algebra_functions.py rename to src/array_api_jax_compat/_array_api/_linear_algebra_functions.py index b2bc159..2d0ff78 100644 --- a/src/array_api_jax_compat/_linear_algebra_functions.py +++ b/src/array_api_jax_compat/_array_api/_linear_algebra_functions.py @@ -6,7 +6,7 @@ from jax.experimental import array_api from quax import Value -from ._utils import quaxify +from array_api_jax_compat._utils import quaxify @quaxify diff --git a/src/array_api_jax_compat/_manipulation_functions.py b/src/array_api_jax_compat/_array_api/_manipulation_functions.py similarity index 97% rename from src/array_api_jax_compat/_manipulation_functions.py rename to src/array_api_jax_compat/_array_api/_manipulation_functions.py index b9708ef..8fca86c 100644 --- a/src/array_api_jax_compat/_manipulation_functions.py +++ b/src/array_api_jax_compat/_array_api/_manipulation_functions.py @@ -14,7 +14,7 @@ from jax.experimental import array_api from quax import Value -from ._utils import quaxify +from array_api_jax_compat._utils import quaxify @quaxify diff --git a/src/array_api_jax_compat/_searching_functions.py b/src/array_api_jax_compat/_array_api/_searching_functions.py similarity index 93% rename from src/array_api_jax_compat/_searching_functions.py rename to src/array_api_jax_compat/_array_api/_searching_functions.py index 96362d5..4e9abd1 100644 --- a/src/array_api_jax_compat/_searching_functions.py +++ b/src/array_api_jax_compat/_array_api/_searching_functions.py @@ -4,7 +4,7 @@ from jax.experimental import array_api from quax import Value -from ._utils import quaxify +from array_api_jax_compat._utils import quaxify @quaxify diff --git a/src/array_api_jax_compat/_set_functions.py b/src/array_api_jax_compat/_array_api/_set_functions.py similarity index 92% rename from src/array_api_jax_compat/_set_functions.py rename to src/array_api_jax_compat/_array_api/_set_functions.py index 22c159b..27e9baa 100644 --- a/src/array_api_jax_compat/_set_functions.py +++ b/src/array_api_jax_compat/_array_api/_set_functions.py @@ -4,7 +4,7 @@ from jax.experimental import array_api from quax import Value -from ._utils import quaxify +from array_api_jax_compat._utils import quaxify @quaxify diff --git a/src/array_api_jax_compat/_sorting_functions.py b/src/array_api_jax_compat/_array_api/_sorting_functions.py similarity index 92% rename from src/array_api_jax_compat/_sorting_functions.py rename to src/array_api_jax_compat/_array_api/_sorting_functions.py index 96fb65b..4ec505e 100644 --- a/src/array_api_jax_compat/_sorting_functions.py +++ b/src/array_api_jax_compat/_array_api/_sorting_functions.py @@ -6,7 +6,7 @@ from jax.experimental import array_api from quax import Value -from ._utils import quaxify +from array_api_jax_compat._utils import quaxify @quaxify diff --git a/src/array_api_jax_compat/_statistical_functions.py b/src/array_api_jax_compat/_array_api/_statistical_functions.py similarity index 95% rename from src/array_api_jax_compat/_statistical_functions.py rename to src/array_api_jax_compat/_array_api/_statistical_functions.py index 6806d1e..5714db3 100644 --- a/src/array_api_jax_compat/_statistical_functions.py +++ b/src/array_api_jax_compat/_array_api/_statistical_functions.py @@ -4,8 +4,8 @@ from jax.experimental import array_api from quax import Value -from ._types import DType -from ._utils import quaxify +from array_api_jax_compat._types import DType +from array_api_jax_compat._utils import quaxify @quaxify diff --git a/src/array_api_jax_compat/_utility_functions.py b/src/array_api_jax_compat/_array_api/_utility_functions.py similarity index 91% rename from src/array_api_jax_compat/_utility_functions.py rename to src/array_api_jax_compat/_array_api/_utility_functions.py index e7cf4f4..91ffa65 100644 --- a/src/array_api_jax_compat/_utility_functions.py +++ b/src/array_api_jax_compat/_array_api/_utility_functions.py @@ -5,7 +5,7 @@ from jax.experimental import array_api from quax import Value -from ._utils import quaxify +from array_api_jax_compat._utils import quaxify @quaxify diff --git a/src/array_api_jax_compat/fft.py b/src/array_api_jax_compat/_array_api/fft.py similarity index 98% rename from src/array_api_jax_compat/fft.py rename to src/array_api_jax_compat/_array_api/fft.py index da923b6..53e37ca 100644 --- a/src/array_api_jax_compat/fft.py +++ b/src/array_api_jax_compat/_array_api/fft.py @@ -24,7 +24,7 @@ from jax.experimental.array_api import fft as _jax_fft from quax import Value -from ._utils import quaxify +from array_api_jax_compat._utils import quaxify @quaxify diff --git a/src/array_api_jax_compat/linalg.py b/src/array_api_jax_compat/_array_api/linalg.py similarity index 97% rename from src/array_api_jax_compat/linalg.py rename to src/array_api_jax_compat/_array_api/linalg.py index 86fc8b1..fa9d30a 100644 --- a/src/array_api_jax_compat/linalg.py +++ b/src/array_api_jax_compat/_array_api/linalg.py @@ -34,8 +34,8 @@ from jax.experimental import array_api from quax import Value -from ._types import DType -from ._utils import quaxify +from array_api_jax_compat._types import DType +from array_api_jax_compat._utils import quaxify @quaxify From 1cc023aaf2de83b076a994184f9bbf1630028f0e Mon Sep 17 00:00:00 2001 From: nstarman Date: Mon, 1 Jan 2024 22:30:46 -0800 Subject: [PATCH 2/3] WIP grad Signed-off-by: nstarman --- src/array_api_jax_compat/__init__.py | 4 ++- src/array_api_jax_compat/_grad.py | 41 ++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 1 deletion(-) create mode 100644 src/array_api_jax_compat/_grad.py diff --git a/src/array_api_jax_compat/__init__.py b/src/array_api_jax_compat/__init__.py index 217b5a3..071a3e1 100644 --- a/src/array_api_jax_compat/__init__.py +++ b/src/array_api_jax_compat/__init__.py @@ -13,12 +13,14 @@ from jaxtyping import install_import_hook with install_import_hook("array_api_jax_compat", None): - from . import _array_api + from . import _array_api, _grad from ._array_api import * + from ._grad import * from ._version import version as __version__ __all__ = ["__version__"] __all__ += _array_api.__all__ +__all__ += _grad.__all__ def __getattr__(name: str) -> Any: # TODO: fuller annotation diff --git a/src/array_api_jax_compat/_grad.py b/src/array_api_jax_compat/_grad.py new file mode 100644 index 0000000..4cd9a3a --- /dev/null +++ b/src/array_api_jax_compat/_grad.py @@ -0,0 +1,41 @@ +"""Copyright (c) 2023 Nathaniel Starkman. All rights reserved. + +array-api-jax-compat: Array-API JAX compatibility +""" + +from __future__ import annotations + +__all__ = ["grad", "value_and_grad"] + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from collections.abc import Callable, Sequence + + from jax._src.api import AxisName + + +def grad( + fun: Callable[..., Any], + *, + argnums: int | Sequence[int] = 0, + has_aux: bool = False, + holomorphic: bool = False, + allow_int: bool = False, + reduce_axes: Sequence[AxisName] = (), +) -> Callable[..., Any]: + """`grad`.""" + raise NotImplementedError("TODO") # noqa: EM101 + + +def value_and_grad( + fun: Callable[..., Any], + *, + argnums: int | Sequence[int] = 0, + has_aux: bool = False, + holomorphic: bool = False, + allow_int: bool = False, + reduce_axes: Sequence[AxisName] = (), +) -> Callable[..., tuple[Any, Any]]: + """`value_and_grad`.""" + raise NotImplementedError("TODO") # noqa: EM101 From 7de395e37db412123f6c1759fc0d91073abfb4af Mon Sep 17 00:00:00 2001 From: nstarman Date: Sat, 13 Jan 2024 14:00:54 -0500 Subject: [PATCH 3/3] grad via jac --- src/array_api_jax_compat/_grad.py | 109 ++++++++++++++++++++++++++---- 1 file changed, 94 insertions(+), 15 deletions(-) diff --git a/src/array_api_jax_compat/_grad.py b/src/array_api_jax_compat/_grad.py index 4cd9a3a..2c4027e 100644 --- a/src/array_api_jax_compat/_grad.py +++ b/src/array_api_jax_compat/_grad.py @@ -5,9 +5,14 @@ from __future__ import annotations -__all__ = ["grad", "value_and_grad"] +__all__ = ["grad"] -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Protocol, TypeVar + +import jax +import jax.numpy as jnp +from quax import quaxify +from typing_extensions import Self if TYPE_CHECKING: from collections.abc import Callable, Sequence @@ -15,6 +20,15 @@ from jax._src.api import AxisName +class SupportsGetItem(Protocol): + def __getitem__(self, key: Any) -> Self: + ... + + +T = TypeVar("T") +IT = TypeVar("IT", bound=SupportsGetItem) + + def grad( fun: Callable[..., Any], *, @@ -23,19 +37,84 @@ def grad( holomorphic: bool = False, allow_int: bool = False, reduce_axes: Sequence[AxisName] = (), + vmap_kw: dict[str, Any] | None = None, + vmap_batch: tuple[int, ...] | None = None, ) -> Callable[..., Any]: - """`grad`.""" - raise NotImplementedError("TODO") # noqa: EM101 + """Quaxified :func:`jax.grad`. + Creates a function that evaluates the gradient of ``fun``. -def value_and_grad( - fun: Callable[..., Any], - *, - argnums: int | Sequence[int] = 0, - has_aux: bool = False, - holomorphic: bool = False, - allow_int: bool = False, - reduce_axes: Sequence[AxisName] = (), -) -> Callable[..., tuple[Any, Any]]: - """`value_and_grad`.""" - raise NotImplementedError("TODO") # noqa: EM101 + Parameters + ---------- + fun : callable + Function to be differentiated. Its arguments at positions specified by + `argnums` should be arrays, scalars, or standard Python containers. + Argument arrays in the positions specified by `argnums` must be of + inexact (i.e., floating-point or complex) type. It should return a + scalar (which includes arrays with shape `()` but not arrays with shape + `(1,)` etc.) + argnums : int or sequence of ints, optional + Specifies which positional argument(s) to differentiate with respect to + (default 0). + has_aux : bool, optional + Indicates whether `fun` returns a pair where the first element is + considered the output of the mathematical function to be differentiated + and the second element is auxiliary data. Default False. + holomorphic : bool, optional + Indicates whether `fun` is promised to be holomorphic. If True, inputs + and outputs must be complex. Default False. + allow_int : bool, optional + Whether to allow differentiating with respect to integer valued inputs. + The gradient of an integer input will have a trivial vector-space dtype + (float0). Default False. + reduce_axes : tuple of axis names, optional + If an axis is listed here, and `fun` implicitly broadcasts a value over + that axis, the backward pass will perform a `psum` of the corresponding + gradient. Otherwise, the gradient will be per-example over named axes. + For example, if `'batch'` is a named batch axis, `grad(f, + reduce_axes=('batch',))` will create a function that computes the total + gradient while `grad(f)` will create one that computes the per-example + gradient. + + Returns + ------- + callable + A function with the same arguments as `fun`, that evaluates the gradient + of `fun`. If `argnums` is an integer then the gradient has the same + shape and type as the positional argument indicated by that integer. If + `argnums` is a tuple of integers, the gradient is a tuple of values with + the same shapes and types as the corresponding arguments. If `has_aux` + is True then a pair of (gradient, auxiliary_data) is returned. + + Examples + -------- + >>> import jax + >>> + >>> grad_tanh = jax.grad(jax.numpy.tanh) + >>> print(grad_tanh(0.2)) + 0.961043 + """ + # TODO: get this working using the actual `grad` function. + # There are some interesting issues to resolve. See + # https://github.com/patrick-kidger/quax/issues/5. + # In the meantime, we workaround this by using `jacfwd` instead. + if allow_int: + msg = "allow_int is not yet supported" + raise NotImplementedError(msg) + if reduce_axes: + msg = "reduce_axes is not yet supported" + raise NotImplementedError(msg) + + grad_substitute = quaxify( + jax.jacfwd(fun, argnums=argnums, has_aux=has_aux, holomorphic=holomorphic), + ) + + def grad_func(*args: Any) -> Any: + for i, arg in enumerate(args): + assert ( # noqa: S101 + len(jnp.shape(arg)) < 2 + ), f"arg {i} has shape {arg.shape}" + + return grad_substitute(*args) + + return grad_func