WIP: grad functions #11

Closed · wants to merge 3 commits
48 changes: 6 additions & 42 deletions src/array_api_jax_compat/__init__.py
@@ -10,53 +10,17 @@

 from typing import Any

-from jax.experimental.array_api import __array_api_version__
-from jaxtyping import install_import_hook
-
-with install_import_hook("array_api_jax_compat", None):
-    from . import (
-        _constants,
-        _creation_functions,
-        _data_type_functions,
-        _elementwise_functions,
-        _indexing_functions,
-        _linear_algebra_functions,
-        _manipulation_functions,
-        _searching_functions,
-        _set_functions,
-        _sorting_functions,
-        _statistical_functions,
-        _utility_functions,
-        fft,
-        linalg,
-    )
-from ._constants import *
-from ._creation_functions import *
-from ._data_type_functions import *
-from ._elementwise_functions import *
-from ._indexing_functions import *
-from ._linear_algebra_functions import *
-from ._manipulation_functions import *
-from ._searching_functions import *
-from ._set_functions import *
-from ._sorting_functions import *
-from ._statistical_functions import *
-from ._utility_functions import *
+from . import _array_api, _grad
+from ._array_api import *
+from ._grad import *
 from ._version import version as __version__

-__all__ = ["__version__", "__array_api_version__", "fft", "linalg"]
-__all__ += _constants.__all__
-__all__ += _creation_functions.__all__
-__all__ += _data_type_functions.__all__
-__all__ += _elementwise_functions.__all__
-__all__ += _indexing_functions.__all__
-__all__ += _linear_algebra_functions.__all__
-__all__ += _manipulation_functions.__all__
-__all__ += _searching_functions.__all__
-__all__ += _set_functions.__all__
-__all__ += _sorting_functions.__all__
-__all__ += _statistical_functions.__all__
-__all__ += _utility_functions.__all__
+__all__ = ["__version__"]
+__all__ += _array_api.__all__
+__all__ += _grad.__all__


 def __getattr__(name: str) -> Any:  # TODO: fuller annotation
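The net effect of this hunk is that the array-API re-exports move behind a single `_array_api` subpackage while the top-level namespace stays flat. A minimal usage sketch (not from the PR; it assumes the standard array-API names such as `asarray`, `sin`, and `sum` survive the move into `_array_api.__all__`):

```python
# Sketch (not from the PR): the intended flat public surface after the
# refactor, assuming the standard array-API names are re-exported
# through _array_api.__all__.
import array_api_jax_compat as xp

x = xp.asarray([0.0, 1.0, 2.0])  # creation functions: still top level
y = xp.sin(x)                    # elementwise functions: still top level

df = xp.grad(lambda z: xp.sum(z * z))  # new in this PR, via _grad.__all__
print(df(x))                           # [0. 2. 4.]
```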
54 changes: 54 additions & 0 deletions src/array_api_jax_compat/_array_api/__init__.py
@@ -0,0 +1,54 @@
"""Copyright (c) 2023 Nathaniel Starkman. All rights reserved.

array-api-jax-compat: Array-API JAX compatibility
"""

# pylint: disable=redefined-builtin


from __future__ import annotations

from jax.experimental.array_api import __array_api_version__

from . import (
    _constants,
    _creation_functions,
    _data_type_functions,
    _elementwise_functions,
    _indexing_functions,
    _linear_algebra_functions,
    _manipulation_functions,
    _searching_functions,
    _set_functions,
    _sorting_functions,
    _statistical_functions,
    _utility_functions,
    fft,
    linalg,
)
from ._constants import *
from ._creation_functions import *
from ._data_type_functions import *
from ._elementwise_functions import *
from ._indexing_functions import *
from ._linear_algebra_functions import *
from ._manipulation_functions import *
from ._searching_functions import *
from ._set_functions import *
from ._sorting_functions import *
from ._statistical_functions import *
from ._utility_functions import *

__all__ = ["__array_api_version__", "fft", "linalg"]
__all__ += _constants.__all__
__all__ += _creation_functions.__all__
__all__ += _data_type_functions.__all__
__all__ += _elementwise_functions.__all__
__all__ += _indexing_functions.__all__
__all__ += _linear_algebra_functions.__all__
__all__ += _manipulation_functions.__all__
__all__ += _searching_functions.__all__
__all__ += _set_functions.__all__
__all__ += _sorting_functions.__all__
__all__ += _statistical_functions.__all__
__all__ += _utility_functions.__all__
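The hunks that follow (the extracted page does not preserve their file names) all make the same change: the per-category modules now live inside the `_array_api` subpackage, while `_utils`, `_types`, and `_dispatch` remain at the package root, so their relative imports must become absolute. A sketch of the failure mode, using `_elementwise_functions.py` as a hypothetical example:

```python
# Hypothetical: src/array_api_jax_compat/_array_api/_elementwise_functions.py
# after the move into the _array_api subpackage.

# Before the move (module at the package root) this resolved to
# array_api_jax_compat._utils:
#     from ._utils import quaxify
# After the move, the same relative import would look for
# array_api_jax_compat._array_api._utils, which does not exist,
# so the import is made absolute:
from array_api_jax_compat._utils import quaxify
```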
@@ -29,9 +29,9 @@
 from jax.experimental import array_api
 from quax import Value

-from ._dispatch import dispatcher
-from ._types import DType
-from ._utils import quaxify
+from array_api_jax_compat._dispatch import dispatcher
+from array_api_jax_compat._types import DType
+from array_api_jax_compat._utils import quaxify

 T = TypeVar("T")
@@ -5,8 +5,8 @@
 from jax.experimental.array_api._data_type_functions import FInfo, IInfo
 from quax import Value

-from ._types import DType
-from ._utils import quaxify
+from array_api_jax_compat._types import DType
+from array_api_jax_compat._utils import quaxify


 @quaxify
@@ -64,7 +64,7 @@
 from jax.experimental import array_api
 from quax import Value

-from ._utils import quaxify
+from array_api_jax_compat._utils import quaxify


 @quaxify
@@ -3,7 +3,7 @@
 from jax.experimental import array_api
 from quax import Value

-from ._utils import quaxify
+from array_api_jax_compat._utils import quaxify


 @quaxify
@@ -6,7 +6,7 @@
 from jax.experimental import array_api
 from quax import Value

-from ._utils import quaxify
+from array_api_jax_compat._utils import quaxify


 @quaxify
@@ -14,7 +14,7 @@
 from jax.experimental import array_api
 from quax import Value

-from ._utils import quaxify
+from array_api_jax_compat._utils import quaxify


 @quaxify
@@ -4,7 +4,7 @@
 from jax.experimental import array_api
 from quax import Value

-from ._utils import quaxify
+from array_api_jax_compat._utils import quaxify


 @quaxify
@@ -4,7 +4,7 @@
 from jax.experimental import array_api
 from quax import Value

-from ._utils import quaxify
+from array_api_jax_compat._utils import quaxify


 @quaxify
@@ -6,7 +6,7 @@
 from jax.experimental import array_api
 from quax import Value

-from ._utils import quaxify
+from array_api_jax_compat._utils import quaxify


 @quaxify
@@ -4,8 +4,8 @@
 from jax.experimental import array_api
 from quax import Value

-from ._types import DType
-from ._utils import quaxify
+from array_api_jax_compat._types import DType
+from array_api_jax_compat._utils import quaxify


 @quaxify
@@ -5,7 +5,7 @@
 from jax.experimental import array_api
 from quax import Value

-from ._utils import quaxify
+from array_api_jax_compat._utils import quaxify


 @quaxify
@@ -24,7 +24,7 @@
 from jax.experimental.array_api import fft as _jax_fft
 from quax import Value

-from ._utils import quaxify
+from array_api_jax_compat._utils import quaxify


 @quaxify
@@ -34,8 +34,8 @@
 from jax.experimental import array_api
 from quax import Value

-from ._types import DType
-from ._utils import quaxify
+from array_api_jax_compat._types import DType
+from array_api_jax_compat._utils import quaxify


 @quaxify
120 changes: 120 additions & 0 deletions src/array_api_jax_compat/_grad.py
@@ -0,0 +1,120 @@
"""Copyright (c) 2023 Nathaniel Starkman. All rights reserved.

array-api-jax-compat: Array-API JAX compatibility
"""

from __future__ import annotations

__all__ = ["grad"]

from typing import TYPE_CHECKING, Any, Protocol, TypeVar

import jax
import jax.numpy as jnp
from quax import quaxify
from typing_extensions import Self

if TYPE_CHECKING:
from collections.abc import Callable, Sequence

from jax._src.api import AxisName


class SupportsGetItem(Protocol):
def __getitem__(self, key: Any) -> Self:
...


T = TypeVar("T")
IT = TypeVar("IT", bound=SupportsGetItem)


def grad(
fun: Callable[..., Any],
*,
argnums: int | Sequence[int] = 0,
has_aux: bool = False,
holomorphic: bool = False,
allow_int: bool = False,
reduce_axes: Sequence[AxisName] = (),
vmap_kw: dict[str, Any] | None = None,
vmap_batch: tuple[int, ...] | None = None,
) -> Callable[..., Any]:
"""Quaxified :func:`jax.grad`.

Creates a function that evaluates the gradient of ``fun``.

Parameters
----------
fun : callable
Function to be differentiated. Its arguments at positions specified by
`argnums` should be arrays, scalars, or standard Python containers.
Argument arrays in the positions specified by `argnums` must be of
inexact (i.e., floating-point or complex) type. It should return a
scalar (which includes arrays with shape `()` but not arrays with shape
`(1,)` etc.)
argnums : int or sequence of ints, optional
Specifies which positional argument(s) to differentiate with respect to
(default 0).
has_aux : bool, optional
Indicates whether `fun` returns a pair where the first element is
considered the output of the mathematical function to be differentiated
and the second element is auxiliary data. Default False.
holomorphic : bool, optional
Indicates whether `fun` is promised to be holomorphic. If True, inputs
and outputs must be complex. Default False.
allow_int : bool, optional
Whether to allow differentiating with respect to integer valued inputs.
The gradient of an integer input will have a trivial vector-space dtype
(float0). Default False.
reduce_axes : tuple of axis names, optional
If an axis is listed here, and `fun` implicitly broadcasts a value over
that axis, the backward pass will perform a `psum` of the corresponding
gradient. Otherwise, the gradient will be per-example over named axes.
For example, if `'batch'` is a named batch axis, `grad(f,
reduce_axes=('batch',))` will create a function that computes the total
gradient while `grad(f)` will create one that computes the per-example
gradient.

Returns
-------
callable
A function with the same arguments as `fun`, that evaluates the gradient
of `fun`. If `argnums` is an integer then the gradient has the same
shape and type as the positional argument indicated by that integer. If
`argnums` is a tuple of integers, the gradient is a tuple of values with
the same shapes and types as the corresponding arguments. If `has_aux`
is True then a pair of (gradient, auxiliary_data) is returned.

Examples
--------
>>> import jax
>>>
>>> grad_tanh = jax.grad(jax.numpy.tanh)
>>> print(grad_tanh(0.2))
0.961043
"""
# TODO: get this working using the actual `grad` function.
# There are some interesting issues to resolve. See
# https://github.com/patrick-kidger/quax/issues/5.
# In the meantime, we workaround this by using `jacfwd` instead.
if allow_int:
msg = "allow_int is not yet supported"
raise NotImplementedError(msg)
if reduce_axes:
msg = "reduce_axes is not yet supported"
raise NotImplementedError(msg)

grad_substitute = quaxify(
jax.jacfwd(fun, argnums=argnums, has_aux=has_aux, holomorphic=holomorphic),
)

def grad_func(*args: Any) -> Any:
for i, arg in enumerate(args):
assert ( # noqa: S101
len(jnp.shape(arg)) < 2
), f"arg {i} has shape {arg.shape}"

return grad_substitute(*args)

return grad_func
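The TODO above substitutes `jax.jacfwd` for `jax.grad` pending quax issue #5. For scalar-valued functions the two agree, which is what makes the substitution sound; a quick sanity check (not part of the PR):

```python
# Sanity check (not from the PR): for a scalar-valued function, the
# forward-mode Jacobian coincides with the reverse-mode gradient.
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sum(x * x)  # scalar output


x = jnp.array([1.0, 2.0, 3.0])
print(jax.grad(f)(x))    # [2. 4. 6.]
print(jax.jacfwd(f)(x))  # [2. 4. 6.] -- same values for scalar outputs
```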